package service import ( "context" "crypto/sha256" "encoding/hex" "errors" "fmt" "net/http" "strings" "time" "git.umbrella.haus/ae/notatest/pkg/data" "github.com/go-chi/chi/v5" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" ) const ( accessTokenDuration = 15 * time.Minute refreshTokenDuration = 7 * 24 * time.Hour ) var ( ErrInvalidToken = errors.New("invalid token") ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header") ) type userClaims struct { Admin bool `json:"admin"` TokenType string `json:"type"` // "access" or "refresh" jwt.RegisteredClaims // User's UUID should be stored in the subject claim } type tokenPair struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` } // Mockable database operations interface type TokenStore interface { CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) RevokeRefreshToken(ctx context.Context, tokenHash string) error RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error } type tokensResource struct { JWTSecret string Tokens TokenStore } func (rs tokensResource) Routes() chi.Router { r := chi.NewRouter() r.Group(func(r chi.Router) { r.Use(requireRefreshToken(rs.JWTSecret)) r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair }) return r } func (rs tokensResource) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) { tokenPair, err := generateTokenPair(userID.String(), isAdmin, rs.JWTSecret) if err != nil { return nil, err } hash := sha256.Sum256([]byte(tokenPair.RefreshToken)) tokenHash := hex.EncodeToString(hash[:]) // Store to DB with (almost) identical expiration timestamp expiresAt := time.Now().Add(refreshTokenDuration) _, err = rs.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{ UserID: userID, TokenHash: tokenHash, ExpiresAt: expiresAt, }) if err != nil { return nil, err } return tokenPair, nil } func (rs tokensResource) RevokeRefreshToken(ctx context.Context, token string) error { hash := sha256.Sum256([]byte(token)) tokenHash := hex.EncodeToString(hash[:]) return rs.Tokens.RevokeRefreshToken(ctx, tokenHash) } func (rs tokensResource) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) { hash := sha256.Sum256([]byte(token)) tokenHash := hex.EncodeToString(hash[:]) dbToken, err := rs.Tokens.GetRefreshTokenByHash(ctx, tokenHash) if err != nil { return nil, err } if dbToken.Revoked || time.Now().After(dbToken.ExpiresAt) { return nil, ErrInvalidToken } return &dbToken, nil } func (rs tokensResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request) { // Get claims from context claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) if !ok || claims.TokenType != "refresh" { respondError(w, http.StatusUnauthorized, "Invalid token") return } // Attempt to get the token from Authentication header ("Bearer ") refreshToken, err := getTokenFromRequest(r) if err != nil { respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err)) } // Validate the refresh token in DB if _, err := rs.ValidateRefreshToken(r.Context(), refreshToken); err != nil { respondError(w, http.StatusUnauthorized, "Invalid refresh token") return } // Revoke the used refresh token if err := rs.RevokeRefreshToken(r.Context(), refreshToken); err != nil { respondError(w, http.StatusInternalServerError, "Failed to revoke token") return } // Generate a new pair (access & refresh tokens) userID, err := uuid.Parse(claims.Subject) if err != nil { respondError(w, http.StatusInternalServerError, "Invalid user ID") return } tokenPair, err := rs.GenerateTokenPair(r.Context(), userID, claims.Admin) if err != nil { respondError(w, http.StatusInternalServerError, "Failed to generate tokens") return } respondJSON(w, http.StatusOK, tokenPair) } func (rs tokensResource) HandleLogout(w http.ResponseWriter, r *http.Request) { claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) if !ok { respondError(w, http.StatusUnauthorized, "Not authenticated") return } userID, err := uuid.Parse(claims.ID) if err != nil { respondError(w, http.StatusInternalServerError, "Invalid user ID") return } if err := rs.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil { respondError(w, http.StatusInternalServerError, "Failed to logout") return } respondJSON(w, http.StatusOK, map[string]string{"status": "logged out"}) } func getTokenFromRequest(r *http.Request) (string, error) { bearerToken := r.Header.Get("Authorization") bearerFields := strings.Fields(bearerToken) if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" { return bearerFields[1], nil } return "", ErrAuthHeaderInvalid } func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) { atClaims := userClaims{ Admin: isAdmin, TokenType: "access", RegisteredClaims: jwt.RegisteredClaims{ Subject: userID, ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenDuration)), IssuedAt: jwt.NewNumericDate(time.Now()), NotBefore: jwt.NewNumericDate(time.Now()), }, } accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, atClaims) t, err := accessToken.SignedString([]byte(jwtSecret)) if err != nil { return nil, err } rtClaims := userClaims{ Admin: isAdmin, TokenType: "refresh", RegisteredClaims: jwt.RegisteredClaims{ Subject: userID, ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenDuration)), IssuedAt: jwt.NewNumericDate(time.Now()), NotBefore: jwt.NewNumericDate(time.Now()), }, } refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, rtClaims) rt, err := refreshToken.SignedString([]byte(jwtSecret)) if err != nil { return nil, err } return &tokenPair{AccessToken: t, RefreshToken: rt}, nil }