206 lines
5.8 KiB
Go

package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net/http"
"strings"
"time"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// Token lifetimes. Access tokens are short-lived; refresh tokens last longer
// so a client can exchange them for a new pair (see RefreshAccessToken).
const (
	// accessTokenDuration is the validity window written into access tokens' exp claim.
	accessTokenDuration = 15 * time.Minute
	// refreshTokenDuration (one week) is used both for the refresh JWT's exp
	// claim and for the expiry stored alongside the token hash.
	refreshTokenDuration = 168 * time.Hour // 7 days
)
// ErrInvalidToken is returned when a refresh token is known to the store but
// has been revoked or has passed its stored expiry.
var ErrInvalidToken = errors.New("invalid token")

// ErrAuthHeaderInvalid is returned when the Authorization header is not of the
// form "Bearer <token>".
var ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header")
// userClaims is the JWT payload shared by access and refresh tokens.
// The user's UUID is carried in the registered "sub" claim
// (RegisteredClaims.Subject), not in the "jti" claim.
type userClaims struct {
	// Admin is true for tokens issued to administrators.
	Admin bool `json:"admin"`
	// TokenType distinguishes the two kinds of token: "access" or "refresh".
	TokenType string `json:"type"`

	jwt.RegisteredClaims
}
// tokenPair is the JSON response body carrying a freshly issued access token
// and its companion refresh token, both as signed JWT strings.
type tokenPair struct {
	AccessToken  string `json:"access_token"`  // short-lived, TokenType "access"
	RefreshToken string `json:"refresh_token"` // long-lived, TokenType "refresh"
}
// TokenStore is the refresh-token persistence used by tokenService. It is an
// interface so the database layer can be replaced by a mock in tests.
// Tokens are identified by the hex-encoded SHA-256 of the raw token string.
type TokenStore interface {
	// CreateRefreshToken persists a new refresh-token record.
	CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error)
	// GetRefreshTokenByHash fetches the record whose hash equals tokenHash.
	GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error)
	// RevokeRefreshToken revokes the single record identified by tokenHash.
	RevokeRefreshToken(ctx context.Context, tokenHash string) error
	// RevokeAllUserRefreshTokens revokes every record belonging to the user id.
	RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
}
// tokenService issues, validates, and revokes JWT access/refresh token pairs.
// Refresh tokens are recorded (as SHA-256 hashes) in Tokens so they can be
// revoked server-side.
type tokenService struct {
	JWTSecret string     // HMAC key used to sign tokens with HS256
	Tokens    TokenStore // refresh-token persistence
}
// GenerateTokenPair signs a new access/refresh token pair for userID and
// stores the refresh token's hex-encoded SHA-256 hash, with an expiry of
// refreshTokenDuration from now, through ts.Tokens. The stored expiry is
// computed separately from the JWT's exp claim, so the two can differ slightly.
// Errors from signing or from the store are returned unchanged.
func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
	pair, err := generateTokenPair(userID.String(), isAdmin, ts.JWTSecret)
	if err != nil {
		return nil, err
	}

	digest := sha256.Sum256([]byte(pair.RefreshToken))
	params := data.CreateRefreshTokenParams{
		UserID:    userID,
		TokenHash: hex.EncodeToString(digest[:]),
		ExpiresAt: time.Now().Add(refreshTokenDuration),
	}
	if _, err := ts.Tokens.CreateRefreshToken(ctx, params); err != nil {
		return nil, err
	}
	return pair, nil
}
// RevokeRefreshToken revokes the stored refresh token matching the raw token
// string. Only the token's hex-encoded SHA-256 is passed to the store.
func (ts *tokenService) RevokeRefreshToken(ctx context.Context, token string) error {
	sum := sha256.Sum256([]byte(token))
	return ts.Tokens.RevokeRefreshToken(ctx, hex.EncodeToString(sum[:]))
}
// ValidateRefreshToken looks up the raw refresh token by its hex-encoded
// SHA-256 hash and returns the stored record when it is neither revoked nor
// expired. Store lookup errors are returned unchanged; a revoked or expired
// record yields ErrInvalidToken.
func (ts *tokenService) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) {
	sum := sha256.Sum256([]byte(token))
	record, err := ts.Tokens.GetRefreshTokenByHash(ctx, hex.EncodeToString(sum[:]))

	switch {
	case err != nil:
		return nil, err
	case record.Revoked, record.ExpiresAt.Before(time.Now()):
		return nil, ErrInvalidToken
	}
	return &record, nil
}
// RefreshAccessToken is an HTTP handler that exchanges a refresh token for a
// new access/refresh token pair, revoking the presented refresh token.
//
// The request context must carry *userClaims under userCtxKey{} with
// TokenType "refresh" (NOTE(review): presumably set by auth middleware after
// verifying the JWT signature — confirm). The raw token is read from the
// "Authorization: Bearer <token>" header and must exist in the store, not
// revoked and not expired.
//
// Responses: 401 for a missing/invalid token, 500 for a malformed subject or
// store/signing failures, 200 (via respondJSON) with the new tokenPair.
func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	claims, ok := ctx.Value(userCtxKey{}).(*userClaims)
	if !ok || claims.TokenType != "refresh" {
		respondError(w, http.StatusUnauthorized, "Invalid token")
		return
	}

	refreshToken, err := getTokenFromRequest(r)
	if err != nil {
		respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
		// Without this return the handler would carry on with an empty token
		// and attempt to write a second response.
		return
	}

	if _, err := ts.ValidateRefreshToken(ctx, refreshToken); err != nil {
		respondError(w, http.StatusUnauthorized, "Invalid refresh token")
		return
	}

	// Parse the subject before revoking, so a malformed subject does not
	// consume the caller's refresh token without issuing a replacement.
	userID, err := uuid.Parse(claims.Subject)
	if err != nil {
		respondError(w, http.StatusInternalServerError, "Invalid user ID")
		return
	}

	// NOTE(review): validation and revocation are separate store calls, so two
	// concurrent requests presenting the same token could both pass validation
	// before either revokes it. A conditional revoke (fail if already revoked)
	// would close that window — confirm against the store implementation.
	if err := ts.RevokeRefreshToken(ctx, refreshToken); err != nil {
		respondError(w, http.StatusInternalServerError, "Failed to revoke token")
		return
	}

	tokenPair, err := ts.GenerateTokenPair(ctx, userID, claims.Admin)
	if err != nil {
		respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
		return
	}
	respondJSON(w, http.StatusOK, tokenPair)
}
// HandleLogout is an HTTP handler that revokes every refresh token stored for
// the authenticated user, so none of them can be exchanged for new tokens.
// Already-issued access tokens are not touched by this handler.
//
// The request context must carry *userClaims under userCtxKey{}
// (NOTE(review): presumably placed there by auth middleware — confirm).
// Responses: 401 if no claims are present, 500 if the subject is not a UUID
// or the store fails, 200 with {"status": "logged out"} on success.
func (ts *tokenService) HandleLogout(w http.ResponseWriter, r *http.Request) {
	claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
	if !ok {
		respondError(w, http.StatusUnauthorized, "Not authenticated")
		return
	}

	// The user's UUID is carried in the subject ("sub") claim. The JWT ID
	// (claims.ID, "jti") is never set when tokens are minted in this file, so
	// parsing it would always fail and logout could never succeed.
	userID, err := uuid.Parse(claims.Subject)
	if err != nil {
		respondError(w, http.StatusInternalServerError, "Invalid user ID")
		return
	}

	if err := ts.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
		respondError(w, http.StatusInternalServerError, "Failed to logout")
		return
	}
	respondJSON(w, http.StatusOK, map[string]string{"status": "logged out"})
}
// getTokenFromRequest returns the token from an "Authorization: Bearer <token>"
// header. The header is split on whitespace and must have exactly two fields;
// the scheme is matched case-insensitively. Anything else yields
// ErrAuthHeaderInvalid.
func getTokenFromRequest(r *http.Request) (string, error) {
	fields := strings.Fields(r.Header.Get("Authorization"))
	if len(fields) != 2 || !strings.EqualFold(fields[0], "bearer") {
		return "", ErrAuthHeaderInvalid
	}
	return fields[1], nil
}
// generateTokenPair signs an access token and a refresh token for userID using
// HS256 with jwtSecret. Both tokens carry the same subject and admin flag and
// differ only in their "type" claim and lifetime (accessTokenDuration vs
// refreshTokenDuration). The iat, nbf, and exp claims are all derived from a
// single timestamp, so they are mutually consistent.
//
// NOTE(review): the claims contain no per-token unique value (no jti), and
// golang-jwt v5 NumericDate values are second-precision by default
// (jwt.TimePrecision). Two pairs minted for the same user and admin flag
// within the same second are therefore byte-identical, and so are their
// refresh-token hashes. Consider adding a random jti if the store needs
// unique hashes.
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
	now := time.Now()
	key := []byte(jwtSecret)

	// sign builds and signs one token of the given type that expires ttl after now.
	sign := func(tokenType string, ttl time.Duration) (string, error) {
		claims := userClaims{
			Admin:     isAdmin,
			TokenType: tokenType,
			RegisteredClaims: jwt.RegisteredClaims{
				Subject:   userID,
				ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
				IssuedAt:  jwt.NewNumericDate(now),
				NotBefore: jwt.NewNumericDate(now),
			},
		}
		return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(key)
	}

	accessToken, err := sign("access", accessTokenDuration)
	if err != nil {
		return nil, fmt.Errorf("signing access token: %w", err)
	}
	refreshToken, err := sign("refresh", refreshTokenDuration)
	if err != nil {
		return nil, fmt.Errorf("signing refresh token: %w", err)
	}
	return &tokenPair{AccessToken: accessToken, RefreshToken: refreshToken}, nil
}