fix: route rt reqs through same middleware as at reqs
This commit is contained in:
parent
9324bb5321
commit
b1e98fcf80
@ -15,7 +15,7 @@ func Run(conn *pgx.Conn, jwtSecret string, httpPort string) error {
|
|||||||
q := data.New(conn)
|
q := data.New(conn)
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
|
|
||||||
tokenService := tokenService{
|
tokensRouter := tokensResource{
|
||||||
JWTSecret: jwtSecret,
|
JWTSecret: jwtSecret,
|
||||||
Tokens: q,
|
Tokens: q,
|
||||||
}
|
}
|
||||||
@ -33,7 +33,7 @@ func Run(conn *pgx.Conn, jwtSecret string, httpPort string) error {
|
|||||||
r.Use(middleware.AllowContentType("application/json"))
|
r.Use(middleware.AllowContentType("application/json"))
|
||||||
|
|
||||||
// Routes grouped by functionality
|
// Routes grouped by functionality
|
||||||
r.Post("/auth/refresh", tokenService.RefreshAccessToken) // POST /auth/refresh - new access token for refresh token
|
r.Mount("/auth", tokensRouter.Routes())
|
||||||
r.Mount("/users", usersRouter.Routes())
|
r.Mount("/users", usersRouter.Routes())
|
||||||
r.Mount("/notes", notesRouter.Routes())
|
r.Mount("/notes", notesRouter.Routes())
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@ -44,13 +45,24 @@ type TokenStore interface {
|
|||||||
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
|
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type tokenService struct {
|
type tokensResource struct {
|
||||||
JWTSecret string
|
JWTSecret string
|
||||||
Tokens TokenStore
|
Tokens TokenStore
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
|
func (rs tokensResource) Routes() chi.Router {
|
||||||
tokenPair, err := generateTokenPair(userID.String(), isAdmin, ts.JWTSecret)
|
r := chi.NewRouter()
|
||||||
|
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Use(requireRefreshToken(rs.JWTSecret))
|
||||||
|
r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair
|
||||||
|
})
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs tokensResource) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
|
||||||
|
tokenPair, err := generateTokenPair(userID.String(), isAdmin, rs.JWTSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -60,7 +72,7 @@ func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID,
|
|||||||
|
|
||||||
// Store to DB with (almost) identical expiration timestamp
|
// Store to DB with (almost) identical expiration timestamp
|
||||||
expiresAt := time.Now().Add(refreshTokenDuration)
|
expiresAt := time.Now().Add(refreshTokenDuration)
|
||||||
_, err = ts.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{
|
_, err = rs.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
TokenHash: tokenHash,
|
TokenHash: tokenHash,
|
||||||
ExpiresAt: expiresAt,
|
ExpiresAt: expiresAt,
|
||||||
@ -72,17 +84,17 @@ func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID,
|
|||||||
return tokenPair, nil
|
return tokenPair, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *tokenService) RevokeRefreshToken(ctx context.Context, token string) error {
|
func (rs tokensResource) RevokeRefreshToken(ctx context.Context, token string) error {
|
||||||
hash := sha256.Sum256([]byte(token))
|
hash := sha256.Sum256([]byte(token))
|
||||||
tokenHash := hex.EncodeToString(hash[:])
|
tokenHash := hex.EncodeToString(hash[:])
|
||||||
return ts.Tokens.RevokeRefreshToken(ctx, tokenHash)
|
return rs.Tokens.RevokeRefreshToken(ctx, tokenHash)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *tokenService) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) {
|
func (rs tokensResource) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) {
|
||||||
hash := sha256.Sum256([]byte(token))
|
hash := sha256.Sum256([]byte(token))
|
||||||
tokenHash := hex.EncodeToString(hash[:])
|
tokenHash := hex.EncodeToString(hash[:])
|
||||||
|
|
||||||
dbToken, err := ts.Tokens.GetRefreshTokenByHash(ctx, tokenHash)
|
dbToken, err := rs.Tokens.GetRefreshTokenByHash(ctx, tokenHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -94,7 +106,7 @@ func (ts *tokenService) ValidateRefreshToken(ctx context.Context, token string)
|
|||||||
return &dbToken, nil
|
return &dbToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
|
func (rs tokensResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
|
||||||
// Get claims from context
|
// Get claims from context
|
||||||
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
if !ok || claims.TokenType != "refresh" {
|
if !ok || claims.TokenType != "refresh" {
|
||||||
@ -109,13 +121,13 @@ func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Reques
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate the refresh token in DB
|
// Validate the refresh token in DB
|
||||||
if _, err := ts.ValidateRefreshToken(r.Context(), refreshToken); err != nil {
|
if _, err := rs.ValidateRefreshToken(r.Context(), refreshToken); err != nil {
|
||||||
respondError(w, http.StatusUnauthorized, "Invalid refresh token")
|
respondError(w, http.StatusUnauthorized, "Invalid refresh token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Revoke the used refresh token
|
// Revoke the used refresh token
|
||||||
if err := ts.RevokeRefreshToken(r.Context(), refreshToken); err != nil {
|
if err := rs.RevokeRefreshToken(r.Context(), refreshToken); err != nil {
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to revoke token")
|
respondError(w, http.StatusInternalServerError, "Failed to revoke token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -127,7 +139,7 @@ func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Reques
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenPair, err := ts.GenerateTokenPair(r.Context(), userID, claims.Admin)
|
tokenPair, err := rs.GenerateTokenPair(r.Context(), userID, claims.Admin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
|
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
|
||||||
return
|
return
|
||||||
@ -136,7 +148,7 @@ func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Reques
|
|||||||
respondJSON(w, http.StatusOK, tokenPair)
|
respondJSON(w, http.StatusOK, tokenPair)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *tokenService) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
func (rs tokensResource) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
||||||
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
if !ok {
|
if !ok {
|
||||||
respondError(w, http.StatusUnauthorized, "Not authenticated")
|
respondError(w, http.StatusUnauthorized, "Not authenticated")
|
||||||
@ -149,7 +161,7 @@ func (ts *tokenService) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ts.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
|
if err := rs.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to logout")
|
respondError(w, http.StatusInternalServerError, "Failed to logout")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -52,12 +52,12 @@ func TestGenerateTokenPair_Success(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ts := tokenService{
|
rs := tokensResource{
|
||||||
JWTSecret: "test-secret",
|
JWTSecret: "test-secret",
|
||||||
Tokens: mockStore,
|
Tokens: mockStore,
|
||||||
}
|
}
|
||||||
|
|
||||||
pair, err := ts.GenerateTokenPair(context.Background(), userID, false)
|
pair, err := rs.GenerateTokenPair(context.Background(), userID, false)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, called)
|
assert.True(t, called)
|
||||||
assert.NotEmpty(t, pair.AccessToken)
|
assert.NotEmpty(t, pair.AccessToken)
|
||||||
@ -71,8 +71,8 @@ func TestGenerateTokenPair_DBError(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ts := tokenService{Tokens: mockStore}
|
rs := tokensResource{Tokens: mockStore}
|
||||||
_, err := ts.GenerateTokenPair(context.Background(), uuid.New(), false)
|
_, err := rs.GenerateTokenPair(context.Background(), uuid.New(), false)
|
||||||
assert.ErrorContains(t, err, "db error")
|
assert.ErrorContains(t, err, "db error")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -88,8 +88,8 @@ func TestValidateRefreshToken_Valid(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ts := tokenService{Tokens: mockStore}
|
rs := tokensResource{Tokens: mockStore}
|
||||||
_, err := ts.ValidateRefreshToken(context.Background(), token)
|
_, err := rs.ValidateRefreshToken(context.Background(), token)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,7 +121,7 @@ func TestRefreshAccessToken_Success(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ts := tokenService{
|
rs := tokensResource{
|
||||||
JWTSecret: "test-secret",
|
JWTSecret: "test-secret",
|
||||||
Tokens: mockStore,
|
Tokens: mockStore,
|
||||||
}
|
}
|
||||||
@ -141,7 +141,7 @@ func TestRefreshAccessToken_Success(t *testing.T) {
|
|||||||
))
|
))
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
ts.RefreshAccessToken(w, req)
|
rs.RefreshAccessToken(w, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), "access_token", "refresh_token")
|
assert.Contains(t, w.Body.String(), "access_token", "refresh_token")
|
||||||
@ -159,7 +159,7 @@ func TestHandleLogout_Success(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ts := tokenService{Tokens: mockStore}
|
rs := tokensResource{Tokens: mockStore}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/", nil)
|
req := httptest.NewRequest("POST", "/", nil)
|
||||||
req = req.WithContext(context.WithValue(
|
req = req.WithContext(context.WithValue(
|
||||||
@ -173,7 +173,7 @@ func TestHandleLogout_Success(t *testing.T) {
|
|||||||
))
|
))
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
ts.HandleLogout(w, req)
|
rs.HandleLogout(w, req)
|
||||||
|
|
||||||
assert.True(t, called)
|
assert.True(t, called)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user