Compare commits
4 Commits
91daec42de
...
5de5c8c285
Author | SHA1 | Date | |
---|---|---|---|
5de5c8c285 | |||
18e650c898 | |||
a32bdef092 | |||
10bcdf88c7 |
@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
@ -18,8 +19,6 @@ const (
|
||||
defaultLogMsg = "incoming request"
|
||||
)
|
||||
|
||||
type userCtxKey struct{}
|
||||
|
||||
// Get JWT bearer from request's authorization header, parse it with custom user claims, and
|
||||
// ensure its validity before attaching the claims to the request's context.
|
||||
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
|
||||
@ -30,7 +29,7 @@ func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) ht
|
||||
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (any, error) {
|
||||
return []byte(jwtSecret), nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
@ -109,6 +108,46 @@ func userCtx(store UserStore) func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// Append note data into request's context based on note ID as a URL parameter and user ID as
|
||||
// context parameter.
|
||||
func noteCtx(store NoteStore) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
noteIDStr := chi.URLParam(r, "id")
|
||||
noteID, err := uuid.Parse(noteIDStr)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusNotFound, "Invalid note ID")
|
||||
return
|
||||
}
|
||||
|
||||
// NOTE: user must already be in the context (e.g. via JWT middleware)
|
||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
if !ok {
|
||||
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(user.Subject)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
note, err := store.GetNote(r.Context(), data.GetNoteParams{
|
||||
ID: noteID,
|
||||
UserID: userID,
|
||||
})
|
||||
if err != nil {
|
||||
respondError(w, http.StatusNotFound, "Note not found")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), noteCtxKey{}, note)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Zerolog compatible logger middleware. Automatically logs and recovers from errors with HTTP 500
|
||||
// response, by default logs to INFO level.
|
||||
func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler {
|
||||
@ -136,7 +175,7 @@ func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler {
|
||||
log.Info().
|
||||
Str("type", "access").
|
||||
Timestamp().
|
||||
Fields(map[string]interface{}{
|
||||
Fields(map[string]any{
|
||||
"remote_ip": r.RemoteAddr,
|
||||
"url": r.URL.Path,
|
||||
"proto": r.Proto,
|
||||
|
@ -232,9 +232,21 @@ func TestUserCtxMiddleware(t *testing.T) {
|
||||
urlID string
|
||||
statusCode int
|
||||
}{
|
||||
{"valid ID", validUserID.String(), http.StatusOK},
|
||||
{"invalid ID", invalidUserID, http.StatusNotFound},
|
||||
{"non existent ID", uuid.New().String(), http.StatusNotFound},
|
||||
{
|
||||
"valid ID",
|
||||
validUserID.String(),
|
||||
http.StatusOK,
|
||||
},
|
||||
{
|
||||
"invalid ID",
|
||||
invalidUserID,
|
||||
http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
"non existent ID",
|
||||
uuid.New().String(),
|
||||
http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@ -257,6 +269,97 @@ func TestUserCtxMiddleware(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoteCtxMiddleware(t *testing.T) {
|
||||
userID := uuid.New()
|
||||
noteID := uuid.New()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
noteID string
|
||||
user any
|
||||
mock func(*mockNoteStore)
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
"invalid note ID",
|
||||
"invalid",
|
||||
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID.String(),
|
||||
}},
|
||||
func(m *mockNoteStore) {},
|
||||
http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
"unauthorized user",
|
||||
noteID.String(),
|
||||
nil,
|
||||
func(m *mockNoteStore) {},
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"note not found",
|
||||
noteID.String(),
|
||||
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID.String(),
|
||||
}},
|
||||
func(m *mockNoteStore) {
|
||||
m.GetNoteFunc = func(ctx context.Context, arg data.GetNoteParams) (data.Note, error) {
|
||||
return data.Note{}, errors.New("not found")
|
||||
}
|
||||
},
|
||||
http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
"success",
|
||||
noteID.String(),
|
||||
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID.String(),
|
||||
}},
|
||||
func(m *mockNoteStore) {
|
||||
m.GetNoteFunc = func(ctx context.Context, arg data.GetNoteParams) (data.Note, error) {
|
||||
assert.Equal(t, noteID, arg.ID)
|
||||
assert.Equal(t, userID, arg.UserID)
|
||||
return data.Note{ID: noteID}, nil
|
||||
}
|
||||
},
|
||||
http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockStore := &mockNoteStore{}
|
||||
tc.mock(mockStore)
|
||||
|
||||
handler := noteCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
||||
assert.True(t, ok)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/notes/%s", tc.noteID), nil)
|
||||
|
||||
// Chi router context mocks ID passed in a URL parameter
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("id", tc.noteID)
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
if tc.user != nil {
|
||||
req = req.WithContext(context.WithValue(
|
||||
req.Context(),
|
||||
userCtxKey{},
|
||||
tc.user,
|
||||
))
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tc.statusCode, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
|
||||
t.Helper()
|
||||
|
||||
|
@ -1,19 +1,262 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type noteCtxKey struct{}
|
||||
|
||||
// Mockable database operations interface
|
||||
type NoteStore interface {
|
||||
// TODO: implement
|
||||
CreateNote(ctx context.Context, userID uuid.UUID) (data.Note, error)
|
||||
DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error
|
||||
GetNote(ctx context.Context, arg data.GetNoteParams) (data.Note, error)
|
||||
ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.Note, error)
|
||||
CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error)
|
||||
FindDuplicateContent(ctx context.Context, arg data.FindDuplicateContentParams) (bool, error)
|
||||
GetNoteVersion(ctx context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error)
|
||||
GetNoteVersions(ctx context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error)
|
||||
}
|
||||
|
||||
type notesResource struct {
|
||||
Notes NoteStore
|
||||
JWTSecret string
|
||||
Notes NoteStore
|
||||
}
|
||||
|
||||
func (rs notesResource) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireAccessToken(rs.JWTSecret))
|
||||
|
||||
r.Post("/", rs.CreateNote) // POST /notes - note creation
|
||||
r.Get("/", rs.ListNotes) // GET /notes - get all notes
|
||||
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
r.Use(noteCtx(rs.Notes))
|
||||
|
||||
r.Get("/", rs.GetNote) // GET /notes/{id} - get specific note
|
||||
r.Delete("/", rs.DeleteNote) // DELETE /notes/{id} - delete specific note
|
||||
|
||||
r.Route("/versions", func(r chi.Router) {
|
||||
r.Post("/", rs.CreateNoteVersion) // POST /notes/{id}/versions - create new version
|
||||
r.Get("/", rs.ListNoteVersions) // GET /notes/{id}/versions - get all existing versions
|
||||
r.Get("/{version}", rs.GetNoteVersion) // GET /notes/{id}/versions/{version} - get specific version
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (rs *notesResource) CreateNote(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
if !ok {
|
||||
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(user.Subject)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
note, err := rs.Notes.CreateNote(r.Context(), userID)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to create note")
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusCreated, note)
|
||||
}
|
||||
|
||||
func (rs *notesResource) ListNotes(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
if !ok {
|
||||
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(user.Subject)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
limit, offset := getPaginationParams(r)
|
||||
|
||||
notes, err := rs.Notes.ListNotes(r.Context(), data.ListNotesParams{
|
||||
UserID: userID,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
})
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to retrieve notes")
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusOK, notes)
|
||||
}
|
||||
|
||||
func (rs *notesResource) GetNote(w http.ResponseWriter, r *http.Request) {
|
||||
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
||||
if !ok {
|
||||
respondError(w, http.StatusNotFound, "Note not found")
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusOK, note)
|
||||
}
|
||||
|
||||
func (rs *notesResource) DeleteNote(w http.ResponseWriter, r *http.Request) {
|
||||
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
||||
if !ok {
|
||||
respondError(w, http.StatusNotFound, "Note not found")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
if !ok {
|
||||
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(user.Subject)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = rs.Notes.DeleteNote(r.Context(), data.DeleteNoteParams{
|
||||
ID: note.ID,
|
||||
UserID: userID,
|
||||
})
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to delete note")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (rs *notesResource) CreateNoteVersion(w http.ResponseWriter, r *http.Request) {
|
||||
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
||||
if !ok {
|
||||
respondError(w, http.StatusNotFound, "Note not found")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
// De-duplication check
|
||||
duplicate, err := rs.Notes.FindDuplicateContent(r.Context(), data.FindDuplicateContentParams{
|
||||
NoteID: note.ID,
|
||||
Column2: []byte(req.Title),
|
||||
Column3: []byte(req.Content),
|
||||
})
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to check for duplicate content")
|
||||
return
|
||||
}
|
||||
if duplicate {
|
||||
respondError(w, http.StatusConflict, "Duplicate content detected")
|
||||
return
|
||||
}
|
||||
|
||||
version, err := rs.Notes.CreateNoteVersion(r.Context(), data.CreateNoteVersionParams{
|
||||
NoteID: note.ID,
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
})
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to create note version")
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusCreated, version)
|
||||
}
|
||||
|
||||
func (rs *notesResource) ListNoteVersions(w http.ResponseWriter, r *http.Request) {
|
||||
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
||||
if !ok {
|
||||
respondError(w, http.StatusNotFound, "Note not found")
|
||||
return
|
||||
}
|
||||
|
||||
limit, offset := getPaginationParams(r)
|
||||
|
||||
versions, err := rs.Notes.GetNoteVersions(r.Context(), data.GetNoteVersionsParams{
|
||||
NoteID: note.ID,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
})
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to retrieve versions")
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusOK, versions)
|
||||
}
|
||||
|
||||
func (rs *notesResource) GetNoteVersion(w http.ResponseWriter, r *http.Request) {
|
||||
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
||||
if !ok {
|
||||
respondError(w, http.StatusNotFound, "Note not found")
|
||||
return
|
||||
}
|
||||
|
||||
versionStr := chi.URLParam(r, "version")
|
||||
versionNumber, err := strconv.ParseInt(versionStr, 10, 32)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusBadRequest, "Invalid version number")
|
||||
return
|
||||
}
|
||||
|
||||
version, err := rs.Notes.GetNoteVersion(r.Context(), data.GetNoteVersionParams{
|
||||
NoteID: note.ID,
|
||||
VersionNumber: int32(versionNumber),
|
||||
})
|
||||
if err != nil {
|
||||
respondError(w, http.StatusNotFound, "Version not found")
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusOK, version)
|
||||
}
|
||||
|
||||
func getPaginationParams(r *http.Request) (limit int32, offset int32) {
|
||||
defaultLimit := 50
|
||||
defaultOffset := 0
|
||||
|
||||
limitStr := r.URL.Query().Get("limit")
|
||||
if limitStr != "" {
|
||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
||||
defaultLimit = l
|
||||
}
|
||||
}
|
||||
|
||||
offsetStr := r.URL.Query().Get("offset")
|
||||
if offsetStr != "" {
|
||||
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
|
||||
defaultOffset = o
|
||||
}
|
||||
}
|
||||
|
||||
return int32(defaultLimit), int32(defaultOffset)
|
||||
}
|
||||
|
364
server/pkg/service/notes_test.go
Normal file
364
server/pkg/service/notes_test.go
Normal file
@ -0,0 +1,364 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockNoteStore struct {
|
||||
CreateNoteFunc func(context.Context, uuid.UUID) (data.Note, error)
|
||||
DeleteNoteFunc func(context.Context, data.DeleteNoteParams) error
|
||||
GetNoteFunc func(context.Context, data.GetNoteParams) (data.Note, error)
|
||||
ListNotesFunc func(context.Context, data.ListNotesParams) ([]data.Note, error)
|
||||
CreateNoteVersionFunc func(context.Context, data.CreateNoteVersionParams) (data.NoteVersion, error)
|
||||
FindDuplicateContentFunc func(context.Context, data.FindDuplicateContentParams) (bool, error)
|
||||
GetNoteVersionFunc func(context.Context, data.GetNoteVersionParams) (data.NoteVersion, error)
|
||||
GetNoteVersionsFunc func(context.Context, data.GetNoteVersionsParams) ([]data.NoteVersion, error)
|
||||
}
|
||||
|
||||
func (m *mockNoteStore) CreateNote(ctx context.Context, userID uuid.UUID) (data.Note, error) {
|
||||
return m.CreateNoteFunc(ctx, userID)
|
||||
}
|
||||
|
||||
func (m *mockNoteStore) DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error {
|
||||
return m.DeleteNoteFunc(ctx, arg)
|
||||
}
|
||||
|
||||
func (m *mockNoteStore) GetNote(ctx context.Context, arg data.GetNoteParams) (data.Note, error) {
|
||||
return m.GetNoteFunc(ctx, arg)
|
||||
}
|
||||
|
||||
func (m *mockNoteStore) ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.Note, error) {
|
||||
return m.ListNotesFunc(ctx, arg)
|
||||
}
|
||||
|
||||
func (m *mockNoteStore) CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error) {
|
||||
return m.CreateNoteVersionFunc(ctx, arg)
|
||||
}
|
||||
|
||||
func (m *mockNoteStore) FindDuplicateContent(ctx context.Context, arg data.FindDuplicateContentParams) (bool, error) {
|
||||
return m.FindDuplicateContentFunc(ctx, arg)
|
||||
}
|
||||
|
||||
func (m *mockNoteStore) GetNoteVersion(ctx context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error) {
|
||||
return m.GetNoteVersionFunc(ctx, arg)
|
||||
}
|
||||
|
||||
func (m *mockNoteStore) GetNoteVersions(ctx context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) {
|
||||
return m.GetNoteVersionsFunc(ctx, arg)
|
||||
}
|
||||
|
||||
func TestNotes_CreateNote(t *testing.T) {
|
||||
userID := uuid.New()
|
||||
testNote := data.Note{ID: uuid.New(), UserID: userID}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mock func(*mockNoteStore)
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
"success",
|
||||
func(m *mockNoteStore) {
|
||||
m.CreateNoteFunc = func(_ context.Context, uid uuid.UUID) (data.Note, error) {
|
||||
assert.Equal(t, userID, uid)
|
||||
return testNote, nil
|
||||
}
|
||||
},
|
||||
http.StatusCreated,
|
||||
},
|
||||
{
|
||||
"database error",
|
||||
func(m *mockNoteStore) {
|
||||
m.CreateNoteFunc = func(context.Context, uuid.UUID) (data.Note, error) {
|
||||
return data.Note{}, errors.New("db error")
|
||||
}
|
||||
},
|
||||
http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockStore := &mockNoteStore{}
|
||||
tc.mock(mockStore)
|
||||
|
||||
rs := notesResource{Notes: mockStore}
|
||||
req := httptest.NewRequest("POST", "/", nil)
|
||||
req = req.WithContext(context.WithValue(
|
||||
req.Context(),
|
||||
userCtxKey{},
|
||||
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID.String(),
|
||||
}},
|
||||
))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
rs.CreateNote(w, req)
|
||||
|
||||
assert.Equal(t, tc.statusCode, w.Code)
|
||||
if tc.statusCode == http.StatusCreated {
|
||||
var note data.Note
|
||||
json.Unmarshal(w.Body.Bytes(), ¬e)
|
||||
assert.Equal(t, testNote.ID, note.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotes_ListNotes(t *testing.T) {
|
||||
userID := uuid.New()
|
||||
notes := []data.Note{
|
||||
{ID: uuid.New(), UserID: userID},
|
||||
{ID: uuid.New(), UserID: userID},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
mock func(*mockNoteStore)
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
"success",
|
||||
"",
|
||||
func(m *mockNoteStore) {
|
||||
m.ListNotesFunc = func(_ context.Context, arg data.ListNotesParams) ([]data.Note, error) {
|
||||
assert.Equal(t, userID, arg.UserID)
|
||||
return notes, nil
|
||||
}
|
||||
},
|
||||
http.StatusOK,
|
||||
},
|
||||
{
|
||||
"with pagination",
|
||||
"?limit=10&offset=20",
|
||||
func(m *mockNoteStore) {
|
||||
m.ListNotesFunc = func(_ context.Context, arg data.ListNotesParams) ([]data.Note, error) {
|
||||
assert.EqualValues(t, 10, arg.Limit)
|
||||
assert.EqualValues(t, 20, arg.Offset)
|
||||
return notes, nil
|
||||
}
|
||||
},
|
||||
http.StatusOK,
|
||||
},
|
||||
{
|
||||
"database error",
|
||||
"",
|
||||
func(m *mockNoteStore) {
|
||||
m.ListNotesFunc = func(context.Context, data.ListNotesParams) ([]data.Note, error) {
|
||||
return nil, errors.New("db error")
|
||||
}
|
||||
},
|
||||
http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockStore := &mockNoteStore{}
|
||||
tc.mock(mockStore)
|
||||
|
||||
rs := notesResource{Notes: mockStore}
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.query), nil)
|
||||
req = req.WithContext(context.WithValue(
|
||||
req.Context(),
|
||||
userCtxKey{},
|
||||
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID.String(),
|
||||
}},
|
||||
))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
rs.ListNotes(w, req)
|
||||
|
||||
assert.Equal(t, tc.statusCode, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotes_GetNote(t *testing.T) {
|
||||
noteID := uuid.New()
|
||||
userID := uuid.New()
|
||||
validNote := data.Note{ID: noteID, UserID: userID}
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
rs := notesResource{}
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req = req.WithContext(context.WithValue(
|
||||
req.Context(),
|
||||
noteCtxKey{},
|
||||
validNote,
|
||||
))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
rs.GetNote(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var note data.Note
|
||||
json.Unmarshal(w.Body.Bytes(), ¬e)
|
||||
assert.Equal(t, validNote.ID, note.ID)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
rs := notesResource{}
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
rs.GetNote(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotes_DeleteNote(t *testing.T) {
|
||||
noteID := uuid.New()
|
||||
userID := uuid.New()
|
||||
validNote := data.Note{ID: noteID, UserID: userID}
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
mockStore := &mockNoteStore{
|
||||
DeleteNoteFunc: func(_ context.Context, arg data.DeleteNoteParams) error {
|
||||
assert.Equal(t, noteID, arg.ID)
|
||||
assert.Equal(t, userID, arg.UserID)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
rs := notesResource{Notes: mockStore}
|
||||
req := httptest.NewRequest("DELETE", "/", nil)
|
||||
req = req.WithContext(context.WithValue(
|
||||
req.Context(),
|
||||
noteCtxKey{},
|
||||
validNote,
|
||||
))
|
||||
req = req.WithContext(context.WithValue(
|
||||
req.Context(),
|
||||
userCtxKey{},
|
||||
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID.String(),
|
||||
}},
|
||||
))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
rs.DeleteNote(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, w.Code)
|
||||
})
|
||||
|
||||
t.Run("database error", func(t *testing.T) {
|
||||
mockStore := &mockNoteStore{
|
||||
DeleteNoteFunc: func(context.Context, data.DeleteNoteParams) error {
|
||||
return errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
rs := notesResource{Notes: mockStore}
|
||||
req := httptest.NewRequest("DELETE", "/", nil)
|
||||
req = req.WithContext(context.WithValue(
|
||||
req.Context(),
|
||||
noteCtxKey{},
|
||||
validNote,
|
||||
))
|
||||
req = req.WithContext(context.WithValue(
|
||||
req.Context(),
|
||||
userCtxKey{},
|
||||
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID.String(),
|
||||
}},
|
||||
))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
rs.DeleteNote(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotes_CreateNoteVersion(t *testing.T) {
|
||||
noteID := uuid.New()
|
||||
validRequest := `{"title": "Test", "content": "Content"}`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
mock func(*mockNoteStore)
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
"success",
|
||||
validRequest,
|
||||
func(m *mockNoteStore) {
|
||||
m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
m.CreateNoteVersionFunc = func(_ context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error) {
|
||||
assert.Equal(t, noteID, arg.NoteID)
|
||||
return data.NoteVersion{}, nil
|
||||
}
|
||||
},
|
||||
http.StatusCreated,
|
||||
},
|
||||
{
|
||||
"duplicate content",
|
||||
validRequest,
|
||||
func(m *mockNoteStore) {
|
||||
m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
},
|
||||
http.StatusConflict,
|
||||
},
|
||||
{
|
||||
"invalid request",
|
||||
"{invalid}",
|
||||
func(m *mockNoteStore) {},
|
||||
http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
"database error",
|
||||
validRequest,
|
||||
func(m *mockNoteStore) {
|
||||
m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
m.CreateNoteVersionFunc = func(context.Context, data.CreateNoteVersionParams) (data.NoteVersion, error) {
|
||||
return data.NoteVersion{}, errors.New("db error")
|
||||
}
|
||||
},
|
||||
http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockStore := &mockNoteStore{}
|
||||
tc.mock(mockStore)
|
||||
|
||||
rs := notesResource{Notes: mockStore}
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
|
||||
req = req.WithContext(context.WithValue(
|
||||
req.Context(),
|
||||
noteCtxKey{},
|
||||
data.Note{ID: noteID},
|
||||
))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
rs.CreateNoteVersion(w, req)
|
||||
|
||||
assert.Equal(t, tc.statusCode, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: add similar tests for `ListNoteVersions` and `GetNoteVersion`
|
@ -14,6 +14,8 @@ import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type userCtxKey struct{}
|
||||
|
||||
// Mockable database operations interface
|
||||
type UserStore interface {
|
||||
CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
|
||||
@ -124,9 +126,9 @@ func (rs usersResource) List(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Output sanitization
|
||||
var output []map[string]interface{}
|
||||
var output []map[string]any
|
||||
for _, user := range users {
|
||||
output = append(output, map[string]interface{}{
|
||||
output = append(output, map[string]any{
|
||||
"id": user.ID,
|
||||
"username": user.Username,
|
||||
"created_at": user.CreatedAt,
|
||||
@ -144,7 +146,7 @@ func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusOK, map[string]interface{}{
|
||||
respondJSON(w, http.StatusOK, map[string]any{
|
||||
"id": user.ID,
|
||||
"username": user.Username,
|
||||
"created_at": user.CreatedAt,
|
||||
|
@ -154,7 +154,7 @@ func TestListUsers_Success(t *testing.T) {
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response []map[string]interface{}
|
||||
var response []map[string]any
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ var (
|
||||
usernameRegex = regexp.MustCompile("^[a-z0-9_]+$")
|
||||
)
|
||||
|
||||
func respondJSON(w http.ResponseWriter, status int, data interface{}) {
|
||||
func respondJSON(w http.ResponseWriter, status int, data any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
|
@ -89,9 +89,18 @@ func TestPasswordEntropyCalculation(t *testing.T) {
|
||||
password string
|
||||
entropy float64
|
||||
}{
|
||||
{"password", 37.6},
|
||||
{"SecurePassw0rd!123", 103.12},
|
||||
{"aaaaaaaaaaaaaaaa", 9.5},
|
||||
{
|
||||
"password",
|
||||
37.6,
|
||||
},
|
||||
{
|
||||
"SecurePassw0rd!123",
|
||||
103.12,
|
||||
},
|
||||
{
|
||||
"aaaaaaaaaaaaaaaa",
|
||||
9.5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
Loading…
x
Reference in New Issue
Block a user