fix: parse refresh_token from httpOnly cookie instead of header

This commit is contained in:
ae 2025-04-10 12:22:08 +03:00
parent e8b20d24fc
commit 24f4d8023e
Signed by: ae
GPG Key ID: 995EFD5C1B532B3E
3 changed files with 108 additions and 6 deletions

View File

@ -452,8 +452,8 @@ func (rs authResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request
return return
} }
// Attempt to get the token from Authorization header (formatted as "Bearer <token>") // Attempt to get the token from httpOnly cookie
refreshToken, err := getTokenFromRequest(r) refreshToken, err := getTokenFromCookie(r)
if err != nil { if err != nil {
respondError(w, http.StatusUnauthorized, "Unauthorized") respondError(w, http.StatusUnauthorized, "Unauthorized")
return return
@ -578,8 +578,8 @@ func CreateAdminIfNotExists(ctx context.Context, q *data.Queries, username, pass
return nil return nil
} }
// Parse the JWT bearer token from the request's Authorization header. // Parse the JWT token from the request's Authorization header.
func getTokenFromRequest(r *http.Request) (string, error) { func getTokenFromHeader(r *http.Request) (string, error) {
bearerToken := r.Header.Get("Authorization") bearerToken := r.Header.Get("Authorization")
bearerFields := strings.Fields(bearerToken) bearerFields := strings.Fields(bearerToken)
if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" { if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" {
@ -588,6 +588,15 @@ func getTokenFromRequest(r *http.Request) (string, error) {
return "", ErrAuthHeaderInvalid return "", ErrAuthHeaderInvalid
} }
// Parse the JWT token from the request's cookies (httpOnly cookie).
func getTokenFromCookie(r *http.Request) (string, error) {
cookie, err := r.Cookie("refresh_token")
if err != nil {
return "", err
}
return cookie.Value, nil
}
// Helper function for generating a new JWT token pair with the given specifications. // Helper function for generating a new JWT token pair with the given specifications.
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) { func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
atClaims := userClaims{ atClaims := userClaims{

View File

@ -32,7 +32,16 @@ type uuidCtxKey struct {
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler { func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenString, err := getTokenFromRequest(r) var tokenString string
var err error
// Get token from appropriate source based on type (Authorization header or httpOnly cookie)
if expectedType == "refresh" {
tokenString, err = getTokenFromCookie(r)
} else {
tokenString, err = getTokenFromHeader(r)
}
if err != nil { if err != nil {
respondError(w, http.StatusUnauthorized, "Unauthorized") respondError(w, http.StatusUnauthorized, "Unauthorized")
return return

View File

@ -54,7 +54,91 @@ func (m *mockNoteStore) GetVersionHistory(ctx context.Context, arg data.GetVersi
return m.GetVersionHistoryFunc(ctx, arg) return m.GetVersionHistoryFunc(ctx, arg)
} }
func TestAuthMiddleware(t *testing.T) { func TestRequireRTMiddleware(t *testing.T) {
secret := "test-jwt-secret"
testUserID := uuid.New().String()
validRT := generateTestToken(t, secret, "refresh", testUserID, true)
validAT := generateTestToken(t, secret, "access", testUserID, true)
expiredRT := generateTestToken(t, secret, "refresh", testUserID, true, func(claims *userClaims) {
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
})
tests := []struct {
name string
token string
expectedErr string
statusCode int
}{
{
"no token",
"",
"Unauthorized",
http.StatusUnauthorized,
},
{
"invalid token",
"invalid",
"Invalid token",
http.StatusUnauthorized,
},
{
"expired token",
expiredRT,
"Invalid token",
http.StatusUnauthorized,
},
{
"wrong token type",
validAT,
"Invalid token type",
http.StatusUnauthorized,
},
{
"valid token",
validRT,
"",
http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
rtAuthMiddleware := requireRefreshToken(secret)
// Mock request with cookie
req := httptest.NewRequest("GET", "/", nil)
if tc.token != "" {
req.AddCookie(&http.Cookie{
Name: "refresh_token",
Value: tc.token,
})
}
w := httptest.NewRecorder()
called := false
// Mock endpoint that the middleware protects
handler := rtAuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
assert.True(t, ok)
assert.Equal(t, "refresh", claims.TokenType)
assert.Equal(t, testUserID, claims.Subject)
}))
handler.ServeHTTP(w, req)
assert.Equal(t, tc.statusCode, w.Code)
assert.Equal(t, tc.statusCode == http.StatusOK, called)
if tc.expectedErr != "" {
assert.Contains(t, w.Body.String(), tc.expectedErr)
}
})
}
}
func TestRequireATMiddleware(t *testing.T) {
secret := "test-jwt-secret" secret := "test-jwt-secret"
testUserID := uuid.New().String() testUserID := uuid.New().String()