package service

import (
	"context"
	"fmt"
	"net/http"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/rs/zerolog"
)

const (
	// Message attached to the ERROR entry written when loggerMiddleware recovers a panic.
	panicRecoveryMsg = "panic recovered"
	// Message attached to the INFO access-log entry written for every request.
	defaultLogMsg = "incoming request"
)

// userCtxKey is the unexported context key under which request-scoped user data is stored.
// NOTE(review): authMiddleware stores *userClaims under this key, while userCtx stores the
// value returned by UserStore.GetUserByID under the same key — confirm they never need to
// coexist on one route.
type userCtxKey struct{}

// authMiddleware returns a middleware that gets the JWT bearer from the request's
// authorization header (via getTokenFromRequest), parses it into userClaims and verifies it
// with the HMAC secret jwtSecret.
//
// The request is answered with 401 Unauthorized and not forwarded when:
//   - no token can be extracted from the request;
//   - the token is not HMAC-signed, fails verification, or is otherwise invalid;
//   - the claims' TokenType differs from expectedType (e.g. "access" vs "refresh").
//
// Otherwise the *userClaims are stored in the request context under userCtxKey{} and the
// request is passed to next.
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			tokenString, err := getTokenFromRequest(r)
			if err != nil {
				respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
				// Stop here: falling through would parse an empty token and write a second
				// error response to the same writer.
				return
			}

			token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (interface{}, error) {
				// The secret is an HMAC key, so reject any token whose header announces a
				// different algorithm family before handing the key out.
				if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
					return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
				}
				return []byte(jwtSecret), nil
			})
			// token may be nil when err != nil; the short-circuit keeps token.Valid safe.
			if err != nil || !token.Valid {
				respondError(w, http.StatusUnauthorized, "Invalid token")
				return
			}

			claims, ok := token.Claims.(*userClaims)
			if !ok || claims.TokenType != expectedType {
				respondError(w, http.StatusUnauthorized, "Invalid token type")
				return
			}

			ctx := context.WithValue(r.Context(), userCtxKey{}, claims)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// requireAccessToken rejects requests that do not carry a valid JWT whose TokenType is
// "access"; see authMiddleware for the exact checks.
func requireAccessToken(jwtSecret string) func(http.Handler) http.Handler {
	return authMiddleware(jwtSecret, "access")
}

// requireRefreshToken rejects requests that do not carry a valid JWT whose TokenType is
// "refresh"; see authMiddleware for the exact checks.
func requireRefreshToken(jwtSecret string) func(http.Handler) http.Handler {
	return authMiddleware(jwtSecret, "refresh")
}

// Ensure the current user is an administrator.
func adminOnlyMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user, ok := r.Context().Value(userCtxKey{}).(*userClaims) if !ok || !user.Admin { respondError(w, http.StatusForbidden, "Forbidden") return } next.ServeHTTP(w, r) }) } // Ensure the targeted resource is owned by the current user (i.e. current user's ID matches with // the one stored into the resource). func ownerOnlyMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user, ok := r.Context().Value(userCtxKey{}).(*userClaims) requestedID := chi.URLParam(r, "id") if !ok || user.ID != requestedID { respondError(w, http.StatusForbidden, "Forbidden") return } next.ServeHTTP(w, r) }) } // Append user data into request's context based on user ID as a URL parameter. func userCtx(store UserStore) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { userIDStr := chi.URLParam(r, "id") userID, err := uuid.Parse(userIDStr) if err != nil { respondError(w, http.StatusNotFound, "Invalid user ID") return } user, err := store.GetUserByID(r.Context(), userID) if err != nil { respondError(w, http.StatusNotFound, "User not found") return } ctx := context.WithValue(r.Context(), userCtxKey{}, user) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // Zerolog compatible logger middleware. Automatically logs and recovers from errors with HTTP 500 // response, by default logs to INFO level. func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { log := log.With().Logger() ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) t1 := time.Now() defer func() { t2 := time.Now() // Recover automatically and respond with HTTP 500 if rec := recover(); rec != nil { log.Error(). Str("type", "error"). 
Timestamp(). Interface("recover_info", rec). Msg(panicRecoveryMsg) http.Error(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } // Log a regular HTTP request with some metadata log.Info(). Str("type", "access"). Timestamp(). Fields(map[string]interface{}{ "remote_ip": r.RemoteAddr, "url": r.URL.Path, "proto": r.Proto, "method": r.Method, "user_agent": r.Header.Get("User-Agent"), "status": ww.Status(), "latency_ms": float64(t2.Sub(t1).Nanoseconds()) / 1000000.0, "bytes_in": r.Header.Get("Content-Length"), "bytes_out": ww.BytesWritten(), }). Msg(defaultLogMsg) }() next.ServeHTTP(ww, r) } return http.HandlerFunc(fn) } }