diff --git a/server/go.mod b/server/go.mod index 1b0e23a..bd5c341 100644 --- a/server/go.mod +++ b/server/go.mod @@ -5,9 +5,11 @@ go 1.24.1 require ( github.com/caarlos0/env/v10 v10.0.0 github.com/go-chi/chi/v5 v5.2.1 + github.com/go-chi/cors v1.2.1 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/golang-migrate/migrate/v4 v4.18.2 github.com/google/uuid v1.6.0 + github.com/gorilla/csrf v1.7.2 github.com/jackc/pgx/v5 v5.7.4 github.com/rs/zerolog v1.34.0 github.com/stretchr/testify v1.10.0 @@ -17,6 +19,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect diff --git a/server/go.sum b/server/go.sum index edabdcc..a491459 100644 --- a/server/go.sum +++ b/server/go.sum @@ -22,6 +22,8 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2 github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8= github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= +github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= +github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= @@ -33,8 +35,14 @@ github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeD github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8= github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI= +github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= diff --git a/server/internal/service/auth.go b/server/internal/service/auth.go index 2577dec..e7d3846 100644 --- a/server/internal/service/auth.go +++ b/server/internal/service/auth.go @@ -16,6 +16,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" + "github.com/gorilla/csrf" "github.com/jackc/pgx/v5/pgconn" "github.com/rs/zerolog/log" "golang.org/x/crypto/bcrypt" @@ -79,9 +80,9 @@ type UserStore interface { // (especially in production) the `UserStore` and `TokenStore` will point to the same database // handler, but for code readability they should be kept in separate structs. type authResource struct { - JWTSecret string - Users UserStore - Tokens TokenStore + Config SvcConfig + Users UserStore + Tokens TokenStore } func (rs authResource) Routes() chi.Router { @@ -93,9 +94,9 @@ func (rs authResource) Routes() chi.Router { // Protected routes (access token required) r.Group(func(r chi.Router) { - r.Use(requireAccessToken(rs.JWTSecret)) // JWT claims -> ctx. - r.Get("/me", rs.Get) // GET /auth/me - current user data - r.Post("/logout", rs.Logout) // POST /auth/logout - revoke all refresh cookies + r.Use(requireAccessToken(rs.Config.JWTSecret)) // JWT claims -> ctx. + r.Get("/me", rs.Get) // GET /auth/me - current user data + r.Post("/logout", rs.Logout) // POST /auth/logout - revoke all refresh cookies // Owner routes r.Route("/owner", func(r chi.Router) { @@ -115,9 +116,18 @@ func (rs authResource) Routes() chi.Router { }) // Protected routes (refresh token required) - r.Group(func(r chi.Router) { - r.Use(requireRefreshToken(rs.JWTSecret)) - r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair + r.Route("/cookie", func(r chi.Router) { + // The refresh token httpOnly cookie is restricted to `/api/auth/cookie`, which is why this + // is the only endpoint where CSRF must be taken into account (HTTPS requirement disabled + // for local development) + if rs.Config.IsProd { + r.Use(csrf.Protect([]byte(rs.Config.CSRFSecret))) + } else { + r.Use(csrf.Protect([]byte(rs.Config.CSRFSecret), csrf.Secure(false))) + } + r.Use(requireRefreshToken(rs.Config.JWTSecret)) + r.Get("/csrf", rs.GetCSRFToken) // GET /auth/cookie/csrf - get a new CSRF token in response headers + r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/cookie/refresh - convert refresh token to new token pair }) return r @@ -214,12 +224,12 @@ func (rs authResource) Login(w http.ResponseWriter, r *http.Request) { // Set refresh token into a httpOnly cookie http.SetCookie(w, &http.Cookie{ - Name: "refresh_token", + Name: "notatest.refresh_token", Value: tokenPair.RefreshToken, - Path: "/", + Path: "/api/auth/cookie", MaxAge: int(refreshTokenDuration.Seconds()), HttpOnly: true, - Secure: true, + Secure: rs.Config.IsProd, SameSite: http.SameSiteStrictMode, }) @@ -393,7 +403,7 @@ func (rs authResource) AdminDelete(w http.ResponseWriter, r *http.Request) { // ("refresh"/"access"). Stores a SHA256 hash of the refresh token into the database for further // token rotations. func (rs authResource) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) { - tokenPair, err := generateTokenPair(userID.String(), isAdmin, rs.JWTSecret) + tokenPair, err := generateTokenPair(userID.String(), isAdmin, rs.Config.JWTSecret) if err != nil { return nil, err } @@ -442,6 +452,14 @@ func (rs authResource) ValidateRefreshToken(ctx context.Context, token string) ( return &dbToken, nil } +// Handler for returning the CSRF token in the `X-CSRF-Token` header. Notably this request doesn't +// need to contain a valid CSRF token as it uses "safe" (non-mutating) method GET, which means CSRF +// won't be enforced. +func (rs authResource) GetCSRFToken(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-CSRF-Token", csrf.Token(r)) + w.WriteHeader(http.StatusNoContent) +} + // Handler for performing a token rotation, i.e. invalidating the given refresh token (each refresh // token is a single use utility) and exchanging it for a new pair of refresh and access tokens. func (rs authResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request) { @@ -486,12 +504,12 @@ func (rs authResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request // Set refresh token into a httpOnly cookie http.SetCookie(w, &http.Cookie{ - Name: "refresh_token", + Name: "notatest.refresh_token", Value: tokenPair.RefreshToken, - Path: "/", + Path: "/api/auth/cookie", MaxAge: int(refreshTokenDuration.Seconds()), HttpOnly: true, - Secure: true, + Secure: rs.Config.IsProd, SameSite: http.SameSiteStrictMode, }) @@ -520,12 +538,12 @@ func (rs authResource) Logout(w http.ResponseWriter, r *http.Request) { // Clear the refresh token cookie http.SetCookie(w, &http.Cookie{ - Name: "refresh_token", + Name: "notatest.refresh_token", Value: "", - Path: "/", + Path: "/api/auth/cookie", MaxAge: 0, // Expires immediately HttpOnly: true, - Secure: true, + Secure: rs.Config.IsProd, SameSite: http.SameSiteStrictMode, }) @@ -537,6 +555,30 @@ func (rs authResource) Logout(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } +// Parse JWT claims (`userClaims`) from the request's context, and perform a database lookup based +// on `Subject` (after parsing it to `uuid.UUID`) to fetch the corresponding user's data. +func (rs authResource) userFromCtxClaims(w http.ResponseWriter, r *http.Request) *data.User { + claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) + if !ok { + respondError(w, http.StatusUnauthorized, "Unauthorized") + return nil + } + + userID, err := uuid.Parse(claims.Subject) + if err != nil { + respondError(w, http.StatusBadRequest, "Invalid user ID") + return nil + } + + user, err := rs.Users.GetUserByID(r.Context(), userID) + if err != nil { + respondError(w, http.StatusNotFound, "User not found") + return nil + } + + return &user +} + // Helper function for generating the initial administrator level account if one doesn't already // exists in the database. func CreateAdminIfNotExists(ctx context.Context, q *data.Queries, username, password string) error { @@ -590,7 +632,7 @@ func getTokenFromHeader(r *http.Request) (string, error) { // Parse the JWT token from the request's cookies (httpOnly cookie). func getTokenFromCookie(r *http.Request) (string, error) { - cookie, err := r.Cookie("refresh_token") + cookie, err := r.Cookie("notatest.refresh_token") if err != nil { return "", err } @@ -636,30 +678,6 @@ func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPai return &tokenPair{AccessToken: t, RefreshToken: rt}, nil } -// Parse JWT claims (`userClaims`) from the request's context, and perform a database lookup based -// on `Subject` (after parsing it to `uuid.UUID`) to fetch the corresponding user's data. -func (rs authResource) userFromCtxClaims(w http.ResponseWriter, r *http.Request) *data.User { - claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) - if !ok { - respondError(w, http.StatusUnauthorized, "Unauthorized") - return nil - } - - userID, err := uuid.Parse(claims.Subject) - if err != nil { - respondError(w, http.StatusBadRequest, "Invalid user ID") - return nil - } - - user, err := rs.Users.GetUserByID(r.Context(), userID) - if err != nil { - respondError(w, http.StatusNotFound, "User not found") - return nil - } - - return &user -} - // Check if the given error is a PostgreSQL error for `unique_violation` (error code 23505), i.e. // whether an entry with the given details already exists in the database table. func isDuplicateEntry(err error) bool { diff --git a/server/internal/service/middleware.go b/server/internal/service/middleware.go index de6c177..6db7ba9 100644 --- a/server/internal/service/middleware.go +++ b/server/internal/service/middleware.go @@ -210,7 +210,7 @@ func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler { } // Log a regular HTTP request with some metadata - log.Info(). + log.Debug(). Str("type", "access"). Timestamp(). Fields(map[string]any{ diff --git a/server/internal/service/middleware_test.go b/server/internal/service/middleware_test.go index 616ba3c..97d01a0 100644 --- a/server/internal/service/middleware_test.go +++ b/server/internal/service/middleware_test.go @@ -110,7 +110,7 @@ func TestRequireRTMiddleware(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) if tc.token != "" { req.AddCookie(&http.Cookie{ - Name: "refresh_token", + Name: "notatest.refresh_token", Value: tc.token, HttpOnly: true, }) diff --git a/server/internal/service/notes.go b/server/internal/service/notes.go index d1bb330..292c810 100644 --- a/server/internal/service/notes.go +++ b/server/internal/service/notes.go @@ -38,15 +38,15 @@ type NoteStore interface { // Chi HTTP router for notes related CRUD actions. type notesResource struct { - JWTSecret string - Notes NoteStore + Config SvcConfig + Notes NoteStore } func (rs notesResource) Routes() chi.Router { r := chi.NewRouter() r.Group(func(r chi.Router) { - r.Use(requireAccessToken(rs.JWTSecret)) // JWT claims -> ctx. + r.Use(requireAccessToken(rs.Config.JWTSecret)) // JWT claims -> ctx. r.Post("/", rs.Create) // POST /notes - create new note r.Get("/", rs.ListMetadata) // GET /notes - get all notes (metadata + titles) diff --git a/server/internal/service/service.go b/server/internal/service/service.go index baa804c..119bc7b 100644 --- a/server/internal/service/service.go +++ b/server/internal/service/service.go @@ -6,27 +6,61 @@ import ( "git.umbrella.haus/ae/notatest/internal/data" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/cors" "github.com/jackc/pgx/v5" "github.com/rs/zerolog/log" ) -func Run(conn *pgx.Conn, q *data.Queries, jwtSecret string) error { +type SvcConfig struct { + JWTSecret string + CSRFSecret string + IsProd bool + Domain string + FrontendURL string +} + +func (sc *SvcConfig) allowedOrigins() []string { + var allowed []string + if sc.IsProd { + allowed = []string{sc.FrontendURL} + } else { + allowed = []string{"http://localhost:3000"} + } + + log.Debug().Msgf("CORS allowedOrigins: %v", allowed) + + return allowed +} + +func Run(conn *pgx.Conn, q *data.Queries, config SvcConfig) error { r := chi.NewRouter() + if !config.IsProd { + log.Warn().Msg("Running in *INSECURE* development mode") + } + authRouter := authResource{ - JWTSecret: jwtSecret, - Users: q, - Tokens: q, + Config: config, + Users: q, + Tokens: q, } notesRouter := notesResource{ - JWTSecret: jwtSecret, - Notes: q, + Config: config, + Notes: q, } // Global middlewares r.Use(middleware.RequestID) r.Use(middleware.RealIP) r.Use(loggerMiddleware(&log.Logger)) + r.Use(cors.Handler(cors.Options{ + AllowedOrigins: config.allowedOrigins(), + AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, + ExposedHeaders: []string{"List"}, + AllowCredentials: true, + MaxAge: 300, + })) r.Use(middleware.Recoverer) r.Use(middleware.AllowContentType("application/json")) @@ -35,8 +69,15 @@ func Run(conn *pgx.Conn, q *data.Queries, jwtSecret string) error { r.Route("/api", func(r chi.Router) { r.Mount("/auth", authRouter.Routes()) r.Mount("/notes", notesRouter.Routes()) + r.Get("/ping", ping) }) log.Info().Msg("Starting server on :8080") return http.ListenAndServe(":8080", r) } + +func ping(w http.ResponseWriter, r *http.Request) { + respondJSON(w, http.StatusOK, map[string]string{ + "message": "pong", + }) +} diff --git a/server/main.go b/server/main.go index 5c8af8d..b26fba9 100644 --- a/server/main.go +++ b/server/main.go @@ -21,15 +21,20 @@ import ( var migrationsFS embed.FS var ( - config Config + config Config + svcConfig service.SvcConfig ) type Config struct { JWTSecret string `env:"JWT_SECRET,notEmpty"` + CSRFSecret string `env:"CSRF_SECRET,notEmpty"` DatabaseURL string `env:"DB_URL,notEmpty"` - LogLevel string `env:"LOG_LEVEL" envDefault:"info"` AdminUsername string `env:"ADMIN_USERNAME,notEmpty,unset"` AdminPassword string `env:"ADMIN_PASSWORD,notEmpty,unset"` + Domain string `env:"DOMAIN" envDefault:"localhost"` + FrontendURL string `env:"FRONTEND_URL" envDefault:"http://localhost:5173"` + LogLevel string `env:"LOG_LEVEL" envDefault:"info"` + AppEnv string `env:"APP_ENV" envDefault:"production"` } func init() { @@ -38,6 +43,15 @@ func init() { log.Fatal().Err(err).Msg("Failed to parse environment variables") } initLogger() + + svcConfig = service.SvcConfig{ + JWTSecret: config.JWTSecret, + CSRFSecret: config.CSRFSecret, + IsProd: config.AppEnv == "production", + Domain: config.Domain, + FrontendURL: config.FrontendURL, + } + log.Debug().Msg("Initialization completed") } @@ -71,7 +85,7 @@ func main() { } log.Info().Msg("Migrations applied succesfully, proceeding to HTTP server startup") - service.Run(conn, q, config.JWTSecret) + service.Run(conn, q, svcConfig) } func initLogger() {