package service

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"git.umbrella.haus/ae/notatest/pkg/data"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
)

// mockTokenStore is a hand-rolled test double for the token store used by
// tokenService. Each method forwards to the matching *Func field, so a test
// only sets the functions it expects to be called. Calling a method whose
// field was left nil panics (nil func call), which surfaces an unexpected
// store call as a test failure.
type mockTokenStore struct {
	CreateRefreshTokenFunc         func(context.Context, data.CreateRefreshTokenParams) (data.RefreshToken, error)
	GetRefreshTokenByHashFunc      func(context.Context, string) (data.RefreshToken, error)
	RevokeRefreshTokenFunc         func(context.Context, string) error
	RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
}

// CreateRefreshToken forwards to CreateRefreshTokenFunc.
func (m *mockTokenStore) CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
	return m.CreateRefreshTokenFunc(ctx, arg)
}

// GetRefreshTokenByHash forwards to GetRefreshTokenByHashFunc.
func (m *mockTokenStore) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
	return m.GetRefreshTokenByHashFunc(ctx, tokenHash)
}

// RevokeRefreshToken forwards to RevokeRefreshTokenFunc.
func (m *mockTokenStore) RevokeRefreshToken(ctx context.Context, token string) error {
	return m.RevokeRefreshTokenFunc(ctx, token)
}

// RevokeAllUserRefreshTokens forwards to RevokeAllUserRefreshTokensFunc.
func (m *mockTokenStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
	return m.RevokeAllUserRefreshTokensFunc(ctx, id)
}

// TestGenerateTokenPair_Success checks that GenerateTokenPair calls the
// store's CreateRefreshToken with the requested user ID and returns a pair
// with non-empty access and refresh tokens.
func TestGenerateTokenPair_Success(t *testing.T) {
	userID := uuid.New()
	called := false

	mockStore := &mockTokenStore{
		CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
			called = true
			assert.Equal(t, userID, arg.UserID)
			return data.RefreshToken{}, nil
		},
	}

	ts := tokenService{
		JWTSecret: "test-secret",
		Tokens:    mockStore,
	}

	// NOTE(review): the trailing `false` presumably is the admin flag — confirm.
	pair, err := ts.GenerateTokenPair(context.Background(), userID, false)

	assert.NoError(t, err)
	assert.True(t, called)
	assert.NotEmpty(t, pair.AccessToken)
	assert.NotEmpty(t, pair.RefreshToken)
}

// TestGenerateTokenPair_DBError checks that an error returned by the store's
// CreateRefreshToken is surfaced in the error returned by GenerateTokenPair.
func TestGenerateTokenPair_DBError(t *testing.T) {
	mockStore := &mockTokenStore{
		CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
			return data.RefreshToken{}, errors.New("db error")
		},
	}

	// Configure a signing secret like the success case so that the injected
	// store error is the only failure this test can observe.
	ts := tokenService{
		JWTSecret: "test-secret",
		Tokens:    mockStore,
	}

	_, err := ts.GenerateTokenPair(context.Background(), uuid.New(), false)

	assert.ErrorContains(t, err, "db error")
}

// TestValidateRefreshToken_Valid checks that ValidateRefreshToken looks the
// token up by the hex-encoded SHA-256 hash of the raw token string and
// returns no error for a stored record whose expiry is in the future.
func TestValidateRefreshToken_Valid(t *testing.T) {
	token := "valid-jwt-token"
	hash := sha256.Sum256([]byte(token))
	expectedHash := hex.EncodeToString(hash[:])

	mockStore := &mockTokenStore{
		GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
			assert.Equal(t, expectedHash, tokenHash)
			return data.RefreshToken{ExpiresAt: time.Now().Add(time.Hour)}, nil
		},
	}

	ts := tokenService{Tokens: mockStore}

	_, err := ts.ValidateRefreshToken(context.Background(), token)

	assert.NoError(t, err)
}

// TestRefreshAccessToken_Success drives the RefreshAccessToken HTTP handler.
// The request carries the refresh token as a Bearer Authorization header and
// has refresh-type claims already stored in its context. The test expects a
// 200 response whose body mentions both "access_token" and "refresh_token".
func TestRefreshAccessToken_Success(t *testing.T) {
	userID := uuid.New()
	refreshToken := "valid-jwt-token"

	// Expected hash of the test token
	hash := sha256.Sum256([]byte(refreshToken))
	expectedHash := hex.EncodeToString(hash[:])

	mockStore := &mockTokenStore{
		GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
			assert.Equal(t, expectedHash, tokenHash)
			// Must return an unrevoked token with future expiration
			return data.RefreshToken{
				TokenHash: tokenHash,
				ExpiresAt: time.Now().Add(time.Hour),
				Revoked:   false,
			}, nil
		},
		RevokeRefreshTokenFunc: func(ctx context.Context, tokenHash string) error {
			assert.Equal(t, expectedHash, tokenHash)
			return nil
		},
		CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
			return data.RefreshToken{}, nil
		},
	}

	ts := tokenService{
		JWTSecret: "test-secret",
		Tokens:    mockStore,
	}

	req := httptest.NewRequest(http.MethodPost, "/", nil)
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", refreshToken))
	// NOTE(review): the user ID is carried in the Subject claim here, while
	// TestHandleLogout_Success puts it in the ID claim; confirm which claim
	// each handler actually reads.
	req = req.WithContext(context.WithValue(
		req.Context(),
		userCtxKey{},
		&userClaims{
			Admin:     false,
			TokenType: "refresh",
			RegisteredClaims: jwt.RegisteredClaims{
				Subject: userID.String(),
			},
		},
	))
	w := httptest.NewRecorder()

	ts.RefreshAccessToken(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
	// assert.Contains checks a single element; any extra arguments are only
	// used as the failure message, so each expected key gets its own check.
	body := w.Body.String()
	assert.Contains(t, body, "access_token")
	assert.Contains(t, body, "refresh_token")
}

// TestHandleLogout_Success checks that HandleLogout calls the store's
// RevokeAllUserRefreshTokens with the user ID taken from the claims in the
// request context (set via the ID claim here) and responds with 200.
func TestHandleLogout_Success(t *testing.T) {
	userID := uuid.New()
	called := false

	mockStore := &mockTokenStore{
		RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
			called = true
			assert.Equal(t, userID, id)
			return nil
		},
	}

	ts := tokenService{Tokens: mockStore}

	req := httptest.NewRequest(http.MethodPost, "/", nil)
	req = req.WithContext(context.WithValue(
		req.Context(),
		userCtxKey{},
		&userClaims{
			RegisteredClaims: jwt.RegisteredClaims{
				ID: userID.String(),
			},
		},
	))
	w := httptest.NewRecorder()

	ts.HandleLogout(w, req)

	assert.True(t, called)
	assert.Equal(t, http.StatusOK, w.Code)
}