package service import ( "context" "net/http" "time" "git.umbrella.haus/ae/notatest/internal/data" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/rs/zerolog" ) const ( panicRecoveryMsg = "panic recovered" defaultLogMsg = "incoming request" noteUUIDCtxParameter = "noteID" versionUUIDCtxParameter = "versionID" targetUserUUIDCtxParameter = "targetID" ) // General resource ID (UUID) context key. type uuidCtxKey struct { Name string } // Get JWT bearer from request's authorization header, parse it with custom user claims, and // ensure its validity before attaching the claims to the request's context. func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tokenString, err := getTokenFromRequest(r) if err != nil { respondError(w, http.StatusUnauthorized, "Unauthorized") return } token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (any, error) { return []byte(jwtSecret), nil }) if err != nil || !token.Valid { respondError(w, http.StatusUnauthorized, "Invalid token") return } claims, ok := token.Claims.(*userClaims) if !ok || claims.TokenType != expectedType { respondError(w, http.StatusUnauthorized, "Invalid token type") return } ctx := context.WithValue(r.Context(), userCtxKey{}, claims) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // JWT access token parsing, verification, and validation. func requireAccessToken(jwtSecret string) func(http.Handler) http.Handler { return authMiddleware(jwtSecret, "access") } // JWT refresh token parsing, verification, and validation. func requireRefreshToken(jwtSecret string) func(http.Handler) http.Handler { return authMiddleware(jwtSecret, "refresh") } // Ensure the current user is an administrator. Can be used to protect routes that can be utilized // to view/modify/delete accounts that the current user isn't the owner of. func adminOnlyMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user, ok := r.Context().Value(userCtxKey{}).(*userClaims) if !ok || !user.Admin { respondError(w, http.StatusForbidden, "Forbidden") return } next.ServeHTTP(w, r) }) } // Append UUID from the given URL parameter to the request's context (`uuidCtxKey` with the // parameter name as the "context identifier"). func uuidCtx(parameter string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { uuidParam := chi.URLParam(r, parameter) resourceID, err := uuid.Parse(uuidParam) if err != nil { respondError(w, http.StatusBadRequest, "Invalid resource ID") return } ctx := context.WithValue(r.Context(), uuidCtxKey{Name: parameter}, resourceID) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // Append full note data (metadata + active version) into request's context based on note ID as a // URL parameter and user ID as context parameter. Must be chained with `uuidCtx` to parse the // resource ID into the request's context. func noteCtx(store NoteStore) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() noteID, ok := ctx.Value(uuidCtxKey{Name: noteUUIDCtxParameter}).(uuid.UUID) if !ok { respondError(w, http.StatusBadRequest, "Resource ID missing") return } user, ok := ctx.Value(userCtxKey{}).(*userClaims) if !ok { respondError(w, http.StatusUnauthorized, "Unauthorized") return } userID, err := uuid.Parse(user.Subject) if err != nil { respondError(w, http.StatusUnauthorized, "Invalid token") return } // Get the "full note" (metadata + active version) with a single query fullNote, err := store.GetFullNote(r.Context(), noteID) if err != nil { respondError(w, http.StatusNotFound, "Note not found") return } // Validate note ownership if userID != fullNote.OwnerID { respondError(w, http.StatusForbidden, "Forbidden") return } ctx = context.WithValue(r.Context(), noteCtxKey{}, &fullNote) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // Append single version's data into request's context based on version ID as a URL parameter and // note ID as context parameter. Must be chained with `noteCtx` and `uuidCtx` to parse the necessary // resource IDs into request's context. func versionCtx(store NoteStore) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() fullNote, ok := ctx.Value(noteCtxKey{}).(*data.GetFullNoteRow) if !ok { respondError(w, http.StatusNotFound, "Note not found") return } versionID, ok := ctx.Value(uuidCtxKey{Name: versionUUIDCtxParameter}).(uuid.UUID) if !ok { respondError(w, http.StatusBadRequest, "Resource ID missing") return } version, err := store.GetVersion(r.Context(), data.GetVersionParams{ NoteID: fullNote.NoteID, ID: versionID, }) if err != nil { respondError(w, http.StatusNotFound, "Version not found") return } ctx = context.WithValue(r.Context(), versionCtxKey{}, &version) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // Zerolog compatible logger middleware. Automatically logs and recovers from errors with HTTP 500 // response, by default logs to INFO level. func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { log := log.With().Logger() ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) t1 := time.Now() defer func() { t2 := time.Now() // Recover automatically and respond with HTTP 500 if rec := recover(); rec != nil { log.Error(). Str("type", "error"). Timestamp(). Interface("recover_info", rec). Msg(panicRecoveryMsg) http.Error(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } // Log a regular HTTP request with some metadata log.Info(). Str("type", "access"). Timestamp(). Fields(map[string]any{ "remote_ip": r.RemoteAddr, "url": r.URL.Path, "proto": r.Proto, "method": r.Method, "user_agent": r.Header.Get("User-Agent"), "status": ww.Status(), "latency_ms": float64(t2.Sub(t1).Nanoseconds()) / 1000000.0, "bytes_in": r.Header.Get("Content-Length"), "bytes_out": ww.BytesWritten(), }). Msg(defaultLogMsg) }() next.ServeHTTP(ww, r) } return http.HandlerFunc(fn) } }