From 998176c3f920053d169c98734f3c9676f7894f94 Mon Sep 17 00:00:00 2001
From: ae <git@golfed.xyz>
Date: Tue, 1 Apr 2025 22:54:39 +0300
Subject: [PATCH] feat: rt as httponly cookie & add login handler

---
 server/main.go                               |  3 +-
 server/pkg/service/{server.go => service.go} |  8 ++--
 server/pkg/service/tokens.go                 | 16 ++++++-
 server/pkg/service/tokens_test.go            |  3 +-
 server/pkg/service/users.go                  | 50 +++++++++++++++++++-
 server/pkg/service/users_test.go             |  5 ++
 6 files changed, 75 insertions(+), 10 deletions(-)
 rename server/pkg/service/{server.go => service.go} (79%)

diff --git a/server/main.go b/server/main.go
index 96d7798..8a943fd 100644
--- a/server/main.go
+++ b/server/main.go
@@ -23,7 +23,6 @@ var (
 
 type Config struct {
 	JWTSecret string `env:"JWT_SECRET,notEmpty"`
-	HTTPPort  string `env:"HTTP_PORT" envDefault:"8080"`
 	DBURL     string `env:"PG_URL,notEmpty"`
 	RunMode   string `env:"GO_ENV" envDefault:"production"`
 }
@@ -55,7 +54,7 @@ func main() {
 		}
 	}
 
