From b1e98fcf807ef7f7af4a0f64eea2b33774d84ff6 Mon Sep 17 00:00:00 2001 From: ae Date: Tue, 1 Apr 2025 12:58:19 +0300 Subject: [PATCH] fix: route rt reqs through same middleware as at reqs --- server/pkg/service/server.go | 4 ++-- server/pkg/service/tokens.go | 40 ++++++++++++++++++++----------- server/pkg/service/tokens_test.go | 20 ++++++++-------- 3 files changed, 38 insertions(+), 26 deletions(-) diff --git a/server/pkg/service/server.go b/server/pkg/service/server.go index a1105f9..011a6ad 100644 --- a/server/pkg/service/server.go +++ b/server/pkg/service/server.go @@ -15,7 +15,7 @@ func Run(conn *pgx.Conn, jwtSecret string, httpPort string) error { q := data.New(conn) r := chi.NewRouter() - tokenService := tokenService{ + tokensRouter := tokensResource{ JWTSecret: jwtSecret, Tokens: q, } @@ -33,7 +33,7 @@ func Run(conn *pgx.Conn, jwtSecret string, httpPort string) error { r.Use(middleware.AllowContentType("application/json")) // Routes grouped by functionality - r.Post("/auth/refresh", tokenService.RefreshAccessToken) // POST /auth/refresh - new access token for refresh token + r.Mount("/auth", tokensRouter.Routes()) r.Mount("/users", usersRouter.Routes()) r.Mount("/notes", notesRouter.Routes()) diff --git a/server/pkg/service/tokens.go b/server/pkg/service/tokens.go index 3dc2fe5..03d0c90 100644 --- a/server/pkg/service/tokens.go +++ b/server/pkg/service/tokens.go @@ -11,6 +11,7 @@ import ( "time" "git.umbrella.haus/ae/notatest/pkg/data" + "github.com/go-chi/chi/v5" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" ) @@ -44,13 +45,24 @@ type TokenStore interface { RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error } -type tokenService struct { +type tokensResource struct { JWTSecret string Tokens TokenStore } -func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) { - tokenPair, err := generateTokenPair(userID.String(), isAdmin, ts.JWTSecret) +func (rs tokensResource) Routes() chi.Router { + r := chi.NewRouter() + + r.Group(func(r chi.Router) { + r.Use(requireRefreshToken(rs.JWTSecret)) + r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair + }) + + return r +} + +func (rs tokensResource) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) { + tokenPair, err := generateTokenPair(userID.String(), isAdmin, rs.JWTSecret) if err != nil { return nil, err } @@ -60,7 +72,7 @@ func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID, // Store to DB with (almost) identical expiration timestamp expiresAt := time.Now().Add(refreshTokenDuration) - _, err = ts.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{ + _, err = rs.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{ UserID: userID, TokenHash: tokenHash, ExpiresAt: expiresAt, @@ -72,17 +84,17 @@ func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID, return tokenPair, nil } -func (ts *tokenService) RevokeRefreshToken(ctx context.Context, token string) error { +func (rs tokensResource) RevokeRefreshToken(ctx context.Context, token string) error { hash := sha256.Sum256([]byte(token)) tokenHash := hex.EncodeToString(hash[:]) - return ts.Tokens.RevokeRefreshToken(ctx, tokenHash) + return rs.Tokens.RevokeRefreshToken(ctx, tokenHash) } -func (ts *tokenService) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) { +func (rs tokensResource) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) { hash := sha256.Sum256([]byte(token)) tokenHash := hex.EncodeToString(hash[:]) - dbToken, err := ts.Tokens.GetRefreshTokenByHash(ctx, tokenHash) + dbToken, err := rs.Tokens.GetRefreshTokenByHash(ctx, tokenHash) if err != nil { return nil, err } @@ -94,7 +106,7 @@ func (ts *tokenService) ValidateRefreshToken(ctx context.Context, token string) return &dbToken, nil } -func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Request) { +func (rs tokensResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request) { // Get claims from context claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) if !ok || claims.TokenType != "refresh" { @@ -109,13 +121,13 @@ func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Reques } // Validate the refresh token in DB - if _, err := ts.ValidateRefreshToken(r.Context(), refreshToken); err != nil { + if _, err := rs.ValidateRefreshToken(r.Context(), refreshToken); err != nil { respondError(w, http.StatusUnauthorized, "Invalid refresh token") return } // Revoke the used refresh token - if err := ts.RevokeRefreshToken(r.Context(), refreshToken); err != nil { + if err := rs.RevokeRefreshToken(r.Context(), refreshToken); err != nil { respondError(w, http.StatusInternalServerError, "Failed to revoke token") return } @@ -127,7 +139,7 @@ func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Reques return } - tokenPair, err := ts.GenerateTokenPair(r.Context(), userID, claims.Admin) + tokenPair, err := rs.GenerateTokenPair(r.Context(), userID, claims.Admin) if err != nil { respondError(w, http.StatusInternalServerError, "Failed to generate tokens") return @@ -136,7 +148,7 @@ func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Reques respondJSON(w, http.StatusOK, tokenPair) } -func (ts *tokenService) HandleLogout(w http.ResponseWriter, r *http.Request) { +func (rs tokensResource) HandleLogout(w http.ResponseWriter, r *http.Request) { claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) if !ok { respondError(w, http.StatusUnauthorized, "Not authenticated") @@ -149,7 +161,7 @@ func (ts *tokenService) HandleLogout(w http.ResponseWriter, r *http.Request) { return } - if err := ts.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil { + if err := rs.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil { respondError(w, http.StatusInternalServerError, "Failed to logout") return } diff --git a/server/pkg/service/tokens_test.go b/server/pkg/service/tokens_test.go index b72196b..f1b546d 100644 --- a/server/pkg/service/tokens_test.go +++ b/server/pkg/service/tokens_test.go @@ -52,12 +52,12 @@ func TestGenerateTokenPair_Success(t *testing.T) { }, } - ts := tokenService{ + rs := tokensResource{ JWTSecret: "test-secret", Tokens: mockStore, } - pair, err := ts.GenerateTokenPair(context.Background(), userID, false) + pair, err := rs.GenerateTokenPair(context.Background(), userID, false) assert.NoError(t, err) assert.True(t, called) assert.NotEmpty(t, pair.AccessToken) @@ -71,8 +71,8 @@ func TestGenerateTokenPair_DBError(t *testing.T) { }, } - ts := tokenService{Tokens: mockStore} - _, err := ts.GenerateTokenPair(context.Background(), uuid.New(), false) + rs := tokensResource{Tokens: mockStore} + _, err := rs.GenerateTokenPair(context.Background(), uuid.New(), false) assert.ErrorContains(t, err, "db error") } @@ -88,8 +88,8 @@ func TestValidateRefreshToken_Valid(t *testing.T) { }, } - ts := tokenService{Tokens: mockStore} - _, err := ts.ValidateRefreshToken(context.Background(), token) + rs := tokensResource{Tokens: mockStore} + _, err := rs.ValidateRefreshToken(context.Background(), token) assert.NoError(t, err) } @@ -121,7 +121,7 @@ func TestRefreshAccessToken_Success(t *testing.T) { }, } - ts := tokenService{ + rs := tokensResource{ JWTSecret: "test-secret", Tokens: mockStore, } @@ -141,7 +141,7 @@ func TestRefreshAccessToken_Success(t *testing.T) { )) w := httptest.NewRecorder() - ts.RefreshAccessToken(w, req) + rs.RefreshAccessToken(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "access_token", "refresh_token") @@ -159,7 +159,7 @@ func TestHandleLogout_Success(t *testing.T) { }, } - ts := tokenService{Tokens: mockStore} + rs := tokensResource{Tokens: mockStore} req := httptest.NewRequest("POST", "/", nil) req = req.WithContext(context.WithValue( @@ -173,7 +173,7 @@ func TestHandleLogout_Success(t *testing.T) { )) w := httptest.NewRecorder() - ts.HandleLogout(w, req) + rs.HandleLogout(w, req) assert.True(t, called) assert.Equal(t, http.StatusOK, w.Code)