Compare commits
4 Commits
91daec42de
...
5de5c8c285
Author | SHA1 | Date | |
---|---|---|---|
5de5c8c285 | |||
18e650c898 | |||
a32bdef092 | |||
10bcdf88c7 |
@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
@ -18,8 +19,6 @@ const (
|
|||||||
defaultLogMsg = "incoming request"
|
defaultLogMsg = "incoming request"
|
||||||
)
|
)
|
||||||
|
|
||||||
type userCtxKey struct{}
|
|
||||||
|
|
||||||
// Get JWT bearer from request's authorization header, parse it with custom user claims, and
|
// Get JWT bearer from request's authorization header, parse it with custom user claims, and
|
||||||
// ensure its validity before attaching the claims to the request's context.
|
// ensure its validity before attaching the claims to the request's context.
|
||||||
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
|
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
|
||||||
@ -30,7 +29,7 @@ func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) ht
|
|||||||
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
|
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (any, error) {
|
||||||
return []byte(jwtSecret), nil
|
return []byte(jwtSecret), nil
|
||||||
})
|
})
|
||||||
if err != nil || !token.Valid {
|
if err != nil || !token.Valid {
|
||||||
@ -109,6 +108,46 @@ func userCtx(store UserStore) func(http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Append note data into request's context based on note ID as a URL parameter and user ID as
|
||||||
|
// context parameter.
|
||||||
|
func noteCtx(store NoteStore) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
noteIDStr := chi.URLParam(r, "id")
|
||||||
|
noteID, err := uuid.Parse(noteIDStr)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusNotFound, "Invalid note ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: user must already be in the context (e.g. via JWT middleware)
|
||||||
|
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := uuid.Parse(user.Subject)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
note, err := store.GetNote(r.Context(), data.GetNoteParams{
|
||||||
|
ID: noteID,
|
||||||
|
UserID: userID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusNotFound, "Note not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(r.Context(), noteCtxKey{}, note)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Zerolog compatible logger middleware. Automatically logs and recovers from errors with HTTP 500
|
// Zerolog compatible logger middleware. Automatically logs and recovers from errors with HTTP 500
|
||||||
// response, by default logs to INFO level.
|
// response, by default logs to INFO level.
|
||||||
func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler {
|
func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler {
|
||||||
@ -136,7 +175,7 @@ func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler {
|
|||||||
log.Info().
|
log.Info().
|
||||||
Str("type", "access").
|
Str("type", "access").
|
||||||
Timestamp().
|
Timestamp().
|
||||||
Fields(map[string]interface{}{
|
Fields(map[string]any{
|
||||||
"remote_ip": r.RemoteAddr,
|
"remote_ip": r.RemoteAddr,
|
||||||
"url": r.URL.Path,
|
"url": r.URL.Path,
|
||||||
"proto": r.Proto,
|
"proto": r.Proto,
|
||||||
|
@ -232,9 +232,21 @@ func TestUserCtxMiddleware(t *testing.T) {
|
|||||||
urlID string
|
urlID string
|
||||||
statusCode int
|
statusCode int
|
||||||
}{
|
}{
|
||||||
{"valid ID", validUserID.String(), http.StatusOK},
|
{
|
||||||
{"invalid ID", invalidUserID, http.StatusNotFound},
|
"valid ID",
|
||||||
{"non existent ID", uuid.New().String(), http.StatusNotFound},
|
validUserID.String(),
|
||||||
|
http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invalid ID",
|
||||||
|
invalidUserID,
|
||||||
|
http.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"non existent ID",
|
||||||
|
uuid.New().String(),
|
||||||
|
http.StatusNotFound,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
@ -257,6 +269,97 @@ func TestUserCtxMiddleware(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNoteCtxMiddleware(t *testing.T) {
|
||||||
|
userID := uuid.New()
|
||||||
|
noteID := uuid.New()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
noteID string
|
||||||
|
user any
|
||||||
|
mock func(*mockNoteStore)
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"invalid note ID",
|
||||||
|
"invalid",
|
||||||
|
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID.String(),
|
||||||
|
}},
|
||||||
|
func(m *mockNoteStore) {},
|
||||||
|
http.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"unauthorized user",
|
||||||
|
noteID.String(),
|
||||||
|
nil,
|
||||||
|
func(m *mockNoteStore) {},
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"note not found",
|
||||||
|
noteID.String(),
|
||||||
|
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID.String(),
|
||||||
|
}},
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.GetNoteFunc = func(ctx context.Context, arg data.GetNoteParams) (data.Note, error) {
|
||||||
|
return data.Note{}, errors.New("not found")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"success",
|
||||||
|
noteID.String(),
|
||||||
|
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID.String(),
|
||||||
|
}},
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.GetNoteFunc = func(ctx context.Context, arg data.GetNoteParams) (data.Note, error) {
|
||||||
|
assert.Equal(t, noteID, arg.ID)
|
||||||
|
assert.Equal(t, userID, arg.UserID)
|
||||||
|
return data.Note{ID: noteID}, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
mockStore := &mockNoteStore{}
|
||||||
|
tc.mock(mockStore)
|
||||||
|
|
||||||
|
handler := noteCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
||||||
|
assert.True(t, ok)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/notes/%s", tc.noteID), nil)
|
||||||
|
|
||||||
|
// Chi router context mocks ID passed in a URL parameter
|
||||||
|
rctx := chi.NewRouteContext()
|
||||||
|
rctx.URLParams.Add("id", tc.noteID)
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||||
|
|
||||||
|
if tc.user != nil {
|
||||||
|
req = req.WithContext(context.WithValue(
|
||||||
|
req.Context(),
|
||||||
|
userCtxKey{},
|
||||||
|
tc.user,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.statusCode, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
|
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
@ -1,19 +1,262 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type noteCtxKey struct{}
|
||||||
|
|
||||||
// Mockable database operations interface
|
// Mockable database operations interface
|
||||||
type NoteStore interface {
|
type NoteStore interface {
|
||||||
// TODO: implement
|
CreateNote(ctx context.Context, userID uuid.UUID) (data.Note, error)
|
||||||
|
DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error
|
||||||
|
GetNote(ctx context.Context, arg data.GetNoteParams) (data.Note, error)
|
||||||
|
ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.Note, error)
|
||||||
|
CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error)
|
||||||
|
FindDuplicateContent(ctx context.Context, arg data.FindDuplicateContentParams) (bool, error)
|
||||||
|
GetNoteVersion(ctx context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error)
|
||||||
|
GetNoteVersions(ctx context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type notesResource struct {
|
type notesResource struct {
|
||||||
Notes NoteStore
|
JWTSecret string
|
||||||
|
Notes NoteStore
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs notesResource) Routes() chi.Router {
|
func (rs notesResource) Routes() chi.Router {
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
|
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Use(requireAccessToken(rs.JWTSecret))
|
||||||
|
|
||||||
|
r.Post("/", rs.CreateNote) // POST /notes - note creation
|
||||||
|
r.Get("/", rs.ListNotes) // GET /notes - get all notes
|
||||||
|
|
||||||
|
r.Route("/{id}", func(r chi.Router) {
|
||||||
|
r.Use(noteCtx(rs.Notes))
|
||||||
|
|
||||||
|
r.Get("/", rs.GetNote) // GET /notes/{id} - get specific note
|
||||||
|
r.Delete("/", rs.DeleteNote) // DELETE /notes/{id} - delete specific note
|
||||||
|
|
||||||
|
r.Route("/versions", func(r chi.Router) {
|
||||||
|
r.Post("/", rs.CreateNoteVersion) // POST /notes/{id}/versions - create new version
|
||||||
|
r.Get("/", rs.ListNoteVersions) // GET /notes/{id}/versions - get all existing versions
|
||||||
|
r.Get("/{version}", rs.GetNoteVersion) // GET /notes/{id}/versions/{version} - get specific version
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rs *notesResource) CreateNote(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := uuid.Parse(user.Subject)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
note, err := rs.Notes.CreateNote(r.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to create note")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusCreated, note)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *notesResource) ListNotes(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := uuid.Parse(user.Subject)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, offset := getPaginationParams(r)
|
||||||
|
|
||||||
|
notes, err := rs.Notes.ListNotes(r.Context(), data.ListNotesParams{
|
||||||
|
UserID: userID,
|
||||||
|
Limit: limit,
|
||||||
|
Offset: offset,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to retrieve notes")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, notes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *notesResource) GetNote(w http.ResponseWriter, r *http.Request) {
|
||||||
|
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "Note not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, note)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *notesResource) DeleteNote(w http.ResponseWriter, r *http.Request) {
|
||||||
|
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "Note not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := uuid.Parse(user.Subject)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rs.Notes.DeleteNote(r.Context(), data.DeleteNoteParams{
|
||||||
|
ID: note.ID,
|
||||||
|
UserID: userID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to delete note")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *notesResource) CreateNoteVersion(w http.ResponseWriter, r *http.Request) {
|
||||||
|
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "Note not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// De-duplication check
|
||||||
|
duplicate, err := rs.Notes.FindDuplicateContent(r.Context(), data.FindDuplicateContentParams{
|
||||||
|
NoteID: note.ID,
|
||||||
|
Column2: []byte(req.Title),
|
||||||
|
Column3: []byte(req.Content),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to check for duplicate content")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if duplicate {
|
||||||
|
respondError(w, http.StatusConflict, "Duplicate content detected")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
version, err := rs.Notes.CreateNoteVersion(r.Context(), data.CreateNoteVersionParams{
|
||||||
|
NoteID: note.ID,
|
||||||
|
Title: req.Title,
|
||||||
|
Content: req.Content,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to create note version")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusCreated, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *notesResource) ListNoteVersions(w http.ResponseWriter, r *http.Request) {
|
||||||
|
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "Note not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, offset := getPaginationParams(r)
|
||||||
|
|
||||||
|
versions, err := rs.Notes.GetNoteVersions(r.Context(), data.GetNoteVersionsParams{
|
||||||
|
NoteID: note.ID,
|
||||||
|
Limit: limit,
|
||||||
|
Offset: offset,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to retrieve versions")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, versions)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *notesResource) GetNoteVersion(w http.ResponseWriter, r *http.Request) {
|
||||||
|
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "Note not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
versionStr := chi.URLParam(r, "version")
|
||||||
|
versionNumber, err := strconv.ParseInt(versionStr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, "Invalid version number")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
version, err := rs.Notes.GetNoteVersion(r.Context(), data.GetNoteVersionParams{
|
||||||
|
NoteID: note.ID,
|
||||||
|
VersionNumber: int32(versionNumber),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusNotFound, "Version not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPaginationParams(r *http.Request) (limit int32, offset int32) {
|
||||||
|
defaultLimit := 50
|
||||||
|
defaultOffset := 0
|
||||||
|
|
||||||
|
limitStr := r.URL.Query().Get("limit")
|
||||||
|
if limitStr != "" {
|
||||||
|
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
||||||
|
defaultLimit = l
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
offsetStr := r.URL.Query().Get("offset")
|
||||||
|
if offsetStr != "" {
|
||||||
|
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
|
||||||
|
defaultOffset = o
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return int32(defaultLimit), int32(defaultOffset)
|
||||||
|
}
|
||||||
|
364
server/pkg/service/notes_test.go
Normal file
364
server/pkg/service/notes_test.go
Normal file
@ -0,0 +1,364 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockNoteStore struct {
|
||||||
|
CreateNoteFunc func(context.Context, uuid.UUID) (data.Note, error)
|
||||||
|
DeleteNoteFunc func(context.Context, data.DeleteNoteParams) error
|
||||||
|
GetNoteFunc func(context.Context, data.GetNoteParams) (data.Note, error)
|
||||||
|
ListNotesFunc func(context.Context, data.ListNotesParams) ([]data.Note, error)
|
||||||
|
CreateNoteVersionFunc func(context.Context, data.CreateNoteVersionParams) (data.NoteVersion, error)
|
||||||
|
FindDuplicateContentFunc func(context.Context, data.FindDuplicateContentParams) (bool, error)
|
||||||
|
GetNoteVersionFunc func(context.Context, data.GetNoteVersionParams) (data.NoteVersion, error)
|
||||||
|
GetNoteVersionsFunc func(context.Context, data.GetNoteVersionsParams) ([]data.NoteVersion, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) CreateNote(ctx context.Context, userID uuid.UUID) (data.Note, error) {
|
||||||
|
return m.CreateNoteFunc(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error {
|
||||||
|
return m.DeleteNoteFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) GetNote(ctx context.Context, arg data.GetNoteParams) (data.Note, error) {
|
||||||
|
return m.GetNoteFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.Note, error) {
|
||||||
|
return m.ListNotesFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error) {
|
||||||
|
return m.CreateNoteVersionFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) FindDuplicateContent(ctx context.Context, arg data.FindDuplicateContentParams) (bool, error) {
|
||||||
|
return m.FindDuplicateContentFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) GetNoteVersion(ctx context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error) {
|
||||||
|
return m.GetNoteVersionFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) GetNoteVersions(ctx context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) {
|
||||||
|
return m.GetNoteVersionsFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotes_CreateNote(t *testing.T) {
|
||||||
|
userID := uuid.New()
|
||||||
|
testNote := data.Note{ID: uuid.New(), UserID: userID}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mock func(*mockNoteStore)
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"success",
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.CreateNoteFunc = func(_ context.Context, uid uuid.UUID) (data.Note, error) {
|
||||||
|
assert.Equal(t, userID, uid)
|
||||||
|
return testNote, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusCreated,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"database error",
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.CreateNoteFunc = func(context.Context, uuid.UUID) (data.Note, error) {
|
||||||
|
return data.Note{}, errors.New("db error")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
mockStore := &mockNoteStore{}
|
||||||
|
tc.mock(mockStore)
|
||||||
|
|
||||||
|
rs := notesResource{Notes: mockStore}
|
||||||
|
req := httptest.NewRequest("POST", "/", nil)
|
||||||
|
req = req.WithContext(context.WithValue(
|
||||||
|
req.Context(),
|
||||||
|
userCtxKey{},
|
||||||
|
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID.String(),
|
||||||
|
}},
|
||||||
|
))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rs.CreateNote(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.statusCode, w.Code)
|
||||||
|
if tc.statusCode == http.StatusCreated {
|
||||||
|
var note data.Note
|
||||||
|
json.Unmarshal(w.Body.Bytes(), ¬e)
|
||||||
|
assert.Equal(t, testNote.ID, note.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotes_ListNotes(t *testing.T) {
|
||||||
|
userID := uuid.New()
|
||||||
|
notes := []data.Note{
|
||||||
|
{ID: uuid.New(), UserID: userID},
|
||||||
|
{ID: uuid.New(), UserID: userID},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
mock func(*mockNoteStore)
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"success",
|
||||||
|
"",
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.ListNotesFunc = func(_ context.Context, arg data.ListNotesParams) ([]data.Note, error) {
|
||||||
|
assert.Equal(t, userID, arg.UserID)
|
||||||
|
return notes, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"with pagination",
|
||||||
|
"?limit=10&offset=20",
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.ListNotesFunc = func(_ context.Context, arg data.ListNotesParams) ([]data.Note, error) {
|
||||||
|
assert.EqualValues(t, 10, arg.Limit)
|
||||||
|
assert.EqualValues(t, 20, arg.Offset)
|
||||||
|
return notes, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"database error",
|
||||||
|
"",
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.ListNotesFunc = func(context.Context, data.ListNotesParams) ([]data.Note, error) {
|
||||||
|
return nil, errors.New("db error")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
mockStore := &mockNoteStore{}
|
||||||
|
tc.mock(mockStore)
|
||||||
|
|
||||||
|
rs := notesResource{Notes: mockStore}
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.query), nil)
|
||||||
|
req = req.WithContext(context.WithValue(
|
||||||
|
req.Context(),
|
||||||
|
userCtxKey{},
|
||||||
|
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID.String(),
|
||||||
|
}},
|
||||||
|
))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rs.ListNotes(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.statusCode, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotes_GetNote(t *testing.T) {
|
||||||
|
noteID := uuid.New()
|
||||||
|
userID := uuid.New()
|
||||||
|
validNote := data.Note{ID: noteID, UserID: userID}
|
||||||
|
|
||||||
|
t.Run("success", func(t *testing.T) {
|
||||||
|
rs := notesResource{}
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req = req.WithContext(context.WithValue(
|
||||||
|
req.Context(),
|
||||||
|
noteCtxKey{},
|
||||||
|
validNote,
|
||||||
|
))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rs.GetNote(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
var note data.Note
|
||||||
|
json.Unmarshal(w.Body.Bytes(), ¬e)
|
||||||
|
assert.Equal(t, validNote.ID, note.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("not found", func(t *testing.T) {
|
||||||
|
rs := notesResource{}
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rs.GetNote(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotes_DeleteNote(t *testing.T) {
|
||||||
|
noteID := uuid.New()
|
||||||
|
userID := uuid.New()
|
||||||
|
validNote := data.Note{ID: noteID, UserID: userID}
|
||||||
|
|
||||||
|
t.Run("success", func(t *testing.T) {
|
||||||
|
mockStore := &mockNoteStore{
|
||||||
|
DeleteNoteFunc: func(_ context.Context, arg data.DeleteNoteParams) error {
|
||||||
|
assert.Equal(t, noteID, arg.ID)
|
||||||
|
assert.Equal(t, userID, arg.UserID)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rs := notesResource{Notes: mockStore}
|
||||||
|
req := httptest.NewRequest("DELETE", "/", nil)
|
||||||
|
req = req.WithContext(context.WithValue(
|
||||||
|
req.Context(),
|
||||||
|
noteCtxKey{},
|
||||||
|
validNote,
|
||||||
|
))
|
||||||
|
req = req.WithContext(context.WithValue(
|
||||||
|
req.Context(),
|
||||||
|
userCtxKey{},
|
||||||
|
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID.String(),
|
||||||
|
}},
|
||||||
|
))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rs.DeleteNote(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNoContent, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("database error", func(t *testing.T) {
|
||||||
|
mockStore := &mockNoteStore{
|
||||||
|
DeleteNoteFunc: func(context.Context, data.DeleteNoteParams) error {
|
||||||
|
return errors.New("db error")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rs := notesResource{Notes: mockStore}
|
||||||
|
req := httptest.NewRequest("DELETE", "/", nil)
|
||||||
|
req = req.WithContext(context.WithValue(
|
||||||
|
req.Context(),
|
||||||
|
noteCtxKey{},
|
||||||
|
validNote,
|
||||||
|
))
|
||||||
|
req = req.WithContext(context.WithValue(
|
||||||
|
req.Context(),
|
||||||
|
userCtxKey{},
|
||||||
|
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID.String(),
|
||||||
|
}},
|
||||||
|
))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rs.DeleteNote(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotes_CreateNoteVersion(t *testing.T) {
|
||||||
|
noteID := uuid.New()
|
||||||
|
validRequest := `{"title": "Test", "content": "Content"}`
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
mock func(*mockNoteStore)
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"success",
|
||||||
|
validRequest,
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
m.CreateNoteVersionFunc = func(_ context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error) {
|
||||||
|
assert.Equal(t, noteID, arg.NoteID)
|
||||||
|
return data.NoteVersion{}, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusCreated,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"duplicate content",
|
||||||
|
validRequest,
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusConflict,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invalid request",
|
||||||
|
"{invalid}",
|
||||||
|
func(m *mockNoteStore) {},
|
||||||
|
http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"database error",
|
||||||
|
validRequest,
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
m.CreateNoteVersionFunc = func(context.Context, data.CreateNoteVersionParams) (data.NoteVersion, error) {
|
||||||
|
return data.NoteVersion{}, errors.New("db error")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
mockStore := &mockNoteStore{}
|
||||||
|
tc.mock(mockStore)
|
||||||
|
|
||||||
|
rs := notesResource{Notes: mockStore}
|
||||||
|
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
|
||||||
|
req = req.WithContext(context.WithValue(
|
||||||
|
req.Context(),
|
||||||
|
noteCtxKey{},
|
||||||
|
data.Note{ID: noteID},
|
||||||
|
))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rs.CreateNoteVersion(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.statusCode, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: add similar tests for `ListNoteVersions` and `GetNoteVersion`
|
@ -14,6 +14,8 @@ import (
|
|||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type userCtxKey struct{}
|
||||||
|
|
||||||
// Mockable database operations interface
|
// Mockable database operations interface
|
||||||
type UserStore interface {
|
type UserStore interface {
|
||||||
CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
|
CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
|
||||||
@ -124,9 +126,9 @@ func (rs usersResource) List(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Output sanitization
|
// Output sanitization
|
||||||
var output []map[string]interface{}
|
var output []map[string]any
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
output = append(output, map[string]interface{}{
|
output = append(output, map[string]any{
|
||||||
"id": user.ID,
|
"id": user.ID,
|
||||||
"username": user.Username,
|
"username": user.Username,
|
||||||
"created_at": user.CreatedAt,
|
"created_at": user.CreatedAt,
|
||||||
@ -144,7 +146,7 @@ func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
respondJSON(w, http.StatusOK, map[string]interface{}{
|
respondJSON(w, http.StatusOK, map[string]any{
|
||||||
"id": user.ID,
|
"id": user.ID,
|
||||||
"username": user.Username,
|
"username": user.Username,
|
||||||
"created_at": user.CreatedAt,
|
"created_at": user.CreatedAt,
|
||||||
|
@ -154,7 +154,7 @@ func TestListUsers_Success(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
var response []map[string]interface{}
|
var response []map[string]any
|
||||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,7 @@ var (
|
|||||||
usernameRegex = regexp.MustCompile("^[a-z0-9_]+$")
|
usernameRegex = regexp.MustCompile("^[a-z0-9_]+$")
|
||||||
)
|
)
|
||||||
|
|
||||||
func respondJSON(w http.ResponseWriter, status int, data interface{}) {
|
func respondJSON(w http.ResponseWriter, status int, data any) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(status)
|
w.WriteHeader(status)
|
||||||
json.NewEncoder(w).Encode(data)
|
json.NewEncoder(w).Encode(data)
|
||||||
|
@ -89,9 +89,18 @@ func TestPasswordEntropyCalculation(t *testing.T) {
|
|||||||
password string
|
password string
|
||||||
entropy float64
|
entropy float64
|
||||||
}{
|
}{
|
||||||
{"password", 37.6},
|
{
|
||||||
{"SecurePassw0rd!123", 103.12},
|
"password",
|
||||||
{"aaaaaaaaaaaaaaaa", 9.5},
|
37.6,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SecurePassw0rd!123",
|
||||||
|
103.12,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"aaaaaaaaaaaaaaaa",
|
||||||
|
9.5,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user