-	service.Run(conn, config.JWTSecret, config.HTTPPort)
+	service.Run(conn, config.JWTSecret)
 }
 
 func initLogger() {
diff --git a/server/pkg/service/server.go b/server/pkg/service/service.go
similarity index 79%
rename from server/pkg/service/server.go
rename to server/pkg/service/service.go
index 011a6ad..78223b6 100644
--- a/server/pkg/service/server.go
+++ b/server/pkg/service/service.go
@@ -1,7 +1,6 @@
 package service
 
 import (
-	"fmt"
 	"net/http"
 
 	"git.umbrella.haus/ae/notatest/pkg/data"
@@ -11,7 +10,7 @@ import (
 	"github.com/rs/zerolog/log"
 )
 
-func Run(conn *pgx.Conn, jwtSecret string, httpPort string) error {
+func Run(conn *pgx.Conn, jwtSecret string) error {
 	q := data.New(conn)
 	r := chi.NewRouter()
 
@@ -37,7 +36,6 @@ func Run(conn *pgx.Conn, jwtSecret string, httpPort string) error {
 	r.Mount("/users", usersRouter.Routes())
 	r.Mount("/notes", notesRouter.Routes())
 
-	portStr := fmt.Sprintf(":%s", httpPort)
-	log.Info().Msgf("Starting server on %s", portStr)
-	return http.ListenAndServe(portStr, r)
+	log.Info().Msg("Starting server on :8080")
+	return http.ListenAndServe(":8080", r)
 }
diff --git a/server/pkg/service/tokens.go b/server/pkg/service/tokens.go
index 03d0c90..8316b80 100644
--- a/server/pkg/service/tokens.go
+++ b/server/pkg/service/tokens.go
@@ -145,7 +145,21 @@ func (rs tokensResource) RefreshAccessToken(w http.ResponseWriter, r *http.Reque
 		return
 	}
 
-	respondJSON(w, http.StatusOK, tokenPair)
+	// Set refresh token in HTTP-only cookie
+	http.SetCookie(w, &http.Cookie{
+		Name:     "refresh_token",
+		Value:    tokenPair.RefreshToken,
+		Path:     "/",
+		MaxAge:   int(refreshTokenDuration.Seconds()),
+		HttpOnly: true,
+		Secure:   true,
+		SameSite: http.SameSiteStrictMode,
+	})
+
+	// Return the access token in the response body (it should be stored in browser's memory client-side)
+	respondJSON(w, http.StatusOK, map[string]string{
+		"access_token": tokenPair.AccessToken,
+	})
 }
 
 func (rs tokensResource) HandleLogout(w http.ResponseWriter, r *http.Request) {
diff --git a/server/pkg/service/tokens_test.go b/server/pkg/service/tokens_test.go
index f1b546d..d27f42a 100644
--- a/server/pkg/service/tokens_test.go
+++ b/server/pkg/service/tokens_test.go
@@ -144,7 +144,8 @@ func TestRefreshAccessToken_Success(t *testing.T) {
 	rs.RefreshAccessToken(w, req)
 
 	assert.Equal(t, http.StatusOK, w.Code)
-	assert.Contains(t, w.Body.String(), "access_token", "refresh_token")
+	assert.Contains(t, w.Body.String(), "access_token")
+	assert.Contains(t, w.Result().Cookies()[0].Name, "refresh_token")
 }
 
 func TestHandleLogout_Success(t *testing.T) {
diff --git a/server/pkg/service/users.go b/server/pkg/service/users.go
index 914cbbe..1a178b7 100644
--- a/server/pkg/service/users.go
+++ b/server/pkg/service/users.go
@@ -21,6 +21,7 @@ type UserStore interface {
 	CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
 	ListUsers(ctx context.Context) ([]data.User, error)
 	GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error)
+	GetUserByUsername(ctx context.Context, username string) (data.User, error)
 	UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error
 	DeleteUser(ctx context.Context, id uuid.UUID) error
 	RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
@@ -35,7 +36,8 @@ func (rs usersResource) Routes() chi.Router {
 	r := chi.NewRouter()
 
 	// Public routes (no tokens required)
-	r.Post("/", rs.Create) // POST /users - registration/signup
+	r.Post("/", rs.Create)     // POST /users - registration/signup
+	r.Post("/login", rs.Login) // POST /users/login - login as existing user
 
 	// Protected routes (access token required)
 	r.Group(func(r chi.Router) {
@@ -118,6 +120,52 @@ func (rs usersResource) Create(w http.ResponseWriter, r *http.Request) {
 	})
 }
 
+func (rs usersResource) Login(w http.ResponseWriter, r *http.Request) {
+	type request struct {
+		Username string `json:"username"`
+		Password string `json:"password"`
+	}
+
+	var req request
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		respondError(w, http.StatusBadRequest, "Invalid request body")
+		return
+	}
+
+	user, err := rs.Users.GetUserByUsername(r.Context(), normalizeUsername(req.Username))
+	if err != nil {
+		respondError(w, http.StatusUnauthorized, "Invalid credentials")
+		return
+	}
+
+	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
+		respondError(w, http.StatusUnauthorized, "Invalid credentials")
+		return
+	}
+
+	tokenPair, err := generateTokenPair(user.ID.String(), user.IsAdmin, rs.JWTSecret)
+	if err != nil {
+		respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
+		return
+	}
+
+	// Set refresh token in HTTP-only cookie
+	http.SetCookie(w, &http.Cookie{
+		Name:     "refresh_token",
+		Value:    tokenPair.RefreshToken,
+		Path:     "/",
+		MaxAge:   int(refreshTokenDuration.Seconds()),
+		HttpOnly: true,
+		Secure:   true,
+		SameSite: http.SameSiteStrictMode,
+	})
+
+	// Return the access token in the response body (it should be stored in browser's memory client-side)
+	respondJSON(w, http.StatusOK, map[string]string{
+		"access_token": tokenPair.AccessToken,
+	})
+}
+
 func (rs usersResource) List(w http.ResponseWriter, r *http.Request) {
 	users, err := rs.Users.ListUsers(r.Context())
 	if err != nil {
diff --git a/server/pkg/service/users_test.go b/server/pkg/service/users_test.go
index 148313f..4d99910 100644
--- a/server/pkg/service/users_test.go
+++ b/server/pkg/service/users_test.go
@@ -21,6 +21,7 @@ type mockUserStore struct {
 	CreateUserFunc                 func(context.Context, data.CreateUserParams) (data.User, error)
 	ListUsersFunc                  func(context.Context) ([]data.User, error)
 	GetUserByIDFunc                func(context.Context, uuid.UUID) (data.User, error)
+	GetUserByUsernameFunc          func(context.Context, string) (data.User, error)
 	UpdatePasswordFunc             func(context.Context, data.UpdatePasswordParams) error
 	DeleteUserFunc                 func(context.Context, uuid.UUID) error
 	RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
@@ -38,6 +39,10 @@ func (m *mockUserStore) GetUserByID(ctx context.Context, id uuid.UUID) (data.Use
 	return m.GetUserByIDFunc(ctx, id)
 }
 
+func (m *mockUserStore) GetUserByUsername(ctx context.Context, username string) (data.User, error) {
+	return m.GetUserByUsernameFunc(ctx, username)
+}
+
 func (m *mockUserStore) UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error {
 	return m.UpdatePasswordFunc(ctx, arg)
 }