notatest/server/pkg/service/middleware.go
2025-04-01 18:48:32 +03:00

197 lines
5.8 KiB
Go

package service
import (
	"context"
	"fmt"
	"net/http"
	"runtime/debug"
	"time"

	"git.umbrella.haus/ae/notatest/pkg/data"
	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/rs/zerolog"
)
// panicRecoveryMsg is the message of the ERROR entry loggerMiddleware writes when it recovers
// from a panic in a downstream handler.
const panicRecoveryMsg = "panic recovered"

// defaultLogMsg is the message of the INFO access-log entry loggerMiddleware writes for every
// request.
const defaultLogMsg = "incoming request"
// Get JWT bearer from request's authorization header, parse it with custom user claims, and
// ensure its validity before attaching the claims to the request's context.
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenString, err := getTokenFromRequest(r)
if err != nil {
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
}
token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (any, error) {
return []byte(jwtSecret), nil
})
if err != nil || !token.Valid {
respondError(w, http.StatusUnauthorized, "Invalid token")
return
}
claims, ok := token.Claims.(*userClaims)
if !ok || claims.TokenType != expectedType {
respondError(w, http.StatusUnauthorized, "Invalid token type")
return
}
ctx := context.WithValue(r.Context(), userCtxKey{}, claims)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// requireAccessToken returns middleware that only lets a request through when it carries a valid
// JWT whose token type is "access". The check is delegated to authMiddleware.
func requireAccessToken(jwtSecret string) func(http.Handler) http.Handler {
	const tokenType = "access"
	mw := authMiddleware(jwtSecret, tokenType)
	return mw
}
// requireRefreshToken returns middleware that only lets a request through when it carries a
// valid JWT whose token type is "refresh". The check is delegated to authMiddleware.
func requireRefreshToken(jwtSecret string) func(http.Handler) http.Handler {
	const tokenType = "refresh"
	mw := authMiddleware(jwtSecret, tokenType)
	return mw
}
// adminOnlyMiddleware forwards the request to next only when the request context holds
// *userClaims under userCtxKey{} (as authMiddleware stores them) with Admin set. In every other
// case it responds with HTTP 403.
func adminOnlyMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if claims, found := r.Context().Value(userCtxKey{}).(*userClaims); found && claims.Admin {
			next.ServeHTTP(w, r)
			return
		}
		respondError(w, http.StatusForbidden, "Forbidden")
	})
}
// ownerOnlyMiddleware forwards the request to next only when the Subject of the *userClaims
// stored under userCtxKey{} equals the "id" URL parameter. In other words, the authenticated
// user is accessing their own resource. Otherwise it responds with HTTP 403.
//
// NOTE(review): this is an exact string comparison, so an ID that differs only in formatting
// (e.g. upper-case hex) is rejected even though uuid.Parse would treat it as equal.
func ownerOnlyMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		claims, hasClaims := r.Context().Value(userCtxKey{}).(*userClaims)
		isOwner := hasClaims && claims.Subject == chi.URLParam(r, "id")
		if !isOwner {
			respondError(w, http.StatusForbidden, "Forbidden")
			return
		}
		next.ServeHTTP(w, r)
	})
}
// userCtx returns middleware that looks up the user whose UUID is given by the "id" URL
// parameter via store.GetUserByID and stores the result in the request context under
// userCtxKey{}. A malformed ID or a failed lookup responds with HTTP 404.
//
// NOTE(review): authMiddleware stores *userClaims under this same userCtxKey{}, so this
// overwrites them. Unless GetUserByID returns a *userClaims, any later middleware that asserts
// *userClaims (adminOnlyMiddleware, ownerOnlyMiddleware, noteCtx) will no longer find them.
// Confirm the route ordering, or consider a dedicated context key.
func userCtx(store UserStore) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			id, parseErr := uuid.Parse(chi.URLParam(r, "id"))
			if parseErr != nil {
				respondError(w, http.StatusNotFound, "Invalid user ID")
				return
			}

			user, lookupErr := store.GetUserByID(r.Context(), id)
			if lookupErr != nil {
				respondError(w, http.StatusNotFound, "User not found")
				return
			}

			next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), userCtxKey{}, user)))
		})
	}
}
// noteCtx returns middleware that loads the note identified by the "id" URL parameter, scoped to
// the authenticated user, and stores it in the request context under noteCtxKey{}.
//
// The request context must already hold *userClaims under userCtxKey{} (e.g. set by
// authMiddleware); the claims' Subject is used as the owning user ID for the lookup.
// Responses on failure:
//   - malformed note ID or failed store lookup: HTTP 404
//   - missing claims: HTTP 401
//   - Subject that is not a UUID: HTTP 500
func noteCtx(store NoteStore) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			noteID, err := uuid.Parse(chi.URLParam(r, "id"))
			if err != nil {
				respondError(w, http.StatusNotFound, "Invalid note ID")
				return
			}

			claims, ok := ctx.Value(userCtxKey{}).(*userClaims)
			if !ok {
				respondError(w, http.StatusUnauthorized, "Unauthorized")
				return
			}

			ownerID, err := uuid.Parse(claims.Subject)
			if err != nil {
				respondError(w, http.StatusInternalServerError, "Invalid user ID")
				return
			}

			params := data.GetNoteParams{ID: noteID, UserID: ownerID}
			note, err := store.GetNote(ctx, params)
			if err != nil {
				respondError(w, http.StatusNotFound, "Note not found")
				return
			}

			next.ServeHTTP(w, r.WithContext(context.WithValue(ctx, noteCtxKey{}, note)))
		})
	}
}
// loggerMiddleware returns zerolog-based access-logging middleware. Every request produces an
// INFO entry (type "access") with the remote address, path, protocol, method, user agent,
// status, latency, and byte counts.
//
// A panic in a downstream handler is recovered and logged at ERROR level with its value and a
// stack trace. The client then gets HTTP 500, but only if no response status has been written
// yet. http.ErrAbortHandler is not treated as an error: it is re-panicked after the access
// entry is logged, so net/http can abort the response as that sentinel intends.
func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
			start := time.Now()

			defer func() {
				rec := recover()
				if rec != nil && rec != http.ErrAbortHandler {
					log.Error().
						Str("type", "error").
						Timestamp().
						Interface("recover_info", rec).
						Bytes("stack", debug.Stack()).
						Msg(panicRecoveryMsg)
					// ww.Status() stays 0 until a status is written. If the handler already
					// started its response, another WriteHeader would be superfluous and the
					// error text would be appended to a partial body.
					if ww.Status() == 0 {
						http.Error(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
					}
				}

				latency := time.Since(start)
				status := ww.Status()
				if status == 0 && rec == nil {
					// The handler wrote nothing; net/http sends an implicit 200 on return.
					status = http.StatusOK
				}

				log.Info().
					Str("type", "access").
					Timestamp().
					Fields(map[string]any{
						"remote_ip":  r.RemoteAddr,
						"url":        r.URL.Path,
						"proto":      r.Proto,
						"method":     r.Method,
						"user_agent": r.Header.Get("User-Agent"),
						"status":     status,
						"latency_ms": float64(latency.Nanoseconds()) / 1000000.0,
						"bytes_in":   r.Header.Get("Content-Length"),
						"bytes_out":  ww.BytesWritten(),
					}).
					Msg(defaultLogMsg)

				if rec == http.ErrAbortHandler {
					panic(rec)
				}
			}()

			next.ServeHTTP(ww, r)
		}
		return http.HandlerFunc(fn)
	}
}