feat: cors, secure cookies, & csrf

This commit is contained in:
ae 2025-04-10 21:33:01 +03:00
parent a5a443a61e
commit a969629f2d
Signed by: ae
GPG Key ID: 995EFD5C1B532B3E
8 changed files with 142 additions and 58 deletions

View File

@ -5,9 +5,11 @@ go 1.24.1
require (
github.com/caarlos0/env/v10 v10.0.0
github.com/go-chi/chi/v5 v5.2.1
github.com/go-chi/cors v1.2.1
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/golang-migrate/migrate/v4 v4.18.2
github.com/google/uuid v1.6.0
github.com/gorilla/csrf v1.7.2
github.com/jackc/pgx/v5 v5.7.4
github.com/rs/zerolog v1.34.0
github.com/stretchr/testify v1.10.0
@ -17,6 +19,7 @@ require (
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect

View File

@ -22,6 +22,8 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4=
github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
@ -33,8 +35,14 @@ github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeD
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8=
github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI=
github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=

View File

@ -16,6 +16,7 @@ import (
"github.com/go-chi/chi/v5"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/gorilla/csrf"
"github.com/jackc/pgx/v5/pgconn"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
@ -79,9 +80,9 @@ type UserStore interface {
// (especially in production) the `UserStore` and `TokenStore` will point to the same database
// handler, but for code readability they should be kept in separate structs.
type authResource struct {
JWTSecret string
Users UserStore
Tokens TokenStore
Config SvcConfig
Users UserStore
Tokens TokenStore
}
func (rs authResource) Routes() chi.Router {
@ -93,9 +94,9 @@ func (rs authResource) Routes() chi.Router {
// Protected routes (access token required)
r.Group(func(r chi.Router) {
r.Use(requireAccessToken(rs.JWTSecret)) // JWT claims -> ctx.
r.Get("/me", rs.Get) // GET /auth/me - current user data
r.Post("/logout", rs.Logout) // POST /auth/logout - revoke all refresh cookies
r.Use(requireAccessToken(rs.Config.JWTSecret)) // JWT claims -> ctx.
r.Get("/me", rs.Get) // GET /auth/me - current user data
r.Post("/logout", rs.Logout) // POST /auth/logout - revoke all refresh cookies
// Owner routes
r.Route("/owner", func(r chi.Router) {
@ -115,9 +116,18 @@ func (rs authResource) Routes() chi.Router {
})
// Protected routes (refresh token required)
r.Group(func(r chi.Router) {
r.Use(requireRefreshToken(rs.JWTSecret))
r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair
r.Route("/cookie", func(r chi.Router) {
// The refresh token httpOnly cookie is restricted to `/api/auth/cookie`, which is why this
// is the only endpoint where CSRF must be taken into account (HTTPS requirement disabled
// for local development)
if rs.Config.IsProd {
r.Use(csrf.Protect([]byte(rs.Config.CSRFSecret)))
} else {
r.Use(csrf.Protect([]byte(rs.Config.CSRFSecret), csrf.Secure(false)))
}
r.Use(requireRefreshToken(rs.Config.JWTSecret))
r.Get("/csrf", rs.GetCSRFToken) // GET /auth/cookie/csrf - get a new CSRF token in response headers
r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/cookie/refresh - convert refresh token to new token pair
})
return r
@ -214,12 +224,12 @@ func (rs authResource) Login(w http.ResponseWriter, r *http.Request) {
// Set refresh token into a httpOnly cookie
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Name: "notatest.refresh_token",
Value: tokenPair.RefreshToken,
Path: "/",
Path: "/api/auth/cookie",
MaxAge: int(refreshTokenDuration.Seconds()),
HttpOnly: true,
Secure: true,
Secure: rs.Config.IsProd,
SameSite: http.SameSiteStrictMode,
})
@ -393,7 +403,7 @@ func (rs authResource) AdminDelete(w http.ResponseWriter, r *http.Request) {
// ("refresh"/"access"). Stores a SHA256 hash of the refresh token into the database for further
// token rotations.
func (rs authResource) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
tokenPair, err := generateTokenPair(userID.String(), isAdmin, rs.JWTSecret)
tokenPair, err := generateTokenPair(userID.String(), isAdmin, rs.Config.JWTSecret)
if err != nil {
return nil, err
}
@ -442,6 +452,14 @@ func (rs authResource) ValidateRefreshToken(ctx context.Context, token string) (
return &dbToken, nil
}
// Handler for returning the CSRF token in the `X-CSRF-Token` header. Notably this request doesn't
// need to contain a valid CSRF token as it uses "safe" (non-mutating) method GET, which means CSRF
// won't be enforced.
func (rs authResource) GetCSRFToken(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-CSRF-Token", csrf.Token(r))
w.WriteHeader(http.StatusNoContent)
}
// Handler for performing a token rotation, i.e. invalidating the given refresh token (each refresh
// token is a single use utility) and exchanging it for a new pair of refresh and access tokens.
func (rs authResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
@ -486,12 +504,12 @@ func (rs authResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request
// Set refresh token into a httpOnly cookie
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Name: "notatest.refresh_token",
Value: tokenPair.RefreshToken,
Path: "/",
Path: "/api/auth/cookie",
MaxAge: int(refreshTokenDuration.Seconds()),
HttpOnly: true,
Secure: true,
Secure: rs.Config.IsProd,
SameSite: http.SameSiteStrictMode,
})
@ -520,12 +538,12 @@ func (rs authResource) Logout(w http.ResponseWriter, r *http.Request) {
// Clear the refresh token cookie
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Name: "notatest.refresh_token",
Value: "",
Path: "/",
Path: "/api/auth/cookie",
MaxAge: 0, // Expires immediately
HttpOnly: true,
Secure: true,
Secure: rs.Config.IsProd,
SameSite: http.SameSiteStrictMode,
})
@ -537,6 +555,30 @@ func (rs authResource) Logout(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// Parse JWT claims (`userClaims`) from the request's context, and perform a database lookup based
// on `Subject` (after parsing it to `uuid.UUID`) to fetch the corresponding user's data.
func (rs authResource) userFromCtxClaims(w http.ResponseWriter, r *http.Request) *data.User {
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok {
respondError(w, http.StatusUnauthorized, "Unauthorized")
return nil
}
userID, err := uuid.Parse(claims.Subject)
if err != nil {
respondError(w, http.StatusBadRequest, "Invalid user ID")
return nil
}
user, err := rs.Users.GetUserByID(r.Context(), userID)
if err != nil {
respondError(w, http.StatusNotFound, "User not found")
return nil
}
return &user
}
// Helper function for generating the initial administrator level account if one doesn't already
// exists in the database.
func CreateAdminIfNotExists(ctx context.Context, q *data.Queries, username, password string) error {
@ -590,7 +632,7 @@ func getTokenFromHeader(r *http.Request) (string, error) {
// Parse the JWT token from the request's cookies (httpOnly cookie).
func getTokenFromCookie(r *http.Request) (string, error) {
cookie, err := r.Cookie("refresh_token")
cookie, err := r.Cookie("notatest.refresh_token")
if err != nil {
return "", err
}
@ -636,30 +678,6 @@ func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPai
return &tokenPair{AccessToken: t, RefreshToken: rt}, nil
}
// Parse JWT claims (`userClaims`) from the request's context, and perform a database lookup based
// on `Subject` (after parsing it to `uuid.UUID`) to fetch the corresponding user's data.
func (rs authResource) userFromCtxClaims(w http.ResponseWriter, r *http.Request) *data.User {
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok {
respondError(w, http.StatusUnauthorized, "Unauthorized")
return nil
}
userID, err := uuid.Parse(claims.Subject)
if err != nil {
respondError(w, http.StatusBadRequest, "Invalid user ID")
return nil
}
user, err := rs.Users.GetUserByID(r.Context(), userID)
if err != nil {
respondError(w, http.StatusNotFound, "User not found")
return nil
}
return &user
}
// Check if the given error is a PostgreSQL error for `unique_violation` (error code 23505), i.e.
// whether an entry with the given details already exists in the database table.
func isDuplicateEntry(err error) bool {

View File

@ -210,7 +210,7 @@ func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler {
}
// Log a regular HTTP request with some metadata
log.Info().
log.Debug().
Str("type", "access").
Timestamp().
Fields(map[string]any{

View File

@ -110,7 +110,7 @@ func TestRequireRTMiddleware(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
if tc.token != "" {
req.AddCookie(&http.Cookie{
Name: "refresh_token",
Name: "notatest.refresh_token",
Value: tc.token,
HttpOnly: true,
})

View File

@ -38,15 +38,15 @@ type NoteStore interface {
// Chi HTTP router for notes related CRUD actions.
type notesResource struct {
JWTSecret string
Notes NoteStore
Config SvcConfig
Notes NoteStore
}
func (rs notesResource) Routes() chi.Router {
r := chi.NewRouter()
r.Group(func(r chi.Router) {
r.Use(requireAccessToken(rs.JWTSecret)) // JWT claims -> ctx.
r.Use(requireAccessToken(rs.Config.JWTSecret)) // JWT claims -> ctx.
r.Post("/", rs.Create) // POST /notes - create new note
r.Get("/", rs.ListMetadata) // GET /notes - get all notes (metadata + titles)

View File

@ -6,27 +6,61 @@ import (
"git.umbrella.haus/ae/notatest/internal/data"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/jackc/pgx/v5"
"github.com/rs/zerolog/log"
)
func Run(conn *pgx.Conn, q *data.Queries, jwtSecret string) error {
type SvcConfig struct {
JWTSecret string
CSRFSecret string
IsProd bool
Domain string
FrontendURL string
}
func (sc *SvcConfig) allowedOrigins() []string {
var allowed []string
if sc.IsProd {
allowed = []string{sc.FrontendURL}
} else {
allowed = []string{"http://localhost:3000"}
}
log.Debug().Msgf("CORS allowedOrigins: %v", allowed)
return allowed
}
func Run(conn *pgx.Conn, q *data.Queries, config SvcConfig) error {
r := chi.NewRouter()
if !config.IsProd {
log.Warn().Msg("Running in *INSECURE* development mode")
}
authRouter := authResource{
JWTSecret: jwtSecret,
Users: q,
Tokens: q,
Config: config,
Users: q,
Tokens: q,
}
notesRouter := notesResource{
JWTSecret: jwtSecret,
Notes: q,
Config: config,
Notes: q,
}
// Global middlewares
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(loggerMiddleware(&log.Logger))
r.Use(cors.Handler(cors.Options{
AllowedOrigins: config.allowedOrigins(),
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"List"},
AllowCredentials: true,
MaxAge: 300,
}))
r.Use(middleware.Recoverer)
r.Use(middleware.AllowContentType("application/json"))
@ -35,8 +69,15 @@ func Run(conn *pgx.Conn, q *data.Queries, jwtSecret string) error {
r.Route("/api", func(r chi.Router) {
r.Mount("/auth", authRouter.Routes())
r.Mount("/notes", notesRouter.Routes())
r.Get("/ping", ping)
})
log.Info().Msg("Starting server on :8080")
return http.ListenAndServe(":8080", r)
}
func ping(w http.ResponseWriter, r *http.Request) {
respondJSON(w, http.StatusOK, map[string]string{
"message": "pong",
})
}

View File

@ -21,15 +21,20 @@ import (
var migrationsFS embed.FS
var (
config Config
config Config
svcConfig service.SvcConfig
)
type Config struct {
JWTSecret string `env:"JWT_SECRET,notEmpty"`
CSRFSecret string `env:"CSRF_SECRET,notEmpty"`
DatabaseURL string `env:"DB_URL,notEmpty"`
LogLevel string `env:"LOG_LEVEL" envDefault:"info"`
AdminUsername string `env:"ADMIN_USERNAME,notEmpty,unset"`
AdminPassword string `env:"ADMIN_PASSWORD,notEmpty,unset"`
Domain string `env:"DOMAIN" envDefault:"localhost"`
FrontendURL string `env:"FRONTEND_URL" envDefault:"http://localhost:5173"`
LogLevel string `env:"LOG_LEVEL" envDefault:"info"`
AppEnv string `env:"APP_ENV" envDefault:"production"`
}
func init() {
@ -38,6 +43,15 @@ func init() {
log.Fatal().Err(err).Msg("Failed to parse environment variables")
}
initLogger()
svcConfig = service.SvcConfig{
JWTSecret: config.JWTSecret,
CSRFSecret: config.CSRFSecret,
IsProd: config.AppEnv == "production",
Domain: config.Domain,
FrontendURL: config.FrontendURL,
}
log.Debug().Msg("Initialization completed")
}
@ -71,7 +85,7 @@ func main() {
}
log.Info().Msg("Migrations applied succesfully, proceeding to HTTP server startup")
service.Run(conn, q, config.JWTSecret)
service.Run(conn, q, svcConfig)
}
func initLogger() {