From 3257b193137abbcad67b71c5f3c2955b50aa2d9a Mon Sep 17 00:00:00 2001 From: ae Date: Mon, 31 Mar 2025 23:32:39 +0300 Subject: [PATCH] feat: modular api handlers for users/tokens (auth) incl. unit tests --- server/go.mod | 8 +- server/go.sum | 11 ++ server/pkg/service/middleware.go | 157 +++++++++++++++++ server/pkg/service/notes.go | 19 +++ server/pkg/service/server.go | 43 +++++ server/pkg/service/tokens.go | 205 +++++++++++++++++++++++ server/pkg/service/tokens_test.go | 180 ++++++++++++++++++++ server/pkg/service/users.go | 268 ++++++++++++++++++++++++++++++ server/pkg/service/users_test.go | 259 +++++++++++++++++++++++++++++ server/pkg/service/util.go | 133 +++++++++++++++ server/pkg/service/util_test.go | 184 ++++++++++++++++++++ 11 files changed, 1466 insertions(+), 1 deletion(-) create mode 100644 server/pkg/service/middleware.go create mode 100644 server/pkg/service/notes.go create mode 100644 server/pkg/service/server.go create mode 100644 server/pkg/service/tokens.go create mode 100644 server/pkg/service/tokens_test.go create mode 100644 server/pkg/service/users.go create mode 100644 server/pkg/service/users_test.go create mode 100644 server/pkg/service/util.go create mode 100644 server/pkg/service/util_test.go diff --git a/server/go.mod b/server/go.mod index d0f0e30..38651b5 100644 --- a/server/go.mod +++ b/server/go.mod @@ -9,16 +9,22 @@ require ( github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.7.4 github.com/rs/zerolog v1.34.0 + github.com/stretchr/testify v1.10.0 github.com/wagslane/go-password-validator v0.3.0 golang.org/x/crypto v0.36.0 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/stretchr/testify v1.10.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect + github.com/stretchr/objx v0.5.2 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/server/go.sum b/server/go.sum index 22c8694..675d802 100644 --- a/server/go.sum +++ b/server/go.sum @@ -1,6 +1,7 @@ github.com/caarlos0/env v3.5.0+incompatible h1:Yy0UN8o9Wtr/jGHZDpCBLpNrzcFLLM2yixi/rBrKyJs= github.com/caarlos0/env v3.5.0+incompatible/go.mod h1:tdCsowwCzMLdkqRYDlHpZCp2UooDD3MspDBjZ2AD02Y= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -19,6 +20,10 @@ github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg= github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= @@ -29,10 +34,14 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= @@ -51,6 +60,8 @@ golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/server/pkg/service/middleware.go b/server/pkg/service/middleware.go new file mode 100644 index 0000000..35fd2ee --- /dev/null +++ b/server/pkg/service/middleware.go @@ -0,0 +1,157 @@ +package service + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/rs/zerolog" +) + +const ( + panicRecoveryMsg = "panic recovered" + defaultLogMsg = "incoming request" +) + +type userCtxKey struct{} + +// Get JWT bearer from request's authorization header, parse it with custom user claims, and +// ensure its validity before attaching the claims to the request's context. +func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenString, err := getTokenFromRequest(r) + if err != nil { + respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err)) + } + + token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (interface{}, error) { + return []byte(jwtSecret), nil + }) + if err != nil || !token.Valid { + respondError(w, http.StatusUnauthorized, "Invalid token") + return + } + + claims, ok := token.Claims.(*userClaims) + if !ok || claims.TokenType != expectedType { + respondError(w, http.StatusUnauthorized, "Invalid token type") + return + } + + ctx := context.WithValue(r.Context(), userCtxKey{}, claims) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// JWT access token parsing, verification, and validation. +func requireAccessToken(jwtSecret string) func(http.Handler) http.Handler { + return authMiddleware(jwtSecret, "access") +} + +// JWT refresh token parsing, verification, and validation. +func requireRefreshToken(jwtSecret string) func(http.Handler) http.Handler { + return authMiddleware(jwtSecret, "refresh") +} + +// Ensure the current user is an administrator. +func adminOnlyMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, ok := r.Context().Value(userCtxKey{}).(*userClaims) + if !ok || !user.Admin { + respondError(w, http.StatusForbidden, "Forbidden") + return + } + next.ServeHTTP(w, r) + }) +} + +// Ensure the targeted resource is owned by the current user (i.e. current user's ID matches with +// the one stored into the resource). +func ownerOnlyMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, ok := r.Context().Value(userCtxKey{}).(*userClaims) + requestedID := chi.URLParam(r, "id") + if !ok || user.ID != requestedID { + respondError(w, http.StatusForbidden, "Forbidden") + return + } + next.ServeHTTP(w, r) + }) +} + +// Append user data into request's context based on user ID as a URL parameter. +func userCtx(store UserStore) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userIDStr := chi.URLParam(r, "id") + userID, err := uuid.Parse(userIDStr) + if err != nil { + respondError(w, http.StatusNotFound, "Invalid user ID") + return + } + + user, err := store.GetUserByID(r.Context(), userID) + if err != nil { + respondError(w, http.StatusNotFound, "User not found") + return + } + + ctx := context.WithValue(r.Context(), userCtxKey{}, user) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// Zerolog compatible logger middleware. Automatically logs and recovers from errors with HTTP 500 +// response, by default logs to INFO level. +func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + log := log.With().Logger() + + ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) + + t1 := time.Now() + defer func() { + t2 := time.Now() + + // Recover automatically and respond with HTTP 500 + if rec := recover(); rec != nil { + log.Error(). + Str("type", "error"). + Timestamp(). + Interface("recover_info", rec). + Msg(panicRecoveryMsg) + http.Error(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } + + // Log a regular HTTP request with some metadata + log.Info(). + Str("type", "access"). + Timestamp(). + Fields(map[string]interface{}{ + "remote_ip": r.RemoteAddr, + "url": r.URL.Path, + "proto": r.Proto, + "method": r.Method, + "user_agent": r.Header.Get("User-Agent"), + "status": ww.Status(), + "latency_ms": float64(t2.Sub(t1).Nanoseconds()) / 1000000.0, + "bytes_in": r.Header.Get("Content-Length"), + "bytes_out": ww.BytesWritten(), + }). + Msg(defaultLogMsg) + }() + + next.ServeHTTP(ww, r) + } + return http.HandlerFunc(fn) + } +} diff --git a/server/pkg/service/notes.go b/server/pkg/service/notes.go new file mode 100644 index 0000000..37b117d --- /dev/null +++ b/server/pkg/service/notes.go @@ -0,0 +1,19 @@ +package service + +import ( + "github.com/go-chi/chi/v5" +) + +// Mockable database operations interface +type NoteStore interface { + // TODO: implement +} + +type notesResource struct { + Notes NoteStore +} + +func (rs notesResource) Routes() chi.Router { + r := chi.NewRouter() + return r +} diff --git a/server/pkg/service/server.go b/server/pkg/service/server.go new file mode 100644 index 0000000..a1105f9 --- /dev/null +++ b/server/pkg/service/server.go @@ -0,0 +1,43 @@ +package service + +import ( + "fmt" + "net/http" + + "git.umbrella.haus/ae/notatest/pkg/data" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/jackc/pgx/v5" + "github.com/rs/zerolog/log" +) + +func Run(conn *pgx.Conn, jwtSecret string, httpPort string) error { + q := data.New(conn) + r := chi.NewRouter() + + tokenService := tokenService{ + JWTSecret: jwtSecret, + Tokens: q, + } + usersRouter := usersResource{ + JWTSecret: jwtSecret, + Users: q, + } + notesRouter := notesResource{} + + // Global middlewares + r.Use(middleware.RequestID) + r.Use(middleware.RealIP) + r.Use(loggerMiddleware(&log.Logger)) + r.Use(middleware.Recoverer) + r.Use(middleware.AllowContentType("application/json")) + + // Routes grouped by functionality + r.Post("/auth/refresh", tokenService.RefreshAccessToken) // POST /auth/refresh - new access token for refresh token + r.Mount("/users", usersRouter.Routes()) + r.Mount("/notes", notesRouter.Routes()) + + portStr := fmt.Sprintf(":%s", httpPort) + log.Info().Msgf("Starting server on %s", portStr) + return http.ListenAndServe(portStr, r) +} diff --git a/server/pkg/service/tokens.go b/server/pkg/service/tokens.go new file mode 100644 index 0000000..3dc2fe5 --- /dev/null +++ b/server/pkg/service/tokens.go @@ -0,0 +1,205 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "git.umbrella.haus/ae/notatest/pkg/data" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +const ( + accessTokenDuration = 15 * time.Minute + refreshTokenDuration = 7 * 24 * time.Hour +) + +var ( + ErrInvalidToken = errors.New("invalid token") + ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header") +) + +type userClaims struct { + Admin bool `json:"admin"` + TokenType string `json:"type"` // "access" or "refresh" + jwt.RegisteredClaims // User's UUID should be stored in the subject claim +} + +type tokenPair struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` +} + +// Mockable database operations interface +type TokenStore interface { + CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) + GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) + RevokeRefreshToken(ctx context.Context, tokenHash string) error + RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error +} + +type tokenService struct { + JWTSecret string + Tokens TokenStore +} + +func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) { + tokenPair, err := generateTokenPair(userID.String(), isAdmin, ts.JWTSecret) + if err != nil { + return nil, err + } + + hash := sha256.Sum256([]byte(tokenPair.RefreshToken)) + tokenHash := hex.EncodeToString(hash[:]) + + // Store to DB with (almost) identical expiration timestamp + expiresAt := time.Now().Add(refreshTokenDuration) + _, err = ts.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{ + UserID: userID, + TokenHash: tokenHash, + ExpiresAt: expiresAt, + }) + if err != nil { + return nil, err + } + + return tokenPair, nil +} + +func (ts *tokenService) RevokeRefreshToken(ctx context.Context, token string) error { + hash := sha256.Sum256([]byte(token)) + tokenHash := hex.EncodeToString(hash[:]) + return ts.Tokens.RevokeRefreshToken(ctx, tokenHash) +} + +func (ts *tokenService) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) { + hash := sha256.Sum256([]byte(token)) + tokenHash := hex.EncodeToString(hash[:]) + + dbToken, err := ts.Tokens.GetRefreshTokenByHash(ctx, tokenHash) + if err != nil { + return nil, err + } + + if dbToken.Revoked || time.Now().After(dbToken.ExpiresAt) { + return nil, ErrInvalidToken + } + + return &dbToken, nil +} + +func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Request) { + // Get claims from context + claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) + if !ok || claims.TokenType != "refresh" { + respondError(w, http.StatusUnauthorized, "Invalid token") + return + } + + // Attempt to get the token from Authentication header ("Bearer ") + refreshToken, err := getTokenFromRequest(r) + if err != nil { + respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err)) + } + + // Validate the refresh token in DB + if _, err := ts.ValidateRefreshToken(r.Context(), refreshToken); err != nil { + respondError(w, http.StatusUnauthorized, "Invalid refresh token") + return + } + + // Revoke the used refresh token + if err := ts.RevokeRefreshToken(r.Context(), refreshToken); err != nil { + respondError(w, http.StatusInternalServerError, "Failed to revoke token") + return + } + + // Generate a new pair (access & refresh tokens) + userID, err := uuid.Parse(claims.Subject) + if err != nil { + respondError(w, http.StatusInternalServerError, "Invalid user ID") + return + } + + tokenPair, err := ts.GenerateTokenPair(r.Context(), userID, claims.Admin) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to generate tokens") + return + } + + respondJSON(w, http.StatusOK, tokenPair) +} + +func (ts *tokenService) HandleLogout(w http.ResponseWriter, r *http.Request) { + claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) + if !ok { + respondError(w, http.StatusUnauthorized, "Not authenticated") + return + } + + userID, err := uuid.Parse(claims.ID) + if err != nil { + respondError(w, http.StatusInternalServerError, "Invalid user ID") + return + } + + if err := ts.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil { + respondError(w, http.StatusInternalServerError, "Failed to logout") + return + } + + respondJSON(w, http.StatusOK, map[string]string{"status": "logged out"}) +} + +func getTokenFromRequest(r *http.Request) (string, error) { + bearerToken := r.Header.Get("Authorization") + bearerFields := strings.Fields(bearerToken) + if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" { + return bearerFields[1], nil + } + return "", ErrAuthHeaderInvalid +} + +func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) { + atClaims := userClaims{ + Admin: isAdmin, + TokenType: "access", + RegisteredClaims: jwt.RegisteredClaims{ + Subject: userID, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenDuration)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + } + accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, atClaims) + + t, err := accessToken.SignedString([]byte(jwtSecret)) + if err != nil { + return nil, err + } + + rtClaims := userClaims{ + Admin: isAdmin, + TokenType: "refresh", + RegisteredClaims: jwt.RegisteredClaims{ + Subject: userID, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenDuration)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + } + refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, rtClaims) + + rt, err := refreshToken.SignedString([]byte(jwtSecret)) + if err != nil { + return nil, err + } + + return &tokenPair{AccessToken: t, RefreshToken: rt}, nil +} diff --git a/server/pkg/service/tokens_test.go b/server/pkg/service/tokens_test.go new file mode 100644 index 0000000..b72196b --- /dev/null +++ b/server/pkg/service/tokens_test.go @@ -0,0 +1,180 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "git.umbrella.haus/ae/notatest/pkg/data" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +type mockTokenStore struct { + CreateRefreshTokenFunc func(context.Context, data.CreateRefreshTokenParams) (data.RefreshToken, error) + GetRefreshTokenByHashFunc func(context.Context, string) (data.RefreshToken, error) + RevokeRefreshTokenFunc func(context.Context, string) error + RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error +} + +func (m *mockTokenStore) CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) { + return m.CreateRefreshTokenFunc(ctx, arg) +} + +func (m *mockTokenStore) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) { + return m.GetRefreshTokenByHashFunc(ctx, tokenHash) +} + +func (m *mockTokenStore) RevokeRefreshToken(ctx context.Context, token string) error { + return m.RevokeRefreshTokenFunc(ctx, token) +} + +func (m *mockTokenStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error { + return m.RevokeAllUserRefreshTokensFunc(ctx, id) +} + +func TestGenerateTokenPair_Success(t *testing.T) { + userID := uuid.New() + called := false + + mockStore := &mockTokenStore{ + CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) { + called = true + assert.Equal(t, userID, arg.UserID) + return data.RefreshToken{}, nil + }, + } + + ts := tokenService{ + JWTSecret: "test-secret", + Tokens: mockStore, + } + + pair, err := ts.GenerateTokenPair(context.Background(), userID, false) + assert.NoError(t, err) + assert.True(t, called) + assert.NotEmpty(t, pair.AccessToken) + assert.NotEmpty(t, pair.RefreshToken) +} + +func TestGenerateTokenPair_DBError(t *testing.T) { + mockStore := &mockTokenStore{ + CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) { + return data.RefreshToken{}, errors.New("db error") + }, + } + + ts := tokenService{Tokens: mockStore} + _, err := ts.GenerateTokenPair(context.Background(), uuid.New(), false) + assert.ErrorContains(t, err, "db error") +} + +func TestValidateRefreshToken_Valid(t *testing.T) { + token := "valid-jwt-token" + hash := sha256.Sum256([]byte(token)) + expectedHash := hex.EncodeToString(hash[:]) + + mockStore := &mockTokenStore{ + GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) { + assert.Equal(t, expectedHash, tokenHash) + return data.RefreshToken{ExpiresAt: time.Now().Add(1 * time.Hour)}, nil + }, + } + + ts := tokenService{Tokens: mockStore} + _, err := ts.ValidateRefreshToken(context.Background(), token) + assert.NoError(t, err) +} + +func TestRefreshAccessToken_Success(t *testing.T) { + userID := uuid.New() + refreshToken := "valid-jwt-token" + + // Expected hash of the test token + hash := sha256.Sum256([]byte(refreshToken)) + expectedHash := hex.EncodeToString(hash[:]) + + mockStore := &mockTokenStore{ + GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) { + assert.Equal(t, expectedHash, tokenHash) + + // Must return an unrevoked token with future expiration + return data.RefreshToken{ + TokenHash: tokenHash, + ExpiresAt: time.Now().Add(1 * time.Hour), + Revoked: false, + }, nil + }, + RevokeRefreshTokenFunc: func(ctx context.Context, tokenHash string) error { + assert.Equal(t, expectedHash, tokenHash) + return nil + }, + CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) { + return data.RefreshToken{}, nil + }, + } + + ts := tokenService{ + JWTSecret: "test-secret", + Tokens: mockStore, + } + + req := httptest.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", refreshToken)) + req = req.WithContext(context.WithValue( + req.Context(), + userCtxKey{}, + &userClaims{ + Admin: false, + TokenType: "refresh", + RegisteredClaims: jwt.RegisteredClaims{ + Subject: userID.String(), + }, + }, + )) + + w := httptest.NewRecorder() + ts.RefreshAccessToken(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), "access_token", "refresh_token") +} + +func TestHandleLogout_Success(t *testing.T) { + userID := uuid.New() + called := false + + mockStore := &mockTokenStore{ + RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error { + called = true + assert.Equal(t, userID, id) + return nil + }, + } + + ts := tokenService{Tokens: mockStore} + + req := httptest.NewRequest("POST", "/", nil) + req = req.WithContext(context.WithValue( + req.Context(), + userCtxKey{}, + &userClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ID: userID.String(), + }, + }, + )) + + w := httptest.NewRecorder() + ts.HandleLogout(w, req) + + assert.True(t, called) + assert.Equal(t, http.StatusOK, w.Code) +} diff --git a/server/pkg/service/users.go b/server/pkg/service/users.go new file mode 100644 index 0000000..4e4c6d2 --- /dev/null +++ b/server/pkg/service/users.go @@ -0,0 +1,268 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "net/http" + + "git.umbrella.haus/ae/notatest/pkg/data" + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgconn" + "github.com/rs/zerolog/log" + "golang.org/x/crypto/bcrypt" +) + +// Mockable database operations interface +type UserStore interface { + CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error) + ListUsers(ctx context.Context) ([]data.User, error) + GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error) + UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error + DeleteUser(ctx context.Context, id uuid.UUID) error + RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error +} + +type usersResource struct { + JWTSecret string + Users UserStore +} + +func (rs usersResource) Routes() chi.Router { + r := chi.NewRouter() + + // Public routes (no tokens required) + r.Post("/", rs.Create) // POST /users - registration/signup + + // Protected routes (access token required) + r.Group(func(r chi.Router) { + r.Use(requireAccessToken(rs.JWTSecret)) + + // Admin only general routes + r.Group(func(r chi.Router) { + r.Use(adminOnlyMiddleware) + r.Get("/", rs.List) // GET /users - list all users + }) + + // User specific routes + r.Route("/{id}", func(r chi.Router) { + r.Use(userCtx(rs.Users)) // DB -> req. context + + // Admin routes + r.Route("/admin", func(r chi.Router) { + r.Use(adminOnlyMiddleware) + r.Get("/", rs.Get) // GET /users/admin/{id} - get single user + r.Delete("/", rs.AdminDelete) // DELETE /users/admin/{id} - delete user + }) + + // Owner routes + r.Route("/owner", func(r chi.Router) { + r.Use(ownerOnlyMiddleware) + r.Put("/", rs.UpdatePassword) // PUT /users/owner/{id} - update user password + r.Delete("/", rs.OwnerDelete) // DELETE /users/owner/{id} - delete user + }) + }) + }) + + return r +} + +func (rs usersResource) Create(w http.ResponseWriter, r *http.Request) { + type request struct { + Username string `json:"username"` + Password string `json:"password"` + } + + var req request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + respondError(w, http.StatusBadRequest, "Invalid request body") + return + } + + normalizedUsername := normalizeUsername(req.Username) + if err := validateUsername(normalizedUsername); err != nil { + respondError(w, http.StatusBadRequest, err.Error()) + return + } + + if err := validatePassword(req.Password); err != nil { + respondError(w, http.StatusBadRequest, err.Error()) + return + } + + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to create user") + return + } + + user, err := rs.Users.CreateUser(r.Context(), data.CreateUserParams{ + Username: normalizedUsername, + PasswordHash: string(hashedPassword), + }) + if err != nil { + if isDuplicateEntry(err) { + respondError(w, http.StatusConflict, "Username is already in use") + } else { + respondError(w, http.StatusInternalServerError, "Failed to create user") + } + return + } + + respondJSON(w, http.StatusCreated, map[string]string{ + "id": user.ID.String(), + "username": user.Username, + }) +} + +func (rs usersResource) List(w http.ResponseWriter, r *http.Request) { + users, err := rs.Users.ListUsers(r.Context()) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to retrieve users") + return + } + + // Output sanitization + var output []map[string]interface{} + for _, user := range users { + output = append(output, map[string]interface{}{ + "id": user.ID, + "username": user.Username, + "created_at": user.CreatedAt, + "updated_at": user.UpdatedAt, + }) + } + + respondJSON(w, http.StatusOK, output) +} + +func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) { + user, ok := r.Context().Value(userCtxKey{}).(data.User) + if !ok { + respondError(w, http.StatusNotFound, "User not found") + return + } + + respondJSON(w, http.StatusOK, map[string]interface{}{ + "id": user.ID, + "username": user.Username, + "created_at": user.CreatedAt, + "updated_at": user.UpdatedAt, + }) +} + +func (rs usersResource) UpdatePassword(w http.ResponseWriter, r *http.Request) { + type request struct { + OldPassword string `json:"old_password"` + NewPassword string `json:"new_password"` + } + + var req request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + respondError(w, http.StatusBadRequest, "Invalid request body") + return + } + + user, ok := r.Context().Value(userCtxKey{}).(data.User) + if !ok { + respondError(w, http.StatusNotFound, "User not found") + return + } + + // Verify the old password before allowing the update + if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.OldPassword)); err != nil { + respondError(w, http.StatusUnauthorized, "Invalid credentials") + return + } + + if err := validatePassword(req.NewPassword); err != nil { + respondError(w, http.StatusBadRequest, err.Error()) + return + } + + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to update password") + return + } + + if err := rs.Users.UpdatePassword(r.Context(), data.UpdatePasswordParams{ + ID: user.ID, + PasswordHash: string(hashedPassword), + }); err != nil { + respondError(w, http.StatusInternalServerError, "Failed to update password") + return + } + + if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil { + log.Error().Msgf("Failed to revoke refresh tokens: %s", err) + } + + w.WriteHeader(http.StatusNoContent) +} + +func (rs usersResource) AdminDelete(w http.ResponseWriter, r *http.Request) { + user, ok := r.Context().Value(userCtxKey{}).(data.User) + if !ok { + respondError(w, http.StatusNotFound, "User not found") + return + } + + if err := rs.Users.DeleteUser(r.Context(), user.ID); err != nil { + respondError(w, http.StatusInternalServerError, "Failed to delete user") + return + } + + if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil { + log.Error().Msgf("Failed to revoke refresh tokens: %s", err) + } + + w.WriteHeader(http.StatusNoContent) +} + +func (rs usersResource) OwnerDelete(w http.ResponseWriter, r *http.Request) { + type request struct { + Password string `json:"password"` + } + + var req request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + respondError(w, http.StatusBadRequest, "Invalid request body") + return + } + + user, ok := r.Context().Value(userCtxKey{}).(data.User) + if !ok { + respondError(w, http.StatusNotFound, "User not found") + return + } + + // Verify the old password before allowing the deletion + if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil { + respondError(w, http.StatusUnauthorized, "Invalid credentials") + return + } + + err := rs.Users.DeleteUser(r.Context(), user.ID) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to delete user") + return + } + + if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil { + log.Error().Msgf("Failed to revoke refresh tokens: %s", err) + } + + w.WriteHeader(http.StatusNoContent) +} + +// Check if the given error is a PostgreSQL error for `unique_violation`, i.e. whether an entry +// with the given details already exists in the database table. +func isDuplicateEntry(err error) bool { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + return pgErr.Code == "23505" + } + return false +} diff --git a/server/pkg/service/users_test.go b/server/pkg/service/users_test.go new file mode 100644 index 0000000..9e7fbd3 --- /dev/null +++ b/server/pkg/service/users_test.go @@ -0,0 +1,259 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "git.umbrella.haus/ae/notatest/pkg/data" + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" +) + +type mockUserStore struct { + CreateUserFunc func(context.Context, data.CreateUserParams) (data.User, error) + ListUsersFunc func(context.Context) ([]data.User, error) + GetUserByIDFunc func(context.Context, uuid.UUID) (data.User, error) + UpdatePasswordFunc func(context.Context, data.UpdatePasswordParams) error + DeleteUserFunc func(context.Context, uuid.UUID) error + RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error +} + +func (m *mockUserStore) CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error) { + return m.CreateUserFunc(ctx, arg) +} + +func (m *mockUserStore) ListUsers(ctx context.Context) ([]data.User, error) { + return m.ListUsersFunc(ctx) +} + +func (m *mockUserStore) GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error) { + return m.GetUserByIDFunc(ctx, id) +} + +func (m *mockUserStore) UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error { + return m.UpdatePasswordFunc(ctx, arg) +} + +func (m *mockUserStore) DeleteUser(ctx context.Context, id uuid.UUID) error { + return m.DeleteUserFunc(ctx, id) +} + +func (m *mockUserStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error { + return m.RevokeAllUserRefreshTokensFunc(ctx, id) +} + +func TestCreateUser_Duplicate(t *testing.T) { + mockStore := &mockUserStore{ + CreateUserFunc: func(ctx context.Context, arg data.CreateUserParams) (data.User, error) { + return data.User{}, &pgconn.PgError{Code: "23505"} + }, + } + + rs := usersResource{Users: mockStore, JWTSecret: "test-secret"} + + reqBody := `{"username": "existing", "password": "validPass123!"}` + req := httptest.NewRequest("POST", "/", strings.NewReader(reqBody)) + w := httptest.NewRecorder() + + rs.Create(w, req) + + assert.Equal(t, http.StatusConflict, w.Code) + assert.Contains(t, w.Body.String(), "Username is already in use") +} + +func TestCreateUser_InvalidUsername(t *testing.T) { + mockStore := &mockUserStore{} // No DB calls expected + rs := usersResource{Users: mockStore, JWTSecret: "test-secret"} + + // Test various invalid usernames + tests := []struct { + name string + body string + }{ + {"Too short", `{"username": "a", "password": "validPass123!"}`}, + {"Invalid chars", `{"username": "user@name", "password": "validPass123!"}`}, + {"Empty", `{"username": "", "password": "validPass123!"}`}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body)) + w := httptest.NewRecorder() + rs.Create(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) + }) + } +} + +func TestCreateUser_InvalidPassword(t *testing.T) { + mockStore := &mockUserStore{} + rs := usersResource{Users: mockStore, JWTSecret: "test-secret"} + + tests := []struct { + name string + body string + }{ + {"too short", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength-1))}, + {"too long", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", maxPasswordLength+1))}, + {"low entropy", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength))}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body)) + w := httptest.NewRecorder() + rs.Create(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) + }) + } +} + +func TestListUsers_Success(t *testing.T) { + testUsers := []data.User{ + {ID: uuid.New(), Username: "user1"}, + {ID: uuid.New(), Username: "user2"}, + } + + mockStore := &mockUserStore{ + ListUsersFunc: func(ctx context.Context) ([]data.User, error) { + return testUsers, nil + }, + } + + rs := usersResource{Users: mockStore} + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + rs.List(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response []map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatal(err) + } + + assert.Len(t, response, 2) + assert.Equal(t, "user1", response[0]["username"]) + assert.NotContains(t, response[0], "password_hash") +} + +func TestUpdatePassword_InvalidOldPassword(t *testing.T) { + // User with password hash that won't match "wrongpassword" + user := data.User{ + ID: uuid.New(), + PasswordHash: "$2a$10$PHhno.bZBF8IEINdFRZAPujMxIN65msElATgJG6FIxZdeWYVLSfFi", // Hash of "correctpassword" + } + + mockStore := &mockUserStore{} + rs := usersResource{Users: mockStore} + + reqBody := `{"old_password": "wrongpassword", "new_password": "NewValidPass321!"}` + req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody)) + req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user)) + w := httptest.NewRecorder() + + rs.UpdatePassword(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) +} + +func TestAdminDelete_Success(t *testing.T) { + user := data.User{ID: uuid.New()} + deleteCalled := false + revokeCalled := false + + mockStore := &mockUserStore{ + DeleteUserFunc: func(ctx context.Context, id uuid.UUID) error { + deleteCalled = true + assert.Equal(t, user.ID, id) + return nil + }, + RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error { + revokeCalled = true + assert.Equal(t, user.ID, id) + return nil + }, + } + + rs := usersResource{Users: mockStore} + req := httptest.NewRequest("DELETE", "/", nil) + req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user)) + w := httptest.NewRecorder() + + rs.AdminDelete(w, req) + + assert.Equal(t, http.StatusNoContent, w.Code) + assert.True(t, deleteCalled) + assert.True(t, revokeCalled) +} + +func TestOwnerDelete_InvalidCredentials(t *testing.T) { + // Create user with known password hash + correctPassword := "CorrectPass123!" + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(correctPassword), bcrypt.DefaultCost) + user := data.User{ + ID: uuid.New(), + PasswordHash: string(hashedPassword), + } + + mockStore := &mockUserStore{} + rs := usersResource{Users: mockStore} + + reqBody := `{"password": "wrongpassword"}` + req := httptest.NewRequest("DELETE", "/", strings.NewReader(reqBody)) + req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user)) + w := httptest.NewRecorder() + + rs.OwnerDelete(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) +} + +func TestGetUser_NotFound(t *testing.T) { + mockStore := &mockUserStore{} + rs := usersResource{Users: mockStore} + + // No user in context + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + rs.Get(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) +} + +func TestUpdatePassword_DatabaseError(t *testing.T) { + // Add user with a valid password to the context + oldPassword := "OldValidPass321!" + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(oldPassword), bcrypt.DefaultCost) + user := data.User{ + ID: uuid.New(), + PasswordHash: string(hashedPassword), + } + mockStore := &mockUserStore{ + UpdatePasswordFunc: func(ctx context.Context, arg data.UpdatePasswordParams) error { + return errors.New("database error") + }, + } + + rs := usersResource{Users: mockStore} + + reqBody := fmt.Sprintf(`{"old_password": "%s", "new_password": "NewValidPass123!"}`, oldPassword) + req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody)) + req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user)) + w := httptest.NewRecorder() + + rs.UpdatePassword(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Contains(t, w.Body.String(), "Failed to update password") +} diff --git a/server/pkg/service/util.go b/server/pkg/service/util.go new file mode 100644 index 0000000..8add2ac --- /dev/null +++ b/server/pkg/service/util.go @@ -0,0 +1,133 @@ +package service + +import ( + "crypto/sha1" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "regexp" + "strings" + "unicode/utf8" + + passwordvalidator "github.com/wagslane/go-password-validator" +) + +const ( + minPasswordLength = 12 // Entropy checks prevent short passwords anyway + maxPasswordLength = 72 // Limitation of bcrypt + minPasswordEntropy = 60.0 + + minUsernameLength = 3 + maxUsernameLength = 20 + + hibpAPI = "https://api.pwnedpasswords.com/range" // Doesn't require an API key +) + +var ( + usernameRegex = regexp.MustCompile("^[a-z0-9_]+$") +) + +func respondJSON(w http.ResponseWriter, status int, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(data) +} + +func respondError(w http.ResponseWriter, status int, message string) { + respondJSON(w, status, map[string]string{"error": message}) +} + +/* + Client-side check: + + ``` + function estimateEntropy(password: string): number { + const pool: number = getCharsetSize(password); // Character diversity (R) + const entropy: number = password.length * Math.log2(pool); // E = L * log_2(R) + return entropy; // Value (E) that can be compared against a hardcoded threshold (e.g. 60) + } + ``` +*/ + +// Validate the given password using a hybrid approach: length (max. set due to bcrypt's input +// upper limit of 72 bytes), entropy, and HIBP API. +func validatePassword(password string) error { + if len(password) < minPasswordLength { + return fmt.Errorf("password must be at least %d characters", minPasswordLength) + } + if len(password) > maxPasswordLength { + return fmt.Errorf("password cannot be longer than %d characters", maxPasswordLength) + } + + // Formatted error message will contain tips to increase the password strength (safe to show) + err := passwordvalidator.Validate(password, minPasswordEntropy) + if err != nil { + return err + } + + if compromised, _ := isPasswordCompromised(password); compromised { + return errors.New("password is compromised") + } + + return nil +} + +// Send the first five bytes of the password's SHA-1 hash to HIBP API, then check if the rest of +// the hash is present in the APi's response data (k-Anonymity model). +func isPasswordCompromised(password string) (bool, error) { + hash := sha1.Sum([]byte(password)) + hashStr := strings.ToUpper(hex.EncodeToString(hash[:])) + prefix, suffix := hashStr[:5], hashStr[5:] + + resp, err := http.Get(fmt.Sprintf("%s/%s", hibpAPI, prefix)) + if err != nil { + return false, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return false, err + } + + return strings.Contains(strings.ToUpper(string(body)), suffix), nil +} + +// Normalize the username by making it lowercase and trimming any leading or trailing whitespace. +func normalizeUsername(username string) string { + return strings.ToLower(strings.TrimSpace(username)) +} + +/* + Client-side check (additionally input should automatically perform the normalization steps): + + ``` + function validateUsername(username: string): string { + const min: number = 3, max: number = 20; + if (username.length < min) return "Too short"; + if (username.length > max) return "Too long"; + if (!/^[a-zA-Z0-9_]+$/.test(username)) return "Invalid characters"; + return "Valid"; + } + ``` +*/ + +// Validate the given username by making sure it only contains alphanumeric characters or +// underscores and adheres the hardcoded minimum and maximum length rules. +func validateUsername(username string) error { + if utf8.RuneCountInString(username) < minUsernameLength { + return fmt.Errorf("username must be at least %d characters", minUsernameLength) + } + if utf8.RuneCountInString(username) > maxUsernameLength { + return fmt.Errorf("username cannot be longer than %d characters", maxUsernameLength) + } + + if !usernameRegex.MatchString(username) { + return errors.New("username can only contain numbers, letters, and underscores") + } + + return nil +} diff --git a/server/pkg/service/util_test.go b/server/pkg/service/util_test.go new file mode 100644 index 0000000..8f17882 --- /dev/null +++ b/server/pkg/service/util_test.go @@ -0,0 +1,184 @@ +package service + +import ( + "fmt" + "math/rand" + "net/http" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + passwordvalidator "github.com/wagslane/go-password-validator" +) + +type MockHTTPClient struct { + mock.Mock +} + +func (m *MockHTTPClient) Get(url string) (*http.Response, error) { + args := m.Called(url) + return args.Get(0).(*http.Response), args.Error(1) +} + +func TestValidatePassword(t *testing.T) { + tests := []struct { + name string + password string + wantErr string + mockHTTP func(*MockHTTPClient) + }{ + { + name: "too short", + password: strings.Repeat("a", minPasswordLength-1), + wantErr: fmt.Sprintf("password must be at least %d characters", minPasswordLength), + }, + { + name: "too long", + password: strings.Repeat("a", maxPasswordLength+1), + wantErr: fmt.Sprintf("password cannot be longer than %d characters", maxPasswordLength), + }, + { + name: "low entropy", + password: strings.Repeat("a", minPasswordLength), + wantErr: "insecure password", // Error produced by wagslane/go-password-validator + }, + { + name: "valid password", + password: "SecurePassw0rd!123", + wantErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock HTTP client if needed + if tt.mockHTTP != nil { + mockClient := new(MockHTTPClient) + tt.mockHTTP(mockClient) + } + + err := validatePassword(tt.password) + if tt.wantErr == "" { + assert.NoError(t, err) + } else { + assert.ErrorContains(t, err, tt.wantErr) + } + }) + } +} + +func TestIsPasswordCompromised(t *testing.T) { + t.Run("known compromised", func(t *testing.T) { + compromised, err := isPasswordCompromised("password123456") + assert.NoError(t, err) + assert.True(t, compromised) + }) + + t.Run("randomly generated", func(t *testing.T) { + randomStr := genRandomString(12) + compromised, err := isPasswordCompromised(randomStr) + assert.NoError(t, err) + assert.False(t, compromised) + }) +} + +func TestPasswordEntropyCalculation(t *testing.T) { + tests := []struct { + password string + entropy float64 + }{ + {"password", 37.6}, + {"SecurePassw0rd!123", 103.12}, + {"aaaaaaaaaaaaaaaa", 9.5}, + } + + for _, tt := range tests { + t.Run(tt.password, func(t *testing.T) { + entropy := passwordvalidator.GetEntropy(tt.password) + assert.InDelta(t, tt.entropy, entropy, 1.0) + }) + } +} + +func TestValidateUsername(t *testing.T) { + tests := []struct { + name string + input string + wantErr string + }{ + { + name: "too short", + input: strings.Repeat("a", minUsernameLength-1), + wantErr: fmt.Sprintf("username must be at least %d characters", minUsernameLength), + }, + { + name: "too long", + input: strings.Repeat("a", maxUsernameLength+1), + wantErr: fmt.Sprintf("username cannot be longer than %d characters", maxUsernameLength), + }, + { + name: "invalid characters", + input: "user@name", + wantErr: "username can only contain numbers, letters, and underscores", + }, + { + name: "valid username", + input: "valid_user123", + wantErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateUsername(tt.input) + if tt.wantErr == "" { + assert.NoError(t, err) + } else { + assert.ErrorContains(t, err, tt.wantErr) + } + }) + } +} + +func TestNormalizeUsername(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "trim whitespace", + input: " test_user ", + want: "test_user", + }, + { + name: "lowercase", + input: "TestUser", + want: "testuser", + }, + { + name: "mixed case and spaces", + input: " UserName123 ", + want: "username123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeUsername(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} + +func genRandomString(length int) string { + const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + seededRand := rand.New(rand.NewSource(time.Now().UnixNano())) + b := make([]byte, length) + for i := range b { + b[i] = charset[seededRand.Intn(len(charset))] + } + return string(b) +}