feat: modular api handlers for users/tokens (auth) incl. unit tests

This commit is contained in:
ae 2025-03-31 23:32:39 +03:00
parent 41d1336f58
commit 3257b19313
Signed by: ae
GPG Key ID: 995EFD5C1B532B3E
11 changed files with 1466 additions and 1 deletions

View File

@ -9,16 +9,22 @@ require (
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.7.4
github.com/rs/zerolog v1.34.0
github.com/stretchr/testify v1.10.0
github.com/wagslane/go-password-validator v0.3.0
golang.org/x/crypto v0.36.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/stretchr/testify v1.10.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/stretchr/objx v0.5.2 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -1,6 +1,7 @@
github.com/caarlos0/env v3.5.0+incompatible h1:Yy0UN8o9Wtr/jGHZDpCBLpNrzcFLLM2yixi/rBrKyJs=
github.com/caarlos0/env v3.5.0+incompatible/go.mod h1:tdCsowwCzMLdkqRYDlHpZCp2UooDD3MspDBjZ2AD02Y=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -19,6 +20,10 @@ github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
@ -29,10 +34,14 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
@ -51,6 +60,8 @@ golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -0,0 +1,157 @@
package service
import (
"context"
"fmt"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/rs/zerolog"
)
const (
panicRecoveryMsg = "panic recovered"
defaultLogMsg = "incoming request"
)
type userCtxKey struct{}
// Get JWT bearer from request's authorization header, parse it with custom user claims, and
// ensure its validity before attaching the claims to the request's context.
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenString, err := getTokenFromRequest(r)
if err != nil {
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
}
token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(jwtSecret), nil
})
if err != nil || !token.Valid {
respondError(w, http.StatusUnauthorized, "Invalid token")
return
}
claims, ok := token.Claims.(*userClaims)
if !ok || claims.TokenType != expectedType {
respondError(w, http.StatusUnauthorized, "Invalid token type")
return
}
ctx := context.WithValue(r.Context(), userCtxKey{}, claims)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// JWT access token parsing, verification, and validation.
func requireAccessToken(jwtSecret string) func(http.Handler) http.Handler {
return authMiddleware(jwtSecret, "access")
}
// JWT refresh token parsing, verification, and validation.
func requireRefreshToken(jwtSecret string) func(http.Handler) http.Handler {
return authMiddleware(jwtSecret, "refresh")
}
// Ensure the current user is an administrator.
func adminOnlyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok || !user.Admin {
respondError(w, http.StatusForbidden, "Forbidden")
return
}
next.ServeHTTP(w, r)
})
}
// Ensure the targeted resource is owned by the current user (i.e. current user's ID matches with
// the one stored into the resource).
func ownerOnlyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
requestedID := chi.URLParam(r, "id")
if !ok || user.ID != requestedID {
respondError(w, http.StatusForbidden, "Forbidden")
return
}
next.ServeHTTP(w, r)
})
}
// Append user data into request's context based on user ID as a URL parameter.
func userCtx(store UserStore) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userIDStr := chi.URLParam(r, "id")
userID, err := uuid.Parse(userIDStr)
if err != nil {
respondError(w, http.StatusNotFound, "Invalid user ID")
return
}
user, err := store.GetUserByID(r.Context(), userID)
if err != nil {
respondError(w, http.StatusNotFound, "User not found")
return
}
ctx := context.WithValue(r.Context(), userCtxKey{}, user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// Zerolog compatible logger middleware. Automatically logs and recovers from errors with HTTP 500
// response, by default logs to INFO level.
func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
log := log.With().Logger()
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
t1 := time.Now()
defer func() {
t2 := time.Now()
// Recover automatically and respond with HTTP 500
if rec := recover(); rec != nil {
log.Error().
Str("type", "error").
Timestamp().
Interface("recover_info", rec).
Msg(panicRecoveryMsg)
http.Error(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
// Log a regular HTTP request with some metadata
log.Info().
Str("type", "access").
Timestamp().
Fields(map[string]interface{}{
"remote_ip": r.RemoteAddr,
"url": r.URL.Path,
"proto": r.Proto,
"method": r.Method,
"user_agent": r.Header.Get("User-Agent"),
"status": ww.Status(),
"latency_ms": float64(t2.Sub(t1).Nanoseconds()) / 1000000.0,
"bytes_in": r.Header.Get("Content-Length"),
"bytes_out": ww.BytesWritten(),
}).
Msg(defaultLogMsg)
}()
next.ServeHTTP(ww, r)
}
return http.HandlerFunc(fn)
}
}

