diff --git a/server/internal/service/auth.go b/server/internal/service/auth.go index 6cc5ccb..2577dec 100644 --- a/server/internal/service/auth.go +++ b/server/internal/service/auth.go @@ -452,8 +452,8 @@ func (rs authResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request return } - // Attempt to get the token from Authorization header (formatted as "Bearer ") - refreshToken, err := getTokenFromRequest(r) + // Attempt to get the token from httpOnly cookie + refreshToken, err := getTokenFromCookie(r) if err != nil { respondError(w, http.StatusUnauthorized, "Unauthorized") return @@ -578,8 +578,8 @@ func CreateAdminIfNotExists(ctx context.Context, q *data.Queries, username, pass return nil } -// Parse the JWT bearer token from the request's Authorization header. -func getTokenFromRequest(r *http.Request) (string, error) { +// Parse the JWT token from the request's Authorization header. +func getTokenFromHeader(r *http.Request) (string, error) { bearerToken := r.Header.Get("Authorization") bearerFields := strings.Fields(bearerToken) if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" { @@ -588,6 +588,15 @@ func getTokenFromRequest(r *http.Request) (string, error) { return "", ErrAuthHeaderInvalid } +// Parse the JWT token from the request's cookies (httpOnly cookie). +func getTokenFromCookie(r *http.Request) (string, error) { + cookie, err := r.Cookie("refresh_token") + if err != nil { + return "", err + } + return cookie.Value, nil +} + // Helper function for generating a new JWT token pair with the given specifications. func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) { atClaims := userClaims{ diff --git a/server/internal/service/middleware.go b/server/internal/service/middleware.go index c01fbe4..de6c177 100644 --- a/server/internal/service/middleware.go +++ b/server/internal/service/middleware.go @@ -32,7 +32,16 @@ type uuidCtxKey struct { func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tokenString, err := getTokenFromRequest(r) + var tokenString string + var err error + + // Get token from appropriate source based on type (Authorization header or httpOnly cookie) + if expectedType == "refresh" { + tokenString, err = getTokenFromCookie(r) + } else { + tokenString, err = getTokenFromHeader(r) + } + if err != nil { respondError(w, http.StatusUnauthorized, "Unauthorized") return diff --git a/server/internal/service/middleware_test.go b/server/internal/service/middleware_test.go index ce11f82..22690c1 100644 --- a/server/internal/service/middleware_test.go +++ b/server/internal/service/middleware_test.go @@ -54,7 +54,91 @@ func (m *mockNoteStore) GetVersionHistory(ctx context.Context, arg data.GetVersi return m.GetVersionHistoryFunc(ctx, arg) } -func TestAuthMiddleware(t *testing.T) { +func TestRequireRTMiddleware(t *testing.T) { + secret := "test-jwt-secret" + testUserID := uuid.New().String() + + validRT := generateTestToken(t, secret, "refresh", testUserID, true) + validAT := generateTestToken(t, secret, "access", testUserID, true) + expiredRT := generateTestToken(t, secret, "refresh", testUserID, true, func(claims *userClaims) { + claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)) + }) + + tests := []struct { + name string + token string + expectedErr string + statusCode int + }{ + { + "no token", + "", + "Unauthorized", + http.StatusUnauthorized, + }, + { + "invalid token", + "invalid", + "Invalid token", + http.StatusUnauthorized, + }, + { + "expired token", + expiredRT, + "Invalid token", + http.StatusUnauthorized, + }, + { + "wrong token type", + validAT, + "Invalid token type", + http.StatusUnauthorized, + }, + { + "valid token", + validRT, + "", + http.StatusOK, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + rtAuthMiddleware := requireRefreshToken(secret) + + // Mock request with cookie + req := httptest.NewRequest("GET", "/", nil) + if tc.token != "" { + req.AddCookie(&http.Cookie{ + Name: "refresh_token", + Value: tc.token, + }) + } + + w := httptest.NewRecorder() + called := false + + // Mock endpoint that the middleware protects + handler := rtAuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) + assert.True(t, ok) + assert.Equal(t, "refresh", claims.TokenType) + assert.Equal(t, testUserID, claims.Subject) + })) + + handler.ServeHTTP(w, req) + + assert.Equal(t, tc.statusCode, w.Code) + assert.Equal(t, tc.statusCode == http.StatusOK, called) + if tc.expectedErr != "" { + assert.Contains(t, w.Body.String(), tc.expectedErr) + } + }) + } +} + +func TestRequireATMiddleware(t *testing.T) { secret := "test-jwt-secret" testUserID := uuid.New().String()