package service import ( "context" "encoding/json" "errors" "fmt" "net/http" "net/http/httptest" "strings" "testing" "git.umbrella.haus/ae/notatest/pkg/data" "github.com/go-chi/chi/v5" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/stretchr/testify/assert" ) type mockNoteStore struct { CreateNoteFunc func(context.Context, uuid.UUID) (data.Note, error) DeleteNoteFunc func(context.Context, data.DeleteNoteParams) error GetNoteFunc func(context.Context, data.GetNoteParams) (data.Note, error) ListNotesFunc func(context.Context, data.ListNotesParams) ([]data.Note, error) CreateNoteVersionFunc func(context.Context, data.CreateNoteVersionParams) (data.NoteVersion, error) FindDuplicateContentFunc func(context.Context, data.FindDuplicateContentParams) (bool, error) GetNoteVersionFunc func(context.Context, data.GetNoteVersionParams) (data.NoteVersion, error) GetNoteVersionsFunc func(context.Context, data.GetNoteVersionsParams) ([]data.NoteVersion, error) } func (m *mockNoteStore) CreateNote(ctx context.Context, userID uuid.UUID) (data.Note, error) { return m.CreateNoteFunc(ctx, userID) } func (m *mockNoteStore) DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error { return m.DeleteNoteFunc(ctx, arg) } func (m *mockNoteStore) GetNote(ctx context.Context, arg data.GetNoteParams) (data.Note, error) { return m.GetNoteFunc(ctx, arg) } func (m *mockNoteStore) ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.Note, error) { return m.ListNotesFunc(ctx, arg) } func (m *mockNoteStore) CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error) { return m.CreateNoteVersionFunc(ctx, arg) } func (m *mockNoteStore) FindDuplicateContent(ctx context.Context, arg data.FindDuplicateContentParams) (bool, error) { return m.FindDuplicateContentFunc(ctx, arg) } func (m *mockNoteStore) GetNoteVersion(ctx context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error) { return m.GetNoteVersionFunc(ctx, arg) } func (m *mockNoteStore) GetNoteVersions(ctx context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) { return m.GetNoteVersionsFunc(ctx, arg) } func TestNotes_CreateNote(t *testing.T) { userID := uuid.New() testNote := data.Note{ID: uuid.New(), UserID: userID} tests := []struct { name string mock func(*mockNoteStore) statusCode int }{ { "success", func(m *mockNoteStore) { m.CreateNoteFunc = func(_ context.Context, uid uuid.UUID) (data.Note, error) { assert.Equal(t, userID, uid) return testNote, nil } }, http.StatusCreated, }, { "database error", func(m *mockNoteStore) { m.CreateNoteFunc = func(context.Context, uuid.UUID) (data.Note, error) { return data.Note{}, errors.New("db error") } }, http.StatusInternalServerError, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { mockStore := &mockNoteStore{} tc.mock(mockStore) rs := notesResource{Notes: mockStore} req := httptest.NewRequest("POST", "/", nil) req = req.WithContext(context.WithValue( req.Context(), userCtxKey{}, &userClaims{RegisteredClaims: jwt.RegisteredClaims{ Subject: userID.String(), }}, )) w := httptest.NewRecorder() rs.CreateNote(w, req) assert.Equal(t, tc.statusCode, w.Code) if tc.statusCode == http.StatusCreated { var note data.Note json.Unmarshal(w.Body.Bytes(), ¬e) assert.Equal(t, testNote.ID, note.ID) } }) } } func TestNotes_ListNotes(t *testing.T) { userID := uuid.New() notes := []data.Note{ {ID: uuid.New(), UserID: userID}, {ID: uuid.New(), UserID: userID}, } tests := []struct { name string query string mock func(*mockNoteStore) statusCode int }{ { "success", "", func(m *mockNoteStore) { m.ListNotesFunc = func(_ context.Context, arg data.ListNotesParams) ([]data.Note, error) { assert.Equal(t, userID, arg.UserID) return notes, nil } }, http.StatusOK, }, { "with pagination", "?limit=10&offset=20", func(m *mockNoteStore) { m.ListNotesFunc = func(_ context.Context, arg data.ListNotesParams) ([]data.Note, error) { assert.EqualValues(t, 10, arg.Limit) assert.EqualValues(t, 20, arg.Offset) return notes, nil } }, http.StatusOK, }, { "database error", "", func(m *mockNoteStore) { m.ListNotesFunc = func(context.Context, data.ListNotesParams) ([]data.Note, error) { return nil, errors.New("db error") } }, http.StatusInternalServerError, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { mockStore := &mockNoteStore{} tc.mock(mockStore) rs := notesResource{Notes: mockStore} req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.query), nil) req = req.WithContext(context.WithValue( req.Context(), userCtxKey{}, &userClaims{RegisteredClaims: jwt.RegisteredClaims{ Subject: userID.String(), }}, )) w := httptest.NewRecorder() rs.ListNotes(w, req) assert.Equal(t, tc.statusCode, w.Code) }) } } func TestNotes_GetNote(t *testing.T) { noteID := uuid.New() userID := uuid.New() validNote := data.Note{ID: noteID, UserID: userID} t.Run("success", func(t *testing.T) { rs := notesResource{} req := httptest.NewRequest("GET", "/", nil) req = req.WithContext(context.WithValue( req.Context(), noteCtxKey{}, validNote, )) w := httptest.NewRecorder() rs.GetNote(w, req) assert.Equal(t, http.StatusOK, w.Code) var note data.Note json.Unmarshal(w.Body.Bytes(), ¬e) assert.Equal(t, validNote.ID, note.ID) }) t.Run("not found", func(t *testing.T) { rs := notesResource{} req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() rs.GetNote(w, req) assert.Equal(t, http.StatusNotFound, w.Code) }) } func TestNotes_DeleteNote(t *testing.T) { noteID := uuid.New() userID := uuid.New() validNote := data.Note{ID: noteID, UserID: userID} t.Run("success", func(t *testing.T) { mockStore := &mockNoteStore{ DeleteNoteFunc: func(_ context.Context, arg data.DeleteNoteParams) error { assert.Equal(t, noteID, arg.ID) assert.Equal(t, userID, arg.UserID) return nil }, } rs := notesResource{Notes: mockStore} req := httptest.NewRequest("DELETE", "/", nil) req = req.WithContext(context.WithValue( req.Context(), noteCtxKey{}, validNote, )) req = req.WithContext(context.WithValue( req.Context(), userCtxKey{}, &userClaims{RegisteredClaims: jwt.RegisteredClaims{ Subject: userID.String(), }}, )) w := httptest.NewRecorder() rs.DeleteNote(w, req) assert.Equal(t, http.StatusNoContent, w.Code) }) t.Run("database error", func(t *testing.T) { mockStore := &mockNoteStore{ DeleteNoteFunc: func(context.Context, data.DeleteNoteParams) error { return errors.New("db error") }, } rs := notesResource{Notes: mockStore} req := httptest.NewRequest("DELETE", "/", nil) req = req.WithContext(context.WithValue( req.Context(), noteCtxKey{}, validNote, )) req = req.WithContext(context.WithValue( req.Context(), userCtxKey{}, &userClaims{RegisteredClaims: jwt.RegisteredClaims{ Subject: userID.String(), }}, )) w := httptest.NewRecorder() rs.DeleteNote(w, req) assert.Equal(t, http.StatusInternalServerError, w.Code) }) } func TestNotes_CreateNoteVersion(t *testing.T) { noteID := uuid.New() validRequest := `{"title": "Test", "content": "Content"}` tests := []struct { name string body string mock func(*mockNoteStore) statusCode int }{ { "success", validRequest, func(m *mockNoteStore) { m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) { return false, nil } m.CreateNoteVersionFunc = func(_ context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error) { assert.Equal(t, noteID, arg.NoteID) return data.NoteVersion{}, nil } }, http.StatusCreated, }, { "duplicate content", validRequest, func(m *mockNoteStore) { m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) { return true, nil } }, http.StatusConflict, }, { "invalid request", "{invalid}", func(m *mockNoteStore) {}, http.StatusBadRequest, }, { "database error", validRequest, func(m *mockNoteStore) { m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) { return false, nil } m.CreateNoteVersionFunc = func(context.Context, data.CreateNoteVersionParams) (data.NoteVersion, error) { return data.NoteVersion{}, errors.New("db error") } }, http.StatusInternalServerError, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { mockStore := &mockNoteStore{} tc.mock(mockStore) rs := notesResource{Notes: mockStore} req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body)) req = req.WithContext(context.WithValue( req.Context(), noteCtxKey{}, data.Note{ID: noteID}, )) w := httptest.NewRecorder() rs.CreateNoteVersion(w, req) assert.Equal(t, tc.statusCode, w.Code) }) } } func TestNotes_ListNoteVersions(t *testing.T) { noteID := uuid.New() versions := []data.NoteVersion{ {ID: uuid.New(), NoteID: noteID, VersionNumber: 1}, {ID: uuid.New(), NoteID: noteID, VersionNumber: 2}, } tests := []struct { name string query string mock func(*mockNoteStore) statusCode int }{ { "success", "", func(m *mockNoteStore) { m.GetNoteVersionsFunc = func(_ context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) { assert.Equal(t, noteID, arg.NoteID) return versions, nil } }, http.StatusOK, }, { "with pagination", "?limit=5&offset=10", func(m *mockNoteStore) { m.GetNoteVersionsFunc = func(_ context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) { assert.EqualValues(t, 5, arg.Limit) assert.EqualValues(t, 10, arg.Offset) return versions, nil } }, http.StatusOK, }, { "database error", "", func(m *mockNoteStore) { m.GetNoteVersionsFunc = func(context.Context, data.GetNoteVersionsParams) ([]data.NoteVersion, error) { return nil, errors.New("db error") } }, http.StatusInternalServerError, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { mockStore := &mockNoteStore{} tc.mock(mockStore) rs := notesResource{Notes: mockStore} req := httptest.NewRequest("GET", fmt.Sprintf("/versions/%s", tc.query), nil) req = req.WithContext(context.WithValue( req.Context(), noteCtxKey{}, data.Note{ID: noteID}, )) w := httptest.NewRecorder() rs.ListNoteVersions(w, req) assert.Equal(t, tc.statusCode, w.Code) if tc.statusCode == http.StatusOK { var result []data.NoteVersion json.Unmarshal(w.Body.Bytes(), &result) assert.Len(t, result, 2) } }) } } func TestNotes_GetNoteVersion(t *testing.T) { noteID := uuid.New() version := data.NoteVersion{ID: uuid.New(), NoteID: noteID, VersionNumber: 1} tests := []struct { name string version string mock func(*mockNoteStore) statusCode int }{ { "invalid version", "invalid", func(m *mockNoteStore) {}, http.StatusBadRequest, }, { "version not found", "1", func(m *mockNoteStore) { m.GetNoteVersionFunc = func(context.Context, data.GetNoteVersionParams) (data.NoteVersion, error) { return data.NoteVersion{}, errors.New("not found") } }, http.StatusNotFound, }, { "success", "1", func(m *mockNoteStore) { m.GetNoteVersionFunc = func(_ context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error) { assert.Equal(t, noteID, arg.NoteID) assert.EqualValues(t, 1, arg.VersionNumber) return version, nil } }, http.StatusOK, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { mockStore := &mockNoteStore{} tc.mock(mockStore) rs := notesResource{Notes: mockStore} req := httptest.NewRequest("GET", fmt.Sprintf("/versions/%s", tc.version), nil) // Chi router context mocks ID (passed in a URL param.) and the note object (passed in req. ctx.) rctx := chi.NewRouteContext() rctx.URLParams.Add("version", tc.version) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) req = req.WithContext(context.WithValue( req.Context(), noteCtxKey{}, data.Note{ID: noteID}, )) w := httptest.NewRecorder() rs.GetNoteVersion(w, req) assert.Equal(t, tc.statusCode, w.Code) if tc.statusCode == http.StatusOK { var result data.NoteVersion json.Unmarshal(w.Body.Bytes(), &result) assert.Equal(t, version.VersionNumber, result.VersionNumber) } }) } }