fix: parse refresh_token from httpOnly cookie instead of header
This commit is contained in:
parent
e8b20d24fc
commit
24f4d8023e
@ -452,8 +452,8 @@ func (rs authResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempt to get the token from Authorization header (formatted as "Bearer <token>")
|
// Attempt to get the token from httpOnly cookie
|
||||||
refreshToken, err := getTokenFromRequest(r)
|
refreshToken, err := getTokenFromCookie(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||||
return
|
return
|
||||||
@ -578,8 +578,8 @@ func CreateAdminIfNotExists(ctx context.Context, q *data.Queries, username, pass
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the JWT bearer token from the request's Authorization header.
|
// Parse the JWT token from the request's Authorization header.
|
||||||
func getTokenFromRequest(r *http.Request) (string, error) {
|
func getTokenFromHeader(r *http.Request) (string, error) {
|
||||||
bearerToken := r.Header.Get("Authorization")
|
bearerToken := r.Header.Get("Authorization")
|
||||||
bearerFields := strings.Fields(bearerToken)
|
bearerFields := strings.Fields(bearerToken)
|
||||||
if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" {
|
if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" {
|
||||||
@ -588,6 +588,15 @@ func getTokenFromRequest(r *http.Request) (string, error) {
|
|||||||
return "", ErrAuthHeaderInvalid
|
return "", ErrAuthHeaderInvalid
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse the JWT token from the request's cookies (httpOnly cookie).
|
||||||
|
func getTokenFromCookie(r *http.Request) (string, error) {
|
||||||
|
cookie, err := r.Cookie("refresh_token")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return cookie.Value, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function for generating a new JWT token pair with the given specifications.
|
// Helper function for generating a new JWT token pair with the given specifications.
|
||||||
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
|
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
|
||||||
atClaims := userClaims{
|
atClaims := userClaims{
|
||||||
|
@ -32,7 +32,16 @@ type uuidCtxKey struct {
|
|||||||
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
|
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
tokenString, err := getTokenFromRequest(r)
|
var tokenString string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Get token from appropriate source based on type (Authorization header or httpOnly cookie)
|
||||||
|
if expectedType == "refresh" {
|
||||||
|
tokenString, err = getTokenFromCookie(r)
|
||||||
|
} else {
|
||||||
|
tokenString, err = getTokenFromHeader(r)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||||
return
|
return
|
||||||
|
@ -54,7 +54,91 @@ func (m *mockNoteStore) GetVersionHistory(ctx context.Context, arg data.GetVersi
|
|||||||
return m.GetVersionHistoryFunc(ctx, arg)
|
return m.GetVersionHistoryFunc(ctx, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthMiddleware(t *testing.T) {
|
func TestRequireRTMiddleware(t *testing.T) {
|
||||||
|
secret := "test-jwt-secret"
|
||||||
|
testUserID := uuid.New().String()
|
||||||
|
|
||||||
|
validRT := generateTestToken(t, secret, "refresh", testUserID, true)
|
||||||
|
validAT := generateTestToken(t, secret, "access", testUserID, true)
|
||||||
|
expiredRT := generateTestToken(t, secret, "refresh", testUserID, true, func(claims *userClaims) {
|
||||||
|
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
expectedErr string
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"no token",
|
||||||
|
"",
|
||||||
|
"Unauthorized",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invalid token",
|
||||||
|
"invalid",
|
||||||
|
"Invalid token",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"expired token",
|
||||||
|
expiredRT,
|
||||||
|
"Invalid token",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"wrong token type",
|
||||||
|
validAT,
|
||||||
|
"Invalid token type",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"valid token",
|
||||||
|
validRT,
|
||||||
|
"",
|
||||||
|
http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
rtAuthMiddleware := requireRefreshToken(secret)
|
||||||
|
|
||||||
|
// Mock request with cookie
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
if tc.token != "" {
|
||||||
|
req.AddCookie(&http.Cookie{
|
||||||
|
Name: "refresh_token",
|
||||||
|
Value: tc.token,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
called := false
|
||||||
|
|
||||||
|
// Mock endpoint that the middleware protects
|
||||||
|
handler := rtAuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "refresh", claims.TokenType)
|
||||||
|
assert.Equal(t, testUserID, claims.Subject)
|
||||||
|
}))
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.statusCode, w.Code)
|
||||||
|
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
||||||
|
if tc.expectedErr != "" {
|
||||||
|
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireATMiddleware(t *testing.T) {
|
||||||
secret := "test-jwt-secret"
|
secret := "test-jwt-secret"
|
||||||
testUserID := uuid.New().String()
|
testUserID := uuid.New().String()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user