View File

@ -0,0 +1,19 @@
package service
import (
"github.com/go-chi/chi/v5"
)
// Mockable database operations interface
type NoteStore interface {
// TODO: implement
}
type notesResource struct {
Notes NoteStore
}
func (rs notesResource) Routes() chi.Router {
r := chi.NewRouter()
return r
}

View File

@ -0,0 +1,43 @@
package service
import (
"fmt"
"net/http"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/jackc/pgx/v5"
"github.com/rs/zerolog/log"
)
func Run(conn *pgx.Conn, jwtSecret string, httpPort string) error {
q := data.New(conn)
r := chi.NewRouter()
tokenService := tokenService{
JWTSecret: jwtSecret,
Tokens: q,
}
usersRouter := usersResource{
JWTSecret: jwtSecret,
Users: q,
}
notesRouter := notesResource{}
// Global middlewares
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(loggerMiddleware(&log.Logger))
r.Use(middleware.Recoverer)
r.Use(middleware.AllowContentType("application/json"))
// Routes grouped by functionality
r.Post("/auth/refresh", tokenService.RefreshAccessToken) // POST /auth/refresh - new access token for refresh token
r.Mount("/users", usersRouter.Routes())
r.Mount("/notes", notesRouter.Routes())
portStr := fmt.Sprintf(":%s", httpPort)
log.Info().Msgf("Starting server on %s", portStr)
return http.ListenAndServe(portStr, r)
}

View File

@ -0,0 +1,205 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net/http"
"strings"
"time"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
const (
accessTokenDuration = 15 * time.Minute
refreshTokenDuration = 7 * 24 * time.Hour
)
var (
ErrInvalidToken = errors.New("invalid token")
ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header")
)
type userClaims struct {
Admin bool `json:"admin"`
TokenType string `json:"type"` // "access" or "refresh"
jwt.RegisteredClaims // User's UUID should be stored in the subject claim
}
type tokenPair struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
// Mockable database operations interface
type TokenStore interface {
CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error)
GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error)
RevokeRefreshToken(ctx context.Context, tokenHash string) error
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
}
type tokenService struct {
JWTSecret string
Tokens TokenStore
}
func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
tokenPair, err := generateTokenPair(userID.String(), isAdmin, ts.JWTSecret)
if err != nil {
return nil, err
}
hash := sha256.Sum256([]byte(tokenPair.RefreshToken))
tokenHash := hex.EncodeToString(hash[:])
// Store to DB with (almost) identical expiration timestamp
expiresAt := time.Now().Add(refreshTokenDuration)
_, err = ts.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{
UserID: userID,
TokenHash: tokenHash,
ExpiresAt: expiresAt,
})
if err != nil {
return nil, err
}
return tokenPair, nil
}
func (ts *tokenService) RevokeRefreshToken(ctx context.Context, token string) error {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
return ts.Tokens.RevokeRefreshToken(ctx, tokenHash)
}
func (ts *tokenService) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
dbToken, err := ts.Tokens.GetRefreshTokenByHash(ctx, tokenHash)
if err != nil {
return nil, err
}
if dbToken.Revoked || time.Now().After(dbToken.ExpiresAt) {
return nil, ErrInvalidToken
}
return &dbToken, nil
}
func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
// Get claims from context
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok || claims.TokenType != "refresh" {
respondError(w, http.StatusUnauthorized, "Invalid token")
return
}
// Attempt to get the token from Authentication header ("Bearer <token>")
refreshToken, err := getTokenFromRequest(r)
if err != nil {
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
}
// Validate the refresh token in DB
if _, err := ts.ValidateRefreshToken(r.Context(), refreshToken); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid refresh token")
return
}
// Revoke the used refresh token
if err := ts.RevokeRefreshToken(r.Context(), refreshToken); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to revoke token")
return
}
// Generate a new pair (access & refresh tokens)
userID, err := uuid.Parse(claims.Subject)
if err != nil {
respondError(w, http.StatusInternalServerError, "Invalid user ID")
return
}
tokenPair, err := ts.GenerateTokenPair(r.Context(), userID, claims.Admin)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
return
}
respondJSON(w, http.StatusOK, tokenPair)
}
func (ts *tokenService) HandleLogout(w http.ResponseWriter, r *http.Request) {
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok {
respondError(w, http.StatusUnauthorized, "Not authenticated")
return
}
userID, err := uuid.Parse(claims.ID)
if err != nil {
respondError(w, http.StatusInternalServerError, "Invalid user ID")
return
}
if err := ts.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to logout")
return
}
respondJSON(w, http.StatusOK, map[string]string{"status": "logged out"})
}
func getTokenFromRequest(r *http.Request) (string, error) {
bearerToken := r.Header.Get("Authorization")
bearerFields := strings.Fields(bearerToken)
if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" {
return bearerFields[1], nil
}
return "", ErrAuthHeaderInvalid
}
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
atClaims := userClaims{
Admin: isAdmin,
TokenType: "access",
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenDuration)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},
}
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, atClaims)
t, err := accessToken.SignedString([]byte(jwtSecret))
if err != nil {
return nil, err
}
rtClaims := userClaims{
Admin: isAdmin,
TokenType: "refresh",
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenDuration)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},
}
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, rtClaims)
rt, err := refreshToken.SignedString([]byte(jwtSecret))
if err != nil {
return nil, err
}
return &tokenPair{AccessToken: t, RefreshToken: rt}, nil
}

