From b393f1a47c285c98476d87b6a3456a1189639ea1 Mon Sep 17 00:00:00 2001 From: ae Date: Wed, 2 Apr 2025 12:44:59 +0300 Subject: [PATCH] test: user login, rt cookies, & note versioning --- server/pkg/service/notes_test.go | 147 +++++++++++++++++++++++++++++- server/pkg/service/tokens_test.go | 34 ++++++- server/pkg/service/users_test.go | 137 ++++++++++++++++++++++++++++ 3 files changed, 315 insertions(+), 3 deletions(-) diff --git a/server/pkg/service/notes_test.go b/server/pkg/service/notes_test.go index 4fb8b57..8a97d17 100644 --- a/server/pkg/service/notes_test.go +++ b/server/pkg/service/notes_test.go @@ -11,6 +11,7 @@ import ( "testing" "git.umbrella.haus/ae/notatest/pkg/data" + "github.com/go-chi/chi/v5" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -361,4 +362,148 @@ func TestNotes_CreateNoteVersion(t *testing.T) { } } -// TODO: add similar tests for `ListNoteVersions` and `GetNoteVersion` +func TestNotes_ListNoteVersions(t *testing.T) { + noteID := uuid.New() + versions := []data.NoteVersion{ + {ID: uuid.New(), NoteID: noteID, VersionNumber: 1}, + {ID: uuid.New(), NoteID: noteID, VersionNumber: 2}, + } + + tests := []struct { + name string + query string + mock func(*mockNoteStore) + statusCode int + }{ + { + "success", + "", + func(m *mockNoteStore) { + m.GetNoteVersionsFunc = func(_ context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) { + assert.Equal(t, noteID, arg.NoteID) + return versions, nil + } + }, + http.StatusOK, + }, + { + "with pagination", + "?limit=5&offset=10", + func(m *mockNoteStore) { + m.GetNoteVersionsFunc = func(_ context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) { + assert.EqualValues(t, 5, arg.Limit) + assert.EqualValues(t, 10, arg.Offset) + return versions, nil + } + }, + http.StatusOK, + }, + { + "database error", + "", + func(m *mockNoteStore) { + m.GetNoteVersionsFunc = func(context.Context, data.GetNoteVersionsParams) ([]data.NoteVersion, error) { + return nil, errors.New("db error") + } + }, + http.StatusInternalServerError, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockStore := &mockNoteStore{} + tc.mock(mockStore) + + rs := notesResource{Notes: mockStore} + req := httptest.NewRequest("GET", fmt.Sprintf("/versions/%s", tc.query), nil) + req = req.WithContext(context.WithValue( + req.Context(), + noteCtxKey{}, + data.Note{ID: noteID}, + )) + + w := httptest.NewRecorder() + rs.ListNoteVersions(w, req) + + assert.Equal(t, tc.statusCode, w.Code) + if tc.statusCode == http.StatusOK { + var result []data.NoteVersion + json.Unmarshal(w.Body.Bytes(), &result) + assert.Len(t, result, 2) + } + }) + } +} + +func TestNotes_GetNoteVersion(t *testing.T) { + noteID := uuid.New() + version := data.NoteVersion{ID: uuid.New(), NoteID: noteID, VersionNumber: 1} + + tests := []struct { + name string + version string + mock func(*mockNoteStore) + statusCode int + }{ + { + "invalid version", + "invalid", + func(m *mockNoteStore) {}, + http.StatusBadRequest, + }, + { + "version not found", + "1", + func(m *mockNoteStore) { + m.GetNoteVersionFunc = func(context.Context, data.GetNoteVersionParams) (data.NoteVersion, error) { + return data.NoteVersion{}, errors.New("not found") + } + }, + http.StatusNotFound, + }, + { + "success", + "1", + func(m *mockNoteStore) { + m.GetNoteVersionFunc = func(_ context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error) { + assert.Equal(t, noteID, arg.NoteID) + assert.EqualValues(t, 1, arg.VersionNumber) + return version, nil + } + }, + http.StatusOK, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockStore := &mockNoteStore{} + tc.mock(mockStore) + + rs := notesResource{Notes: mockStore} + + req := httptest.NewRequest("GET", fmt.Sprintf("/versions/%s", tc.version), nil) + + // Chi router context mocks ID (passed in a URL param.) and the note object (passed in req. ctx.) + rctx := chi.NewRouteContext() + rctx.URLParams.Add("version", tc.version) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + req = req.WithContext(context.WithValue( + req.Context(), + noteCtxKey{}, + data.Note{ID: noteID}, + )) + + w := httptest.NewRecorder() + rs.GetNoteVersion(w, req) + + assert.Equal(t, tc.statusCode, w.Code) + if tc.statusCode == http.StatusOK { + var result data.NoteVersion + json.Unmarshal(w.Body.Bytes(), &result) + assert.Equal(t, version.VersionNumber, result.VersionNumber) + } + }) + } +} diff --git a/server/pkg/service/tokens_test.go b/server/pkg/service/tokens_test.go index d27f42a..c7db4a0 100644 --- a/server/pkg/service/tokens_test.go +++ b/server/pkg/service/tokens_test.go @@ -145,7 +145,24 @@ func TestRefreshAccessToken_Success(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "access_token") - assert.Contains(t, w.Result().Cookies()[0].Name, "refresh_token") + + cookies := w.Result().Cookies() + var refreshCookie *http.Cookie + for _, cookie := range cookies { + if cookie.Name == "refresh_token" { + refreshCookie = cookie + break + } + } + + if refreshCookie == nil { + t.Fatal("refresh token cookie not set") + } + + assert.True(t, refreshCookie.HttpOnly, "cookie should be HttpOnly") + assert.Equal(t, http.SameSiteStrictMode, refreshCookie.SameSite, "invalid SameSite mode") + assert.Equal(t, "/", refreshCookie.Path, "invalid cookie path") + assert.Greater(t, refreshCookie.MaxAge, 0, "cookie should have expiration") } func TestHandleLogout_Success(t *testing.T) { @@ -177,5 +194,18 @@ func TestHandleLogout_Success(t *testing.T) { rs.HandleLogout(w, req) assert.True(t, called) - assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, http.StatusNoContent, w.Code) // 204 + + cookies := w.Result().Cookies() + var refreshCookie *http.Cookie + for _, cookie := range cookies { + if cookie.Name == "refresh_token" { + refreshCookie = cookie + break + } + } + + if refreshCookie != nil && refreshCookie.MaxAge != -1 { + t.Fatal("refresh token cookie not invalidated") + } } diff --git a/server/pkg/service/users_test.go b/server/pkg/service/users_test.go index 4d99910..e12e9dc 100644 --- a/server/pkg/service/users_test.go +++ b/server/pkg/service/users_test.go @@ -1,6 +1,7 @@ package service import ( + "bytes" "context" "encoding/json" "errors" @@ -11,6 +12,7 @@ import ( "testing" "git.umbrella.haus/ae/notatest/pkg/data" + "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/assert" @@ -280,3 +282,138 @@ func TestUpdatePassword_DatabaseError(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Contains(t, w.Body.String(), "Failed to update password") } + +func TestUsersLogin(t *testing.T) { + validPassword := "validPass123!" + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(validPassword), bcrypt.DefaultCost) + testUser := data.User{ + ID: uuid.New(), + Username: "test_username", + PasswordHash: string(hashedPassword), + } + jwtSecret := "test-secret" + + tests := []struct { + name string + requestBody interface{} + mockSetup func(*mockUserStore) + wantStatus int + wantResponse string + checkCookie bool + }{ + { + name: "invalid request body", + requestBody: "invalid", + mockSetup: func(m *mockUserStore) {}, + wantStatus: http.StatusBadRequest, + wantResponse: `{"error":"Invalid request body"}`, + }, + { + name: "user not found", + requestBody: map[string]string{ + "username": "nouser", + "password": validPassword, + }, + mockSetup: func(m *mockUserStore) { + m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) { + return data.User{}, errors.New("not found") + } + }, + wantStatus: http.StatusUnauthorized, + wantResponse: `{"error":"Invalid credentials"}`, + }, + { + name: "invalid password", + requestBody: map[string]string{ + "username": testUser.Username, + "password": "wrongpassword", + }, + mockSetup: func(m *mockUserStore) { + m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) { + return testUser, nil + } + }, + wantStatus: http.StatusUnauthorized, + wantResponse: `{"error":"Invalid credentials"}`, + }, + { + name: "successful login", + requestBody: map[string]string{ + "username": testUser.Username, + "password": validPassword, + }, + mockSetup: func(m *mockUserStore) { + m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) { + return testUser, nil + } + }, + wantStatus: http.StatusOK, + checkCookie: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockStore := &mockUserStore{} + tt.mockSetup(mockStore) + + rs := usersResource{ + Users: mockStore, + JWTSecret: jwtSecret, + } + + body, _ := json.Marshal(tt.requestBody) + req := httptest.NewRequest("POST", "/login", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + rs.Login(w, req) + + if w.Code != tt.wantStatus { + t.Errorf("expected status %d, got %d", tt.wantStatus, w.Code) + } + + if tt.wantResponse != "" && strings.TrimSpace(w.Body.String()) != tt.wantResponse { + t.Errorf("expected response %q, got %q", tt.wantResponse, w.Body.String()) + } + + if tt.checkCookie { + cookies := w.Result().Cookies() + var refreshCookie *http.Cookie + for _, cookie := range cookies { + if cookie.Name == "refresh_token" { + refreshCookie = cookie + break + } + } + + if refreshCookie == nil { + t.Fatal("refresh token cookie not set") + } + + assert.True(t, refreshCookie.HttpOnly, "cookie should be HttpOnly") + assert.Equal(t, http.SameSiteStrictMode, refreshCookie.SameSite, "invalid SameSite mode") + assert.Equal(t, "/", refreshCookie.Path, "invalid cookie path") + assert.Greater(t, refreshCookie.MaxAge, 0, "cookie should have expiration") + + // Validate access token in response + var response map[string]string + json.Unmarshal(w.Body.Bytes(), &response) + if response["access_token"] == "" { + t.Error("access token not in response") + } + + // Verify JWT validity + token, err := jwt.ParseWithClaims( + response["access_token"], + &userClaims{}, + func(token *jwt.Token) (interface{}, error) { + return []byte(jwtSecret), nil + }, + ) + assert.NoError(t, err, "invalid JWT") + assert.True(t, token.Valid, "invalid JWT") + } + }) + } +}