158 lines
4.7 KiB
Go
158 lines
4.7 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
"github.com/rs/zerolog"
|
|
)
|
|
|
|
const (
	// panicRecoveryMsg is the message of the ERROR log line written by loggerMiddleware when it
	// recovers from a panic in a downstream handler.
	panicRecoveryMsg = "panic recovered"

	// defaultLogMsg is the message of the per-request INFO access log line written by
	// loggerMiddleware.
	defaultLogMsg = "incoming request"
)

// userCtxKey is the unexported, zero-size context key under which request-scoped user data is
// stored via context.WithValue. Being a private type, it cannot collide with keys from other
// packages.
type userCtxKey struct{}
|
|
|
|
// Get JWT bearer from request's authorization header, parse it with custom user claims, and
|
|
// ensure its validity before attaching the claims to the request's context.
|
|
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
tokenString, err := getTokenFromRequest(r)
|
|
if err != nil {
|
|
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
|
|
}
|
|
|
|
token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (interface{}, error) {
|
|
return []byte(jwtSecret), nil
|
|
})
|
|
if err != nil || !token.Valid {
|
|
respondError(w, http.StatusUnauthorized, "Invalid token")
|
|
return
|
|
}
|
|
|
|
claims, ok := token.Claims.(*userClaims)
|
|
if !ok || claims.TokenType != expectedType {
|
|
respondError(w, http.StatusUnauthorized, "Invalid token type")
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), userCtxKey{}, claims)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// JWT access token parsing, verification, and validation.
|
|
func requireAccessToken(jwtSecret string) func(http.Handler) http.Handler {
|
|
return authMiddleware(jwtSecret, "access")
|
|
}
|
|
|
|
// JWT refresh token parsing, verification, and validation.
|
|
func requireRefreshToken(jwtSecret string) func(http.Handler) http.Handler {
|
|
return authMiddleware(jwtSecret, "refresh")
|
|
}
|
|
|
|
// Ensure the current user is an administrator.
|
|
func adminOnlyMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
if !ok || !user.Admin {
|
|
respondError(w, http.StatusForbidden, "Forbidden")
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// Ensure the targeted resource is owned by the current user (i.e. current user's ID matches with
|
|
// the one stored into the resource).
|
|
func ownerOnlyMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
requestedID := chi.URLParam(r, "id")
|
|
if !ok || user.ID != requestedID {
|
|
respondError(w, http.StatusForbidden, "Forbidden")
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// Append user data into request's context based on user ID as a URL parameter.
|
|
func userCtx(store UserStore) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
userIDStr := chi.URLParam(r, "id")
|
|
userID, err := uuid.Parse(userIDStr)
|
|
if err != nil {
|
|
respondError(w, http.StatusNotFound, "Invalid user ID")
|
|
return
|
|
}
|
|
|
|
user, err := store.GetUserByID(r.Context(), userID)
|
|
if err != nil {
|
|
respondError(w, http.StatusNotFound, "User not found")
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), userCtxKey{}, user)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// Zerolog compatible logger middleware. Automatically logs and recovers from errors with HTTP 500
|
|
// response, by default logs to INFO level.
|
|
func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
|
log := log.With().Logger()
|
|
|
|
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
|
|
|
|
t1 := time.Now()
|
|
defer func() {
|
|
t2 := time.Now()
|
|
|
|
// Recover automatically and respond with HTTP 500
|
|
if rec := recover(); rec != nil {
|
|
log.Error().
|
|
Str("type", "error").
|
|
Timestamp().
|
|
Interface("recover_info", rec).
|
|
Msg(panicRecoveryMsg)
|
|
http.Error(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
|
}
|
|
|
|
// Log a regular HTTP request with some metadata
|
|
log.Info().
|
|
Str("type", "access").
|
|
Timestamp().
|
|
Fields(map[string]interface{}{
|
|
"remote_ip": r.RemoteAddr,
|
|
"url": r.URL.Path,
|
|
"proto": r.Proto,
|
|
"method": r.Method,
|
|
"user_agent": r.Header.Get("User-Agent"),
|
|
"status": ww.Status(),
|
|
"latency_ms": float64(t2.Sub(t1).Nanoseconds()) / 1000000.0,
|
|
"bytes_in": r.Header.Get("Content-Length"),
|
|
"bytes_out": ww.BytesWritten(),
|
|
}).
|
|
Msg(defaultLogMsg)
|
|
}()
|
|
|
|
next.ServeHTTP(ww, r)
|
|
}
|
|
return http.HandlerFunc(fn)
|
|
}
|
|
}
|