View File

@ -0,0 +1,180 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
type mockTokenStore struct {
CreateRefreshTokenFunc func(context.Context, data.CreateRefreshTokenParams) (data.RefreshToken, error)
GetRefreshTokenByHashFunc func(context.Context, string) (data.RefreshToken, error)
RevokeRefreshTokenFunc func(context.Context, string) error
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
}
func (m *mockTokenStore) CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
return m.CreateRefreshTokenFunc(ctx, arg)
}
func (m *mockTokenStore) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
return m.GetRefreshTokenByHashFunc(ctx, tokenHash)
}
func (m *mockTokenStore) RevokeRefreshToken(ctx context.Context, token string) error {
return m.RevokeRefreshTokenFunc(ctx, token)
}
func (m *mockTokenStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
}
func TestGenerateTokenPair_Success(t *testing.T) {
userID := uuid.New()
called := false
mockStore := &mockTokenStore{
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
called = true
assert.Equal(t, userID, arg.UserID)
return data.RefreshToken{}, nil
},
}
ts := tokenService{
JWTSecret: "test-secret",
Tokens: mockStore,
}
pair, err := ts.GenerateTokenPair(context.Background(), userID, false)
assert.NoError(t, err)
assert.True(t, called)
assert.NotEmpty(t, pair.AccessToken)
assert.NotEmpty(t, pair.RefreshToken)
}
func TestGenerateTokenPair_DBError(t *testing.T) {
mockStore := &mockTokenStore{
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
return data.RefreshToken{}, errors.New("db error")
},
}
ts := tokenService{Tokens: mockStore}
_, err := ts.GenerateTokenPair(context.Background(), uuid.New(), false)
assert.ErrorContains(t, err, "db error")
}
func TestValidateRefreshToken_Valid(t *testing.T) {
token := "valid-jwt-token"
hash := sha256.Sum256([]byte(token))
expectedHash := hex.EncodeToString(hash[:])
mockStore := &mockTokenStore{
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
assert.Equal(t, expectedHash, tokenHash)
return data.RefreshToken{ExpiresAt: time.Now().Add(1 * time.Hour)}, nil
},
}
ts := tokenService{Tokens: mockStore}
_, err := ts.ValidateRefreshToken(context.Background(), token)
assert.NoError(t, err)
}
func TestRefreshAccessToken_Success(t *testing.T) {
userID := uuid.New()
refreshToken := "valid-jwt-token"
// Expected hash of the test token
hash := sha256.Sum256([]byte(refreshToken))
expectedHash := hex.EncodeToString(hash[:])
mockStore := &mockTokenStore{
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
assert.Equal(t, expectedHash, tokenHash)
// Must return an unrevoked token with future expiration
return data.RefreshToken{
TokenHash: tokenHash,
ExpiresAt: time.Now().Add(1 * time.Hour),
Revoked: false,
}, nil
},
RevokeRefreshTokenFunc: func(ctx context.Context, tokenHash string) error {
assert.Equal(t, expectedHash, tokenHash)
return nil
},
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
return data.RefreshToken{}, nil
},
}
ts := tokenService{
JWTSecret: "test-secret",
Tokens: mockStore,
}
req := httptest.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", refreshToken))
req = req.WithContext(context.WithValue(
req.Context(),
userCtxKey{},
&userClaims{
Admin: false,
TokenType: "refresh",
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID.String(),
},
},
))
w := httptest.NewRecorder()
ts.RefreshAccessToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "access_token", "refresh_token")
}
func TestHandleLogout_Success(t *testing.T) {
userID := uuid.New()
called := false
mockStore := &mockTokenStore{
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
called = true
assert.Equal(t, userID, id)
return nil
},
}
ts := tokenService{Tokens: mockStore}
req := httptest.NewRequest("POST", "/", nil)
req = req.WithContext(context.WithValue(
req.Context(),
userCtxKey{},
&userClaims{
RegisteredClaims: jwt.RegisteredClaims{
ID: userID.String(),
},
},
))
w := httptest.NewRecorder()
ts.HandleLogout(w, req)
assert.True(t, called)
assert.Equal(t, http.StatusOK, w.Code)
}

