feat: modular api handlers for users/tokens (auth) incl. unit tests
This commit is contained in:
parent
41d1336f58
commit
3257b19313
@ -9,16 +9,22 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.4
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/wagslane/go-password-validator v0.3.0
|
||||
golang.org/x/crypto v0.36.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/stretchr/testify v1.10.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
@ -1,6 +1,7 @@
|
||||
github.com/caarlos0/env v3.5.0+incompatible h1:Yy0UN8o9Wtr/jGHZDpCBLpNrzcFLLM2yixi/rBrKyJs=
|
||||
github.com/caarlos0/env v3.5.0+incompatible/go.mod h1:tdCsowwCzMLdkqRYDlHpZCp2UooDD3MspDBjZ2AD02Y=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@ -19,6 +20,10 @@ github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
|
||||
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
@ -29,10 +34,14 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
@ -51,6 +60,8 @@ golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
157
server/pkg/service/middleware.go
Normal file
157
server/pkg/service/middleware.go
Normal file
@ -0,0 +1,157 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
const (
|
||||
panicRecoveryMsg = "panic recovered"
|
||||
defaultLogMsg = "incoming request"
|
||||
)
|
||||
|
||||
type userCtxKey struct{}
|
||||
|
||||
// Get JWT bearer from request's authorization header, parse it with custom user claims, and
|
||||
// ensure its validity before attaching the claims to the request's context.
|
||||
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tokenString, err := getTokenFromRequest(r)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(jwtSecret), nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
respondError(w, http.StatusUnauthorized, "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*userClaims)
|
||||
if !ok || claims.TokenType != expectedType {
|
||||
respondError(w, http.StatusUnauthorized, "Invalid token type")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), userCtxKey{}, claims)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// JWT access token parsing, verification, and validation.
|
||||
func requireAccessToken(jwtSecret string) func(http.Handler) http.Handler {
|
||||
return authMiddleware(jwtSecret, "access")
|
||||
}
|
||||
|
||||
// JWT refresh token parsing, verification, and validation.
|
||||
func requireRefreshToken(jwtSecret string) func(http.Handler) http.Handler {
|
||||
return authMiddleware(jwtSecret, "refresh")
|
||||
}
|
||||
|
||||
// Ensure the current user is an administrator.
|
||||
func adminOnlyMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
if !ok || !user.Admin {
|
||||
respondError(w, http.StatusForbidden, "Forbidden")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Ensure the targeted resource is owned by the current user (i.e. current user's ID matches with
|
||||
// the one stored into the resource).
|
||||
func ownerOnlyMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
requestedID := chi.URLParam(r, "id")
|
||||
if !ok || user.ID != requestedID {
|
||||
respondError(w, http.StatusForbidden, "Forbidden")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Append user data into request's context based on user ID as a URL parameter.
|
||||
func userCtx(store UserStore) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
userIDStr := chi.URLParam(r, "id")
|
||||
userID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusNotFound, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := store.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), userCtxKey{}, user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Zerolog compatible logger middleware. Automatically logs and recovers from errors with HTTP 500
|
||||
// response, by default logs to INFO level.
|
||||
func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
log := log.With().Logger()
|
||||
|
||||
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
|
||||
|
||||
t1 := time.Now()
|
||||
defer func() {
|
||||
t2 := time.Now()
|
||||
|
||||
// Recover automatically and respond with HTTP 500
|
||||
if rec := recover(); rec != nil {
|
||||
log.Error().
|
||||
Str("type", "error").
|
||||
Timestamp().
|
||||
Interface("recover_info", rec).
|
||||
Msg(panicRecoveryMsg)
|
||||
http.Error(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// Log a regular HTTP request with some metadata
|
||||
log.Info().
|
||||
Str("type", "access").
|
||||
Timestamp().
|
||||
Fields(map[string]interface{}{
|
||||
"remote_ip": r.RemoteAddr,
|
||||
"url": r.URL.Path,
|
||||
"proto": r.Proto,
|
||||
"method": r.Method,
|
||||
"user_agent": r.Header.Get("User-Agent"),
|
||||
"status": ww.Status(),
|
||||
"latency_ms": float64(t2.Sub(t1).Nanoseconds()) / 1000000.0,
|
||||
"bytes_in": r.Header.Get("Content-Length"),
|
||||
"bytes_out": ww.BytesWritten(),
|
||||
}).
|
||||
Msg(defaultLogMsg)
|
||||
}()
|
||||
|
||||
next.ServeHTTP(ww, r)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
}
|
19
server/pkg/service/notes.go
Normal file
19
server/pkg/service/notes.go
Normal file
@ -0,0 +1,19 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// Mockable database operations interface
|
||||
type NoteStore interface {
|
||||
// TODO: implement
|
||||
}
|
||||
|
||||
type notesResource struct {
|
||||
Notes NoteStore
|
||||
}
|
||||
|
||||
func (rs notesResource) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
return r
|
||||
}
|
43
server/pkg/service/server.go
Normal file
43
server/pkg/service/server.go
Normal file
@ -0,0 +1,43 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func Run(conn *pgx.Conn, jwtSecret string, httpPort string) error {
|
||||
q := data.New(conn)
|
||||
r := chi.NewRouter()
|
||||
|
||||
tokenService := tokenService{
|
||||
JWTSecret: jwtSecret,
|
||||
Tokens: q,
|
||||
}
|
||||
usersRouter := usersResource{
|
||||
JWTSecret: jwtSecret,
|
||||
Users: q,
|
||||
}
|
||||
notesRouter := notesResource{}
|
||||
|
||||
// Global middlewares
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(loggerMiddleware(&log.Logger))
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(middleware.AllowContentType("application/json"))
|
||||
|
||||
// Routes grouped by functionality
|
||||
r.Post("/auth/refresh", tokenService.RefreshAccessToken) // POST /auth/refresh - new access token for refresh token
|
||||
r.Mount("/users", usersRouter.Routes())
|
||||
r.Mount("/notes", notesRouter.Routes())
|
||||
|
||||
portStr := fmt.Sprintf(":%s", httpPort)
|
||||
log.Info().Msgf("Starting server on %s", portStr)
|
||||
return http.ListenAndServe(portStr, r)
|
||||
}
|
205
server/pkg/service/tokens.go
Normal file
205
server/pkg/service/tokens.go
Normal file
@ -0,0 +1,205 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
accessTokenDuration = 15 * time.Minute
|
||||
refreshTokenDuration = 7 * 24 * time.Hour
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header")
|
||||
)
|
||||
|
||||
type userClaims struct {
|
||||
Admin bool `json:"admin"`
|
||||
TokenType string `json:"type"` // "access" or "refresh"
|
||||
jwt.RegisteredClaims // User's UUID should be stored in the subject claim
|
||||
}
|
||||
|
||||
type tokenPair struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
// Mockable database operations interface
|
||||
type TokenStore interface {
|
||||
CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error)
|
||||
GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error)
|
||||
RevokeRefreshToken(ctx context.Context, tokenHash string) error
|
||||
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
|
||||
}
|
||||
|
||||
type tokenService struct {
|
||||
JWTSecret string
|
||||
Tokens TokenStore
|
||||
}
|
||||
|
||||
func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
|
||||
tokenPair, err := generateTokenPair(userID.String(), isAdmin, ts.JWTSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(tokenPair.RefreshToken))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
|
||||
// Store to DB with (almost) identical expiration timestamp
|
||||
expiresAt := time.Now().Add(refreshTokenDuration)
|
||||
_, err = ts.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{
|
||||
UserID: userID,
|
||||
TokenHash: tokenHash,
|
||||
ExpiresAt: expiresAt,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tokenPair, nil
|
||||
}
|
||||
|
||||
func (ts *tokenService) RevokeRefreshToken(ctx context.Context, token string) error {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
return ts.Tokens.RevokeRefreshToken(ctx, tokenHash)
|
||||
}
|
||||
|
||||
func (ts *tokenService) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
|
||||
dbToken, err := ts.Tokens.GetRefreshTokenByHash(ctx, tokenHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dbToken.Revoked || time.Now().After(dbToken.ExpiresAt) {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
return &dbToken, nil
|
||||
}
|
||||
|
||||
func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
|
||||
// Get claims from context
|
||||
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
if !ok || claims.TokenType != "refresh" {
|
||||
respondError(w, http.StatusUnauthorized, "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt to get the token from Authentication header ("Bearer <token>")
|
||||
refreshToken, err := getTokenFromRequest(r)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
|
||||
}
|
||||
|
||||
// Validate the refresh token in DB
|
||||
if _, err := ts.ValidateRefreshToken(r.Context(), refreshToken); err != nil {
|
||||
respondError(w, http.StatusUnauthorized, "Invalid refresh token")
|
||||
return
|
||||
}
|
||||
|
||||
// Revoke the used refresh token
|
||||
if err := ts.RevokeRefreshToken(r.Context(), refreshToken); err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to revoke token")
|
||||
return
|
||||
}
|
||||
|
||||
// Generate a new pair (access & refresh tokens)
|
||||
userID, err := uuid.Parse(claims.Subject)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := ts.GenerateTokenPair(r.Context(), userID, claims.Admin)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusOK, tokenPair)
|
||||
}
|
||||
|
||||
func (ts *tokenService) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
if !ok {
|
||||
respondError(w, http.StatusUnauthorized, "Not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(claims.ID)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := ts.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to logout")
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusOK, map[string]string{"status": "logged out"})
|
||||
}
|
||||
|
||||
func getTokenFromRequest(r *http.Request) (string, error) {
|
||||
bearerToken := r.Header.Get("Authorization")
|
||||
bearerFields := strings.Fields(bearerToken)
|
||||
if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" {
|
||||
return bearerFields[1], nil
|
||||
}
|
||||
return "", ErrAuthHeaderInvalid
|
||||
}
|
||||
|
||||
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
|
||||
atClaims := userClaims{
|
||||
Admin: isAdmin,
|
||||
TokenType: "access",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenDuration)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, atClaims)
|
||||
|
||||
t, err := accessToken.SignedString([]byte(jwtSecret))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rtClaims := userClaims{
|
||||
Admin: isAdmin,
|
||||
TokenType: "refresh",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenDuration)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, rtClaims)
|
||||
|
||||
rt, err := refreshToken.SignedString([]byte(jwtSecret))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tokenPair{AccessToken: t, RefreshToken: rt}, nil
|
||||
}
|
180
server/pkg/service/tokens_test.go
Normal file
180
server/pkg/service/tokens_test.go
Normal file
@ -0,0 +1,180 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockTokenStore struct {
|
||||
CreateRefreshTokenFunc func(context.Context, data.CreateRefreshTokenParams) (data.RefreshToken, error)
|
||||
GetRefreshTokenByHashFunc func(context.Context, string) (data.RefreshToken, error)
|
||||
RevokeRefreshTokenFunc func(context.Context, string) error
|
||||
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
|
||||
}
|
||||
|
||||
func (m *mockTokenStore) CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
||||
return m.CreateRefreshTokenFunc(ctx, arg)
|
||||
}
|
||||
|
||||
func (m *mockTokenStore) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
||||
return m.GetRefreshTokenByHashFunc(ctx, tokenHash)
|
||||
}
|
||||
|
||||
func (m *mockTokenStore) RevokeRefreshToken(ctx context.Context, token string) error {
|
||||
return m.RevokeRefreshTokenFunc(ctx, token)
|
||||
}
|
||||
|
||||
func (m *mockTokenStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
|
||||
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
|
||||
}
|
||||
|
||||
func TestGenerateTokenPair_Success(t *testing.T) {
|
||||
userID := uuid.New()
|
||||
called := false
|
||||
|
||||
mockStore := &mockTokenStore{
|
||||
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
||||
called = true
|
||||
assert.Equal(t, userID, arg.UserID)
|
||||
return data.RefreshToken{}, nil
|
||||
},
|
||||
}
|
||||
|
||||
ts := tokenService{
|
||||
JWTSecret: "test-secret",
|
||||
Tokens: mockStore,
|
||||
}
|
||||
|
||||
pair, err := ts.GenerateTokenPair(context.Background(), userID, false)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
assert.NotEmpty(t, pair.AccessToken)
|
||||
assert.NotEmpty(t, pair.RefreshToken)
|
||||
}
|
||||
|
||||
func TestGenerateTokenPair_DBError(t *testing.T) {
|
||||
mockStore := &mockTokenStore{
|
||||
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
||||
return data.RefreshToken{}, errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
ts := tokenService{Tokens: mockStore}
|
||||
_, err := ts.GenerateTokenPair(context.Background(), uuid.New(), false)
|
||||
assert.ErrorContains(t, err, "db error")
|
||||
}
|
||||
|
||||
func TestValidateRefreshToken_Valid(t *testing.T) {
|
||||
token := "valid-jwt-token"
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
expectedHash := hex.EncodeToString(hash[:])
|
||||
|
||||
mockStore := &mockTokenStore{
|
||||
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
||||
assert.Equal(t, expectedHash, tokenHash)
|
||||
return data.RefreshToken{ExpiresAt: time.Now().Add(1 * time.Hour)}, nil
|
||||
},
|
||||
}
|
||||
|
||||
ts := tokenService{Tokens: mockStore}
|
||||
_, err := ts.ValidateRefreshToken(context.Background(), token)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestRefreshAccessToken_Success(t *testing.T) {
|
||||
userID := uuid.New()
|
||||
refreshToken := "valid-jwt-token"
|
||||
|
||||
// Expected hash of the test token
|
||||
hash := sha256.Sum256([]byte(refreshToken))
|
||||
expectedHash := hex.EncodeToString(hash[:])
|
||||
|
||||
mockStore := &mockTokenStore{
|
||||
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
||||
assert.Equal(t, expectedHash, tokenHash)
|
||||
|
||||
// Must return an unrevoked token with future expiration
|
||||
return data.RefreshToken{
|
||||
TokenHash: tokenHash,
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
Revoked: false,
|
||||
}, nil
|
||||
},
|
||||
RevokeRefreshTokenFunc: func(ctx context.Context, tokenHash string) error {
|
||||
assert.Equal(t, expectedHash, tokenHash)
|
||||
return nil
|
||||
},
|
||||
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
||||
return data.RefreshToken{}, nil
|
||||
},
|
||||
}
|
||||
|
||||
ts := tokenService{
|
||||
JWTSecret: "test-secret",
|
||||
Tokens: mockStore,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", refreshToken))
|
||||
req = req.WithContext(context.WithValue(
|
||||
req.Context(),
|
||||
userCtxKey{},
|
||||
&userClaims{
|
||||
Admin: false,
|
||||
TokenType: "refresh",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID.String(),
|
||||
},
|
||||
},
|
||||
))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
ts.RefreshAccessToken(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "access_token", "refresh_token")
|
||||
}
|
||||
|
||||
func TestHandleLogout_Success(t *testing.T) {
|
||||
userID := uuid.New()
|
||||
called := false
|
||||
|
||||
mockStore := &mockTokenStore{
|
||||
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
|
||||
called = true
|
||||
assert.Equal(t, userID, id)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
ts := tokenService{Tokens: mockStore}
|
||||
|
||||
req := httptest.NewRequest("POST", "/", nil)
|
||||
req = req.WithContext(context.WithValue(
|
||||
req.Context(),
|
||||
userCtxKey{},
|
||||
&userClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ID: userID.String(),
|
||||
},
|
||||
},
|
||||
))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
ts.HandleLogout(w, req)
|
||||
|
||||
assert.True(t, called)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
268
server/pkg/service/users.go
Normal file
268
server/pkg/service/users.go
Normal file
@ -0,0 +1,268 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Mockable database operations interface
|
||||
type UserStore interface {
|
||||
CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
|
||||
ListUsers(ctx context.Context) ([]data.User, error)
|
||||
GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error)
|
||||
UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error
|
||||
DeleteUser(ctx context.Context, id uuid.UUID) error
|
||||
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
|
||||
}
|
||||
|
||||
type usersResource struct {
|
||||
JWTSecret string
|
||||
Users UserStore
|
||||
}
|
||||
|
||||
func (rs usersResource) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
|
||||
// Public routes (no tokens required)
|
||||
r.Post("/", rs.Create) // POST /users - registration/signup
|
||||
|
||||
// Protected routes (access token required)
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireAccessToken(rs.JWTSecret))
|
||||
|
||||
// Admin only general routes
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(adminOnlyMiddleware)
|
||||
r.Get("/", rs.List) // GET /users - list all users
|
||||
})
|
||||
|
||||
// User specific routes
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
r.Use(userCtx(rs.Users)) // DB -> req. context
|
||||
|
||||
// Admin routes
|
||||
r.Route("/admin", func(r chi.Router) {
|
||||
r.Use(adminOnlyMiddleware)
|
||||
r.Get("/", rs.Get) // GET /users/admin/{id} - get single user
|
||||
r.Delete("/", rs.AdminDelete) // DELETE /users/admin/{id} - delete user
|
||||
})
|
||||
|
||||
// Owner routes
|
||||
r.Route("/owner", func(r chi.Router) {
|
||||
r.Use(ownerOnlyMiddleware)
|
||||
r.Put("/", rs.UpdatePassword) // PUT /users/owner/{id} - update user password
|
||||
r.Delete("/", rs.OwnerDelete) // DELETE /users/owner/{id} - delete user
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (rs usersResource) Create(w http.ResponseWriter, r *http.Request) {
|
||||
type request struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var req request
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
normalizedUsername := normalizeUsername(req.Username)
|
||||
if err := validateUsername(normalizedUsername); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := validatePassword(req.Password); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to create user")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := rs.Users.CreateUser(r.Context(), data.CreateUserParams{
|
||||
Username: normalizedUsername,
|
||||
PasswordHash: string(hashedPassword),
|
||||
})
|
||||
if err != nil {
|
||||
if isDuplicateEntry(err) {
|
||||
respondError(w, http.StatusConflict, "Username is already in use")
|
||||
} else {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to create user")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusCreated, map[string]string{
|
||||
"id": user.ID.String(),
|
||||
"username": user.Username,
|
||||
})
|
||||
}
|
||||
|
||||
func (rs usersResource) List(w http.ResponseWriter, r *http.Request) {
|
||||
users, err := rs.Users.ListUsers(r.Context())
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to retrieve users")
|
||||
return
|
||||
}
|
||||
|
||||
// Output sanitization
|
||||
var output []map[string]interface{}
|
||||
for _, user := range users {
|
||||
output = append(output, map[string]interface{}{
|
||||
"id": user.ID,
|
||||
"username": user.Username,
|
||||
"created_at": user.CreatedAt,
|
||||
"updated_at": user.UpdatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusOK, output)
|
||||
}
|
||||
|
||||
func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||
if !ok {
|
||||
respondError(w, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"id": user.ID,
|
||||
"username": user.Username,
|
||||
"created_at": user.CreatedAt,
|
||||
"updated_at": user.UpdatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
func (rs usersResource) UpdatePassword(w http.ResponseWriter, r *http.Request) {
|
||||
type request struct {
|
||||
OldPassword string `json:"old_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
var req request
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||
if !ok {
|
||||
respondError(w, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the old password before allowing the update
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.OldPassword)); err != nil {
|
||||
respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
||||
return
|
||||
}
|
||||
|
||||
if err := validatePassword(req.NewPassword); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to update password")
|
||||
return
|
||||
}
|
||||
|
||||
if err := rs.Users.UpdatePassword(r.Context(), data.UpdatePasswordParams{
|
||||
ID: user.ID,
|
||||
PasswordHash: string(hashedPassword),
|
||||
}); err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to update password")
|
||||
return
|
||||
}
|
||||
|
||||
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
|
||||
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (rs usersResource) AdminDelete(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||
if !ok {
|
||||
respondError(w, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
if err := rs.Users.DeleteUser(r.Context(), user.ID); err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to delete user")
|
||||
return
|
||||
}
|
||||
|
||||
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
|
||||
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (rs usersResource) OwnerDelete(w http.ResponseWriter, r *http.Request) {
|
||||
type request struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var req request
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||
if !ok {
|
||||
respondError(w, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the old password before allowing the deletion
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
|
||||
respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
||||
return
|
||||
}
|
||||
|
||||
err := rs.Users.DeleteUser(r.Context(), user.ID)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to delete user")
|
||||
return
|
||||
}
|
||||
|
||||
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
|
||||
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Check if the given error is a PostgreSQL error for `unique_violation`, i.e. whether an entry
|
||||
// with the given details already exists in the database table.
|
||||
func isDuplicateEntry(err error) bool {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) {
|
||||
return pgErr.Code == "23505"
|
||||
}
|
||||
return false
|
||||
}
|
259
server/pkg/service/users_test.go
Normal file
259
server/pkg/service/users_test.go
Normal file
@ -0,0 +1,259 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type mockUserStore struct {
|
||||
CreateUserFunc func(context.Context, data.CreateUserParams) (data.User, error)
|
||||
ListUsersFunc func(context.Context) ([]data.User, error)
|
||||
GetUserByIDFunc func(context.Context, uuid.UUID) (data.User, error)
|
||||
UpdatePasswordFunc func(context.Context, data.UpdatePasswordParams) error
|
||||
DeleteUserFunc func(context.Context, uuid.UUID) error
|
||||
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
|
||||
}
|
||||
|
||||
func (m *mockUserStore) CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
|
||||
return m.CreateUserFunc(ctx, arg)
|
||||
}
|
||||
|
||||
func (m *mockUserStore) ListUsers(ctx context.Context) ([]data.User, error) {
|
||||
return m.ListUsersFunc(ctx)
|
||||
}
|
||||
|
||||
func (m *mockUserStore) GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error) {
|
||||
return m.GetUserByIDFunc(ctx, id)
|
||||
}
|
||||
|
||||
func (m *mockUserStore) UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error {
|
||||
return m.UpdatePasswordFunc(ctx, arg)
|
||||
}
|
||||
|
||||
func (m *mockUserStore) DeleteUser(ctx context.Context, id uuid.UUID) error {
|
||||
return m.DeleteUserFunc(ctx, id)
|
||||
}
|
||||
|
||||
func (m *mockUserStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
|
||||
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
|
||||
}
|
||||
|
||||
func TestCreateUser_Duplicate(t *testing.T) {
|
||||
mockStore := &mockUserStore{
|
||||
CreateUserFunc: func(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
|
||||
return data.User{}, &pgconn.PgError{Code: "23505"}
|
||||
},
|
||||
}
|
||||
|
||||
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
||||
|
||||
reqBody := `{"username": "existing", "password": "validPass123!"}`
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rs.Create(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusConflict, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "Username is already in use")
|
||||
}
|
||||
|
||||
func TestCreateUser_InvalidUsername(t *testing.T) {
|
||||
mockStore := &mockUserStore{} // No DB calls expected
|
||||
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
||||
|
||||
// Test various invalid usernames
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{"Too short", `{"username": "a", "password": "validPass123!"}`},
|
||||
{"Invalid chars", `{"username": "user@name", "password": "validPass123!"}`},
|
||||
{"Empty", `{"username": "", "password": "validPass123!"}`},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
|
||||
w := httptest.NewRecorder()
|
||||
rs.Create(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateUser_InvalidPassword(t *testing.T) {
|
||||
mockStore := &mockUserStore{}
|
||||
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{"too short", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength-1))},
|
||||
{"too long", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", maxPasswordLength+1))},
|
||||
{"low entropy", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength))},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
|
||||
w := httptest.NewRecorder()
|
||||
rs.Create(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUsers_Success(t *testing.T) {
|
||||
testUsers := []data.User{
|
||||
{ID: uuid.New(), Username: "user1"},
|
||||
{ID: uuid.New(), Username: "user2"},
|
||||
}
|
||||
|
||||
mockStore := &mockUserStore{
|
||||
ListUsersFunc: func(ctx context.Context) ([]data.User, error) {
|
||||
return testUsers, nil
|
||||
},
|
||||
}
|
||||
|
||||
rs := usersResource{Users: mockStore}
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rs.List(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response []map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assert.Len(t, response, 2)
|
||||
assert.Equal(t, "user1", response[0]["username"])
|
||||
assert.NotContains(t, response[0], "password_hash")
|
||||
}
|
||||
|
||||
func TestUpdatePassword_InvalidOldPassword(t *testing.T) {
|
||||
// User with password hash that won't match "wrongpassword"
|
||||
user := data.User{
|
||||
ID: uuid.New(),
|
||||
PasswordHash: "$2a$10$PHhno.bZBF8IEINdFRZAPujMxIN65msElATgJG6FIxZdeWYVLSfFi", // Hash of "correctpassword"
|
||||
}
|
||||
|
||||
mockStore := &mockUserStore{}
|
||||
rs := usersResource{Users: mockStore}
|
||||
|
||||
reqBody := `{"old_password": "wrongpassword", "new_password": "NewValidPass321!"}`
|
||||
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
|
||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rs.UpdatePassword(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
|
||||
func TestAdminDelete_Success(t *testing.T) {
|
||||
user := data.User{ID: uuid.New()}
|
||||
deleteCalled := false
|
||||
revokeCalled := false
|
||||
|
||||
mockStore := &mockUserStore{
|
||||
DeleteUserFunc: func(ctx context.Context, id uuid.UUID) error {
|
||||
deleteCalled = true
|
||||
assert.Equal(t, user.ID, id)
|
||||
return nil
|
||||
},
|
||||
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
|
||||
revokeCalled = true
|
||||
assert.Equal(t, user.ID, id)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
rs := usersResource{Users: mockStore}
|
||||
req := httptest.NewRequest("DELETE", "/", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rs.AdminDelete(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, w.Code)
|
||||
assert.True(t, deleteCalled)
|
||||
assert.True(t, revokeCalled)
|
||||
}
|
||||
|
||||
func TestOwnerDelete_InvalidCredentials(t *testing.T) {
|
||||
// Create user with known password hash
|
||||
correctPassword := "CorrectPass123!"
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(correctPassword), bcrypt.DefaultCost)
|
||||
user := data.User{
|
||||
ID: uuid.New(),
|
||||
PasswordHash: string(hashedPassword),
|
||||
}
|
||||
|
||||
mockStore := &mockUserStore{}
|
||||
rs := usersResource{Users: mockStore}
|
||||
|
||||
reqBody := `{"password": "wrongpassword"}`
|
||||
req := httptest.NewRequest("DELETE", "/", strings.NewReader(reqBody))
|
||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rs.OwnerDelete(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
|
||||
func TestGetUser_NotFound(t *testing.T) {
|
||||
mockStore := &mockUserStore{}
|
||||
rs := usersResource{Users: mockStore}
|
||||
|
||||
// No user in context
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rs.Get(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestUpdatePassword_DatabaseError(t *testing.T) {
|
||||
// Add user with a valid password to the context
|
||||
oldPassword := "OldValidPass321!"
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(oldPassword), bcrypt.DefaultCost)
|
||||
user := data.User{
|
||||
ID: uuid.New(),
|
||||
PasswordHash: string(hashedPassword),
|
||||
}
|
||||
mockStore := &mockUserStore{
|
||||
UpdatePasswordFunc: func(ctx context.Context, arg data.UpdatePasswordParams) error {
|
||||
return errors.New("database error")
|
||||
},
|
||||
}
|
||||
|
||||
rs := usersResource{Users: mockStore}
|
||||
|
||||
reqBody := fmt.Sprintf(`{"old_password": "%s", "new_password": "NewValidPass123!"}`, oldPassword)
|
||||
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
|
||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rs.UpdatePassword(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "Failed to update password")
|
||||
}
|
133
server/pkg/service/util.go
Normal file
133
server/pkg/service/util.go
Normal file
@ -0,0 +1,133 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
passwordvalidator "github.com/wagslane/go-password-validator"
|
||||
)
|
||||
|
||||
const (
|
||||
minPasswordLength = 12 // Entropy checks prevent short passwords anyway
|
||||
maxPasswordLength = 72 // Limitation of bcrypt
|
||||
minPasswordEntropy = 60.0
|
||||
|
||||
minUsernameLength = 3
|
||||
maxUsernameLength = 20
|
||||
|
||||
hibpAPI = "https://api.pwnedpasswords.com/range" // Doesn't require an API key
|
||||
)
|
||||
|
||||
var (
|
||||
usernameRegex = regexp.MustCompile("^[a-z0-9_]+$")
|
||||
)
|
||||
|
||||
func respondJSON(w http.ResponseWriter, status int, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
func respondError(w http.ResponseWriter, status int, message string) {
|
||||
respondJSON(w, status, map[string]string{"error": message})
|
||||
}
|
||||
|
||||
/*
|
||||
Client-side check:
|
||||
|
||||
```
|
||||
function estimateEntropy(password: string): number {
|
||||
const pool: number = getCharsetSize(password); // Character diversity (R)
|
||||
const entropy: number = password.length * Math.log2(pool); // E = L * log_2(R)
|
||||
return entropy; // Value (E) that can be compared against a hardcoded threshold (e.g. 60)
|
||||
}
|
||||
```
|
||||
*/
|
||||
|
||||
// Validate the given password using a hybrid approach: length (max. set due to bcrypt's input
|
||||
// upper limit of 72 bytes), entropy, and HIBP API.
|
||||
func validatePassword(password string) error {
|
||||
if len(password) < minPasswordLength {
|
||||
return fmt.Errorf("password must be at least %d characters", minPasswordLength)
|
||||
}
|
||||
if len(password) > maxPasswordLength {
|
||||
return fmt.Errorf("password cannot be longer than %d characters", maxPasswordLength)
|
||||
}
|
||||
|
||||
// Formatted error message will contain tips to increase the password strength (safe to show)
|
||||
err := passwordvalidator.Validate(password, minPasswordEntropy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if compromised, _ := isPasswordCompromised(password); compromised {
|
||||
return errors.New("password is compromised")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send the first five bytes of the password's SHA-1 hash to HIBP API, then check if the rest of
|
||||
// the hash is present in the APi's response data (k-Anonymity model).
|
||||
func isPasswordCompromised(password string) (bool, error) {
|
||||
hash := sha1.Sum([]byte(password))
|
||||
hashStr := strings.ToUpper(hex.EncodeToString(hash[:]))
|
||||
prefix, suffix := hashStr[:5], hashStr[5:]
|
||||
|
||||
resp, err := http.Get(fmt.Sprintf("%s/%s", hibpAPI, prefix))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return strings.Contains(strings.ToUpper(string(body)), suffix), nil
|
||||
}
|
||||
|
||||
// Normalize the username by making it lowercase and trimming any leading or trailing whitespace.
|
||||
func normalizeUsername(username string) string {
|
||||
return strings.ToLower(strings.TrimSpace(username))
|
||||
}
|
||||
|
||||
/*
|
||||
Client-side check (additionally input should automatically perform the normalization steps):
|
||||
|
||||
```
|
||||
function validateUsername(username: string): string {
|
||||
const min: number = 3, max: number = 20;
|
||||
if (username.length < min) return "Too short";
|
||||
if (username.length > max) return "Too long";
|
||||
if (!/^[a-zA-Z0-9_]+$/.test(username)) return "Invalid characters";
|
||||
return "Valid";
|
||||
}
|
||||
```
|
||||
*/
|
||||
|
||||
// Validate the given username by making sure it only contains alphanumeric characters or
|
||||
// underscores and adheres the hardcoded minimum and maximum length rules.
|
||||
func validateUsername(username string) error {
|
||||
if utf8.RuneCountInString(username) < minUsernameLength {
|
||||
return fmt.Errorf("username must be at least %d characters", minUsernameLength)
|
||||
}
|
||||
if utf8.RuneCountInString(username) > maxUsernameLength {
|
||||
return fmt.Errorf("username cannot be longer than %d characters", maxUsernameLength)
|
||||
}
|
||||
|
||||
if !usernameRegex.MatchString(username) {
|
||||
return errors.New("username can only contain numbers, letters, and underscores")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
184
server/pkg/service/util_test.go
Normal file
184
server/pkg/service/util_test.go
Normal file
@ -0,0 +1,184 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
passwordvalidator "github.com/wagslane/go-password-validator"
|
||||
)
|
||||
|
||||
type MockHTTPClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockHTTPClient) Get(url string) (*http.Response, error) {
|
||||
args := m.Called(url)
|
||||
return args.Get(0).(*http.Response), args.Error(1)
|
||||
}
|
||||
|
||||
func TestValidatePassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
wantErr string
|
||||
mockHTTP func(*MockHTTPClient)
|
||||
}{
|
||||
{
|
||||
name: "too short",
|
||||
password: strings.Repeat("a", minPasswordLength-1),
|
||||
wantErr: fmt.Sprintf("password must be at least %d characters", minPasswordLength),
|
||||
},
|
||||
{
|
||||
name: "too long",
|
||||
password: strings.Repeat("a", maxPasswordLength+1),
|
||||
wantErr: fmt.Sprintf("password cannot be longer than %d characters", maxPasswordLength),
|
||||
},
|
||||
{
|
||||
name: "low entropy",
|
||||
password: strings.Repeat("a", minPasswordLength),
|
||||
wantErr: "insecure password", // Error produced by wagslane/go-password-validator
|
||||
},
|
||||
{
|
||||
name: "valid password",
|
||||
password: "SecurePassw0rd!123",
|
||||
wantErr: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Mock HTTP client if needed
|
||||
if tt.mockHTTP != nil {
|
||||
mockClient := new(MockHTTPClient)
|
||||
tt.mockHTTP(mockClient)
|
||||
}
|
||||
|
||||
err := validatePassword(tt.password)
|
||||
if tt.wantErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPasswordCompromised(t *testing.T) {
|
||||
t.Run("known compromised", func(t *testing.T) {
|
||||
compromised, err := isPasswordCompromised("password123456")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, compromised)
|
||||
})
|
||||
|
||||
t.Run("randomly generated", func(t *testing.T) {
|
||||
randomStr := genRandomString(12)
|
||||
compromised, err := isPasswordCompromised(randomStr)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, compromised)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPasswordEntropyCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
password string
|
||||
entropy float64
|
||||
}{
|
||||
{"password", 37.6},
|
||||
{"SecurePassw0rd!123", 103.12},
|
||||
{"aaaaaaaaaaaaaaaa", 9.5},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.password, func(t *testing.T) {
|
||||
entropy := passwordvalidator.GetEntropy(tt.password)
|
||||
assert.InDelta(t, tt.entropy, entropy, 1.0)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "too short",
|
||||
input: strings.Repeat("a", minUsernameLength-1),
|
||||
wantErr: fmt.Sprintf("username must be at least %d characters", minUsernameLength),
|
||||
},
|
||||
{
|
||||
name: "too long",
|
||||
input: strings.Repeat("a", maxUsernameLength+1),
|
||||
wantErr: fmt.Sprintf("username cannot be longer than %d characters", maxUsernameLength),
|
||||
},
|
||||
{
|
||||
name: "invalid characters",
|
||||
input: "user@name",
|
||||
wantErr: "username can only contain numbers, letters, and underscores",
|
||||
},
|
||||
{
|
||||
name: "valid username",
|
||||
input: "valid_user123",
|
||||
wantErr: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateUsername(tt.input)
|
||||
if tt.wantErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "trim whitespace",
|
||||
input: " test_user ",
|
||||
want: "test_user",
|
||||
},
|
||||
{
|
||||
name: "lowercase",
|
||||
input: "TestUser",
|
||||
want: "testuser",
|
||||
},
|
||||
{
|
||||
name: "mixed case and spaces",
|
||||
input: " UserName123 ",
|
||||
want: "username123",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeUsername(tt.input)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func genRandomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[seededRand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user