package service import ( "context" "errors" "fmt" "net/http" "net/http/httptest" "testing" "time" "git.umbrella.haus/ae/notatest/internal/data" "github.com/go-chi/chi/v5" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/stretchr/testify/assert" ) type mockNoteStore struct { CreateNoteFunc func(context.Context, uuid.UUID) (data.Note, error) DeleteNoteFunc func(context.Context, data.DeleteNoteParams) error GetFullNoteFunc func(context.Context, uuid.UUID) (data.GetFullNoteRow, error) ListNotesFunc func(context.Context, data.ListNotesParams) ([]data.ListNotesRow, error) CreateNoteVersionFunc func(context.Context, data.CreateNoteVersionParams) error GetVersionFunc func(context.Context, data.GetVersionParams) (data.GetVersionRow, error) GetVersionHistoryFunc func(context.Context, data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error) } func (m *mockNoteStore) CreateNote(ctx context.Context, id uuid.UUID) (data.Note, error) { return m.CreateNoteFunc(ctx, id) } func (m *mockNoteStore) DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error { return m.DeleteNoteFunc(ctx, arg) } func (m *mockNoteStore) GetFullNote(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) { return m.GetFullNoteFunc(ctx, id) } func (m *mockNoteStore) ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.ListNotesRow, error) { return m.ListNotesFunc(ctx, arg) } func (m *mockNoteStore) CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) error { return m.CreateNoteVersionFunc(ctx, arg) } func (m *mockNoteStore) GetVersion(ctx context.Context, arg data.GetVersionParams) (data.GetVersionRow, error) { return m.GetVersionFunc(ctx, arg) } func (m *mockNoteStore) GetVersionHistory(ctx context.Context, arg data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error) { return m.GetVersionHistoryFunc(ctx, arg) } func TestRequireRTMiddleware(t *testing.T) { secret := "test-jwt-secret" testUserID := uuid.New().String() validRT := generateTestToken(t, secret, "refresh", testUserID, true) validAT := generateTestToken(t, secret, "access", testUserID, true) expiredRT := generateTestToken(t, secret, "refresh", testUserID, true, func(claims *userClaims) { claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)) }) tests := []struct { name string token string expectedErr string statusCode int }{ { "no token", "", "Unauthorized", http.StatusUnauthorized, }, { "invalid token", "invalid", "Invalid token", http.StatusUnauthorized, }, { "expired token", expiredRT, "Invalid token", http.StatusUnauthorized, }, { "wrong token type", validAT, "Invalid token type", http.StatusUnauthorized, }, { "valid token", validRT, "", http.StatusOK, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { rtAuthMiddleware := requireRefreshToken(secret) // Mock request with cookie req := httptest.NewRequest("GET", "/", nil) if tc.token != "" { req.AddCookie(&http.Cookie{ Name: "refresh_token", Value: tc.token, HttpOnly: true, }) } w := httptest.NewRecorder() called := false // Mock endpoint that the middleware protects handler := rtAuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) assert.True(t, ok) assert.Equal(t, "refresh", claims.TokenType) assert.Equal(t, testUserID, claims.Subject) })) handler.ServeHTTP(w, req) assert.Equal(t, tc.statusCode, w.Code) assert.Equal(t, tc.statusCode == http.StatusOK, called) if tc.expectedErr != "" { assert.Contains(t, w.Body.String(), tc.expectedErr) } }) } } func TestRequireATMiddleware(t *testing.T) { secret := "test-jwt-secret" testUserID := uuid.New().String() validRT := generateTestToken(t, secret, "refresh", testUserID, true) validAT := generateTestToken(t, secret, "access", testUserID, true) expiredAT := generateTestToken(t, secret, "access", testUserID, true, func(claims *userClaims) { claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)) }) tests := []struct { name string token string expectedErr string statusCode int }{ { "no token", "", "Unauthorized", http.StatusUnauthorized, }, { "invalid token", "invalid", "Invalid token", http.StatusUnauthorized, }, { "expired token", expiredAT, "Invalid token", http.StatusUnauthorized, }, { "wrong token type", validRT, "Invalid token type", http.StatusUnauthorized, }, { "valid token", validAT, "", http.StatusOK, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { atAuthMiddleware := requireAccessToken(secret) // Mock request req := httptest.NewRequest("GET", "/", nil) if tc.token != "" { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token)) } w := httptest.NewRecorder() called := false // Mock endpoint that the middleware protects handler := atAuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true _, ok := r.Context().Value(userCtxKey{}).(*userClaims) assert.True(t, ok) })) handler.ServeHTTP(w, req) assert.Equal(t, tc.statusCode, w.Code) assert.Equal(t, tc.statusCode == http.StatusOK, called) if tc.expectedErr != "" { assert.Contains(t, w.Body.String(), tc.expectedErr) } }) } } func TestAdminOnlyMiddleware(t *testing.T) { tests := []struct { name string user *userClaims statusCode int }{ { "no user", nil, http.StatusForbidden, }, { "non admin user", &userClaims{ Admin: false, }, http.StatusForbidden, }, { "admin user", &userClaims{ Admin: true, }, http.StatusOK, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Mock request req := httptest.NewRequest("GET", "/", nil) if tc.user != nil { req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user)) } w := httptest.NewRecorder() called := false // Mock endpoint that the middleware protects handler := adminOnlyMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true w.WriteHeader(http.StatusOK) })) handler.ServeHTTP(w, req) assert.Equal(t, tc.statusCode, w.Code) assert.Equal(t, tc.statusCode == http.StatusOK, called) }) } } func TestUUIDCtxMiddleware(t *testing.T) { testKeyName := "testKey" tests := []struct { name string parameter string expectedErr string statusCode int }{ { "missing uuid", "", "Invalid resource ID", http.StatusBadRequest, }, { "invalid uuid", "invalid", "Invalid resource ID", http.StatusBadRequest, }, { "valid uuid", uuid.New().String(), "", http.StatusOK, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { uuidCtxMiddleware := uuidCtx(testKeyName) req := httptest.NewRequest("GET", "/", nil) // We need to mock the URL parameter as we don't setup an actual router in this test env. rctx := chi.NewRouteContext() rctx.URLParams.Add(testKeyName, tc.parameter) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) w := httptest.NewRecorder() called := false // Mock endpoint that the middleware protects handler := uuidCtxMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true _, ok := r.Context().Value(uuidCtxKey{Name: testKeyName}).(uuid.UUID) assert.True(t, ok) })) handler.ServeHTTP(w, req) assert.Equal(t, tc.statusCode, w.Code) assert.Equal(t, tc.statusCode == http.StatusOK, called) if tc.expectedErr != "" { assert.Contains(t, w.Body.String(), tc.expectedErr) } }) } } func TestNoteCtxMiddleware(t *testing.T) { testTitle := "Test title" tesTContent := "## Test content\nData 123" testVersion := int32(3) noteID := uuid.New() ownerUserID := uuid.New() testOwnerClaims := userClaims{RegisteredClaims: jwt.RegisteredClaims{ Subject: ownerUserID.String(), }} otherUserID := uuid.New() testOtherClaims := userClaims{RegisteredClaims: jwt.RegisteredClaims{ Subject: otherUserID.String(), }} tests := []struct { name string resourceID *uuid.UUID user *userClaims mock func(*mockNoteStore) statusCode int expectedErr string }{ { "no resource id", nil, nil, func(m *mockNoteStore) {}, http.StatusBadRequest, "Resource ID missing", }, { "unauthorized", ¬eID, nil, func(m *mockNoteStore) {}, http.StatusUnauthorized, "Unauthorized", }, { "note not found", ¬eID, &testOwnerClaims, func(m *mockNoteStore) { m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) { return data.GetFullNoteRow{}, errors.New("not found") } }, http.StatusNotFound, "Note not found", }, { "not owner", ¬eID, &testOtherClaims, func(m *mockNoteStore) { m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) { assert.Equal(t, noteID, id) testTs := time.Now() return data.GetFullNoteRow{ NoteID: id, OwnerID: ownerUserID, Title: testTitle, Content: tesTContent, VersionNumber: testVersion, VersionCreatedAt: &testTs, NoteCreatedAt: &testTs, NoteUpdatedAt: &testTs, }, nil } }, http.StatusForbidden, "Forbidden", }, { "success", ¬eID, &testOwnerClaims, func(m *mockNoteStore) { m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) { assert.Equal(t, noteID, id) testTs := time.Now() return data.GetFullNoteRow{ NoteID: id, OwnerID: ownerUserID, Title: testTitle, Content: tesTContent, VersionNumber: testVersion, VersionCreatedAt: &testTs, NoteCreatedAt: &testTs, NoteUpdatedAt: &testTs, }, nil } }, http.StatusOK, "", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // UUID ctx. (mock) -> note ctx. (tested here) mockStore := &mockNoteStore{} req := httptest.NewRequest("GET", "/", nil) tc.mock(mockStore) // Mock endpoint that the middleware protects (where the attached note data is actually utilized) handler := noteCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow) assert.True(t, ok) assert.Equal(t, noteID, fullNote.NoteID) assert.Equal(t, testTitle, fullNote.Title) assert.Equal(t, tesTContent, fullNote.Content) assert.Equal(t, testVersion, fullNote.VersionNumber) w.WriteHeader(http.StatusOK) })) // Request parameters don't need to be mocked, as parsing of them isn't handled // by this middleware, and thus that portion shouldn't be tested here. if tc.resourceID != nil { req = req.WithContext(context.WithValue(req.Context(), uuidCtxKey{Name: noteUUIDCtxParameter}, *tc.resourceID)) } if tc.user != nil { req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user)) } w := httptest.NewRecorder() handler.ServeHTTP(w, req) assert.Equal(t, tc.statusCode, w.Code) if tc.expectedErr != "" { assert.Contains(t, w.Body.String(), tc.expectedErr) } }) } } func TestVersionCtxMiddleware(t *testing.T) { testTitle := "Test title" tesTContent := "## Test content\nData 123" testVersion := int32(3) versionID := uuid.New() noteID := uuid.New() testNote := data.GetFullNoteRow{ NoteID: noteID, } tests := []struct { name string resourceID *uuid.UUID note *data.GetFullNoteRow mock func(*mockNoteStore) statusCode int expectedErr string }{ { "no note", nil, nil, func(m *mockNoteStore) {}, http.StatusNotFound, "Note not found", }, { "no resource id", nil, &testNote, func(m *mockNoteStore) {}, http.StatusBadRequest, "Resource ID missing", }, { "version not found", &versionID, &testNote, func(m *mockNoteStore) { m.GetVersionFunc = func(ctx context.Context, gvp data.GetVersionParams) (data.GetVersionRow, error) { return data.GetVersionRow{}, errors.New("not found") } }, http.StatusNotFound, "Version not found", }, { "success", &versionID, &testNote, func(m *mockNoteStore) { m.GetVersionFunc = func(ctx context.Context, gvp data.GetVersionParams) (data.GetVersionRow, error) { assert.Equal(t, versionID, gvp.ID) assert.Equal(t, noteID, gvp.NoteID) testTs := time.Now() return data.GetVersionRow{ VersionID: gvp.ID, Title: testTitle, Content: tesTContent, VersionNumber: testVersion, CreatedAt: &testTs, }, nil } }, http.StatusOK, "", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // note ctx. (mock) -> UUID ctx. (mock) -> version ctx. (tested here) mockStore := &mockNoteStore{} req := httptest.NewRequest("GET", "/", nil) tc.mock(mockStore) // Mock endpoint that the middleware protects (where the attached note data is actually utilized) handler := versionCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fullVersion, ok := r.Context().Value(versionCtxKey{}).(*data.GetVersionRow) assert.True(t, ok) assert.Equal(t, versionID, fullVersion.VersionID) assert.Equal(t, testTitle, fullVersion.Title) assert.Equal(t, tesTContent, fullVersion.Content) assert.Equal(t, testVersion, fullVersion.VersionNumber) w.WriteHeader(http.StatusOK) })) // Request parameters don't need to be mocked, as parsing of them isn't handled // by this middleware, and thus that portion shouldn't be tested here. if tc.note != nil { req = req.WithContext(context.WithValue(req.Context(), noteCtxKey{}, tc.note)) } if tc.resourceID != nil { req = req.WithContext(context.WithValue(req.Context(), uuidCtxKey{Name: versionUUIDCtxParameter}, *tc.resourceID)) } w := httptest.NewRecorder() handler.ServeHTTP(w, req) assert.Equal(t, tc.statusCode, w.Code) if tc.expectedErr != "" { assert.Contains(t, w.Body.String(), tc.expectedErr) } }) } } func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string { t.Helper() claims := &userClaims{ Admin: isAdmin, TokenType: tokenType, RegisteredClaims: jwt.RegisteredClaims{ Subject: userID, ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), }, } for _, opt := range opts { opt(claims) } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) signedToken, err := token.SignedString([]byte(secret)) if err != nil { t.Fatalf("Failed to generate test token: %v", err) } return signedToken }