268
server/pkg/service/users.go Normal file
View File

@ -0,0 +1,268 @@
package service
import (
"context"
"encoding/json"
"errors"
"net/http"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgconn"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
)
// Mockable database operations interface
type UserStore interface {
CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
ListUsers(ctx context.Context) ([]data.User, error)
GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error)
UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error
DeleteUser(ctx context.Context, id uuid.UUID) error
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
}
type usersResource struct {
JWTSecret string
Users UserStore
}
func (rs usersResource) Routes() chi.Router {
r := chi.NewRouter()
// Public routes (no tokens required)
r.Post("/", rs.Create) // POST /users - registration/signup
// Protected routes (access token required)
r.Group(func(r chi.Router) {
r.Use(requireAccessToken(rs.JWTSecret))
// Admin only general routes
r.Group(func(r chi.Router) {
r.Use(adminOnlyMiddleware)
r.Get("/", rs.List) // GET /users - list all users
})
// User specific routes
r.Route("/{id}", func(r chi.Router) {
r.Use(userCtx(rs.Users)) // DB -> req. context
// Admin routes
r.Route("/admin", func(r chi.Router) {
r.Use(adminOnlyMiddleware)
r.Get("/", rs.Get) // GET /users/admin/{id} - get single user
r.Delete("/", rs.AdminDelete) // DELETE /users/admin/{id} - delete user
})
// Owner routes
r.Route("/owner", func(r chi.Router) {
r.Use(ownerOnlyMiddleware)
r.Put("/", rs.UpdatePassword) // PUT /users/owner/{id} - update user password
r.Delete("/", rs.OwnerDelete) // DELETE /users/owner/{id} - delete user
})
})
})
return r
}
func (rs usersResource) Create(w http.ResponseWriter, r *http.Request) {
type request struct {
Username string `json:"username"`
Password string `json:"password"`
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
normalizedUsername := normalizeUsername(req.Username)
if err := validateUsername(normalizedUsername); err != nil {
respondError(w, http.StatusBadRequest, err.Error())
return
}
if err := validatePassword(req.Password); err != nil {
respondError(w, http.StatusBadRequest, err.Error())
return
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to create user")
return
}
user, err := rs.Users.CreateUser(r.Context(), data.CreateUserParams{
Username: normalizedUsername,
PasswordHash: string(hashedPassword),
})
if err != nil {
if isDuplicateEntry(err) {
respondError(w, http.StatusConflict, "Username is already in use")
} else {
respondError(w, http.StatusInternalServerError, "Failed to create user")
}
return
}
respondJSON(w, http.StatusCreated, map[string]string{
"id": user.ID.String(),
"username": user.Username,
})
}
func (rs usersResource) List(w http.ResponseWriter, r *http.Request) {
users, err := rs.Users.ListUsers(r.Context())
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to retrieve users")
return
}
// Output sanitization
var output []map[string]interface{}
for _, user := range users {
output = append(output, map[string]interface{}{
"id": user.ID,
"username": user.Username,
"created_at": user.CreatedAt,
"updated_at": user.UpdatedAt,
})
}
respondJSON(w, http.StatusOK, output)
}
func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(data.User)
if !ok {
respondError(w, http.StatusNotFound, "User not found")
return
}
respondJSON(w, http.StatusOK, map[string]interface{}{
"id": user.ID,
"username": user.Username,
"created_at": user.CreatedAt,
"updated_at": user.UpdatedAt,
})
}
func (rs usersResource) UpdatePassword(w http.ResponseWriter, r *http.Request) {
type request struct {
OldPassword string `json:"old_password"`
NewPassword string `json:"new_password"`
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
user, ok := r.Context().Value(userCtxKey{}).(data.User)
if !ok {
respondError(w, http.StatusNotFound, "User not found")
return
}
// Verify the old password before allowing the update
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.OldPassword)); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
if err := validatePassword(req.NewPassword); err != nil {
respondError(w, http.StatusBadRequest, err.Error())
return
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to update password")
return
}
if err := rs.Users.UpdatePassword(r.Context(), data.UpdatePasswordParams{
ID: user.ID,
PasswordHash: string(hashedPassword),
}); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to update password")
return
}
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
}
w.WriteHeader(http.StatusNoContent)
}
func (rs usersResource) AdminDelete(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(data.User)
if !ok {
respondError(w, http.StatusNotFound, "User not found")
return
}
if err := rs.Users.DeleteUser(r.Context(), user.ID); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to delete user")
return
}
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
}
w.WriteHeader(http.StatusNoContent)
}
func (rs usersResource) OwnerDelete(w http.ResponseWriter, r *http.Request) {
type request struct {
Password string `json:"password"`
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
user, ok := r.Context().Value(userCtxKey{}).(data.User)
if !ok {
respondError(w, http.StatusNotFound, "User not found")
return
}
// Verify the old password before allowing the deletion
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
err := rs.Users.DeleteUser(r.Context(), user.ID)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to delete user")
return
}
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
}
w.WriteHeader(http.StatusNoContent)
}
// Check if the given error is a PostgreSQL error for `unique_violation`, i.e. whether an entry
// with the given details already exists in the database table.
func isDuplicateEntry(err error) bool {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
return pgErr.Code == "23505"
}
return false
}

