Compare commits
No commits in common. "b393f1a47c285c98476d87b6a3456a1189639ea1" and "998176c3f920053d169c98734f3c9676f7894f94" have entirely different histories.
b393f1a47c
...
998176c3f9
@ -11,7 +11,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -362,148 +361,4 @@ func TestNotes_CreateNoteVersion(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNotes_ListNoteVersions(t *testing.T) {
|
// TODO: add similar tests for `ListNoteVersions` and `GetNoteVersion`
|
||||||
noteID := uuid.New()
|
|
||||||
versions := []data.NoteVersion{
|
|
||||||
{ID: uuid.New(), NoteID: noteID, VersionNumber: 1},
|
|
||||||
{ID: uuid.New(), NoteID: noteID, VersionNumber: 2},
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
query string
|
|
||||||
mock func(*mockNoteStore)
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"success",
|
|
||||||
"",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.GetNoteVersionsFunc = func(_ context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) {
|
|
||||||
assert.Equal(t, noteID, arg.NoteID)
|
|
||||||
return versions, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"with pagination",
|
|
||||||
"?limit=5&offset=10",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.GetNoteVersionsFunc = func(_ context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) {
|
|
||||||
assert.EqualValues(t, 5, arg.Limit)
|
|
||||||
assert.EqualValues(t, 10, arg.Offset)
|
|
||||||
return versions, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"database error",
|
|
||||||
"",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.GetNoteVersionsFunc = func(context.Context, data.GetNoteVersionsParams) ([]data.NoteVersion, error) {
|
|
||||||
return nil, errors.New("db error")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusInternalServerError,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
mockStore := &mockNoteStore{}
|
|
||||||
tc.mock(mockStore)
|
|
||||||
|
|
||||||
rs := notesResource{Notes: mockStore}
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/versions/%s", tc.query), nil)
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
noteCtxKey{},
|
|
||||||
data.Note{ID: noteID},
|
|
||||||
))
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.ListNoteVersions(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
if tc.statusCode == http.StatusOK {
|
|
||||||
var result []data.NoteVersion
|
|
||||||
json.Unmarshal(w.Body.Bytes(), &result)
|
|
||||||
assert.Len(t, result, 2)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotes_GetNoteVersion(t *testing.T) {
|
|
||||||
noteID := uuid.New()
|
|
||||||
version := data.NoteVersion{ID: uuid.New(), NoteID: noteID, VersionNumber: 1}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
version string
|
|
||||||
mock func(*mockNoteStore)
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"invalid version",
|
|
||||||
"invalid",
|
|
||||||
func(m *mockNoteStore) {},
|
|
||||||
http.StatusBadRequest,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"version not found",
|
|
||||||
"1",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.GetNoteVersionFunc = func(context.Context, data.GetNoteVersionParams) (data.NoteVersion, error) {
|
|
||||||
return data.NoteVersion{}, errors.New("not found")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusNotFound,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"success",
|
|
||||||
"1",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.GetNoteVersionFunc = func(_ context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error) {
|
|
||||||
assert.Equal(t, noteID, arg.NoteID)
|
|
||||||
assert.EqualValues(t, 1, arg.VersionNumber)
|
|
||||||
return version, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
mockStore := &mockNoteStore{}
|
|
||||||
tc.mock(mockStore)
|
|
||||||
|
|
||||||
rs := notesResource{Notes: mockStore}
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/versions/%s", tc.version), nil)
|
|
||||||
|
|
||||||
// Chi router context mocks ID (passed in a URL param.) and the note object (passed in req. ctx.)
|
|
||||||
rctx := chi.NewRouteContext()
|
|
||||||
rctx.URLParams.Add("version", tc.version)
|
|
||||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
noteCtxKey{},
|
|
||||||
data.Note{ID: noteID},
|
|
||||||
))
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.GetNoteVersion(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
if tc.statusCode == http.StatusOK {
|
|
||||||
var result data.NoteVersion
|
|
||||||
json.Unmarshal(w.Body.Bytes(), &result)
|
|
||||||
assert.Equal(t, version.VersionNumber, result.VersionNumber)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -53,13 +53,6 @@ type tokensResource struct {
|
|||||||
func (rs tokensResource) Routes() chi.Router {
|
func (rs tokensResource) Routes() chi.Router {
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
|
|
||||||
// Protected routes (access token required)
|
|
||||||
r.Group(func(r chi.Router) {
|
|
||||||
r.Use(requireAccessToken(rs.JWTSecret))
|
|
||||||
r.Post("/logout", rs.HandleLogout) // POST /auth/logout - revoke all refresh cookies
|
|
||||||
})
|
|
||||||
|
|
||||||
// Protected routes (refresh token required)
|
|
||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
r.Use(requireRefreshToken(rs.JWTSecret))
|
r.Use(requireRefreshToken(rs.JWTSecret))
|
||||||
r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair
|
r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair
|
||||||
@ -182,23 +175,12 @@ func (rs tokensResource) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear the refresh token cookie
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: "refresh_token",
|
|
||||||
Value: "",
|
|
||||||
Path: "/",
|
|
||||||
MaxAge: -1, // Expires immediately
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: true,
|
|
||||||
SameSite: http.SameSiteStrictMode,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := rs.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
|
if err := rs.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to logout")
|
respondError(w, http.StatusInternalServerError, "Failed to logout")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(http.StatusNoContent)
|
respondJSON(w, http.StatusOK, map[string]string{"status": "logged out"})
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTokenFromRequest(r *http.Request) (string, error) {
|
func getTokenFromRequest(r *http.Request) (string, error) {
|
||||||
|
@ -145,24 +145,7 @@ func TestRefreshAccessToken_Success(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), "access_token")
|
assert.Contains(t, w.Body.String(), "access_token")
|
||||||
|
assert.Contains(t, w.Result().Cookies()[0].Name, "refresh_token")
|
||||||
cookies := w.Result().Cookies()
|
|
||||||
var refreshCookie *http.Cookie
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
if cookie.Name == "refresh_token" {
|
|
||||||
refreshCookie = cookie
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if refreshCookie == nil {
|
|
||||||
t.Fatal("refresh token cookie not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.True(t, refreshCookie.HttpOnly, "cookie should be HttpOnly")
|
|
||||||
assert.Equal(t, http.SameSiteStrictMode, refreshCookie.SameSite, "invalid SameSite mode")
|
|
||||||
assert.Equal(t, "/", refreshCookie.Path, "invalid cookie path")
|
|
||||||
assert.Greater(t, refreshCookie.MaxAge, 0, "cookie should have expiration")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleLogout_Success(t *testing.T) {
|
func TestHandleLogout_Success(t *testing.T) {
|
||||||
@ -194,18 +177,5 @@ func TestHandleLogout_Success(t *testing.T) {
|
|||||||
rs.HandleLogout(w, req)
|
rs.HandleLogout(w, req)
|
||||||
|
|
||||||
assert.True(t, called)
|
assert.True(t, called)
|
||||||
assert.Equal(t, http.StatusNoContent, w.Code) // 204
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
cookies := w.Result().Cookies()
|
|
||||||
var refreshCookie *http.Cookie
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
if cookie.Name == "refresh_token" {
|
|
||||||
refreshCookie = cookie
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if refreshCookie != nil && refreshCookie.MaxAge != -1 {
|
|
||||||
t.Fatal("refresh token cookie not invalidated")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@ -12,7 +11,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -282,138 +280,3 @@ func TestUpdatePassword_DatabaseError(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), "Failed to update password")
|
assert.Contains(t, w.Body.String(), "Failed to update password")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUsersLogin(t *testing.T) {
|
|
||||||
validPassword := "validPass123!"
|
|
||||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(validPassword), bcrypt.DefaultCost)
|
|
||||||
testUser := data.User{
|
|
||||||
ID: uuid.New(),
|
|
||||||
Username: "test_username",
|
|
||||||
PasswordHash: string(hashedPassword),
|
|
||||||
}
|
|
||||||
jwtSecret := "test-secret"
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
requestBody interface{}
|
|
||||||
mockSetup func(*mockUserStore)
|
|
||||||
wantStatus int
|
|
||||||
wantResponse string
|
|
||||||
checkCookie bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "invalid request body",
|
|
||||||
requestBody: "invalid",
|
|
||||||
mockSetup: func(m *mockUserStore) {},
|
|
||||||
wantStatus: http.StatusBadRequest,
|
|
||||||
wantResponse: `{"error":"Invalid request body"}`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user not found",
|
|
||||||
requestBody: map[string]string{
|
|
||||||
"username": "nouser",
|
|
||||||
"password": validPassword,
|
|
||||||
},
|
|
||||||
mockSetup: func(m *mockUserStore) {
|
|
||||||
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
|
|
||||||
return data.User{}, errors.New("not found")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
wantStatus: http.StatusUnauthorized,
|
|
||||||
wantResponse: `{"error":"Invalid credentials"}`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid password",
|
|
||||||
requestBody: map[string]string{
|
|
||||||
"username": testUser.Username,
|
|
||||||
"password": "wrongpassword",
|
|
||||||
},
|
|
||||||
mockSetup: func(m *mockUserStore) {
|
|
||||||
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
|
|
||||||
return testUser, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
wantStatus: http.StatusUnauthorized,
|
|
||||||
wantResponse: `{"error":"Invalid credentials"}`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "successful login",
|
|
||||||
requestBody: map[string]string{
|
|
||||||
"username": testUser.Username,
|
|
||||||
"password": validPassword,
|
|
||||||
},
|
|
||||||
mockSetup: func(m *mockUserStore) {
|
|
||||||
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
|
|
||||||
return testUser, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
wantStatus: http.StatusOK,
|
|
||||||
checkCookie: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
mockStore := &mockUserStore{}
|
|
||||||
tt.mockSetup(mockStore)
|
|
||||||
|
|
||||||
rs := usersResource{
|
|
||||||
Users: mockStore,
|
|
||||||
JWTSecret: jwtSecret,
|
|
||||||
}
|
|
||||||
|
|
||||||
body, _ := json.Marshal(tt.requestBody)
|
|
||||||
req := httptest.NewRequest("POST", "/login", bytes.NewReader(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.Login(w, req)
|
|
||||||
|
|
||||||
if w.Code != tt.wantStatus {
|
|
||||||
t.Errorf("expected status %d, got %d", tt.wantStatus, w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.wantResponse != "" && strings.TrimSpace(w.Body.String()) != tt.wantResponse {
|
|
||||||
t.Errorf("expected response %q, got %q", tt.wantResponse, w.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.checkCookie {
|
|
||||||
cookies := w.Result().Cookies()
|
|
||||||
var refreshCookie *http.Cookie
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
if cookie.Name == "refresh_token" {
|
|
||||||
refreshCookie = cookie
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if refreshCookie == nil {
|
|
||||||
t.Fatal("refresh token cookie not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.True(t, refreshCookie.HttpOnly, "cookie should be HttpOnly")
|
|
||||||
assert.Equal(t, http.SameSiteStrictMode, refreshCookie.SameSite, "invalid SameSite mode")
|
|
||||||
assert.Equal(t, "/", refreshCookie.Path, "invalid cookie path")
|
|
||||||
assert.Greater(t, refreshCookie.MaxAge, 0, "cookie should have expiration")
|
|
||||||
|
|
||||||
// Validate access token in response
|
|
||||||
var response map[string]string
|
|
||||||
json.Unmarshal(w.Body.Bytes(), &response)
|
|
||||||
if response["access_token"] == "" {
|
|
||||||
t.Error("access token not in response")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify JWT validity
|
|
||||||
token, err := jwt.ParseWithClaims(
|
|
||||||
response["access_token"],
|
|
||||||
&userClaims{},
|
|
||||||
func(token *jwt.Token) (interface{}, error) {
|
|
||||||
return []byte(jwtSecret), nil
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert.NoError(t, err, "invalid JWT")
|
|
||||||
assert.True(t, token.Valid, "invalid JWT")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user