notatest/server/pkg/service/tokens_test.go

181 lines
4.9 KiB
Go

package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
// mockTokenStore is a hand-rolled test double for the token store used by
// tokenService. Each method forwards to the function field of the same name
// (suffixed with Func), so a test only needs to populate the fields for the
// store calls it expects. Calling a method whose field is nil panics.
type mockTokenStore struct {
	// CreateRefreshTokenFunc backs CreateRefreshToken.
	CreateRefreshTokenFunc func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error)
	// GetRefreshTokenByHashFunc backs GetRefreshTokenByHash.
	GetRefreshTokenByHashFunc func(ctx context.Context, tokenHash string) (data.RefreshToken, error)
	// RevokeRefreshTokenFunc backs RevokeRefreshToken.
	RevokeRefreshTokenFunc func(ctx context.Context, token string) error
	// RevokeAllUserRefreshTokensFunc backs RevokeAllUserRefreshTokens.
	RevokeAllUserRefreshTokensFunc func(ctx context.Context, id uuid.UUID) error
}
// CreateRefreshToken hands the call to CreateRefreshTokenFunc and returns its
// results unchanged. It panics if that field has not been set.
func (m *mockTokenStore) CreateRefreshToken(ctx context.Context, params data.CreateRefreshTokenParams) (data.RefreshToken, error) {
	create := m.CreateRefreshTokenFunc
	return create(ctx, params)
}
// GetRefreshTokenByHash hands the call to GetRefreshTokenByHashFunc and
// returns its results unchanged. It panics if that field has not been set.
func (m *mockTokenStore) GetRefreshTokenByHash(ctx context.Context, hash string) (data.RefreshToken, error) {
	lookup := m.GetRefreshTokenByHashFunc
	return lookup(ctx, hash)
}
// RevokeRefreshToken hands the call to RevokeRefreshTokenFunc and returns its
// error unchanged. It panics if that field has not been set.
//
// TestRefreshAccessToken_Success expects the string passed here to be the
// token's hash rather than the raw token, hence the parameter name.
func (m *mockTokenStore) RevokeRefreshToken(ctx context.Context, tokenHash string) error {
	revoke := m.RevokeRefreshTokenFunc
	return revoke(ctx, tokenHash)
}
// RevokeAllUserRefreshTokens hands the call to RevokeAllUserRefreshTokensFunc
// and returns its error unchanged. It panics if that field has not been set.
func (m *mockTokenStore) RevokeAllUserRefreshTokens(ctx context.Context, userID uuid.UUID) error {
	revokeAll := m.RevokeAllUserRefreshTokensFunc
	return revokeAll(ctx, userID)
}
// TestGenerateTokenPair_Success checks the happy path of GenerateTokenPair:
// the store is asked to create a refresh token for the requesting user, no
// error is returned, and both tokens in the returned pair are non-empty.
func TestGenerateTokenPair_Success(t *testing.T) {
	var (
		uid         = uuid.New()
		storeCalled bool
	)

	store := &mockTokenStore{}
	store.CreateRefreshTokenFunc = func(_ context.Context, params data.CreateRefreshTokenParams) (data.RefreshToken, error) {
		storeCalled = true
		// The persisted refresh token must belong to the requesting user.
		assert.Equal(t, uid, params.UserID)
		return data.RefreshToken{}, nil
	}

	ts := tokenService{
		JWTSecret: "test-secret",
		Tokens:    store,
	}

	result, err := ts.GenerateTokenPair(context.Background(), uid, false)

	assert.NoError(t, err)
	assert.True(t, storeCalled)
	assert.NotEmpty(t, result.AccessToken)
	assert.NotEmpty(t, result.RefreshToken)
}
func TestGenerateTokenPair_DBError(t *testing.T) {
mockStore := &mockTokenStore{
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
return data.RefreshToken{}, errors.New("db error")
},
}
ts := tokenService{Tokens: mockStore}
_, err := ts.GenerateTokenPair(context.Background(), uuid.New(), false)
assert.ErrorContains(t, err, "db error")
}
// TestValidateRefreshToken_Valid checks that ValidateRefreshToken looks the
// token up in the store by the hex-encoded SHA-256 digest of the raw token
// and reports no error when the stored record expires in the future.
func TestValidateRefreshToken_Valid(t *testing.T) {
	token := "valid-jwt-token"
	hash := sha256.Sum256([]byte(token))
	expectedHash := hex.EncodeToString(hash[:])

	// Track the store call explicitly: the hash assertion lives inside the
	// mock callback, so without this flag the test would pass vacuously if
	// the service never consulted the store.
	called := false
	mockStore := &mockTokenStore{
		GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
			called = true
			assert.Equal(t, expectedHash, tokenHash)
			return data.RefreshToken{ExpiresAt: time.Now().Add(1 * time.Hour)}, nil
		},
	}
	ts := tokenService{Tokens: mockStore}

	_, err := ts.ValidateRefreshToken(context.Background(), token)

	assert.NoError(t, err)
	assert.True(t, called, "expected the refresh token to be looked up in the store")
}
// TestRefreshAccessToken_Success drives the RefreshAccessToken HTTP handler
// with a bearer refresh token and refresh-type claims in the request context.
// The mocked store returns an unrevoked, unexpired record for the token's
// hash; the handler is expected to answer 200 with a body that mentions both
// "access_token" and "refresh_token".
func TestRefreshAccessToken_Success(t *testing.T) {
	userID := uuid.New()
	refreshToken := "valid-jwt-token"

	// Expected hash of the test token (hex-encoded SHA-256 of the raw value).
	hash := sha256.Sum256([]byte(refreshToken))
	expectedHash := hex.EncodeToString(hash[:])

	mockStore := &mockTokenStore{
		GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
			assert.Equal(t, expectedHash, tokenHash)
			// Must return an unrevoked token with future expiration.
			return data.RefreshToken{
				TokenHash: tokenHash,
				ExpiresAt: time.Now().Add(1 * time.Hour),
				Revoked:   false,
			}, nil
		},
		RevokeRefreshTokenFunc: func(ctx context.Context, tokenHash string) error {
			assert.Equal(t, expectedHash, tokenHash)
			return nil
		},
		CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
			return data.RefreshToken{}, nil
		},
	}
	ts := tokenService{
		JWTSecret: "test-secret",
		Tokens:    mockStore,
	}

	req := httptest.NewRequest("POST", "/", nil)
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", refreshToken))
	req = req.WithContext(context.WithValue(
		req.Context(),
		userCtxKey{},
		&userClaims{
			Admin:     false,
			TokenType: "refresh",
			RegisteredClaims: jwt.RegisteredClaims{
				Subject: userID.String(),
			},
		},
	))

	w := httptest.NewRecorder()
	ts.RefreshAccessToken(w, req)

	assert.Equal(t, http.StatusOK, w.Code)

	// assert.Contains takes a single needle; any further arguments are the
	// failure message. Check each key with its own assertion so that
	// "refresh_token" is actually verified.
	body := w.Body.String()
	assert.Contains(t, body, "access_token")
	assert.Contains(t, body, "refresh_token")
}
// TestHandleLogout_Success drives the HandleLogout HTTP handler with user
// claims in the request context and checks that every refresh token of that
// user is revoked through the store and that the handler answers 200.
func TestHandleLogout_Success(t *testing.T) {
	var (
		uid          = uuid.New()
		revokeCalled bool
	)

	store := &mockTokenStore{}
	store.RevokeAllUserRefreshTokensFunc = func(_ context.Context, got uuid.UUID) error {
		revokeCalled = true
		assert.Equal(t, uid, got)
		return nil
	}
	ts := tokenService{Tokens: store}

	// NOTE(review): the user ID is carried in the ID claim here, whereas
	// TestRefreshAccessToken_Success carries it in Subject — confirm which
	// field HandleLogout actually reads.
	claims := &userClaims{
		RegisteredClaims: jwt.RegisteredClaims{
			ID: uid.String(),
		},
	}
	req := httptest.NewRequest(http.MethodPost, "/", nil)
	ctx := context.WithValue(req.Context(), userCtxKey{}, claims)
	req = req.WithContext(ctx)

	rec := httptest.NewRecorder()
	ts.HandleLogout(rec, req)

	assert.True(t, revokeCalled)
	assert.Equal(t, http.StatusOK, rec.Code)
}