fix: parse refresh_token from httpOnly cookie instead of header
This commit is contained in:
parent
e8b20d24fc
commit
24f4d8023e
@ -452,8 +452,8 @@ func (rs authResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt to get the token from Authorization header (formatted as "Bearer <token>")
|
||||
refreshToken, err := getTokenFromRequest(r)
|
||||
// Attempt to get the token from httpOnly cookie
|
||||
refreshToken, err := getTokenFromCookie(r)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
@ -578,8 +578,8 @@ func CreateAdminIfNotExists(ctx context.Context, q *data.Queries, username, pass
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the JWT bearer token from the request's Authorization header.
|
||||
func getTokenFromRequest(r *http.Request) (string, error) {
|
||||
// Parse the JWT token from the request's Authorization header.
|
||||
func getTokenFromHeader(r *http.Request) (string, error) {
|
||||
bearerToken := r.Header.Get("Authorization")
|
||||
bearerFields := strings.Fields(bearerToken)
|
||||
if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" {
|
||||
@ -588,6 +588,15 @@ func getTokenFromRequest(r *http.Request) (string, error) {
|
||||
return "", ErrAuthHeaderInvalid
|
||||
}
|
||||
|
||||
// Parse the JWT token from the request's cookies (httpOnly cookie).
|
||||
func getTokenFromCookie(r *http.Request) (string, error) {
|
||||
cookie, err := r.Cookie("refresh_token")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return cookie.Value, nil
|
||||
}
|
||||
|
||||
// Helper function for generating a new JWT token pair with the given specifications.
|
||||
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
|
||||
atClaims := userClaims{
|
||||
|
@ -32,7 +32,16 @@ type uuidCtxKey struct {
|
||||
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tokenString, err := getTokenFromRequest(r)
|
||||
var tokenString string
|
||||
var err error
|
||||
|
||||
// Get token from appropriate source based on type (Authorization header or httpOnly cookie)
|
||||
if expectedType == "refresh" {
|
||||
tokenString, err = getTokenFromCookie(r)
|
||||
} else {
|
||||
tokenString, err = getTokenFromHeader(r)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
|
@ -54,7 +54,91 @@ func (m *mockNoteStore) GetVersionHistory(ctx context.Context, arg data.GetVersi
|
||||
return m.GetVersionHistoryFunc(ctx, arg)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
func TestRequireRTMiddleware(t *testing.T) {
|
||||
secret := "test-jwt-secret"
|
||||
testUserID := uuid.New().String()
|
||||
|
||||
validRT := generateTestToken(t, secret, "refresh", testUserID, true)
|
||||
validAT := generateTestToken(t, secret, "access", testUserID, true)
|
||||
expiredRT := generateTestToken(t, secret, "refresh", testUserID, true, func(claims *userClaims) {
|
||||
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expectedErr string
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
"no token",
|
||||
"",
|
||||
"Unauthorized",
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"invalid token",
|
||||
"invalid",
|
||||
"Invalid token",
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"expired token",
|
||||
expiredRT,
|
||||
"Invalid token",
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"wrong token type",
|
||||
validAT,
|
||||
"Invalid token type",
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"valid token",
|
||||
validRT,
|
||||
"",
|
||||
http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rtAuthMiddleware := requireRefreshToken(secret)
|
||||
|
||||
// Mock request with cookie
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if tc.token != "" {
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: tc.token,
|
||||
})
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
called := false
|
||||
|
||||
// Mock endpoint that the middleware protects
|
||||
handler := rtAuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "refresh", claims.TokenType)
|
||||
assert.Equal(t, testUserID, claims.Subject)
|
||||
}))
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tc.statusCode, w.Code)
|
||||
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
||||
if tc.expectedErr != "" {
|
||||
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireATMiddleware(t *testing.T) {
|
||||
secret := "test-jwt-secret"
|
||||
testUserID := uuid.New().String()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user