182 lines
5.0 KiB
Go
182 lines
5.0 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
type mockTokenStore struct {
|
|
CreateRefreshTokenFunc func(context.Context, data.CreateRefreshTokenParams) (data.RefreshToken, error)
|
|
GetRefreshTokenByHashFunc func(context.Context, string) (data.RefreshToken, error)
|
|
RevokeRefreshTokenFunc func(context.Context, string) error
|
|
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
|
|
}
|
|
|
|
func (m *mockTokenStore) CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
|
return m.CreateRefreshTokenFunc(ctx, arg)
|
|
}
|
|
|
|
func (m *mockTokenStore) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
|
return m.GetRefreshTokenByHashFunc(ctx, tokenHash)
|
|
}
|
|
|
|
func (m *mockTokenStore) RevokeRefreshToken(ctx context.Context, token string) error {
|
|
return m.RevokeRefreshTokenFunc(ctx, token)
|
|
}
|
|
|
|
func (m *mockTokenStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
|
|
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
|
|
}
|
|
|
|
func TestGenerateTokenPair_Success(t *testing.T) {
|
|
userID := uuid.New()
|
|
called := false
|
|
|
|
mockStore := &mockTokenStore{
|
|
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
|
called = true
|
|
assert.Equal(t, userID, arg.UserID)
|
|
return data.RefreshToken{}, nil
|
|
},
|
|
}
|
|
|
|
rs := tokensResource{
|
|
JWTSecret: "test-secret",
|
|
Tokens: mockStore,
|
|
}
|
|
|
|
pair, err := rs.GenerateTokenPair(context.Background(), userID, false)
|
|
assert.NoError(t, err)
|
|
assert.True(t, called)
|
|
assert.NotEmpty(t, pair.AccessToken)
|
|
assert.NotEmpty(t, pair.RefreshToken)
|
|
}
|
|
|
|
func TestGenerateTokenPair_DBError(t *testing.T) {
|
|
mockStore := &mockTokenStore{
|
|
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
|
return data.RefreshToken{}, errors.New("db error")
|
|
},
|
|
}
|
|
|
|
rs := tokensResource{Tokens: mockStore}
|
|
_, err := rs.GenerateTokenPair(context.Background(), uuid.New(), false)
|
|
assert.ErrorContains(t, err, "db error")
|
|
}
|
|
|
|
func TestValidateRefreshToken_Valid(t *testing.T) {
|
|
token := "valid-jwt-token"
|
|
hash := sha256.Sum256([]byte(token))
|
|
expectedHash := hex.EncodeToString(hash[:])
|
|
|
|
mockStore := &mockTokenStore{
|
|
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
|
assert.Equal(t, expectedHash, tokenHash)
|
|
return data.RefreshToken{ExpiresAt: time.Now().Add(1 * time.Hour)}, nil
|
|
},
|
|
}
|
|
|
|
rs := tokensResource{Tokens: mockStore}
|
|
_, err := rs.ValidateRefreshToken(context.Background(), token)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestRefreshAccessToken_Success(t *testing.T) {
|
|
userID := uuid.New()
|
|
refreshToken := "valid-jwt-token"
|
|
|
|
// Expected hash of the test token
|
|
hash := sha256.Sum256([]byte(refreshToken))
|
|
expectedHash := hex.EncodeToString(hash[:])
|
|
|
|
mockStore := &mockTokenStore{
|
|
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
|
assert.Equal(t, expectedHash, tokenHash)
|
|
|
|
// Must return an unrevoked token with future expiration
|
|
return data.RefreshToken{
|
|
TokenHash: tokenHash,
|
|
ExpiresAt: time.Now().Add(1 * time.Hour),
|
|
Revoked: false,
|
|
}, nil
|
|
},
|
|
RevokeRefreshTokenFunc: func(ctx context.Context, tokenHash string) error {
|
|
assert.Equal(t, expectedHash, tokenHash)
|
|
return nil
|
|
},
|
|
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
|
return data.RefreshToken{}, nil
|
|
},
|
|
}
|
|
|
|
rs := tokensResource{
|
|
JWTSecret: "test-secret",
|
|
Tokens: mockStore,
|
|
}
|
|
|
|
req := httptest.NewRequest("POST", "/", nil)
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", refreshToken))
|
|
req = req.WithContext(context.WithValue(
|
|
req.Context(),
|
|
userCtxKey{},
|
|
&userClaims{
|
|
Admin: false,
|
|
TokenType: "refresh",
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: userID.String(),
|
|
},
|
|
},
|
|
))
|
|
|
|
w := httptest.NewRecorder()
|
|
rs.RefreshAccessToken(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "access_token")
|
|
assert.Contains(t, w.Result().Cookies()[0].Name, "refresh_token")
|
|
}
|
|
|
|
func TestHandleLogout_Success(t *testing.T) {
|
|
userID := uuid.New()
|
|
called := false
|
|
|
|
mockStore := &mockTokenStore{
|
|
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
|
|
called = true
|
|
assert.Equal(t, userID, id)
|
|
return nil
|
|
},
|
|
}
|
|
|
|
rs := tokensResource{Tokens: mockStore}
|
|
|
|
req := httptest.NewRequest("POST", "/", nil)
|
|
req = req.WithContext(context.WithValue(
|
|
req.Context(),
|
|
userCtxKey{},
|
|
&userClaims{
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: userID.String(),
|
|
},
|
|
},
|
|
))
|
|
|
|
w := httptest.NewRecorder()
|
|
rs.HandleLogout(w, req)
|
|
|
|
assert.True(t, called)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|