fix: route rt reqs through same middleware as at reqs

This commit is contained in:
ae 2025-04-01 12:58:19 +03:00
parent 9324bb5321
commit b1e98fcf80
Signed by: ae
GPG Key ID: 995EFD5C1B532B3E
3 changed files with 38 additions and 26 deletions

View File

@ -15,7 +15,7 @@ func Run(conn *pgx.Conn, jwtSecret string, httpPort string) error {
q := data.New(conn) q := data.New(conn)
r := chi.NewRouter() r := chi.NewRouter()
tokenService := tokenService{ tokensRouter := tokensResource{
JWTSecret: jwtSecret, JWTSecret: jwtSecret,
Tokens: q, Tokens: q,
} }
@ -33,7 +33,7 @@ func Run(conn *pgx.Conn, jwtSecret string, httpPort string) error {
r.Use(middleware.AllowContentType("application/json")) r.Use(middleware.AllowContentType("application/json"))
// Routes grouped by functionality // Routes grouped by functionality
r.Post("/auth/refresh", tokenService.RefreshAccessToken) // POST /auth/refresh - new access token for refresh token r.Mount("/auth", tokensRouter.Routes())
r.Mount("/users", usersRouter.Routes()) r.Mount("/users", usersRouter.Routes())
r.Mount("/notes", notesRouter.Routes()) r.Mount("/notes", notesRouter.Routes())

View File

@ -11,6 +11,7 @@ import (
"time" "time"
"git.umbrella.haus/ae/notatest/pkg/data" "git.umbrella.haus/ae/notatest/pkg/data"
"github.com/go-chi/chi/v5"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -44,13 +45,24 @@ type TokenStore interface {
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
} }
type tokenService struct { type tokensResource struct {
JWTSecret string JWTSecret string
Tokens TokenStore Tokens TokenStore
} }
func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) { func (rs tokensResource) Routes() chi.Router {
tokenPair, err := generateTokenPair(userID.String(), isAdmin, ts.JWTSecret) r := chi.NewRouter()
r.Group(func(r chi.Router) {
r.Use(requireRefreshToken(rs.JWTSecret))
r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair
})
return r
}
func (rs tokensResource) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
tokenPair, err := generateTokenPair(userID.String(), isAdmin, rs.JWTSecret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -60,7 +72,7 @@ func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID,
// Store to DB with (almost) identical expiration timestamp // Store to DB with (almost) identical expiration timestamp
expiresAt := time.Now().Add(refreshTokenDuration) expiresAt := time.Now().Add(refreshTokenDuration)
_, err = ts.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{ _, err = rs.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{
UserID: userID, UserID: userID,
TokenHash: tokenHash, TokenHash: tokenHash,
ExpiresAt: expiresAt, ExpiresAt: expiresAt,
@ -72,17 +84,17 @@ func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID,
return tokenPair, nil return tokenPair, nil
} }
func (ts *tokenService) RevokeRefreshToken(ctx context.Context, token string) error { func (rs tokensResource) RevokeRefreshToken(ctx context.Context, token string) error {
hash := sha256.Sum256([]byte(token)) hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:]) tokenHash := hex.EncodeToString(hash[:])
return ts.Tokens.RevokeRefreshToken(ctx, tokenHash) return rs.Tokens.RevokeRefreshToken(ctx, tokenHash)
} }
func (ts *tokenService) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) { func (rs tokensResource) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) {
hash := sha256.Sum256([]byte(token)) hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:]) tokenHash := hex.EncodeToString(hash[:])
dbToken, err := ts.Tokens.GetRefreshTokenByHash(ctx, tokenHash) dbToken, err := rs.Tokens.GetRefreshTokenByHash(ctx, tokenHash)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -94,7 +106,7 @@ func (ts *tokenService) ValidateRefreshToken(ctx context.Context, token string)
return &dbToken, nil return &dbToken, nil
} }
func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Request) { func (rs tokensResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
// Get claims from context // Get claims from context
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok || claims.TokenType != "refresh" { if !ok || claims.TokenType != "refresh" {
@ -109,13 +121,13 @@ func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Reques
} }
// Validate the refresh token in DB // Validate the refresh token in DB
if _, err := ts.ValidateRefreshToken(r.Context(), refreshToken); err != nil { if _, err := rs.ValidateRefreshToken(r.Context(), refreshToken); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid refresh token") respondError(w, http.StatusUnauthorized, "Invalid refresh token")
return return
} }
// Revoke the used refresh token // Revoke the used refresh token
if err := ts.RevokeRefreshToken(r.Context(), refreshToken); err != nil { if err := rs.RevokeRefreshToken(r.Context(), refreshToken); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to revoke token") respondError(w, http.StatusInternalServerError, "Failed to revoke token")
return return
} }
@ -127,7 +139,7 @@ func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Reques
return return
} }
tokenPair, err := ts.GenerateTokenPair(r.Context(), userID, claims.Admin) tokenPair, err := rs.GenerateTokenPair(r.Context(), userID, claims.Admin)
if err != nil { if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to generate tokens") respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
return return
@ -136,7 +148,7 @@ func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Reques
respondJSON(w, http.StatusOK, tokenPair) respondJSON(w, http.StatusOK, tokenPair)
} }
func (ts *tokenService) HandleLogout(w http.ResponseWriter, r *http.Request) { func (rs tokensResource) HandleLogout(w http.ResponseWriter, r *http.Request) {
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok { if !ok {
respondError(w, http.StatusUnauthorized, "Not authenticated") respondError(w, http.StatusUnauthorized, "Not authenticated")
@ -149,7 +161,7 @@ func (ts *tokenService) HandleLogout(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := ts.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil { if err := rs.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to logout") respondError(w, http.StatusInternalServerError, "Failed to logout")
return return
} }

View File

@ -52,12 +52,12 @@ func TestGenerateTokenPair_Success(t *testing.T) {
}, },
} }
ts := tokenService{ rs := tokensResource{
JWTSecret: "test-secret", JWTSecret: "test-secret",
Tokens: mockStore, Tokens: mockStore,
} }
pair, err := ts.GenerateTokenPair(context.Background(), userID, false) pair, err := rs.GenerateTokenPair(context.Background(), userID, false)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, called) assert.True(t, called)
assert.NotEmpty(t, pair.AccessToken) assert.NotEmpty(t, pair.AccessToken)
@ -71,8 +71,8 @@ func TestGenerateTokenPair_DBError(t *testing.T) {
}, },
} }
ts := tokenService{Tokens: mockStore} rs := tokensResource{Tokens: mockStore}
_, err := ts.GenerateTokenPair(context.Background(), uuid.New(), false) _, err := rs.GenerateTokenPair(context.Background(), uuid.New(), false)
assert.ErrorContains(t, err, "db error") assert.ErrorContains(t, err, "db error")
} }
@ -88,8 +88,8 @@ func TestValidateRefreshToken_Valid(t *testing.T) {
}, },
} }
ts := tokenService{Tokens: mockStore} rs := tokensResource{Tokens: mockStore}
_, err := ts.ValidateRefreshToken(context.Background(), token) _, err := rs.ValidateRefreshToken(context.Background(), token)
assert.NoError(t, err) assert.NoError(t, err)
} }
@ -121,7 +121,7 @@ func TestRefreshAccessToken_Success(t *testing.T) {
}, },
} }
ts := tokenService{ rs := tokensResource{
JWTSecret: "test-secret", JWTSecret: "test-secret",
Tokens: mockStore, Tokens: mockStore,
} }
@ -141,7 +141,7 @@ func TestRefreshAccessToken_Success(t *testing.T) {
)) ))
w := httptest.NewRecorder() w := httptest.NewRecorder()
ts.RefreshAccessToken(w, req) rs.RefreshAccessToken(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "access_token", "refresh_token") assert.Contains(t, w.Body.String(), "access_token", "refresh_token")
@ -159,7 +159,7 @@ func TestHandleLogout_Success(t *testing.T) {
}, },
} }
ts := tokenService{Tokens: mockStore} rs := tokensResource{Tokens: mockStore}
req := httptest.NewRequest("POST", "/", nil) req := httptest.NewRequest("POST", "/", nil)
req = req.WithContext(context.WithValue( req = req.WithContext(context.WithValue(
@ -173,7 +173,7 @@ func TestHandleLogout_Success(t *testing.T) {
)) ))
w := httptest.NewRecorder() w := httptest.NewRecorder()
ts.HandleLogout(w, req) rs.HandleLogout(w, req)
assert.True(t, called) assert.True(t, called)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)