View File

@ -0,0 +1,259 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
)
type mockUserStore struct {
CreateUserFunc func(context.Context, data.CreateUserParams) (data.User, error)
ListUsersFunc func(context.Context) ([]data.User, error)
GetUserByIDFunc func(context.Context, uuid.UUID) (data.User, error)
UpdatePasswordFunc func(context.Context, data.UpdatePasswordParams) error
DeleteUserFunc func(context.Context, uuid.UUID) error
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
}
func (m *mockUserStore) CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
return m.CreateUserFunc(ctx, arg)
}
func (m *mockUserStore) ListUsers(ctx context.Context) ([]data.User, error) {
return m.ListUsersFunc(ctx)
}
func (m *mockUserStore) GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error) {
return m.GetUserByIDFunc(ctx, id)
}
func (m *mockUserStore) UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error {
return m.UpdatePasswordFunc(ctx, arg)
}
func (m *mockUserStore) DeleteUser(ctx context.Context, id uuid.UUID) error {
return m.DeleteUserFunc(ctx, id)
}
func (m *mockUserStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
}
func TestCreateUser_Duplicate(t *testing.T) {
mockStore := &mockUserStore{
CreateUserFunc: func(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
return data.User{}, &pgconn.PgError{Code: "23505"}
},
}
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
reqBody := `{"username": "existing", "password": "validPass123!"}`
req := httptest.NewRequest("POST", "/", strings.NewReader(reqBody))
w := httptest.NewRecorder()
rs.Create(w, req)
assert.Equal(t, http.StatusConflict, w.Code)
assert.Contains(t, w.Body.String(), "Username is already in use")
}
func TestCreateUser_InvalidUsername(t *testing.T) {
mockStore := &mockUserStore{} // No DB calls expected
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
// Test various invalid usernames
tests := []struct {
name string
body string
}{
{"Too short", `{"username": "a", "password": "validPass123!"}`},
{"Invalid chars", `{"username": "user@name", "password": "validPass123!"}`},
{"Empty", `{"username": "", "password": "validPass123!"}`},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
w := httptest.NewRecorder()
rs.Create(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
})
}
}
func TestCreateUser_InvalidPassword(t *testing.T) {
mockStore := &mockUserStore{}
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
tests := []struct {
name string
body string
}{
{"too short", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength-1))},
{"too long", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", maxPasswordLength+1))},
{"low entropy", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength))},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
w := httptest.NewRecorder()
rs.Create(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
})
}
}
func TestListUsers_Success(t *testing.T) {
testUsers := []data.User{
{ID: uuid.New(), Username: "user1"},
{ID: uuid.New(), Username: "user2"},
}
mockStore := &mockUserStore{
ListUsersFunc: func(ctx context.Context) ([]data.User, error) {
return testUsers, nil
},
}
rs := usersResource{Users: mockStore}
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
rs.List(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response []map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatal(err)
}
assert.Len(t, response, 2)
assert.Equal(t, "user1", response[0]["username"])
assert.NotContains(t, response[0], "password_hash")
}
func TestUpdatePassword_InvalidOldPassword(t *testing.T) {
// User with password hash that won't match "wrongpassword"
user := data.User{
ID: uuid.New(),
PasswordHash: "$2a$10$PHhno.bZBF8IEINdFRZAPujMxIN65msElATgJG6FIxZdeWYVLSfFi", // Hash of "correctpassword"
}
mockStore := &mockUserStore{}
rs := usersResource{Users: mockStore}
reqBody := `{"old_password": "wrongpassword", "new_password": "NewValidPass321!"}`
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
w := httptest.NewRecorder()
rs.UpdatePassword(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func TestAdminDelete_Success(t *testing.T) {
user := data.User{ID: uuid.New()}
deleteCalled := false
revokeCalled := false
mockStore := &mockUserStore{
DeleteUserFunc: func(ctx context.Context, id uuid.UUID) error {
deleteCalled = true
assert.Equal(t, user.ID, id)
return nil
},
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
revokeCalled = true
assert.Equal(t, user.ID, id)
return nil
},
}
rs := usersResource{Users: mockStore}
req := httptest.NewRequest("DELETE", "/", nil)
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
w := httptest.NewRecorder()
rs.AdminDelete(w, req)
assert.Equal(t, http.StatusNoContent, w.Code)
assert.True(t, deleteCalled)
assert.True(t, revokeCalled)
}
func TestOwnerDelete_InvalidCredentials(t *testing.T) {
// Create user with known password hash
correctPassword := "CorrectPass123!"
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(correctPassword), bcrypt.DefaultCost)
user := data.User{
ID: uuid.New(),
PasswordHash: string(hashedPassword),
}
mockStore := &mockUserStore{}
rs := usersResource{Users: mockStore}
reqBody := `{"password": "wrongpassword"}`
req := httptest.NewRequest("DELETE", "/", strings.NewReader(reqBody))
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
w := httptest.NewRecorder()
rs.OwnerDelete(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func TestGetUser_NotFound(t *testing.T) {
mockStore := &mockUserStore{}
rs := usersResource{Users: mockStore}
// No user in context
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
rs.Get(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestUpdatePassword_DatabaseError(t *testing.T) {
// Add user with a valid password to the context
oldPassword := "OldValidPass321!"
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(oldPassword), bcrypt.DefaultCost)
user := data.User{
ID: uuid.New(),
PasswordHash: string(hashedPassword),
}
mockStore := &mockUserStore{
UpdatePasswordFunc: func(ctx context.Context, arg data.UpdatePasswordParams) error {
return errors.New("database error")
},
}
rs := usersResource{Users: mockStore}
reqBody := fmt.Sprintf(`{"old_password": "%s", "new_password": "NewValidPass123!"}`, oldPassword)
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
w := httptest.NewRecorder()
rs.UpdatePassword(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Contains(t, w.Body.String(), "Failed to update password")
}

133
server/pkg/service/util.go Normal file
View File

@ -0,0 +1,133 @@
package service
import (
"crypto/sha1"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"unicode/utf8"
passwordvalidator "github.com/wagslane/go-password-validator"
)
const (
minPasswordLength = 12 // Entropy checks prevent short passwords anyway
maxPasswordLength = 72 // Limitation of bcrypt
minPasswordEntropy = 60.0
minUsernameLength = 3
maxUsernameLength = 20
hibpAPI = "https://api.pwnedpasswords.com/range" // Doesn't require an API key
)
var (
usernameRegex = regexp.MustCompile("^[a-z0-9_]+$")
)
func respondJSON(w http.ResponseWriter, status int, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(data)
}
func respondError(w http.ResponseWriter, status int, message string) {
respondJSON(w, status, map[string]string{"error": message})
}
/*
Client-side check:
```
function estimateEntropy(password: string): number {
const pool: number = getCharsetSize(password); // Character diversity (R)
const entropy: number = password.length * Math.log2(pool); // E = L * log_2(R)
return entropy; // Value (E) that can be compared against a hardcoded threshold (e.g. 60)
}
```
*/
// Validate the given password using a hybrid approach: length (max. set due to bcrypt's input
// upper limit of 72 bytes), entropy, and HIBP API.
func validatePassword(password string) error {
if len(password) < minPasswordLength {
return fmt.Errorf("password must be at least %d characters", minPasswordLength)
}
if len(password) > maxPasswordLength {
return fmt.Errorf("password cannot be longer than %d characters", maxPasswordLength)
}
// Formatted error message will contain tips to increase the password strength (safe to show)
err := passwordvalidator.Validate(password, minPasswordEntropy)
if err != nil {
return err
}
if compromised, _ := isPasswordCompromised(password); compromised {
return errors.New("password is compromised")
}
return nil
}
// Send the first five bytes of the password's SHA-1 hash to HIBP API, then check if the rest of
// the hash is present in the APi's response data (k-Anonymity model).
func isPasswordCompromised(password string) (bool, error) {
hash := sha1.Sum([]byte(password))
hashStr := strings.ToUpper(hex.EncodeToString(hash[:]))
prefix, suffix := hashStr[:5], hashStr[5:]
resp, err := http.Get(fmt.Sprintf("%s/%s", hibpAPI, prefix))
if err != nil {
return false, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return false, err
}
return strings.Contains(strings.ToUpper(string(body)), suffix), nil
}
// Normalize the username by making it lowercase and trimming any leading or trailing whitespace.
func normalizeUsername(username string) string {
return strings.ToLower(strings.TrimSpace(username))
}
/*
Client-side check (additionally input should automatically perform the normalization steps):
```
function validateUsername(username: string): string {
const min: number = 3, max: number = 20;
if (username.length < min) return "Too short";
if (username.length > max) return "Too long";
if (!/^[a-zA-Z0-9_]+$/.test(username)) return "Invalid characters";
return "Valid";
}
```
*/
// Validate the given username by making sure it only contains alphanumeric characters or
// underscores and adheres the hardcoded minimum and maximum length rules.
func validateUsername(username string) error {
if utf8.RuneCountInString(username) < minUsernameLength {
return fmt.Errorf("username must be at least %d characters", minUsernameLength)
}
if utf8.RuneCountInString(username) > maxUsernameLength {
return fmt.Errorf("username cannot be longer than %d characters", maxUsernameLength)
}
if !usernameRegex.MatchString(username) {
return errors.New("username can only contain numbers, letters, and underscores")
}
return nil
}

View File

@ -0,0 +1,184 @@
package service
import (
"fmt"
"math/rand"
"net/http"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
passwordvalidator "github.com/wagslane/go-password-validator"
)
type MockHTTPClient struct {
mock.Mock
}
func (m *MockHTTPClient) Get(url string) (*http.Response, error) {
args := m.Called(url)
return args.Get(0).(*http.Response), args.Error(1)
}
func TestValidatePassword(t *testing.T) {
tests := []struct {
name string
password string
wantErr string
mockHTTP func(*MockHTTPClient)
}{
{
name: "too short",
password: strings.Repeat("a", minPasswordLength-1),
wantErr: fmt.Sprintf("password must be at least %d characters", minPasswordLength),
},
{
name: "too long",
password: strings.Repeat("a", maxPasswordLength+1),
wantErr: fmt.Sprintf("password cannot be longer than %d characters", maxPasswordLength),
},
{
name: "low entropy",
password: strings.Repeat("a", minPasswordLength),
wantErr: "insecure password", // Error produced by wagslane/go-password-validator
},
{
name: "valid password",
password: "SecurePassw0rd!123",
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Mock HTTP client if needed
if tt.mockHTTP != nil {
mockClient := new(MockHTTPClient)
tt.mockHTTP(mockClient)
}
err := validatePassword(tt.password)
if tt.wantErr == "" {
assert.NoError(t, err)
} else {
assert.ErrorContains(t, err, tt.wantErr)
}
})
}
}
func TestIsPasswordCompromised(t *testing.T) {
t.Run("known compromised", func(t *testing.T) {
compromised, err := isPasswordCompromised("password123456")
assert.NoError(t, err)
assert.True(t, compromised)
})
t.Run("randomly generated", func(t *testing.T) {
randomStr := genRandomString(12)
compromised, err := isPasswordCompromised(randomStr)
assert.NoError(t, err)
assert.False(t, compromised)
})
}
func TestPasswordEntropyCalculation(t *testing.T) {
tests := []struct {
password string
entropy float64
}{
{"password", 37.6},
{"SecurePassw0rd!123", 103.12},
{"aaaaaaaaaaaaaaaa", 9.5},
}
for _, tt := range tests {
t.Run(tt.password, func(t *testing.T) {
entropy := passwordvalidator.GetEntropy(tt.password)
assert.InDelta(t, tt.entropy, entropy, 1.0)
})
}
}
func TestValidateUsername(t *testing.T) {
tests := []struct {
name string
input string
wantErr string
}{
{
name: "too short",
input: strings.Repeat("a", minUsernameLength-1),
wantErr: fmt.Sprintf("username must be at least %d characters", minUsernameLength),
},
{
name: "too long",
input: strings.Repeat("a", maxUsernameLength+1),
wantErr: fmt.Sprintf("username cannot be longer than %d characters", maxUsernameLength),
},
{
name: "invalid characters",
input: "user@name",
wantErr: "username can only contain numbers, letters, and underscores",
},
{
name: "valid username",
input: "valid_user123",
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateUsername(tt.input)
if tt.wantErr == "" {
assert.NoError(t, err)
} else {
assert.ErrorContains(t, err, tt.wantErr)
}
})
}
}
func TestNormalizeUsername(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "trim whitespace",
input: " test_user ",
want: "test_user",
},
{
name: "lowercase",
input: "TestUser",
want: "testuser",
},
{
name: "mixed case and spaces",
input: " UserName123 ",
want: "username123",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeUsername(tt.input)
assert.Equal(t, tt.want, got)
})
}
}
func genRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return string(b)
}