218 lines
6.0 KiB
Go
218 lines
6.0 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
const (
|
|
accessTokenDuration = 15 * time.Minute
|
|
refreshTokenDuration = 7 * 24 * time.Hour
|
|
)
|
|
|
|
var (
|
|
ErrInvalidToken = errors.New("invalid token")
|
|
ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header")
|
|
)
|
|
|
|
type userClaims struct {
|
|
Admin bool `json:"admin"`
|
|
TokenType string `json:"type"` // "access" or "refresh"
|
|
jwt.RegisteredClaims // User's UUID should be stored in the subject claim
|
|
}
|
|
|
|
// tokenPair is the JSON shape of a freshly issued access/refresh token pair,
// as written to the client by RefreshAccessToken.
type tokenPair struct {
	AccessToken  string `json:"access_token"`  // short-lived bearer token
	RefreshToken string `json:"refresh_token"` // long-lived token exchanged at /refresh
}
|
|
|
|
// Mockable database operations interface
|
|
type TokenStore interface {
|
|
CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error)
|
|
GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error)
|
|
RevokeRefreshToken(ctx context.Context, tokenHash string) error
|
|
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
|
|
}
|
|
|
|
type tokensResource struct {
|
|
JWTSecret string
|
|
Tokens TokenStore
|
|
}
|
|
|
|
func (rs tokensResource) Routes() chi.Router {
|
|
r := chi.NewRouter()
|
|
|
|
r.Group(func(r chi.Router) {
|
|
r.Use(requireRefreshToken(rs.JWTSecret))
|
|
r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair
|
|
})
|
|
|
|
return r
|
|
}
|
|
|
|
// GenerateTokenPair signs a new access/refresh token pair for userID and
// records the refresh token in rs.Tokens. Only the hex-encoded SHA-256 digest
// of the refresh token is passed to the store, not the token itself.
//
// The stored expiry is computed separately from the JWT "exp" claim (both from
// refreshTokenDuration), so the two can differ slightly.
// On any error the returned pair is nil.
func (rs tokensResource) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
	pair, err := generateTokenPair(userID.String(), isAdmin, rs.JWTSecret)
	if err != nil {
		return nil, err
	}

	digest := sha256.Sum256([]byte(pair.RefreshToken))
	record := data.CreateRefreshTokenParams{
		UserID:    userID,
		TokenHash: hex.EncodeToString(digest[:]),
		// Store to DB with (almost) identical expiration timestamp.
		ExpiresAt: time.Now().Add(refreshTokenDuration),
	}

	if _, err := rs.Tokens.CreateRefreshToken(ctx, record); err != nil {
		return nil, err
	}
	return pair, nil
}
|
|
|
|
// RevokeRefreshToken revokes the raw refresh token by handing its hex-encoded
// SHA-256 digest (the same key GenerateTokenPair stores) to rs.Tokens.
// The store's error, if any, is returned unchanged.
func (rs tokensResource) RevokeRefreshToken(ctx context.Context, token string) error {
	digest := sha256.Sum256([]byte(token))
	return rs.Tokens.RevokeRefreshToken(ctx, hex.EncodeToString(digest[:]))
}
|
|
|
|
// ValidateRefreshToken looks up the raw refresh token by its hex-encoded
// SHA-256 digest and returns the stored record if it is still usable.
//
// Lookup errors from rs.Tokens are returned unchanged. A record that is
// revoked or whose ExpiresAt has passed yields ErrInvalidToken.
func (rs tokensResource) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) {
	digest := sha256.Sum256([]byte(token))

	record, err := rs.Tokens.GetRefreshTokenByHash(ctx, hex.EncodeToString(digest[:]))
	switch {
	case err != nil:
		return nil, err
	case record.Revoked, time.Now().After(record.ExpiresAt):
		return nil, ErrInvalidToken
	default:
		return &record, nil
	}
}
|
|
|
|
// RefreshAccessToken exchanges a valid refresh token for a brand-new
// access/refresh pair and revokes the refresh token that was presented.
//
// It needs two inputs:
//   - the parsed *userClaims stored in the request context under userCtxKey{},
//     which must have TokenType "refresh";
//   - the raw token from the "Authorization: Bearer <token>" header, which is
//     used to find the stored record by hash.
//
// Responses:
//   - 401 when the claims are missing or not a refresh token, when the header
//     is malformed, or when the token is unknown, revoked or expired;
//   - 500 when the subject is not a UUID, revocation fails, or new tokens
//     cannot be issued;
//   - otherwise 200 with the new tokenPair as JSON.
func (rs tokensResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	claims, ok := ctx.Value(userCtxKey{}).(*userClaims)
	if !ok || claims.TokenType != "refresh" {
		respondError(w, http.StatusUnauthorized, "Invalid token")
		return
	}

	// Attempt to get the token from the Authorization header ("Bearer <token>").
	refreshToken, err := getTokenFromRequest(r)
	if err != nil {
		respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
		// Without this return the handler would continue with an empty token
		// and write a second response.
		return
	}

	// Validate the refresh token against the DB record.
	if _, err := rs.ValidateRefreshToken(ctx, refreshToken); err != nil {
		respondError(w, http.StatusUnauthorized, "Invalid refresh token")
		return
	}

	// Parse the subject before revoking, so a malformed subject does not
	// destroy the user's refresh token on its way to a 500.
	userID, err := uuid.Parse(claims.Subject)
	if err != nil {
		respondError(w, http.StatusInternalServerError, "Invalid user ID")
		return
	}

	// Rotate: revoke the token that was just used.
	// NOTE(review): validate-then-revoke is not atomic, so two concurrent
	// requests with the same token may both pass validation. Closing that gap
	// needs a conditional "revoke if not already revoked" in TokenStore.
	if err := rs.RevokeRefreshToken(ctx, refreshToken); err != nil {
		respondError(w, http.StatusInternalServerError, "Failed to revoke token")
		return
	}

	// NOTE(review): Admin is copied from the presented refresh token rather
	// than re-read from the user record, so the flag carries forward across
	// every rotation.
	pair, err := rs.GenerateTokenPair(ctx, userID, claims.Admin)
	if err != nil {
		respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
		return
	}

	respondJSON(w, http.StatusOK, pair)
}
|
|
|
|
// HandleLogout revokes every stored refresh token of the authenticated user.
//
// It expects *userClaims in the request context under userCtxKey{}. The user's
// UUID is read from the subject ("sub") claim, which is where
// generateTokenPair puts it. This handler does nothing to access tokens that
// have already been issued.
//
// Responses:
//   - 401 when no claims are present;
//   - 500 when the subject is not a UUID or revocation fails;
//   - otherwise 200 with {"status": "logged out"}.
func (rs tokensResource) HandleLogout(w http.ResponseWriter, r *http.Request) {
	claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
	if !ok {
		respondError(w, http.StatusUnauthorized, "Not authenticated")
		return
	}

	// The user ID lives in Subject. claims.ID is the "jti" claim, which
	// generateTokenPair never sets, so parsing it always failed and logout
	// could never succeed.
	userID, err := uuid.Parse(claims.Subject)
	if err != nil {
		respondError(w, http.StatusInternalServerError, "Invalid user ID")
		return
	}

	if err := rs.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
		respondError(w, http.StatusInternalServerError, "Failed to logout")
		return
	}

	respondJSON(w, http.StatusOK, map[string]string{"status": "logged out"})
}
|
|
|
|
// getTokenFromRequest extracts the token from an "Authorization: Bearer <token>"
// header.
//
// The header is split on whitespace, so leading, trailing and repeated spaces
// are tolerated, and the scheme is matched case-insensitively. Any other shape
// yields ErrAuthHeaderInvalid: a missing header, a different scheme, or a
// field count other than two.
func getTokenFromRequest(r *http.Request) (string, error) {
	parts := strings.Fields(r.Header.Get("Authorization"))
	if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
		return "", ErrAuthHeaderInvalid
	}
	return parts[1], nil
}
|
|
|
|
// generateTokenPair signs a new access token and refresh token for userID
// using HS256 with jwtSecret as the key.
//
// Both tokens carry the same subject (userID) and admin flag. They differ only
// in their "type" claim ("access" / "refresh") and lifetime
// (accessTokenDuration / refreshTokenDuration). On error the returned pair is
// nil.
//
// NOTE(review): no unique "jti" is set. With jwt's default second-level
// NumericDate precision, two pairs issued for the same user within the same
// second produce identical refresh tokens, and therefore identical store
// hashes. Consider adding a random ID claim; HandleLogout must then read the
// user from Subject, not ID.
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
	// Take a single timestamp so iat/nbf/exp agree exactly across both tokens.
	// Separate time.Now() calls could straddle a second boundary after truncation.
	now := time.Now()
	key := []byte(jwtSecret)

	// sign builds the claims for one token kind and returns the signed JWT.
	sign := func(tokenType string, ttl time.Duration) (string, error) {
		claims := userClaims{
			Admin:     isAdmin,
			TokenType: tokenType,
			RegisteredClaims: jwt.RegisteredClaims{
				Subject:   userID,
				ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
				IssuedAt:  jwt.NewNumericDate(now),
				NotBefore: jwt.NewNumericDate(now),
			},
		}
		return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(key)
	}

	accessToken, err := sign("access", accessTokenDuration)
	if err != nil {
		return nil, err
	}

	refreshToken, err := sign("refresh", refreshTokenDuration)
	if err != nil {
		return nil, err
	}

	return &tokenPair{AccessToken: accessToken, RefreshToken: refreshToken}, nil
}
|