From 62b1a58e564e5df0b03735e2c382d146168a6f64 Mon Sep 17 00:00:00 2001 From: ae Date: Wed, 9 Apr 2025 01:58:38 +0300 Subject: [PATCH] feat!: trimming & logic/schema improvements - build: somewhat polished dockerization setup - build: io/fs migrations with `golang-migrate` - feat: automatic init. admin account creation (.env creds) - feat(routers): combined user & token routers into single auth router - feat(routers): improved route layouts (`Routes`) - feat(middlewares): removed redundant `userCtx` middleware - fix(schema): note <-> note_versions relation (versioning) - feat(queries): removed redundant rollback functionality - feat(queries): combined duplicate version check & insertion/creation - tests: decreased redundancy by removing 'unnecessary' unit tests - refactor: hid internal packages behind `server/internal` - docs: notes & auth handler comments --- server/Dockerfile | 19 + server/go.mod | 9 +- server/go.sum | 61 +- server/{pkg => internal}/data/db.go | 0 server/{pkg => internal}/data/models.go | 15 +- server/internal/data/note_versions.sql.go | 156 +++++ server/internal/data/notes.sql.go | 143 ++++ .../data/refresh_tokens.sql.go | 0 server/{pkg => internal}/data/users.sql.go | 57 ++ server/internal/service/auth.go | 658 ++++++++++++++++++ .../{pkg => internal}/service/middleware.go | 117 ++-- server/internal/service/middleware_test.go | 519 ++++++++++++++ server/internal/service/notes.go | 331 +++++++++ server/{pkg => internal}/service/service.go | 27 +- server/{pkg => internal}/service/util.go | 4 +- server/{pkg => internal}/service/util_test.go | 0 server/main.go | 78 ++- server/pkg/data/note_versions.sql.go | 132 ---- server/pkg/data/notes.sql.go | 105 --- server/pkg/migrate/migrate.go | 74 -- server/pkg/service/middleware_test.go | 386 ---------- server/pkg/service/notes.go | 262 ------- server/pkg/service/notes_test.go | 509 -------------- server/pkg/service/tokens.go | 249 ------- server/pkg/service/tokens_test.go | 211 ------ server/pkg/service/users.go | 356 ---------- server/pkg/service/users_test.go | 557 --------------- server/sql/migrations/0001_initial.up.sql | 20 +- server/sql/queries/note_versions.sql | 73 +- server/sql/queries/notes.sql | 32 +- server/sql/queries/users.sql | 9 + server/sql/sqlc.yaml | 2 +- 32 files changed, 2184 insertions(+), 2987 deletions(-) rename server/{pkg => internal}/data/db.go (100%) rename server/{pkg => internal}/data/models.go (78%) create mode 100644 server/internal/data/note_versions.sql.go create mode 100644 server/internal/data/notes.sql.go rename server/{pkg => internal}/data/refresh_tokens.sql.go (100%) rename server/{pkg => internal}/data/users.sql.go (69%) create mode 100644 server/internal/service/auth.go rename server/{pkg => internal}/service/middleware.go (59%) create mode 100644 server/internal/service/middleware_test.go create mode 100644 server/internal/service/notes.go rename server/{pkg => internal}/service/service.go (55%) rename server/{pkg => internal}/service/util.go (97%) rename server/{pkg => internal}/service/util_test.go (100%) delete mode 100644 server/pkg/data/note_versions.sql.go delete mode 100644 server/pkg/data/notes.sql.go delete mode 100644 server/pkg/migrate/migrate.go delete mode 100644 server/pkg/service/middleware_test.go delete mode 100644 server/pkg/service/notes.go delete mode 100644 server/pkg/service/notes_test.go delete mode 100644 server/pkg/service/tokens.go delete mode 100644 server/pkg/service/tokens_test.go delete mode 100644 server/pkg/service/users.go delete mode 100644 server/pkg/service/users_test.go diff --git a/server/Dockerfile b/server/Dockerfile index e69de29..cbc2f48 100644 --- a/server/Dockerfile +++ b/server/Dockerfile @@ -0,0 +1,19 @@ +# Build stage +FROM golang:1.24.2-alpine AS builder + +WORKDIR /app +COPY go.mod go.sum . +RUN go mod download + +COPY . . +# Optionally we could also strip debug symbols with `-ldflags '-s'` +RUN CGO_ENABLED=0 GOOS=linux go build -o /notatest . + +# Final stage (optimized image size) +FROM alpine:latest + +WORKDIR /app +COPY --from=builder /notatest /app/notatest +EXPOSE 8080 + +CMD ["/app/notatest"] diff --git a/server/go.mod b/server/go.mod index 38651b5..1b0e23a 100644 --- a/server/go.mod +++ b/server/go.mod @@ -3,9 +3,10 @@ module git.umbrella.haus/ae/notatest go 1.24.1 require ( - github.com/caarlos0/env v3.5.0+incompatible + github.com/caarlos0/env/v10 v10.0.0 github.com/go-chi/chi/v5 v5.2.1 github.com/golang-jwt/jwt/v5 v5.2.2 + github.com/golang-migrate/migrate/v4 v4.18.2 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.7.4 github.com/rs/zerolog v1.34.0 @@ -16,14 +17,16 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect - github.com/kr/text v0.2.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/stretchr/objx v0.5.2 // indirect + go.uber.org/atomic v1.11.0 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/server/go.sum b/server/go.sum index 675d802..edabdcc 100644 --- a/server/go.sum +++ b/server/go.sum @@ -1,17 +1,45 @@ -github.com/caarlos0/env v3.5.0+incompatible h1:Yy0UN8o9Wtr/jGHZDpCBLpNrzcFLLM2yixi/rBrKyJs= -github.com/caarlos0/env v3.5.0+incompatible/go.mod h1:tdCsowwCzMLdkqRYDlHpZCp2UooDD3MspDBjZ2AD02Y= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/caarlos0/env/v10 v10.0.0 h1:yIHUBZGsyqCnpTkbjk8asUlx6RFhhEs+h7TOBdgdzXA= +github.com/caarlos0/env/v10 v10.0.0/go.mod h1:ZfulV76NvVPw3tm591U4SwL3Xx9ldzBP9aGxzeN7G18= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dhui/dktest v0.4.4 h1:+I4s6JRE1yGuqflzwqG+aIaMdgXIorCf5P98JnaAWa8= +github.com/dhui/dktest v0.4.4/go.mod h1:4+22R4lgsdAXrDyaH4Nqx2JEz2hLp49MqQmm9HLCQhM= +github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= +github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/docker/docker v27.2.0+incompatible h1:Rk9nIVdfH3+Vz4cyI/uhbINhEZ/oLmc+CBXmH6fbNk4= +github.com/docker/docker v27.2.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= +github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= +github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= +github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8= github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8= +github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -24,6 +52,8 @@ github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= @@ -31,11 +61,22 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= +github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= +github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= +github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= +github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= +github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= +github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= -github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= @@ -48,6 +89,16 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/wagslane/go-password-validator v0.3.0 h1:vfxOPzGHkz5S146HDpavl0cw1DSVP061Ry2PX0/ON6I= github.com/wagslane/go-password-validator v0.3.0/go.mod h1:TI1XJ6T5fRdRnHqHt14pvy1tNVnrwe7m3/f1f2fDphQ= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8= +go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= +go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= +go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= +go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8= +go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4= +go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= diff --git a/server/pkg/data/db.go b/server/internal/data/db.go similarity index 100% rename from server/pkg/data/db.go rename to server/internal/data/db.go diff --git a/server/pkg/data/models.go b/server/internal/data/models.go similarity index 78% rename from server/pkg/data/models.go rename to server/internal/data/models.go index 1cabac4..17f5b93 100644 --- a/server/pkg/data/models.go +++ b/server/internal/data/models.go @@ -11,10 +11,12 @@ import ( ) type Note struct { - ID uuid.UUID `json:"id"` - UserID uuid.UUID `json:"user_id"` - CreatedAt *time.Time `json:"created_at"` - UpdatedAt *time.Time `json:"updated_at"` + ID uuid.UUID `json:"id"` + UserID uuid.UUID `json:"user_id"` + CurrentVersion int32 `json:"current_version"` + LatestVersion int32 `json:"latest_version"` + CreatedAt *time.Time `json:"created_at"` + UpdatedAt *time.Time `json:"updated_at"` } type NoteVersion struct { @@ -36,11 +38,6 @@ type RefreshToken struct { Revoked bool `json:"revoked"` } -type SchemaMigration struct { - Version int64 `json:"version"` - AppliedAt *time.Time `json:"applied_at"` -} - type User struct { ID uuid.UUID `json:"id"` Username string `json:"username"` diff --git a/server/internal/data/note_versions.sql.go b/server/internal/data/note_versions.sql.go new file mode 100644 index 0000000..7f36ea8 --- /dev/null +++ b/server/internal/data/note_versions.sql.go @@ -0,0 +1,156 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.28.0 +// source: note_versions.sql + +package data + +import ( + "context" + "time" + + "github.com/google/uuid" +) + +const createNoteVersion = `-- name: CreateNoteVersion :exec +WITH potential_duplicate AS ( + SELECT version_number + FROM note_versions + WHERE + note_id = $1 + AND content_hash = $2 + ORDER BY version_number DESC + LIMIT 1 +), +note_update AS ( + UPDATE notes + SET + current_version = COALESCE( + (SELECT version_number FROM potential_duplicate), + latest_version + 1 -- increment only if we don't jump into a historical version + ), + latest_version = CASE + WHEN (SELECT version_number FROM potential_duplicate) IS NULL + THEN latest_version + 1 + ELSE latest_version + END, + updated_at = NOW() + WHERE id = $1 + RETURNING current_version, latest_version +) +INSERT INTO note_versions ( + note_id, title, content, version_number, content_hash +) +SELECT + $1, -- note_id + $3, -- title + $4, -- content + current_version, + $2 -- content_hash +FROM note_update +WHERE NOT EXISTS (SELECT 1 FROM potential_duplicate) +` + +type CreateNoteVersionParams struct { + NoteID uuid.UUID `json:"note_id"` + ContentHash string `json:"content_hash"` + Title string `json:"title"` + Content string `json:"content"` +} + +func (q *Queries) CreateNoteVersion(ctx context.Context, arg CreateNoteVersionParams) error { + _, err := q.db.Exec(ctx, createNoteVersion, + arg.NoteID, + arg.ContentHash, + arg.Title, + arg.Content, + ) + return err +} + +const getVersion = `-- name: GetVersion :one +SELECT + id AS version_id, + title, + content, + version_number, + created_at +FROM note_versions +WHERE note_id = $1 AND id = $2 +` + +type GetVersionParams struct { + NoteID uuid.UUID `json:"note_id"` + ID uuid.UUID `json:"id"` +} + +type GetVersionRow struct { + VersionID uuid.UUID `json:"version_id"` + Title string `json:"title"` + Content string `json:"content"` + VersionNumber int32 `json:"version_number"` + CreatedAt *time.Time `json:"created_at"` +} + +func (q *Queries) GetVersion(ctx context.Context, arg GetVersionParams) (GetVersionRow, error) { + row := q.db.QueryRow(ctx, getVersion, arg.NoteID, arg.ID) + var i GetVersionRow + err := row.Scan( + &i.VersionID, + &i.Title, + &i.Content, + &i.VersionNumber, + &i.CreatedAt, + ) + return i, err +} + +const getVersionHistory = `-- name: GetVersionHistory :many +SELECT + id AS version_id, + title, + version_number, + created_at +FROM note_versions +WHERE note_id = $1 +ORDER BY version_number DESC +LIMIT $2 OFFSET $3 +` + +type GetVersionHistoryParams struct { + NoteID uuid.UUID `json:"note_id"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` +} + +type GetVersionHistoryRow struct { + VersionID uuid.UUID `json:"version_id"` + Title string `json:"title"` + VersionNumber int32 `json:"version_number"` + CreatedAt *time.Time `json:"created_at"` +} + +func (q *Queries) GetVersionHistory(ctx context.Context, arg GetVersionHistoryParams) ([]GetVersionHistoryRow, error) { + rows, err := q.db.Query(ctx, getVersionHistory, arg.NoteID, arg.Limit, arg.Offset) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetVersionHistoryRow + for rows.Next() { + var i GetVersionHistoryRow + if err := rows.Scan( + &i.VersionID, + &i.Title, + &i.VersionNumber, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/server/internal/data/notes.sql.go b/server/internal/data/notes.sql.go new file mode 100644 index 0000000..c5a444b --- /dev/null +++ b/server/internal/data/notes.sql.go @@ -0,0 +1,143 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.28.0 +// source: notes.sql + +package data + +import ( + "context" + "time" + + "github.com/google/uuid" +) + +const createNote = `-- name: CreateNote :one +INSERT INTO notes (user_id) +VALUES ($1) +RETURNING id, user_id, current_version, latest_version, created_at, updated_at +` + +func (q *Queries) CreateNote(ctx context.Context, userID uuid.UUID) (Note, error) { + row := q.db.QueryRow(ctx, createNote, userID) + var i Note + err := row.Scan( + &i.ID, + &i.UserID, + &i.CurrentVersion, + &i.LatestVersion, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteNote = `-- name: DeleteNote :exec +DELETE FROM notes +WHERE id = $1 AND user_id = $2 +` + +type DeleteNoteParams struct { + ID uuid.UUID `json:"id"` + UserID uuid.UUID `json:"user_id"` +} + +func (q *Queries) DeleteNote(ctx context.Context, arg DeleteNoteParams) error { + _, err := q.db.Exec(ctx, deleteNote, arg.ID, arg.UserID) + return err +} + +const getFullNote = `-- name: GetFullNote :one +SELECT + n.id AS note_id, + n.user_id AS owner_id, + nv.title, + nv.content, + nv.version_number, + nv.created_at AS version_created_at, + n.created_at AS note_created_at, + n.updated_at AS note_updated_at +FROM notes n +JOIN note_versions nv + ON n.id = nv.note_id AND n.current_version = nv.version_number +WHERE n.id = $1 +` + +type GetFullNoteRow struct { + NoteID uuid.UUID `json:"note_id"` + OwnerID uuid.UUID `json:"owner_id"` + Title string `json:"title"` + Content string `json:"content"` + VersionNumber int32 `json:"version_number"` + VersionCreatedAt *time.Time `json:"version_created_at"` + NoteCreatedAt *time.Time `json:"note_created_at"` + NoteUpdatedAt *time.Time `json:"note_updated_at"` +} + +func (q *Queries) GetFullNote(ctx context.Context, id uuid.UUID) (GetFullNoteRow, error) { + row := q.db.QueryRow(ctx, getFullNote, id) + var i GetFullNoteRow + err := row.Scan( + &i.NoteID, + &i.OwnerID, + &i.Title, + &i.Content, + &i.VersionNumber, + &i.VersionCreatedAt, + &i.NoteCreatedAt, + &i.NoteUpdatedAt, + ) + return i, err +} + +const listNotes = `-- name: ListNotes :many +SELECT + n.id AS note_id, + n.user_id AS owner_id, + nv.title, + n.updated_at +FROM notes n +JOIN note_versions nv + ON n.id = nv.note_id AND n.current_version = nv.version_number +WHERE n.user_id = $1 +ORDER BY n.updated_at DESC +LIMIT $2 OFFSET $3 +` + +type ListNotesParams struct { + UserID uuid.UUID `json:"user_id"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` +} + +type ListNotesRow struct { + NoteID uuid.UUID `json:"note_id"` + OwnerID uuid.UUID `json:"owner_id"` + Title string `json:"title"` + UpdatedAt *time.Time `json:"updated_at"` +} + +func (q *Queries) ListNotes(ctx context.Context, arg ListNotesParams) ([]ListNotesRow, error) { + rows, err := q.db.Query(ctx, listNotes, arg.UserID, arg.Limit, arg.Offset) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListNotesRow + for rows.Next() { + var i ListNotesRow + if err := rows.Scan( + &i.NoteID, + &i.OwnerID, + &i.Title, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/server/pkg/data/refresh_tokens.sql.go b/server/internal/data/refresh_tokens.sql.go similarity index 100% rename from server/pkg/data/refresh_tokens.sql.go rename to server/internal/data/refresh_tokens.sql.go diff --git a/server/pkg/data/users.sql.go b/server/internal/data/users.sql.go similarity index 69% rename from server/pkg/data/users.sql.go rename to server/internal/data/users.sql.go index 6dbb0d3..2b2a3d5 100644 --- a/server/pkg/data/users.sql.go +++ b/server/internal/data/users.sql.go @@ -11,6 +11,31 @@ import ( "github.com/google/uuid" ) +const createAdmin = `-- name: CreateAdmin :one +INSERT INTO users (username, password_hash, is_admin) +VALUES ($1, $2, true) +RETURNING id, username, password_hash, is_admin, created_at, updated_at +` + +type CreateAdminParams struct { + Username string `json:"username"` + PasswordHash string `json:"password_hash"` +} + +func (q *Queries) CreateAdmin(ctx context.Context, arg CreateAdminParams) (User, error) { + row := q.db.QueryRow(ctx, createAdmin, arg.Username, arg.PasswordHash) + var i User + err := row.Scan( + &i.ID, + &i.Username, + &i.PasswordHash, + &i.IsAdmin, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const createUser = `-- name: CreateUser :one INSERT INTO users (username, password_hash) VALUES ($1, $2) @@ -84,6 +109,38 @@ func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, return i, err } +const listAdmins = `-- name: ListAdmins :many +SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users +WHERE is_admin = true +` + +func (q *Queries) ListAdmins(ctx context.Context) ([]User, error) { + rows, err := q.db.Query(ctx, listAdmins) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.Username, + &i.PasswordHash, + &i.IsAdmin, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listUsers = `-- name: ListUsers :many SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users ` diff --git a/server/internal/service/auth.go b/server/internal/service/auth.go new file mode 100644 index 0000000..d44246b --- /dev/null +++ b/server/internal/service/auth.go @@ -0,0 +1,658 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "git.umbrella.haus/ae/notatest/internal/data" + "github.com/go-chi/chi/v5" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgconn" + "github.com/rs/zerolog/log" + "golang.org/x/crypto/bcrypt" +) + +const ( + accessTokenDuration = 15 * time.Minute + refreshTokenDuration = 7 * 24 * time.Hour +) + +var ( + ErrInvalidToken = errors.New("invalid token") + ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header") +) + +// User object context key for incoming requests (handled by middlewares). Only `*userClaims` type +// objects should be stored behind this key for consistency. +type userCtxKey struct{} + +// DTO without sensitive data fields such as user's password hash. +type userResponse struct { + ID uuid.UUID `json:"id"` + Username string `json:"username"` + IsAdmin bool `json:"is_admin"` + CreatedAt *time.Time `json:"created_at"` + UpdatedAt *time.Time `json:"updated_at"` +} + +// Custom JWT claims (should always be handled in the middleware layer). +type userClaims struct { + Admin bool `json:"admin"` + TokenType string `json:"type"` // "access" or "refresh" + jwt.RegisteredClaims // User's UUID should be stored in the subject claim +} + +type tokenPair struct { + AccessToken string + RefreshToken string +} + +// Mockable token related database operations interface. +type TokenStore interface { + CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) + GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) + RevokeRefreshToken(ctx context.Context, tokenHash string) error + RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error +} + +// Mockable user related database operations interface. +type UserStore interface { + CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error) + ListUsers(ctx context.Context) ([]data.User, error) + GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error) + GetUserByUsername(ctx context.Context, username string) (data.User, error) + UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error + DeleteUser(ctx context.Context, id uuid.UUID) error + RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error +} + +// Chi HTTP router for authentication/authorization related actions (users/tokens). In theory +// (especially in production) the `UserStore` and `TokenStore` will point to the same database +// handler, but for code readability they should be kept in separate structs. +type authResource struct { + JWTSecret string + Users UserStore + Tokens TokenStore +} + +func (rs authResource) Routes() chi.Router { + r := chi.NewRouter() + + // Public routes + r.Post("/signup", rs.Create) // POST /auth/signup - registration + r.Post("/login", rs.Login) // POST /auth/login - login + + // Protected routes (access token required) + r.Group(func(r chi.Router) { + r.Use(requireAccessToken(rs.JWTSecret)) // JWT claims -> ctx. + r.Get("/me", rs.Get) // GET /auth/me - current user data + r.Post("/logout", rs.Logout) // POST /auth/logout - revoke all refresh cookies + + // Owner routes + r.Route("/owner", func(r chi.Router) { + r.Put("/", rs.UpdatePassword) // PUT /auth/owner - update user password + r.Delete("/", rs.OwnerDelete) // DELETE /auth/owner - delete user (owner) + }) + + // Administration routes (admin claim required) + r.Route("/admin", func(r chi.Router) { + r.Use(adminOnlyMiddleware) + r.Get("/all", rs.List) // GET /auth/admin/all - list all users + r.Route(fmt.Sprintf("/{%s}", targetUserUUIDCtxParameter), func(r chi.Router) { + r.Use(uuidCtx(targetUserUUIDCtxParameter)) + r.Delete("/", rs.AdminDelete) // DELETE /auth/admin/{id} - delete user (admin) + }) + }) + }) + + // Protected routes (refresh token required) + r.Group(func(r chi.Router) { + r.Use(requireRefreshToken(rs.JWTSecret)) + r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair + }) + + return r +} + +// Handler for new user creation. Will check the incoming JSON object's integrity, validate/normalize +// the username, and validate the password (check whether it's compromised via the HIBP API and +// calculate its entropy). +func (rs authResource) Create(w http.ResponseWriter, r *http.Request) { + type request struct { + Username *string `json:"username"` + Password *string `json:"password"` + } + + var req request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Username == nil || req.Password == nil { + respondError(w, http.StatusBadRequest, "Invalid request body") + return + } + + // Username normalization (to lowercase) + normalizedUsername := normalizeUsername(*req.Username) + if err := validateUsername(normalizedUsername); err != nil { + respondError(w, http.StatusBadRequest, err.Error()) + return + } + + // Password validation (length, HIBP API, and entropy) + if err := validatePassword(*req.Password); err != nil { + respondError(w, http.StatusBadRequest, err.Error()) + return + } + + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(*req.Password), bcrypt.DefaultCost) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to create user") + return + } + + user, err := rs.Users.CreateUser(r.Context(), data.CreateUserParams{ + Username: normalizedUsername, + PasswordHash: string(hashedPassword), + }) + if err != nil { + if isDuplicateEntry(err) { + respondError(w, http.StatusConflict, "Username is already in use") + } else { + respondError(w, http.StatusInternalServerError, "Failed to create user") + } + return + } + + respondJSON(w, http.StatusCreated, map[string]string{ + "id": user.ID.String(), + "username": user.Username, + }) +} + +// Handler for logging in to an existing user account using a username-password credentials pair. +// Will check the incoming JSON object's integrity, use a normalized version of the username for a +// database lookup, and compare the given password's hash against the one stored in the database. +// By default only returns a fresh access tokens (and a refresh token as a httpOnly cookie), but +// if the `includeUser` parameter is set to `true` the user DTO will also be included. +func (rs authResource) Login(w http.ResponseWriter, r *http.Request) { + type request struct { + Username *string `json:"username"` + Password *string `json:"password"` + } + + var req request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Username == nil || req.Password == nil { + respondError(w, http.StatusBadRequest, "Invalid request body") + return + } + + user, err := rs.Users.GetUserByUsername(r.Context(), normalizeUsername(*req.Username)) + if err != nil { + respondError(w, http.StatusUnauthorized, "Invalid credentials") + return + } + + if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(*req.Password)); err != nil { + respondError(w, http.StatusUnauthorized, "Invalid credentials") + return + } + + // Generate a new access/refresh token pair and store a SHA256 hash of the refresh token into + // the database for further token rotations. + tokenPair, err := rs.GenerateTokenPair(r.Context(), user.ID, user.IsAdmin) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to generate tokens") + return + } + + // Set refresh token into a httpOnly cookie + http.SetCookie(w, &http.Cookie{ + Name: "refresh_token", + Value: tokenPair.RefreshToken, + Path: "/", + MaxAge: int(refreshTokenDuration.Seconds()), + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, + }) + + // Build response + response := map[string]any{ + "access_token": tokenPair.AccessToken, + } + + // Include user data DTO into the response if the `includeUser` parameter was set to `true` + if includeUser, _ := strconv.ParseBool(r.URL.Query().Get("includeUser")); includeUser { + response["user"] = userResponse{ + ID: user.ID, + Username: user.Username, + IsAdmin: user.IsAdmin, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + } + } + + respondJSON(w, http.StatusOK, response) +} + +// Handler for getting full data of the current user as a DTO (database lookup) based on the JWT +// claims set into the request's context by a middleware. +func (rs authResource) Get(w http.ResponseWriter, r *http.Request) { + user := rs.userFromCtxClaims(w, r) + respondJSON(w, http.StatusOK, userResponse{ + ID: user.ID, + Username: user.Username, + IsAdmin: user.IsAdmin, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + }) +} + +// Handler for updating the current user's password. Performs the same password strength checks as +// the registration handler (`rs.Create`) and revokes any existing refresh tokens the user has +// stored in the database. +func (rs authResource) UpdatePassword(w http.ResponseWriter, r *http.Request) { + type request struct { + OldPassword string `json:"old_password"` + NewPassword string `json:"new_password"` + } + + var req request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + respondError(w, http.StatusBadRequest, "Invalid request body") + return + } + + user := rs.userFromCtxClaims(w, r) + + // Verify the old password before proceeding with the update + if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.OldPassword)); err != nil { + respondError(w, http.StatusUnauthorized, "Invalid credentials") + return + } + + if err := validatePassword(req.NewPassword); err != nil { + respondError(w, http.StatusBadRequest, err.Error()) + return + } + + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to update password") + return + } + + if err := rs.Users.UpdatePassword(r.Context(), data.UpdatePasswordParams{ + ID: user.ID, + PasswordHash: string(hashedPassword), + }); err != nil { + respondError(w, http.StatusInternalServerError, "Failed to update password") + return + } + + if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil { + log.Error().Msgf("Failed to revoke refresh tokens: %s", err) + } + + w.WriteHeader(http.StatusNoContent) +} + +// Handler for hard deleting the current user. Requires the user's password as JSON input as a precaution. +func (rs authResource) OwnerDelete(w http.ResponseWriter, r *http.Request) { + type request struct { + Password string `json:"password"` + } + + var req request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + respondError(w, http.StatusBadRequest, "Invalid request body") + return + } + + user := rs.userFromCtxClaims(w, r) + + // Verify the old password before allowing the deletion + if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil { + respondError(w, http.StatusUnauthorized, "Invalid credentials") + return + } + + err := rs.Users.DeleteUser(r.Context(), user.ID) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to delete user") + return + } + + if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil { + log.Error().Msgf("Failed to revoke refresh tokens: %s", err) + } + + w.WriteHeader(http.StatusNoContent) +} + +// Handler for listing all users stored in the database. Should only be allowed to be called by +// administrator level users. +func (rs authResource) List(w http.ResponseWriter, r *http.Request) { + users, err := rs.Users.ListUsers(r.Context()) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to retrieve users") + return + } + + // Output sanitization to DTO + var output []userResponse + for _, user := range users { + output = append(output, userResponse{ + ID: user.ID, + Username: user.Username, + IsAdmin: user.IsAdmin, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + }) + } + + respondJSON(w, http.StatusOK, output) +} + +// Handler for deleting another user account based on their ID. Will check the existence of the +// user based on the given ID and additionally revoke all the stored refresh tokens on successful +// deletion. Should only be allowed to be called by administrator level users. +func (rs authResource) AdminDelete(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + targetID, ok := ctx.Value(uuidCtxKey{Name: targetUserUUIDCtxParameter}).(uuid.UUID) + if !ok { + respondError(w, http.StatusBadRequest, "Resource ID missing") + return + } + + if err := rs.Users.DeleteUser(r.Context(), targetID); err != nil { + respondError(w, http.StatusInternalServerError, "Failed to delete user") + return + } + + if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), targetID); err != nil { + log.Error().Msgf("Failed to revoke refresh tokens: %s", err) + } + + w.WriteHeader(http.StatusNoContent) +} + +// Generate a new pair of access and refresh tokens (JWTs) with the user's UUID as the identifying +// `Subject` claim and custom claims for the user's administrator status (boolean) and token type +// ("refresh"/"access"). Stores a SHA256 hash of the refresh token into the database for further +// token rotations. +func (rs authResource) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) { + tokenPair, err := generateTokenPair(userID.String(), isAdmin, rs.JWTSecret) + if err != nil { + return nil, err + } + + hash := sha256.Sum256([]byte(tokenPair.RefreshToken)) + tokenHash := hex.EncodeToString(hash[:]) + + // Store the SHA256 hash of the refresh token with (almost) identical expiration timestamp + expiresAt := time.Now().Add(refreshTokenDuration) + _, err = rs.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{ + UserID: userID, + TokenHash: tokenHash, + ExpiresAt: expiresAt, + }) + if err != nil { + return nil, err + } + + return tokenPair, nil +} + +// Revoke the given token from the database by calculating its SHA256 hash. +func (rs authResource) RevokeRefreshToken(ctx context.Context, token string) error { + hash := sha256.Sum256([]byte(token)) + tokenHash := hex.EncodeToString(hash[:]) + return rs.Tokens.RevokeRefreshToken(ctx, tokenHash) +} + +// Validate the given refresh token by performing a database lookup with its SHA256 hash. Returns +// the refresh token database object on successful lookup. Fails if the token has been revoked +// (soft database operation) or expired (i.e. the corresponding user has to log in again). +func (rs authResource) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) { + hash := sha256.Sum256([]byte(token)) + tokenHash := hex.EncodeToString(hash[:]) + + dbToken, err := rs.Tokens.GetRefreshTokenByHash(ctx, tokenHash) + if err != nil { + return nil, err + } + + // Check for soft revocation and/or expiration + if dbToken.Revoked || time.Now().After(dbToken.ExpiresAt) { + return nil, ErrInvalidToken + } + + return &dbToken, nil +} + +// Handler for performing a token rotation, i.e. invalidating the given refresh token (each refresh +// token is a single use utility) and exchanging it for a new pair of refresh and access tokens. +func (rs authResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request) { + // Get claims from context + claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) + if !ok || claims.TokenType != "refresh" { + respondError(w, http.StatusUnauthorized, "Invalid token") + return + } + + // Attempt to get the token from Authorization header (formatted as "Bearer ") + refreshToken, err := getTokenFromRequest(r) + if err != nil { + respondError(w, http.StatusUnauthorized, "Unauthorized") + return + } + + // Validate the refresh token in the database + if _, err := rs.ValidateRefreshToken(r.Context(), refreshToken); err != nil { + respondError(w, http.StatusUnauthorized, "Invalid refresh token") + return + } + + // Revoke the given (single use) refresh token + if err := rs.RevokeRefreshToken(r.Context(), refreshToken); err != nil { + respondError(w, http.StatusInternalServerError, "Failed to revoke token") + return + } + + userID, err := uuid.Parse(claims.Subject) + if err != nil { + respondError(w, http.StatusBadRequest, "Invalid user ID") + return + } + + // Generate a new pair (access & refresh tokens) + tokenPair, err := rs.GenerateTokenPair(r.Context(), userID, claims.Admin) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to generate tokens") + return + } + + // Set refresh token into a httpOnly cookie + http.SetCookie(w, &http.Cookie{ + Name: "refresh_token", + Value: tokenPair.RefreshToken, + Path: "/", + MaxAge: int(refreshTokenDuration.Seconds()), + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, + }) + + // Return the access token in the response body (it should be stored in browser's memory client-side) + respondJSON(w, http.StatusOK, map[string]string{ + "access_token": tokenPair.AccessToken, + }) +} + +// Handler for performing a logout process for the current user, i.e. replacing the current +// httpOnly `refresh_token` cookie with one that expires immediately. Theoretically the user +// will still be able to authenticate until the access token (stored client-side) expires, +// but that's up to the client to handle. +func (rs authResource) Logout(w http.ResponseWriter, r *http.Request) { + claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) + if !ok { + respondError(w, http.StatusUnauthorized, "Not authenticated") + return + } + + userID, err := uuid.Parse(claims.Subject) + if err != nil { + respondError(w, http.StatusBadRequest, "Invalid user ID") + return + } + + // Clear the refresh token cookie + http.SetCookie(w, &http.Cookie{ + Name: "refresh_token", + Value: "", + Path: "/", + MaxAge: 0, // Expires immediately + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, + }) + + if err := rs.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil { + respondError(w, http.StatusInternalServerError, "Failed to logout") + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// Helper function for generating the initial administrator level account if one doesn't already +// exists in the database. +func CreateAdminIfNotExists(ctx context.Context, q *data.Queries, username, password string) error { + admins, err := q.ListAdmins(ctx) + if err != nil { + return err + } + if len(admins) > 0 { + log.Debug().Msg("Admin accounts already exist, skipping creation") + return nil + } + + // Username normalization (to lowercase) + normalizedUsername := normalizeUsername(username) + if err := validateUsername(normalizedUsername); err != nil { + return err + } + + // Password validation (length, HIBP API, and entropy) + if err := validatePassword(password); err != nil { + return err + } + + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return err + } + + _, err = q.CreateAdmin(ctx, data.CreateAdminParams{ + Username: normalizedUsername, + PasswordHash: string(hashedPassword), + }) + if err != nil { + return err + } + + log.Info().Msgf("Initial admin user '%s' created successfully", username) + + return nil +} + +// Parse the JWT bearer token from the request's Authorization header. +func getTokenFromRequest(r *http.Request) (string, error) { + bearerToken := r.Header.Get("Authorization") + bearerFields := strings.Fields(bearerToken) + if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" { + return bearerFields[1], nil + } + return "", ErrAuthHeaderInvalid +} + +// Helper function for generating a new JWT token pair with the given specifications. +func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) { + atClaims := userClaims{ + Admin: isAdmin, + TokenType: "access", + RegisteredClaims: jwt.RegisteredClaims{ + Subject: userID, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenDuration)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + } + accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, atClaims) + + t, err := accessToken.SignedString([]byte(jwtSecret)) + if err != nil { + return nil, err + } + + rtClaims := userClaims{ + Admin: isAdmin, + TokenType: "refresh", + RegisteredClaims: jwt.RegisteredClaims{ + Subject: userID, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenDuration)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + } + refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, rtClaims) + + rt, err := refreshToken.SignedString([]byte(jwtSecret)) + if err != nil { + return nil, err + } + + return &tokenPair{AccessToken: t, RefreshToken: rt}, nil +} + +// Parse JWT claims (`userClaims`) from the request's context, and perform a database lookup based +// on `Subject` (after parsing it to `uuid.UUID`) to fetch the corresponding user's data. +func (rs authResource) userFromCtxClaims(w http.ResponseWriter, r *http.Request) *data.User { + claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) + if !ok { + respondError(w, http.StatusUnauthorized, "Unauthorized") + return nil + } + + userID, err := uuid.Parse(claims.Subject) + if err != nil { + respondError(w, http.StatusBadRequest, "Invalid user ID") + return nil + } + + user, err := rs.Users.GetUserByID(r.Context(), userID) + if err != nil { + respondError(w, http.StatusNotFound, "User not found") + return nil + } + + return &user +} + +// Check if the given error is a PostgreSQL error for `unique_violation` (error code 23505), i.e. +// whether an entry with the given details already exists in the database table. +func isDuplicateEntry(err error) bool { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + return pgErr.Code == "23505" + } + return false +} diff --git a/server/pkg/service/middleware.go b/server/internal/service/middleware.go similarity index 59% rename from server/pkg/service/middleware.go rename to server/internal/service/middleware.go index fd85d5c..c01fbe4 100644 --- a/server/pkg/service/middleware.go +++ b/server/internal/service/middleware.go @@ -2,11 +2,10 @@ package service import ( "context" - "fmt" "net/http" "time" - "git.umbrella.haus/ae/notatest/pkg/data" + "git.umbrella.haus/ae/notatest/internal/data" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/golang-jwt/jwt/v5" @@ -17,8 +16,17 @@ import ( const ( panicRecoveryMsg = "panic recovered" defaultLogMsg = "incoming request" + + noteUUIDCtxParameter = "noteID" + versionUUIDCtxParameter = "versionID" + targetUserUUIDCtxParameter = "targetID" ) +// General resource ID (UUID) context key. +type uuidCtxKey struct { + Name string +} + // Get JWT bearer from request's authorization header, parse it with custom user claims, and // ensure its validity before attaching the claims to the request's context. func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler { @@ -26,7 +34,8 @@ func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) ht return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tokenString, err := getTokenFromRequest(r) if err != nil { - respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err)) + respondError(w, http.StatusUnauthorized, "Unauthorized") + return } token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (any, error) { @@ -59,7 +68,8 @@ func requireRefreshToken(jwtSecret string) func(http.Handler) http.Handler { return authMiddleware(jwtSecret, "refresh") } -// Ensure the current user is an administrator. +// Ensure the current user is an administrator. Can be used to protect routes that can be utilized +// to view/modify/delete accounts that the current user isn't the owner of. func adminOnlyMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user, ok := r.Context().Value(userCtxKey{}).(*userClaims) @@ -71,57 +81,38 @@ func adminOnlyMiddleware(next http.Handler) http.Handler { }) } -// Ensure the targeted resource is owned by the current user (i.e. current user's ID matches with -// the one stored into the resource). -func ownerOnlyMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user, ok := r.Context().Value(userCtxKey{}).(*userClaims) - requestedID := chi.URLParam(r, "id") - if !ok || user.Subject != requestedID { - respondError(w, http.StatusForbidden, "Forbidden") - return - } - next.ServeHTTP(w, r) - }) -} - -// Append user data into request's context based on user ID as a URL parameter. -func userCtx(store UserStore) func(http.Handler) http.Handler { +// Append UUID from the given URL parameter to the request's context (`uuidCtxKey` with the +// parameter name as the "context identifier"). +func uuidCtx(parameter string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - userIDStr := chi.URLParam(r, "id") - userID, err := uuid.Parse(userIDStr) + uuidParam := chi.URLParam(r, parameter) + resourceID, err := uuid.Parse(uuidParam) if err != nil { - respondError(w, http.StatusNotFound, "Invalid user ID") + respondError(w, http.StatusBadRequest, "Invalid resource ID") return } - user, err := store.GetUserByID(r.Context(), userID) - if err != nil { - respondError(w, http.StatusNotFound, "User not found") - return - } - - ctx := context.WithValue(r.Context(), userCtxKey{}, user) + ctx := context.WithValue(r.Context(), uuidCtxKey{Name: parameter}, resourceID) next.ServeHTTP(w, r.WithContext(ctx)) }) } } -// Append note data into request's context based on note ID as a URL parameter and user ID as -// context parameter. +// Append full note data (metadata + active version) into request's context based on note ID as a +// URL parameter and user ID as context parameter. Must be chained with `uuidCtx` to parse the +// resource ID into the request's context. func noteCtx(store NoteStore) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - noteIDStr := chi.URLParam(r, "id") - noteID, err := uuid.Parse(noteIDStr) - if err != nil { - respondError(w, http.StatusNotFound, "Invalid note ID") + ctx := r.Context() + noteID, ok := ctx.Value(uuidCtxKey{Name: noteUUIDCtxParameter}).(uuid.UUID) + if !ok { + respondError(w, http.StatusBadRequest, "Resource ID missing") return } - // NOTE: user must already be in the context (e.g. via JWT middleware) - user, ok := r.Context().Value(userCtxKey{}).(*userClaims) + user, ok := ctx.Value(userCtxKey{}).(*userClaims) if !ok { respondError(w, http.StatusUnauthorized, "Unauthorized") return @@ -129,20 +120,58 @@ func noteCtx(store NoteStore) func(http.Handler) http.Handler { userID, err := uuid.Parse(user.Subject) if err != nil { - respondError(w, http.StatusInternalServerError, "Invalid user ID") + respondError(w, http.StatusUnauthorized, "Invalid token") return } - note, err := store.GetNote(r.Context(), data.GetNoteParams{ - ID: noteID, - UserID: userID, - }) + // Get the "full note" (metadata + active version) with a single query + fullNote, err := store.GetFullNote(r.Context(), noteID) if err != nil { respondError(w, http.StatusNotFound, "Note not found") return } - ctx := context.WithValue(r.Context(), noteCtxKey{}, note) + // Validate note ownership + if userID != fullNote.OwnerID { + respondError(w, http.StatusForbidden, "Forbidden") + return + } + + ctx = context.WithValue(r.Context(), noteCtxKey{}, &fullNote) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// Append single version's data into request's context based on version ID as a URL parameter and +// note ID as context parameter. Must be chained with `noteCtx` and `uuidCtx` to parse the necessary +// resource IDs into request's context. +func versionCtx(store NoteStore) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + fullNote, ok := ctx.Value(noteCtxKey{}).(*data.GetFullNoteRow) + if !ok { + respondError(w, http.StatusNotFound, "Note not found") + return + } + + versionID, ok := ctx.Value(uuidCtxKey{Name: versionUUIDCtxParameter}).(uuid.UUID) + if !ok { + respondError(w, http.StatusBadRequest, "Resource ID missing") + return + } + + version, err := store.GetVersion(r.Context(), data.GetVersionParams{ + NoteID: fullNote.NoteID, + ID: versionID, + }) + if err != nil { + respondError(w, http.StatusNotFound, "Version not found") + return + } + + ctx = context.WithValue(r.Context(), versionCtxKey{}, &version) next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/server/internal/service/middleware_test.go b/server/internal/service/middleware_test.go new file mode 100644 index 0000000..ce11f82 --- /dev/null +++ b/server/internal/service/middleware_test.go @@ -0,0 +1,519 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "git.umbrella.haus/ae/notatest/internal/data" + "github.com/go-chi/chi/v5" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +type mockNoteStore struct { + CreateNoteFunc func(context.Context, uuid.UUID) (data.Note, error) + DeleteNoteFunc func(context.Context, data.DeleteNoteParams) error + GetFullNoteFunc func(context.Context, uuid.UUID) (data.GetFullNoteRow, error) + ListNotesFunc func(context.Context, data.ListNotesParams) ([]data.ListNotesRow, error) + CreateNoteVersionFunc func(context.Context, data.CreateNoteVersionParams) error + GetVersionFunc func(context.Context, data.GetVersionParams) (data.GetVersionRow, error) + GetVersionHistoryFunc func(context.Context, data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error) +} + +func (m *mockNoteStore) CreateNote(ctx context.Context, id uuid.UUID) (data.Note, error) { + return m.CreateNoteFunc(ctx, id) +} + +func (m *mockNoteStore) DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error { + return m.DeleteNoteFunc(ctx, arg) +} + +func (m *mockNoteStore) GetFullNote(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) { + return m.GetFullNoteFunc(ctx, id) +} + +func (m *mockNoteStore) ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.ListNotesRow, error) { + return m.ListNotesFunc(ctx, arg) +} + +func (m *mockNoteStore) CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) error { + return m.CreateNoteVersionFunc(ctx, arg) +} + +func (m *mockNoteStore) GetVersion(ctx context.Context, arg data.GetVersionParams) (data.GetVersionRow, error) { + return m.GetVersionFunc(ctx, arg) +} + +func (m *mockNoteStore) GetVersionHistory(ctx context.Context, arg data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error) { + return m.GetVersionHistoryFunc(ctx, arg) +} + +func TestAuthMiddleware(t *testing.T) { + secret := "test-jwt-secret" + testUserID := uuid.New().String() + + validRT := generateTestToken(t, secret, "refresh", testUserID, true) + validAT := generateTestToken(t, secret, "access", testUserID, true) + expiredAT := generateTestToken(t, secret, "access", testUserID, true, func(claims *userClaims) { + claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)) + }) + + tests := []struct { + name string + token string + expectedErr string + statusCode int + }{ + { + "no token", + "", + "Unauthorized", + http.StatusUnauthorized, + }, + { + "invalid token", + "invalid", + "Invalid token", + http.StatusUnauthorized, + }, + { + "expired token", + expiredAT, + "Invalid token", + http.StatusUnauthorized, + }, + { + "wrong token type", + validRT, + "Invalid token type", + http.StatusUnauthorized, + }, + { + "valid token", + validAT, + "", + http.StatusOK, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + atAuthMiddleware := requireAccessToken(secret) + + // Mock request + req := httptest.NewRequest("GET", "/", nil) + if tc.token != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token)) + } + + w := httptest.NewRecorder() + called := false + + // Mock endpoint that the middleware protects + handler := atAuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + _, ok := r.Context().Value(userCtxKey{}).(*userClaims) + assert.True(t, ok) + })) + handler.ServeHTTP(w, req) + + assert.Equal(t, tc.statusCode, w.Code) + assert.Equal(t, tc.statusCode == http.StatusOK, called) + if tc.expectedErr != "" { + assert.Contains(t, w.Body.String(), tc.expectedErr) + } + }) + } +} + +func TestAdminOnlyMiddleware(t *testing.T) { + tests := []struct { + name string + user *userClaims + statusCode int + }{ + { + "no user", + nil, + http.StatusForbidden, + }, + { + "non admin user", + &userClaims{ + Admin: false, + }, + http.StatusForbidden, + }, + { + "admin user", + &userClaims{ + Admin: true, + }, + http.StatusOK, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Mock request + req := httptest.NewRequest("GET", "/", nil) + if tc.user != nil { + req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user)) + } + + w := httptest.NewRecorder() + called := false + + // Mock endpoint that the middleware protects + handler := adminOnlyMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + handler.ServeHTTP(w, req) + + assert.Equal(t, tc.statusCode, w.Code) + assert.Equal(t, tc.statusCode == http.StatusOK, called) + }) + } +} + +func TestUUIDCtxMiddleware(t *testing.T) { + testKeyName := "testKey" + tests := []struct { + name string + parameter string + expectedErr string + statusCode int + }{ + { + "missing uuid", + "", + "Invalid resource ID", + http.StatusBadRequest, + }, + { + "invalid uuid", + "invalid", + "Invalid resource ID", + http.StatusBadRequest, + }, + { + "valid uuid", + uuid.New().String(), + "", + http.StatusOK, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + uuidCtxMiddleware := uuidCtx(testKeyName) + req := httptest.NewRequest("GET", "/", nil) + + // We need to mock the URL parameter as we don't setup an actual router in this test env. + rctx := chi.NewRouteContext() + rctx.URLParams.Add(testKeyName, tc.parameter) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + w := httptest.NewRecorder() + called := false + + // Mock endpoint that the middleware protects + handler := uuidCtxMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + _, ok := r.Context().Value(uuidCtxKey{Name: testKeyName}).(uuid.UUID) + assert.True(t, ok) + })) + handler.ServeHTTP(w, req) + + assert.Equal(t, tc.statusCode, w.Code) + assert.Equal(t, tc.statusCode == http.StatusOK, called) + if tc.expectedErr != "" { + assert.Contains(t, w.Body.String(), tc.expectedErr) + } + }) + } +} + +func TestNoteCtxMiddleware(t *testing.T) { + testTitle := "Test title" + tesTContent := "## Test content\nData 123" + testVersion := int32(3) + noteID := uuid.New() + + ownerUserID := uuid.New() + testOwnerClaims := userClaims{RegisteredClaims: jwt.RegisteredClaims{ + Subject: ownerUserID.String(), + }} + + otherUserID := uuid.New() + testOtherClaims := userClaims{RegisteredClaims: jwt.RegisteredClaims{ + Subject: otherUserID.String(), + }} + + tests := []struct { + name string + resourceID *uuid.UUID + user *userClaims + mock func(*mockNoteStore) + statusCode int + expectedErr string + }{ + { + "no resource id", + nil, + nil, + func(m *mockNoteStore) {}, + http.StatusBadRequest, + "Resource ID missing", + }, + { + "unauthorized", + ¬eID, + nil, + func(m *mockNoteStore) {}, + http.StatusUnauthorized, + "Unauthorized", + }, + { + "note not found", + ¬eID, + &testOwnerClaims, + func(m *mockNoteStore) { + m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) { + return data.GetFullNoteRow{}, errors.New("not found") + } + }, + http.StatusNotFound, + "Note not found", + }, + { + "not owner", + ¬eID, + &testOtherClaims, + func(m *mockNoteStore) { + m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) { + assert.Equal(t, noteID, id) + testTs := time.Now() + return data.GetFullNoteRow{ + NoteID: id, + OwnerID: ownerUserID, + Title: testTitle, + Content: tesTContent, + VersionNumber: testVersion, + VersionCreatedAt: &testTs, + NoteCreatedAt: &testTs, + NoteUpdatedAt: &testTs, + }, nil + } + }, + http.StatusForbidden, + "Forbidden", + }, + { + "success", + ¬eID, + &testOwnerClaims, + func(m *mockNoteStore) { + m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) { + assert.Equal(t, noteID, id) + testTs := time.Now() + return data.GetFullNoteRow{ + NoteID: id, + OwnerID: ownerUserID, + Title: testTitle, + Content: tesTContent, + VersionNumber: testVersion, + VersionCreatedAt: &testTs, + NoteCreatedAt: &testTs, + NoteUpdatedAt: &testTs, + }, nil + } + }, + http.StatusOK, + "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // UUID ctx. (mock) -> note ctx. (tested here) + mockStore := &mockNoteStore{} + req := httptest.NewRequest("GET", "/", nil) + tc.mock(mockStore) + + // Mock endpoint that the middleware protects (where the attached note data is actually utilized) + handler := noteCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow) + assert.True(t, ok) + assert.Equal(t, noteID, fullNote.NoteID) + assert.Equal(t, testTitle, fullNote.Title) + assert.Equal(t, tesTContent, fullNote.Content) + assert.Equal(t, testVersion, fullNote.VersionNumber) + + w.WriteHeader(http.StatusOK) + })) + + // Request parameters don't need to be mocked, as parsing of them isn't handled + // by this middleware, and thus that portion shouldn't be tested here. + + if tc.resourceID != nil { + req = req.WithContext(context.WithValue(req.Context(), uuidCtxKey{Name: noteUUIDCtxParameter}, *tc.resourceID)) + } + + if tc.user != nil { + req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user)) + } + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, tc.statusCode, w.Code) + if tc.expectedErr != "" { + assert.Contains(t, w.Body.String(), tc.expectedErr) + } + }) + } +} + +func TestVersionCtxMiddleware(t *testing.T) { + testTitle := "Test title" + tesTContent := "## Test content\nData 123" + testVersion := int32(3) + versionID := uuid.New() + + noteID := uuid.New() + testNote := data.GetFullNoteRow{ + NoteID: noteID, + } + + tests := []struct { + name string + resourceID *uuid.UUID + note *data.GetFullNoteRow + mock func(*mockNoteStore) + statusCode int + expectedErr string + }{ + { + "no note", + nil, + nil, + func(m *mockNoteStore) {}, + http.StatusNotFound, + "Note not found", + }, + { + "no resource id", + nil, + &testNote, + func(m *mockNoteStore) {}, + http.StatusBadRequest, + "Resource ID missing", + }, + { + "version not found", + &versionID, + &testNote, + func(m *mockNoteStore) { + m.GetVersionFunc = func(ctx context.Context, gvp data.GetVersionParams) (data.GetVersionRow, error) { + return data.GetVersionRow{}, errors.New("not found") + } + }, + http.StatusNotFound, + "Version not found", + }, + { + "success", + &versionID, + &testNote, + func(m *mockNoteStore) { + m.GetVersionFunc = func(ctx context.Context, gvp data.GetVersionParams) (data.GetVersionRow, error) { + assert.Equal(t, versionID, gvp.ID) + assert.Equal(t, noteID, gvp.NoteID) + testTs := time.Now() + return data.GetVersionRow{ + VersionID: gvp.ID, + Title: testTitle, + Content: tesTContent, + VersionNumber: testVersion, + CreatedAt: &testTs, + }, nil + } + }, + http.StatusOK, + "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // note ctx. (mock) -> UUID ctx. (mock) -> version ctx. (tested here) + mockStore := &mockNoteStore{} + req := httptest.NewRequest("GET", "/", nil) + tc.mock(mockStore) + + // Mock endpoint that the middleware protects (where the attached note data is actually utilized) + handler := versionCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fullVersion, ok := r.Context().Value(versionCtxKey{}).(*data.GetVersionRow) + assert.True(t, ok) + assert.Equal(t, versionID, fullVersion.VersionID) + assert.Equal(t, testTitle, fullVersion.Title) + assert.Equal(t, tesTContent, fullVersion.Content) + assert.Equal(t, testVersion, fullVersion.VersionNumber) + + w.WriteHeader(http.StatusOK) + })) + + // Request parameters don't need to be mocked, as parsing of them isn't handled + // by this middleware, and thus that portion shouldn't be tested here. + + if tc.note != nil { + req = req.WithContext(context.WithValue(req.Context(), noteCtxKey{}, tc.note)) + } + + if tc.resourceID != nil { + req = req.WithContext(context.WithValue(req.Context(), uuidCtxKey{Name: versionUUIDCtxParameter}, *tc.resourceID)) + } + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, tc.statusCode, w.Code) + if tc.expectedErr != "" { + assert.Contains(t, w.Body.String(), tc.expectedErr) + } + }) + } +} + +func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string { + t.Helper() + + claims := &userClaims{ + Admin: isAdmin, + TokenType: tokenType, + RegisteredClaims: jwt.RegisteredClaims{ + Subject: userID, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), + }, + } + + for _, opt := range opts { + opt(claims) + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signedToken, err := token.SignedString([]byte(secret)) + if err != nil { + t.Fatalf("Failed to generate test token: %v", err) + } + + return signedToken +} diff --git a/server/internal/service/notes.go b/server/internal/service/notes.go new file mode 100644 index 0000000..7b3aeb8 --- /dev/null +++ b/server/internal/service/notes.go @@ -0,0 +1,331 @@ +package service + +import ( + "context" + "crypto/sha1" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + + "git.umbrella.haus/ae/notatest/internal/data" + "github.com/go-chi/chi/v5" + "github.com/google/uuid" +) + +const ( + titleMaxLength = 150 + initVersionTitle = "Untitled" + initVersionContent = "" +) + +// Note object context key for incoming requests (handled my middlewares). Only `*data.GetFullNoteRow` +// type objects should be stored behind this key for consistency. +type noteCtxKey struct{} + +// Note version object context key for incoming requests (handled by middlewares). Only +// `*data.GetVersionRow` type objects should be stored behind this key for consistency. +type versionCtxKey struct{} + +// Mockable database operations interface +type NoteStore interface { + CreateNote(ctx context.Context, userID uuid.UUID) (data.Note, error) + DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error + GetFullNote(ctx context.Context, noteID uuid.UUID) (data.GetFullNoteRow, error) + ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.ListNotesRow, error) + CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) error + GetVersion(ctx context.Context, arg data.GetVersionParams) (data.GetVersionRow, error) + GetVersionHistory(ctx context.Context, arg data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error) +} + +// Chi HTTP router for notes related CRUD actions. +type notesResource struct { + JWTSecret string + Notes NoteStore +} + +func (rs notesResource) Routes() chi.Router { + r := chi.NewRouter() + + r.Group(func(r chi.Router) { + r.Use(requireAccessToken(rs.JWTSecret)) // JWT claims -> ctx. + + r.Post("/", rs.Create) // POST /notes - create new note + r.Get("/", rs.ListMetadata) // GET /notes - get all notes (metadata + titles) + + /* + Clients should utilize `rs.ListMetadata` to load index of user's available notes (e.g. + sidebar view), use `rs.GetFullNote` to get full notes individually, and if request the + versioning history with `rs.GetVersionHistory` if necessary (and similarly fetch each + version individually if the client wants to view them). + */ + + r.Route(fmt.Sprintf("/{%s}", noteUUIDCtxParameter), func(r chi.Router) { + r.Use(uuidCtx(noteUUIDCtxParameter)) + r.Use(noteCtx(rs.Notes)) // DB -> req. context (metadata + active version) + r.Get("/", rs.GetFullNote) // GET /notes/{id} - get note from context + r.Delete("/", rs.Delete) // DELETE /notes/{id} - delete note + r.Get("/versions", rs.GetVersionHistory) // GET /notes/{id}/versions - get full versioning history + r.Post("/versions", rs.CreateVersion) // POST /notes/{id}/versions - create new version + + r.Route(fmt.Sprintf("/{%s}", versionUUIDCtxParameter), func(r chi.Router) { + r.Use(uuidCtx(versionUUIDCtxParameter)) + r.Use(versionCtx(rs.Notes)) // DB -> req. context (scoped version) + r.Get("/", rs.GetFullVersion) // GET /notes/{id}/{id} - get + }) + }) + }) + + return r +} + +// Handler for new note creation. Creates the parent metadata object (`notes` table) and an initial +// placeholder content version (`note_versions` table), and returns the placeholder contents to the +// caller in the HTTP response. +func (rs *notesResource) Create(w http.ResponseWriter, r *http.Request) { + user, ok := r.Context().Value(userCtxKey{}).(*userClaims) + if !ok { + respondError(w, http.StatusUnauthorized, "Unauthorized") + return + } + + userID, err := uuid.Parse(user.Subject) + if err != nil { + respondError(w, http.StatusUnauthorized, "Invalid user ID") + return + } + + // Metadata object (parent) + note, err := rs.Notes.CreateNote(r.Context(), userID) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to create note") + return + } + + // Initial (empty) placeholder version of the contents + err = rs.Notes.CreateNoteVersion(r.Context(), data.CreateNoteVersionParams{ + NoteID: note.ID, + Title: initVersionTitle, + Content: initVersionContent, + ContentHash: sha1ContentHash(initVersionTitle, initVersionContent), + }) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to create initial version") + return + } + + // Placeholder contents are decided server-side, so we need to inform the client of them via a + // one-time-use DTO + type response struct { + Title string `json:"title"` + Content string `json:"content"` + } + + res := response{ + Title: initVersionTitle, + Content: initVersionContent, + } + + respondJSON(w, http.StatusCreated, res) +} + +func (rs *notesResource) ListMetadata(w http.ResponseWriter, r *http.Request) { + user, ok := r.Context().Value(userCtxKey{}).(*userClaims) + if !ok { + respondError(w, http.StatusUnauthorized, "Unauthorized") + return + } + + userID, err := uuid.Parse(user.Subject) + if err != nil { + respondError(w, http.StatusUnauthorized, "Invalid user ID") + return + } + + limit, offset := getPaginationParams(r) + notes, err := rs.Notes.ListNotes(r.Context(), data.ListNotesParams{ + UserID: userID, + Limit: limit, + Offset: offset, + }) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to retrieve notes") + return + } + + for _, note := range notes { + if userID != note.OwnerID { + respondError(w, http.StatusForbidden, "Forbidden") + return + } + } + + respondJSON(w, http.StatusOK, notes) +} + +// Handler for returning the currently scoped (included to the request's context by a middleware) +// full note object. +func (rs *notesResource) GetFullNote(w http.ResponseWriter, r *http.Request) { + fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow) + if !ok { + respondError(w, http.StatusNotFound, "Note not found") + return + } + + respondJSON(w, http.StatusOK, fullNote) +} + +// Handler for hard deelting the currently scoped note (including its versions via database cascade). +func (rs *notesResource) Delete(w http.ResponseWriter, r *http.Request) { + fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow) + if !ok { + respondError(w, http.StatusNotFound, "Note not found") + return + } + + user, ok := r.Context().Value(userCtxKey{}).(*userClaims) + if !ok { + respondError(w, http.StatusUnauthorized, "Unauthorized") + return + } + + userID, err := uuid.Parse(user.Subject) + if err != nil { + respondError(w, http.StatusUnauthorized, "Invalid user ID") + return + } + + err = rs.Notes.DeleteNote(r.Context(), data.DeleteNoteParams{ + ID: fullNote.NoteID, + UserID: userID, // NOTE: using `fullNote.userID` here'd be insecure + }) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to delete note") + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// Handler for listing the currently scoped note's version history. If pagination parameters +// (`limit` and `offset`) aren't defined, limit of 50 versions (with offset 0) will be returned. +func (rs *notesResource) GetVersionHistory(w http.ResponseWriter, r *http.Request) { + fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow) + if !ok { + respondError(w, http.StatusNotFound, "Note not found") + return + } + + limit, offset := getPaginationParams(r) + versions, err := rs.Notes.GetVersionHistory(r.Context(), data.GetVersionHistoryParams{ + NoteID: fullNote.NoteID, + Limit: limit, + Offset: offset, + }) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to get version history") + return + } + + respondJSON(w, http.StatusOK, versions) +} + +// Handler for creating a new content version for the currently scoped note. Will check the incoming +// JSON object's integrity and perform a de-duplication check for identical versions stored in the +// database (SHA-1 hash of version contents). If a duplicate version is found, it'll be placed as the +// active version by swapping its version number to HEAD+1. +func (rs *notesResource) CreateVersion(w http.ResponseWriter, r *http.Request) { + fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow) + if !ok { + respondError(w, http.StatusNotFound, "Note not found") + return + } + + var req struct { + Title *string `json:"title"` + Content *string `json:"content"` + } + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Title == nil || req.Content == nil { + respondError(w, http.StatusBadRequest, "Invalid request body") + return + } + + // Extra check for frontend readability reasons (max. length isn't specifically limited in the database) + if len(*req.Title) > titleMaxLength { + respondError(w, http.StatusBadRequest, fmt.Sprintf("Title must be shorter than %d characters", titleMaxLength)) + return + } + + /* + The SQL query handles de-duplication checks and "intelligent" versioning increments, so we + don't have to worry about them here (`latest_version` = highest version number that exists + in this note's context; `current_version` = note's active content version): + + - New version's contents are a duplicate of a historical version: + - Don't increment `latest_version` + - Sync `current_version` with the `version_number` of the duplicate version + - New version's contents are unique: + - Increment `latest_version` + - Sync `current_version` with `latest_version` + */ + err := rs.Notes.CreateNoteVersion(r.Context(), data.CreateNoteVersionParams{ + NoteID: fullNote.NoteID, + Title: *req.Title, + Content: *req.Content, + ContentHash: sha1ContentHash(*req.Title, *req.Content), + }) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to create note version") + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// Handler for returning full data of the currently scoped note version. Identical to the beginning +// of the `RollbackNoteVersion` handler. +func (rs *notesResource) GetFullVersion(w http.ResponseWriter, r *http.Request) { + fullVersion, ok := r.Context().Value(noteCtxKey{}).(*data.GetVersionRow) + if !ok { + respondError(w, http.StatusNotFound, "Note not found") + return + } + + respondJSON(w, http.StatusOK, fullVersion) +} + +// Parse `limit` and `offset` 32-bit integer URL parameters from the given request. Defaults to +// limit of 50 and offset 0 if parameters are missing/invalid. +func getPaginationParams(r *http.Request) (limit int32, offset int32) { + defaultLimit := 50 + defaultOffset := 0 + + limitStr := r.URL.Query().Get("limit") + if limitStr != "" { + if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { + defaultLimit = l + } + } + + offsetStr := r.URL.Query().Get("offset") + if offsetStr != "" { + if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 { + defaultOffset = o + } + } + + return int32(defaultLimit), int32(defaultOffset) +} + +// Concatenate the title and content strings, calculate a SHA-1 hash of the resulting string, and +// return the resulting hash as a string. +func sha1ContentHash(title, content string) string { + hashContent := title + content + hash := sha1.Sum([]byte(hashContent)) + hashStr := strings.ToUpper(hex.EncodeToString(hash[:])) + + return hashStr +} diff --git a/server/pkg/service/service.go b/server/internal/service/service.go similarity index 55% rename from server/pkg/service/service.go rename to server/internal/service/service.go index 78223b6..baa804c 100644 --- a/server/pkg/service/service.go +++ b/server/internal/service/service.go @@ -3,26 +3,25 @@ package service import ( "net/http" - "git.umbrella.haus/ae/notatest/pkg/data" + "git.umbrella.haus/ae/notatest/internal/data" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/jackc/pgx/v5" "github.com/rs/zerolog/log" ) -func Run(conn *pgx.Conn, jwtSecret string) error { - q := data.New(conn) +func Run(conn *pgx.Conn, q *data.Queries, jwtSecret string) error { r := chi.NewRouter() - tokensRouter := tokensResource{ - JWTSecret: jwtSecret, - Tokens: q, - } - usersRouter := usersResource{ + authRouter := authResource{ JWTSecret: jwtSecret, Users: q, + Tokens: q, + } + notesRouter := notesResource{ + JWTSecret: jwtSecret, + Notes: q, } - notesRouter := notesResource{} // Global middlewares r.Use(middleware.RequestID) @@ -31,10 +30,12 @@ func Run(conn *pgx.Conn, jwtSecret string) error { r.Use(middleware.Recoverer) r.Use(middleware.AllowContentType("application/json")) - // Routes grouped by functionality - r.Mount("/auth", tokensRouter.Routes()) - r.Mount("/users", usersRouter.Routes()) - r.Mount("/notes", notesRouter.Routes()) + // Routes grouped by functionality (we must prefix the API routes with `/api` + // as the domain will be the same for the front and back ends) + r.Route("/api", func(r chi.Router) { + r.Mount("/auth", authRouter.Routes()) + r.Mount("/notes", notesRouter.Routes()) + }) log.Info().Msg("Starting server on :8080") return http.ListenAndServe(":8080", r) diff --git a/server/pkg/service/util.go b/server/internal/service/util.go similarity index 97% rename from server/pkg/service/util.go rename to server/internal/service/util.go index 3a5a826..d238644 100644 --- a/server/pkg/service/util.go +++ b/server/internal/service/util.go @@ -41,7 +41,7 @@ func respondError(w http.ResponseWriter, status int, message string) { } /* - Client-side check: + Example client-side check: ``` function estimateEntropy(password: string): number { @@ -102,7 +102,7 @@ func normalizeUsername(username string) string { } /* - Client-side check (additionally input should automatically perform the normalization steps): + Example client-side check (without input normalization): ``` function validateUsername(username: string): string { diff --git a/server/pkg/service/util_test.go b/server/internal/service/util_test.go similarity index 100% rename from server/pkg/service/util_test.go rename to server/internal/service/util_test.go diff --git a/server/main.go b/server/main.go index 8a943fd..5c8af8d 100644 --- a/server/main.go +++ b/server/main.go @@ -3,11 +3,15 @@ package main import ( "context" "embed" + "fmt" "os" - "git.umbrella.haus/ae/notatest/pkg/migrate" - "git.umbrella.haus/ae/notatest/pkg/service" - "github.com/caarlos0/env" + "git.umbrella.haus/ae/notatest/internal/data" + "git.umbrella.haus/ae/notatest/internal/service" + "github.com/caarlos0/env/v10" + "github.com/golang-migrate/migrate/v4" + _ "github.com/golang-migrate/migrate/v4/database/postgres" + "github.com/golang-migrate/migrate/v4/source/iofs" "github.com/jackc/pgx/v5" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -17,60 +21,72 @@ import ( var migrationsFS embed.FS var ( - isDevelopment = false - config Config + config Config ) type Config struct { - JWTSecret string `env:"JWT_SECRET,notEmpty"` - DBURL string `env:"PG_URL,notEmpty"` - RunMode string `env:"GO_ENV" envDefault:"production"` + JWTSecret string `env:"JWT_SECRET,notEmpty"` + DatabaseURL string `env:"DB_URL,notEmpty"` + LogLevel string `env:"LOG_LEVEL" envDefault:"info"` + AdminUsername string `env:"ADMIN_USERNAME,notEmpty,unset"` + AdminPassword string `env:"ADMIN_PASSWORD,notEmpty,unset"` } func init() { - initLogger() config = Config{} - env.Parse(&config) - - if config.RunMode == "development" { - log.Info().Msg("Development mode enabled") - isDevelopment = true + if err := env.Parse(&config); err != nil { + log.Fatal().Err(err).Msg("Failed to parse environment variables") } - + initLogger() log.Debug().Msg("Initialization completed") } func main() { - conn, err := pgx.Connect(context.Background(), config.DBURL) + log.Debug().Msgf("Database URL: %s", config.DatabaseURL) + conn, err := pgx.Connect(context.Background(), config.DatabaseURL) if err != nil { - log.Fatal().Msgf("Failed connecting to database: %s", err) + log.Fatal().Err(err).Msg("Failed to connect to database") } log.Info().Msg("Successfully connected to the database") - log.Debug().Msg(config.DBURL) + log.Info().Msg("Applying migrations...") - if isDevelopment { - if err := migrate.Run(context.Background(), conn, migrationsFS); err != nil { - log.Fatal().Msgf("Failed running migrations: %s", err) - } + d, err := iofs.New(migrationsFS, "sql/migrations") + if err != nil { + log.Fatal().Err(err).Msg("Failed constructing io/fs driver") + } + migrator, err := migrate.NewWithSourceInstance("iofs", d, config.DatabaseURL) + if err != nil { + log.Fatal().Err(err).Msg("Failed to apply migrations") + } + defer migrator.Close() + + if err := migrator.Up(); err != nil && err != migrate.ErrNoChange { + log.Fatal().Err(err).Msg("Failed to apply migrations") } - service.Run(conn, config.JWTSecret) + q := data.New(conn) + err = service.CreateAdminIfNotExists(context.Background(), q, config.AdminUsername, config.AdminPassword) + if err != nil { + log.Fatal().Err(err).Msg("Failed initial admin account creation") + } + + log.Info().Msg("Migrations applied succesfully, proceeding to HTTP server startup") + service.Run(conn, q, config.JWTSecret) } func initLogger() { - logLevel := os.Getenv("LOG_LEVEL") - level, err := zerolog.ParseLevel(logLevel) + fmt.Println(config.LogLevel) + level, err := zerolog.ParseLevel(config.LogLevel) if err != nil { - // Default to INFO level = zerolog.InfoLevel } zerolog.SetGlobalLevel(level) - if isDevelopment { - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) - } else { - log.Logger = log.Output(os.Stderr) // JSON to stdout/stderr + output := zerolog.ConsoleWriter{ + Out: os.Stdout, + TimeFormat: "2006-01-02 15:04:05", } + log.Logger = log.Output(output).With().Timestamp().Caller().Logger() - log.Debug().Msg("Logger initialized") + log.Info().Msgf("Logger initialized (log level: %s)", level) } diff --git a/server/pkg/data/note_versions.sql.go b/server/pkg/data/note_versions.sql.go deleted file mode 100644 index 73269bb..0000000 --- a/server/pkg/data/note_versions.sql.go +++ /dev/null @@ -1,132 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.28.0 -// source: note_versions.sql - -package data - -import ( - "context" - - "github.com/google/uuid" -) - -const createNoteVersion = `-- name: CreateNoteVersion :one -INSERT INTO note_versions (note_id, title, content, version_number, content_hash) -VALUES ( - $1, - $2, - $3, - (SELECT COALESCE(MAX(version_number), 0) + 1 FROM note_versions WHERE note_id = $1), - encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex') -) -RETURNING id, note_id, title, content, version_number, content_hash, created_at -` - -type CreateNoteVersionParams struct { - NoteID uuid.UUID `json:"note_id"` - Title string `json:"title"` - Content string `json:"content"` -} - -func (q *Queries) CreateNoteVersion(ctx context.Context, arg CreateNoteVersionParams) (NoteVersion, error) { - row := q.db.QueryRow(ctx, createNoteVersion, arg.NoteID, arg.Title, arg.Content) - var i NoteVersion - err := row.Scan( - &i.ID, - &i.NoteID, - &i.Title, - &i.Content, - &i.VersionNumber, - &i.ContentHash, - &i.CreatedAt, - ) - return i, err -} - -const findDuplicateContent = `-- name: FindDuplicateContent :one -SELECT EXISTS( - SELECT 1 FROM note_versions - WHERE note_id = $1 - AND content_hash = encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex') -) -` - -type FindDuplicateContentParams struct { - NoteID uuid.UUID `json:"note_id"` - Column2 []byte `json:"column_2"` - Column3 []byte `json:"column_3"` -} - -func (q *Queries) FindDuplicateContent(ctx context.Context, arg FindDuplicateContentParams) (bool, error) { - row := q.db.QueryRow(ctx, findDuplicateContent, arg.NoteID, arg.Column2, arg.Column3) - var exists bool - err := row.Scan(&exists) - return exists, err -} - -const getNoteVersion = `-- name: GetNoteVersion :one -SELECT id, note_id, title, content, version_number, content_hash, created_at FROM note_versions -WHERE note_id = $1 AND version_number = $2 LIMIT 1 -` - -type GetNoteVersionParams struct { - NoteID uuid.UUID `json:"note_id"` - VersionNumber int32 `json:"version_number"` -} - -func (q *Queries) GetNoteVersion(ctx context.Context, arg GetNoteVersionParams) (NoteVersion, error) { - row := q.db.QueryRow(ctx, getNoteVersion, arg.NoteID, arg.VersionNumber) - var i NoteVersion - err := row.Scan( - &i.ID, - &i.NoteID, - &i.Title, - &i.Content, - &i.VersionNumber, - &i.ContentHash, - &i.CreatedAt, - ) - return i, err -} - -const getNoteVersions = `-- name: GetNoteVersions :many -SELECT id, note_id, title, content, version_number, content_hash, created_at FROM note_versions -WHERE note_id = $1 -ORDER BY version_number DESC -LIMIT $2 OFFSET $3 -` - -type GetNoteVersionsParams struct { - NoteID uuid.UUID `json:"note_id"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` -} - -func (q *Queries) GetNoteVersions(ctx context.Context, arg GetNoteVersionsParams) ([]NoteVersion, error) { - rows, err := q.db.Query(ctx, getNoteVersions, arg.NoteID, arg.Limit, arg.Offset) - if err != nil { - return nil, err - } - defer rows.Close() - var items []NoteVersion - for rows.Next() { - var i NoteVersion - if err := rows.Scan( - &i.ID, - &i.NoteID, - &i.Title, - &i.Content, - &i.VersionNumber, - &i.ContentHash, - &i.CreatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} diff --git a/server/pkg/data/notes.sql.go b/server/pkg/data/notes.sql.go deleted file mode 100644 index 8ba0fcf..0000000 --- a/server/pkg/data/notes.sql.go +++ /dev/null @@ -1,105 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.28.0 -// source: notes.sql - -package data - -import ( - "context" - - "github.com/google/uuid" -) - -const createNote = `-- name: CreateNote :one -INSERT INTO notes (user_id) -VALUES ($1) -RETURNING id, user_id, created_at, updated_at -` - -func (q *Queries) CreateNote(ctx context.Context, userID uuid.UUID) (Note, error) { - row := q.db.QueryRow(ctx, createNote, userID) - var i Note - err := row.Scan( - &i.ID, - &i.UserID, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const deleteNote = `-- name: DeleteNote :exec -DELETE FROM notes -WHERE id = $1 AND user_id = $2 -` - -type DeleteNoteParams struct { - ID uuid.UUID `json:"id"` - UserID uuid.UUID `json:"user_id"` -} - -func (q *Queries) DeleteNote(ctx context.Context, arg DeleteNoteParams) error { - _, err := q.db.Exec(ctx, deleteNote, arg.ID, arg.UserID) - return err -} - -const getNote = `-- name: GetNote :one -SELECT id, user_id, created_at, updated_at FROM notes -WHERE id = $1 AND user_id = $2 LIMIT 1 -` - -type GetNoteParams struct { - ID uuid.UUID `json:"id"` - UserID uuid.UUID `json:"user_id"` -} - -func (q *Queries) GetNote(ctx context.Context, arg GetNoteParams) (Note, error) { - row := q.db.QueryRow(ctx, getNote, arg.ID, arg.UserID) - var i Note - err := row.Scan( - &i.ID, - &i.UserID, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const listNotes = `-- name: ListNotes :many -SELECT id, user_id, created_at, updated_at FROM notes -WHERE user_id = $1 -ORDER BY created_at DESC -LIMIT $2 OFFSET $3 -` - -type ListNotesParams struct { - UserID uuid.UUID `json:"user_id"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` -} - -func (q *Queries) ListNotes(ctx context.Context, arg ListNotesParams) ([]Note, error) { - rows, err := q.db.Query(ctx, listNotes, arg.UserID, arg.Limit, arg.Offset) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Note - for rows.Next() { - var i Note - if err := rows.Scan( - &i.ID, - &i.UserID, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} diff --git a/server/pkg/migrate/migrate.go b/server/pkg/migrate/migrate.go deleted file mode 100644 index 20aa35e..0000000 --- a/server/pkg/migrate/migrate.go +++ /dev/null @@ -1,74 +0,0 @@ -// pkg/migrate/migrate.go -package migrate - -import ( - "context" - "embed" - "fmt" - "io/fs" - "sort" - "strconv" - "strings" - - "github.com/jackc/pgx/v5" -) - -func Run(ctx context.Context, conn *pgx.Conn, migrationsFS embed.FS) error { - // Get already applied migrations - rows, _ := conn.Query(ctx, "SELECT version FROM schema_migrations") - defer rows.Close() - - applied := make(map[int64]bool) - for rows.Next() { - var version int64 - if err := rows.Scan(&version); err != nil { - return err - } - applied[version] = true - } - - files, err := migrationsFS.ReadDir("migrations") - if err != nil { - return err - } - - // Apply the migrations sequentially based on their ordinal number - for _, f := range sortMigrations(files) { - version, err := strconv.ParseInt(strings.Split(f.Name(), "_")[0], 10, 64) - if err != nil { - return fmt.Errorf("invalid migration name: %s", f.Name()) - } - - if applied[version] { - continue - } - - // Run migration - sql, err := migrationsFS.ReadFile("migrations/" + f.Name()) - if err != nil { - return err - } - - if _, err := conn.Exec(ctx, string(sql)); err != nil { - return fmt.Errorf("migration %d failed: %w", version, err) - } - - if _, err := conn.Exec(ctx, - "INSERT INTO schema_migrations (version) VALUES ($1)", version, - ); err != nil { - return err - } - } - - return nil -} - -// Sort the migration files based on their ordinal number prefix. -func sortMigrations(files []fs.DirEntry) []fs.DirEntry { - sort.Slice(files, func(i, j int) bool { - v1, _ := strconv.ParseInt(strings.Split(files[i].Name(), "_")[0], 10, 64) - v2, _ := strconv.ParseInt(strings.Split(files[j].Name(), "_")[0], 10, 64) - return v1 < v2 - }) - return files -} diff --git a/server/pkg/service/middleware_test.go b/server/pkg/service/middleware_test.go deleted file mode 100644 index 59c92c5..0000000 --- a/server/pkg/service/middleware_test.go +++ /dev/null @@ -1,386 +0,0 @@ -package service - -import ( - "context" - "errors" - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" - - "git.umbrella.haus/ae/notatest/pkg/data" - "github.com/go-chi/chi/v5" - "github.com/golang-jwt/jwt/v5" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" -) - -func TestAuthMiddleware(t *testing.T) { - secret := "test-secret" - validToken := generateTestToken(t, secret, "access", uuid.New().String(), true) - expiredToken := generateTestToken(t, secret, "access", uuid.New().String(), true, func(claims *userClaims) { - claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)) - }) - - tests := []struct { - name string - token string - expectedErr string - statusCode int - }{ - { - "no token", - "", - "Unauthorized", - http.StatusUnauthorized, - }, - { - "invalid token", - "invalid", - "Invalid token", - http.StatusUnauthorized, - }, - { - "expired token", - expiredToken, - "Invalid token", - http.StatusUnauthorized, - }, - { - "wrong type", - generateTestToken( - t, - secret, - "refresh", - uuid.New().String(), - true, - ), - "Invalid token type", - http.StatusUnauthorized, - }, - { - "valid token", - validToken, - "", - http.StatusOK, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - mw := requireAccessToken(secret) - req := httptest.NewRequest("GET", "/", nil) - if tc.token != "" { - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token)) - } - - w := httptest.NewRecorder() - called := false - handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called = true - _, ok := r.Context().Value(userCtxKey{}).(*userClaims) - assert.True(t, ok) - })) - - handler.ServeHTTP(w, req) - - assert.Equal(t, tc.statusCode, w.Code) - if tc.expectedErr != "" { - assert.Contains(t, w.Body.String(), tc.expectedErr) - } - assert.Equal(t, tc.statusCode == http.StatusOK, called) - }) - } -} - -func TestAdminOnlyMiddleware(t *testing.T) { - tests := []struct { - name string - user *userClaims - statusCode int - }{ - { - "no user", - nil, - http.StatusForbidden, - }, - { - "non admin user", - &userClaims{ - Admin: false, - }, - http.StatusForbidden, - }, - { - "admin user", - &userClaims{ - Admin: true, - }, - http.StatusOK, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - mw := adminOnlyMiddleware - req := httptest.NewRequest("GET", "/", nil) - if tc.user != nil { - req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user)) - } - - w := httptest.NewRecorder() - called := false - handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called = true - w.WriteHeader(http.StatusOK) - })) - - handler.ServeHTTP(w, req) - - assert.Equal(t, tc.statusCode, w.Code) - assert.Equal(t, tc.statusCode == http.StatusOK, called) - }) - } -} - -func TestOwnerOnlyMiddleware(t *testing.T) { - userID := uuid.New().String() - tests := []struct { - name string - user *userClaims - urlID string - statusCode int - }{ - { - "no user", - nil, - userID, - http.StatusForbidden, - }, - { - "different ID", - &userClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Subject: uuid.New().String(), - }}, - userID, - http.StatusForbidden, - }, - { - "matching ID", - &userClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID, - }, - }, - userID, - http.StatusOK, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - r := chi.NewRouter() - - handlerChain := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - - r.With( - // Add user with the given claims to request's context - func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var ctx context.Context = r.Context() - if tc.user != nil { - ctx = context.WithValue(ctx, userCtxKey{}, tc.user) - } - next.ServeHTTP(w, r.WithContext(ctx)) - }) - }, - ownerOnlyMiddleware, - ).Get("/{id}", handlerChain) - - req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if tc.urlID == "invalid" { - assert.Equal(t, http.StatusNotFound, w.Code) - } else { - assert.Equal(t, tc.statusCode, w.Code) - } - }) - } -} - -func TestUserCtxMiddleware(t *testing.T) { - validUserID := uuid.New() - invalidUserID := "invalid" - - mockStore := &mockUserStore{ - GetUserByIDFunc: func(ctx context.Context, id uuid.UUID) (data.User, error) { - if id == validUserID { - return data.User{ID: validUserID}, nil - } - return data.User{}, errors.New("not found") - }, - } - - tests := []struct { - name string - urlID string - statusCode int - }{ - { - "valid ID", - validUserID.String(), - http.StatusOK, - }, - { - "invalid ID", - invalidUserID, - http.StatusNotFound, - }, - { - "non existent ID", - uuid.New().String(), - http.StatusNotFound, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - mw := userCtx(mockStore) - r := chi.NewRouter() - r.With(mw).Get("/{id}", func(w http.ResponseWriter, r *http.Request) { - user, ok := r.Context().Value(userCtxKey{}).(data.User) - assert.True(t, ok) - assert.Equal(t, validUserID, user.ID) - w.WriteHeader(http.StatusOK) - }) - - req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, tc.statusCode, w.Code) - }) - } -} - -func TestNoteCtxMiddleware(t *testing.T) { - userID := uuid.New() - noteID := uuid.New() - - tests := []struct { - name string - noteID string - user any - mock func(*mockNoteStore) - statusCode int - }{ - { - "invalid note ID", - "invalid", - &userClaims{RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID.String(), - }}, - func(m *mockNoteStore) {}, - http.StatusNotFound, - }, - { - "unauthorized user", - noteID.String(), - nil, - func(m *mockNoteStore) {}, - http.StatusUnauthorized, - }, - { - "note not found", - noteID.String(), - &userClaims{RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID.String(), - }}, - func(m *mockNoteStore) { - m.GetNoteFunc = func(ctx context.Context, arg data.GetNoteParams) (data.Note, error) { - return data.Note{}, errors.New("not found") - } - }, - http.StatusNotFound, - }, - { - "success", - noteID.String(), - &userClaims{RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID.String(), - }}, - func(m *mockNoteStore) { - m.GetNoteFunc = func(ctx context.Context, arg data.GetNoteParams) (data.Note, error) { - assert.Equal(t, noteID, arg.ID) - assert.Equal(t, userID, arg.UserID) - return data.Note{ID: noteID}, nil - } - }, - http.StatusOK, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - mockStore := &mockNoteStore{} - tc.mock(mockStore) - - handler := noteCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, ok := r.Context().Value(noteCtxKey{}).(data.Note) - assert.True(t, ok) - w.WriteHeader(http.StatusOK) - })) - - req := httptest.NewRequest("GET", fmt.Sprintf("/notes/%s", tc.noteID), nil) - - // Chi router context mocks ID passed in a URL parameter - rctx := chi.NewRouteContext() - rctx.URLParams.Add("id", tc.noteID) - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) - - if tc.user != nil { - req = req.WithContext(context.WithValue( - req.Context(), - userCtxKey{}, - tc.user, - )) - } - - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - - assert.Equal(t, tc.statusCode, w.Code) - }) - } -} - -func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string { - t.Helper() - - claims := &userClaims{ - Admin: isAdmin, - TokenType: tokenType, - RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID, - ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), - }, - } - - for _, opt := range opts { - opt(claims) - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - signedToken, err := token.SignedString([]byte(secret)) - if err != nil { - t.Fatalf("Failed to generate test token: %v", err) - } - - return signedToken -} diff --git a/server/pkg/service/notes.go b/server/pkg/service/notes.go deleted file mode 100644 index 26f4fad..0000000 --- a/server/pkg/service/notes.go +++ /dev/null @@ -1,262 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "net/http" - "strconv" - - "git.umbrella.haus/ae/notatest/pkg/data" - "github.com/go-chi/chi/v5" - "github.com/google/uuid" -) - -type noteCtxKey struct{} - -// Mockable database operations interface -type NoteStore interface { - CreateNote(ctx context.Context, userID uuid.UUID) (data.Note, error) - DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error - GetNote(ctx context.Context, arg data.GetNoteParams) (data.Note, error) - ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.Note, error) - CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error) - FindDuplicateContent(ctx context.Context, arg data.FindDuplicateContentParams) (bool, error) - GetNoteVersion(ctx context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error) - GetNoteVersions(ctx context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) -} - -type notesResource struct { - JWTSecret string - Notes NoteStore -} - -func (rs notesResource) Routes() chi.Router { - r := chi.NewRouter() - - r.Group(func(r chi.Router) { - r.Use(requireAccessToken(rs.JWTSecret)) - - r.Post("/", rs.CreateNote) // POST /notes - note creation - r.Get("/", rs.ListNotes) // GET /notes - get all notes - - r.Route("/{id}", func(r chi.Router) { - r.Use(noteCtx(rs.Notes)) - - r.Get("/", rs.GetNote) // GET /notes/{id} - get specific note - r.Delete("/", rs.DeleteNote) // DELETE /notes/{id} - delete specific note - - r.Route("/versions", func(r chi.Router) { - r.Post("/", rs.CreateNoteVersion) // POST /notes/{id}/versions - create new version - r.Get("/", rs.ListNoteVersions) // GET /notes/{id}/versions - get all existing versions - r.Get("/{version}", rs.GetNoteVersion) // GET /notes/{id}/versions/{version} - get specific version - }) - }) - }) - - return r -} - -func (rs *notesResource) CreateNote(w http.ResponseWriter, r *http.Request) { - user, ok := r.Context().Value(userCtxKey{}).(*userClaims) - if !ok { - respondError(w, http.StatusUnauthorized, "Unauthorized") - return - } - - userID, err := uuid.Parse(user.Subject) - if err != nil { - respondError(w, http.StatusInternalServerError, "Invalid user ID") - return - } - - note, err := rs.Notes.CreateNote(r.Context(), userID) - if err != nil { - respondError(w, http.StatusInternalServerError, "Failed to create note") - return - } - - respondJSON(w, http.StatusCreated, note) -} - -func (rs *notesResource) ListNotes(w http.ResponseWriter, r *http.Request) { - user, ok := r.Context().Value(userCtxKey{}).(*userClaims) - if !ok { - respondError(w, http.StatusUnauthorized, "Unauthorized") - return - } - - userID, err := uuid.Parse(user.Subject) - if err != nil { - respondError(w, http.StatusInternalServerError, "Invalid user ID") - return - } - - limit, offset := getPaginationParams(r) - - notes, err := rs.Notes.ListNotes(r.Context(), data.ListNotesParams{ - UserID: userID, - Limit: limit, - Offset: offset, - }) - if err != nil { - respondError(w, http.StatusInternalServerError, "Failed to retrieve notes") - return - } - - respondJSON(w, http.StatusOK, notes) -} - -func (rs *notesResource) GetNote(w http.ResponseWriter, r *http.Request) { - note, ok := r.Context().Value(noteCtxKey{}).(data.Note) - if !ok { - respondError(w, http.StatusNotFound, "Note not found") - return - } - - respondJSON(w, http.StatusOK, note) -} - -func (rs *notesResource) DeleteNote(w http.ResponseWriter, r *http.Request) { - note, ok := r.Context().Value(noteCtxKey{}).(data.Note) - if !ok { - respondError(w, http.StatusNotFound, "Note not found") - return - } - - user, ok := r.Context().Value(userCtxKey{}).(*userClaims) - if !ok { - respondError(w, http.StatusUnauthorized, "Unauthorized") - return - } - - userID, err := uuid.Parse(user.Subject) - if err != nil { - respondError(w, http.StatusInternalServerError, "Invalid user ID") - return - } - - err = rs.Notes.DeleteNote(r.Context(), data.DeleteNoteParams{ - ID: note.ID, - UserID: userID, - }) - if err != nil { - respondError(w, http.StatusInternalServerError, "Failed to delete note") - return - } - - w.WriteHeader(http.StatusNoContent) -} - -func (rs *notesResource) CreateNoteVersion(w http.ResponseWriter, r *http.Request) { - note, ok := r.Context().Value(noteCtxKey{}).(data.Note) - if !ok { - respondError(w, http.StatusNotFound, "Note not found") - return - } - - var req struct { - Title string `json:"title"` - Content string `json:"content"` - } - - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - respondError(w, http.StatusBadRequest, "Invalid request body") - return - } - - // De-duplication check - duplicate, err := rs.Notes.FindDuplicateContent(r.Context(), data.FindDuplicateContentParams{ - NoteID: note.ID, - Column2: []byte(req.Title), - Column3: []byte(req.Content), - }) - if err != nil { - respondError(w, http.StatusInternalServerError, "Failed to check for duplicate content") - return - } - if duplicate { - respondError(w, http.StatusConflict, "Duplicate content detected") - return - } - - version, err := rs.Notes.CreateNoteVersion(r.Context(), data.CreateNoteVersionParams{ - NoteID: note.ID, - Title: req.Title, - Content: req.Content, - }) - if err != nil { - respondError(w, http.StatusInternalServerError, "Failed to create note version") - return - } - - respondJSON(w, http.StatusCreated, version) -} - -func (rs *notesResource) ListNoteVersions(w http.ResponseWriter, r *http.Request) { - note, ok := r.Context().Value(noteCtxKey{}).(data.Note) - if !ok { - respondError(w, http.StatusNotFound, "Note not found") - return - } - - limit, offset := getPaginationParams(r) - - versions, err := rs.Notes.GetNoteVersions(r.Context(), data.GetNoteVersionsParams{ - NoteID: note.ID, - Limit: limit, - Offset: offset, - }) - if err != nil { - respondError(w, http.StatusInternalServerError, "Failed to retrieve versions") - return - } - - respondJSON(w, http.StatusOK, versions) -} - -func (rs *notesResource) GetNoteVersion(w http.ResponseWriter, r *http.Request) { - note, ok := r.Context().Value(noteCtxKey{}).(data.Note) - if !ok { - respondError(w, http.StatusNotFound, "Note not found") - return - } - - versionStr := chi.URLParam(r, "version") - versionNumber, err := strconv.ParseInt(versionStr, 10, 32) - if err != nil { - respondError(w, http.StatusBadRequest, "Invalid version number") - return - } - - version, err := rs.Notes.GetNoteVersion(r.Context(), data.GetNoteVersionParams{ - NoteID: note.ID, - VersionNumber: int32(versionNumber), - }) - if err != nil { - respondError(w, http.StatusNotFound, "Version not found") - return - } - - respondJSON(w, http.StatusOK, version) -} - -func getPaginationParams(r *http.Request) (limit int32, offset int32) { - defaultLimit := 50 - defaultOffset := 0 - - limitStr := r.URL.Query().Get("limit") - if limitStr != "" { - if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { - defaultLimit = l - } - } - - offsetStr := r.URL.Query().Get("offset") - if offsetStr != "" { - if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 { - defaultOffset = o - } - } - - return int32(defaultLimit), int32(defaultOffset) -} diff --git a/server/pkg/service/notes_test.go b/server/pkg/service/notes_test.go deleted file mode 100644 index 8a97d17..0000000 --- a/server/pkg/service/notes_test.go +++ /dev/null @@ -1,509 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "git.umbrella.haus/ae/notatest/pkg/data" - "github.com/go-chi/chi/v5" - "github.com/golang-jwt/jwt/v5" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" -) - -type mockNoteStore struct { - CreateNoteFunc func(context.Context, uuid.UUID) (data.Note, error) - DeleteNoteFunc func(context.Context, data.DeleteNoteParams) error - GetNoteFunc func(context.Context, data.GetNoteParams) (data.Note, error) - ListNotesFunc func(context.Context, data.ListNotesParams) ([]data.Note, error) - CreateNoteVersionFunc func(context.Context, data.CreateNoteVersionParams) (data.NoteVersion, error) - FindDuplicateContentFunc func(context.Context, data.FindDuplicateContentParams) (bool, error) - GetNoteVersionFunc func(context.Context, data.GetNoteVersionParams) (data.NoteVersion, error) - GetNoteVersionsFunc func(context.Context, data.GetNoteVersionsParams) ([]data.NoteVersion, error) -} - -func (m *mockNoteStore) CreateNote(ctx context.Context, userID uuid.UUID) (data.Note, error) { - return m.CreateNoteFunc(ctx, userID) -} - -func (m *mockNoteStore) DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error { - return m.DeleteNoteFunc(ctx, arg) -} - -func (m *mockNoteStore) GetNote(ctx context.Context, arg data.GetNoteParams) (data.Note, error) { - return m.GetNoteFunc(ctx, arg) -} - -func (m *mockNoteStore) ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.Note, error) { - return m.ListNotesFunc(ctx, arg) -} - -func (m *mockNoteStore) CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error) { - return m.CreateNoteVersionFunc(ctx, arg) -} - -func (m *mockNoteStore) FindDuplicateContent(ctx context.Context, arg data.FindDuplicateContentParams) (bool, error) { - return m.FindDuplicateContentFunc(ctx, arg) -} - -func (m *mockNoteStore) GetNoteVersion(ctx context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error) { - return m.GetNoteVersionFunc(ctx, arg) -} - -func (m *mockNoteStore) GetNoteVersions(ctx context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) { - return m.GetNoteVersionsFunc(ctx, arg) -} - -func TestNotes_CreateNote(t *testing.T) { - userID := uuid.New() - testNote := data.Note{ID: uuid.New(), UserID: userID} - - tests := []struct { - name string - mock func(*mockNoteStore) - statusCode int - }{ - { - "success", - func(m *mockNoteStore) { - m.CreateNoteFunc = func(_ context.Context, uid uuid.UUID) (data.Note, error) { - assert.Equal(t, userID, uid) - return testNote, nil - } - }, - http.StatusCreated, - }, - { - "database error", - func(m *mockNoteStore) { - m.CreateNoteFunc = func(context.Context, uuid.UUID) (data.Note, error) { - return data.Note{}, errors.New("db error") - } - }, - http.StatusInternalServerError, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - mockStore := &mockNoteStore{} - tc.mock(mockStore) - - rs := notesResource{Notes: mockStore} - req := httptest.NewRequest("POST", "/", nil) - req = req.WithContext(context.WithValue( - req.Context(), - userCtxKey{}, - &userClaims{RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID.String(), - }}, - )) - - w := httptest.NewRecorder() - rs.CreateNote(w, req) - - assert.Equal(t, tc.statusCode, w.Code) - if tc.statusCode == http.StatusCreated { - var note data.Note - json.Unmarshal(w.Body.Bytes(), ¬e) - assert.Equal(t, testNote.ID, note.ID) - } - }) - } -} - -func TestNotes_ListNotes(t *testing.T) { - userID := uuid.New() - notes := []data.Note{ - {ID: uuid.New(), UserID: userID}, - {ID: uuid.New(), UserID: userID}, - } - - tests := []struct { - name string - query string - mock func(*mockNoteStore) - statusCode int - }{ - { - "success", - "", - func(m *mockNoteStore) { - m.ListNotesFunc = func(_ context.Context, arg data.ListNotesParams) ([]data.Note, error) { - assert.Equal(t, userID, arg.UserID) - return notes, nil - } - }, - http.StatusOK, - }, - { - "with pagination", - "?limit=10&offset=20", - func(m *mockNoteStore) { - m.ListNotesFunc = func(_ context.Context, arg data.ListNotesParams) ([]data.Note, error) { - assert.EqualValues(t, 10, arg.Limit) - assert.EqualValues(t, 20, arg.Offset) - return notes, nil - } - }, - http.StatusOK, - }, - { - "database error", - "", - func(m *mockNoteStore) { - m.ListNotesFunc = func(context.Context, data.ListNotesParams) ([]data.Note, error) { - return nil, errors.New("db error") - } - }, - http.StatusInternalServerError, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - mockStore := &mockNoteStore{} - tc.mock(mockStore) - - rs := notesResource{Notes: mockStore} - req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.query), nil) - req = req.WithContext(context.WithValue( - req.Context(), - userCtxKey{}, - &userClaims{RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID.String(), - }}, - )) - - w := httptest.NewRecorder() - rs.ListNotes(w, req) - - assert.Equal(t, tc.statusCode, w.Code) - }) - } -} - -func TestNotes_GetNote(t *testing.T) { - noteID := uuid.New() - userID := uuid.New() - validNote := data.Note{ID: noteID, UserID: userID} - - t.Run("success", func(t *testing.T) { - rs := notesResource{} - req := httptest.NewRequest("GET", "/", nil) - req = req.WithContext(context.WithValue( - req.Context(), - noteCtxKey{}, - validNote, - )) - - w := httptest.NewRecorder() - rs.GetNote(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - var note data.Note - json.Unmarshal(w.Body.Bytes(), ¬e) - assert.Equal(t, validNote.ID, note.ID) - }) - - t.Run("not found", func(t *testing.T) { - rs := notesResource{} - req := httptest.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - rs.GetNote(w, req) - - assert.Equal(t, http.StatusNotFound, w.Code) - }) -} - -func TestNotes_DeleteNote(t *testing.T) { - noteID := uuid.New() - userID := uuid.New() - validNote := data.Note{ID: noteID, UserID: userID} - - t.Run("success", func(t *testing.T) { - mockStore := &mockNoteStore{ - DeleteNoteFunc: func(_ context.Context, arg data.DeleteNoteParams) error { - assert.Equal(t, noteID, arg.ID) - assert.Equal(t, userID, arg.UserID) - return nil - }, - } - - rs := notesResource{Notes: mockStore} - req := httptest.NewRequest("DELETE", "/", nil) - req = req.WithContext(context.WithValue( - req.Context(), - noteCtxKey{}, - validNote, - )) - req = req.WithContext(context.WithValue( - req.Context(), - userCtxKey{}, - &userClaims{RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID.String(), - }}, - )) - - w := httptest.NewRecorder() - rs.DeleteNote(w, req) - - assert.Equal(t, http.StatusNoContent, w.Code) - }) - - t.Run("database error", func(t *testing.T) { - mockStore := &mockNoteStore{ - DeleteNoteFunc: func(context.Context, data.DeleteNoteParams) error { - return errors.New("db error") - }, - } - - rs := notesResource{Notes: mockStore} - req := httptest.NewRequest("DELETE", "/", nil) - req = req.WithContext(context.WithValue( - req.Context(), - noteCtxKey{}, - validNote, - )) - req = req.WithContext(context.WithValue( - req.Context(), - userCtxKey{}, - &userClaims{RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID.String(), - }}, - )) - - w := httptest.NewRecorder() - rs.DeleteNote(w, req) - - assert.Equal(t, http.StatusInternalServerError, w.Code) - }) -} - -func TestNotes_CreateNoteVersion(t *testing.T) { - noteID := uuid.New() - validRequest := `{"title": "Test", "content": "Content"}` - - tests := []struct { - name string - body string - mock func(*mockNoteStore) - statusCode int - }{ - { - "success", - validRequest, - func(m *mockNoteStore) { - m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) { - return false, nil - } - m.CreateNoteVersionFunc = func(_ context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error) { - assert.Equal(t, noteID, arg.NoteID) - return data.NoteVersion{}, nil - } - }, - http.StatusCreated, - }, - { - "duplicate content", - validRequest, - func(m *mockNoteStore) { - m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) { - return true, nil - } - }, - http.StatusConflict, - }, - { - "invalid request", - "{invalid}", - func(m *mockNoteStore) {}, - http.StatusBadRequest, - }, - { - "database error", - validRequest, - func(m *mockNoteStore) { - m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) { - return false, nil - } - m.CreateNoteVersionFunc = func(context.Context, data.CreateNoteVersionParams) (data.NoteVersion, error) { - return data.NoteVersion{}, errors.New("db error") - } - }, - http.StatusInternalServerError, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - mockStore := &mockNoteStore{} - tc.mock(mockStore) - - rs := notesResource{Notes: mockStore} - req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body)) - req = req.WithContext(context.WithValue( - req.Context(), - noteCtxKey{}, - data.Note{ID: noteID}, - )) - - w := httptest.NewRecorder() - rs.CreateNoteVersion(w, req) - - assert.Equal(t, tc.statusCode, w.Code) - }) - } -} - -func TestNotes_ListNoteVersions(t *testing.T) { - noteID := uuid.New() - versions := []data.NoteVersion{ - {ID: uuid.New(), NoteID: noteID, VersionNumber: 1}, - {ID: uuid.New(), NoteID: noteID, VersionNumber: 2}, - } - - tests := []struct { - name string - query string - mock func(*mockNoteStore) - statusCode int - }{ - { - "success", - "", - func(m *mockNoteStore) { - m.GetNoteVersionsFunc = func(_ context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) { - assert.Equal(t, noteID, arg.NoteID) - return versions, nil - } - }, - http.StatusOK, - }, - { - "with pagination", - "?limit=5&offset=10", - func(m *mockNoteStore) { - m.GetNoteVersionsFunc = func(_ context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) { - assert.EqualValues(t, 5, arg.Limit) - assert.EqualValues(t, 10, arg.Offset) - return versions, nil - } - }, - http.StatusOK, - }, - { - "database error", - "", - func(m *mockNoteStore) { - m.GetNoteVersionsFunc = func(context.Context, data.GetNoteVersionsParams) ([]data.NoteVersion, error) { - return nil, errors.New("db error") - } - }, - http.StatusInternalServerError, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - mockStore := &mockNoteStore{} - tc.mock(mockStore) - - rs := notesResource{Notes: mockStore} - req := httptest.NewRequest("GET", fmt.Sprintf("/versions/%s", tc.query), nil) - req = req.WithContext(context.WithValue( - req.Context(), - noteCtxKey{}, - data.Note{ID: noteID}, - )) - - w := httptest.NewRecorder() - rs.ListNoteVersions(w, req) - - assert.Equal(t, tc.statusCode, w.Code) - if tc.statusCode == http.StatusOK { - var result []data.NoteVersion - json.Unmarshal(w.Body.Bytes(), &result) - assert.Len(t, result, 2) - } - }) - } -} - -func TestNotes_GetNoteVersion(t *testing.T) { - noteID := uuid.New() - version := data.NoteVersion{ID: uuid.New(), NoteID: noteID, VersionNumber: 1} - - tests := []struct { - name string - version string - mock func(*mockNoteStore) - statusCode int - }{ - { - "invalid version", - "invalid", - func(m *mockNoteStore) {}, - http.StatusBadRequest, - }, - { - "version not found", - "1", - func(m *mockNoteStore) { - m.GetNoteVersionFunc = func(context.Context, data.GetNoteVersionParams) (data.NoteVersion, error) { - return data.NoteVersion{}, errors.New("not found") - } - }, - http.StatusNotFound, - }, - { - "success", - "1", - func(m *mockNoteStore) { - m.GetNoteVersionFunc = func(_ context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error) { - assert.Equal(t, noteID, arg.NoteID) - assert.EqualValues(t, 1, arg.VersionNumber) - return version, nil - } - }, - http.StatusOK, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - mockStore := &mockNoteStore{} - tc.mock(mockStore) - - rs := notesResource{Notes: mockStore} - - req := httptest.NewRequest("GET", fmt.Sprintf("/versions/%s", tc.version), nil) - - // Chi router context mocks ID (passed in a URL param.) and the note object (passed in req. ctx.) - rctx := chi.NewRouteContext() - rctx.URLParams.Add("version", tc.version) - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) - req = req.WithContext(context.WithValue( - req.Context(), - noteCtxKey{}, - data.Note{ID: noteID}, - )) - - w := httptest.NewRecorder() - rs.GetNoteVersion(w, req) - - assert.Equal(t, tc.statusCode, w.Code) - if tc.statusCode == http.StatusOK { - var result data.NoteVersion - json.Unmarshal(w.Body.Bytes(), &result) - assert.Equal(t, version.VersionNumber, result.VersionNumber) - } - }) - } -} diff --git a/server/pkg/service/tokens.go b/server/pkg/service/tokens.go deleted file mode 100644 index 758badf..0000000 --- a/server/pkg/service/tokens.go +++ /dev/null @@ -1,249 +0,0 @@ -package service - -import ( - "context" - "crypto/sha256" - "encoding/hex" - "errors" - "fmt" - "net/http" - "strings" - "time" - - "git.umbrella.haus/ae/notatest/pkg/data" - "github.com/go-chi/chi/v5" - "github.com/golang-jwt/jwt/v5" - "github.com/google/uuid" -) - -const ( - accessTokenDuration = 15 * time.Minute - refreshTokenDuration = 7 * 24 * time.Hour -) - -var ( - ErrInvalidToken = errors.New("invalid token") - ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header") -) - -type userClaims struct { - Admin bool `json:"admin"` - TokenType string `json:"type"` // "access" or "refresh" - jwt.RegisteredClaims // User's UUID should be stored in the subject claim -} - -type tokenPair struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` -} - -// Mockable database operations interface -type TokenStore interface { - CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) - GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) - RevokeRefreshToken(ctx context.Context, tokenHash string) error - RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error -} - -type tokensResource struct { - JWTSecret string - Tokens TokenStore -} - -func (rs tokensResource) Routes() chi.Router { - r := chi.NewRouter() - - // Protected routes (access token required) - r.Group(func(r chi.Router) { - r.Use(requireAccessToken(rs.JWTSecret)) - r.Post("/logout", rs.HandleLogout) // POST /auth/logout - revoke all refresh cookies - }) - - // Protected routes (refresh token required) - r.Group(func(r chi.Router) { - r.Use(requireRefreshToken(rs.JWTSecret)) - r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair - }) - - return r -} - -func (rs tokensResource) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) { - tokenPair, err := generateTokenPair(userID.String(), isAdmin, rs.JWTSecret) - if err != nil { - return nil, err - } - - hash := sha256.Sum256([]byte(tokenPair.RefreshToken)) - tokenHash := hex.EncodeToString(hash[:]) - - // Store to DB with (almost) identical expiration timestamp - expiresAt := time.Now().Add(refreshTokenDuration) - _, err = rs.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{ - UserID: userID, - TokenHash: tokenHash, - ExpiresAt: expiresAt, - }) - if err != nil { - return nil, err - } - - return tokenPair, nil -} - -func (rs tokensResource) RevokeRefreshToken(ctx context.Context, token string) error { - hash := sha256.Sum256([]byte(token)) - tokenHash := hex.EncodeToString(hash[:]) - return rs.Tokens.RevokeRefreshToken(ctx, tokenHash) -} - -func (rs tokensResource) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) { - hash := sha256.Sum256([]byte(token)) - tokenHash := hex.EncodeToString(hash[:]) - - dbToken, err := rs.Tokens.GetRefreshTokenByHash(ctx, tokenHash) - if err != nil { - return nil, err - } - - if dbToken.Revoked || time.Now().After(dbToken.ExpiresAt) { - return nil, ErrInvalidToken - } - - return &dbToken, nil -} - -func (rs tokensResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request) { - // Get claims from context - claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) - if !ok || claims.TokenType != "refresh" { - respondError(w, http.StatusUnauthorized, "Invalid token") - return - } - - // Attempt to get the token from Authentication header ("Bearer ") - refreshToken, err := getTokenFromRequest(r) - if err != nil { - respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err)) - } - - // Validate the refresh token in DB - if _, err := rs.ValidateRefreshToken(r.Context(), refreshToken); err != nil { - respondError(w, http.StatusUnauthorized, "Invalid refresh token") - return - } - - // Revoke the used refresh token - if err := rs.RevokeRefreshToken(r.Context(), refreshToken); err != nil { - respondError(w, http.StatusInternalServerError, "Failed to revoke token") - return - } - - // Generate a new pair (access & refresh tokens) - userID, err := uuid.Parse(claims.Subject) - if err != nil { - respondError(w, http.StatusInternalServerError, "Invalid user ID") - return - } - - tokenPair, err := rs.GenerateTokenPair(r.Context(), userID, claims.Admin) - if err != nil { - respondError(w, http.StatusInternalServerError, "Failed to generate tokens") - return - } - - // Set refresh token in HTTP-only cookie - http.SetCookie(w, &http.Cookie{ - Name: "refresh_token", - Value: tokenPair.RefreshToken, - Path: "/", - MaxAge: int(refreshTokenDuration.Seconds()), - HttpOnly: true, - Secure: true, - SameSite: http.SameSiteStrictMode, - }) - - // Return the access token in the response body (it should be stored in browser's memory client-side) - respondJSON(w, http.StatusOK, map[string]string{ - "access_token": tokenPair.AccessToken, - }) -} - -func (rs tokensResource) HandleLogout(w http.ResponseWriter, r *http.Request) { - claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) - if !ok { - respondError(w, http.StatusUnauthorized, "Not authenticated") - return - } - - userID, err := uuid.Parse(claims.ID) - if err != nil { - respondError(w, http.StatusInternalServerError, "Invalid user ID") - return - } - - // Clear the refresh token cookie - http.SetCookie(w, &http.Cookie{ - Name: "refresh_token", - Value: "", - Path: "/", - MaxAge: -1, // Expires immediately - HttpOnly: true, - Secure: true, - SameSite: http.SameSiteStrictMode, - }) - - if err := rs.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil { - respondError(w, http.StatusInternalServerError, "Failed to logout") - return - } - - w.WriteHeader(http.StatusNoContent) -} - -func getTokenFromRequest(r *http.Request) (string, error) { - bearerToken := r.Header.Get("Authorization") - bearerFields := strings.Fields(bearerToken) - if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" { - return bearerFields[1], nil - } - return "", ErrAuthHeaderInvalid -} - -func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) { - atClaims := userClaims{ - Admin: isAdmin, - TokenType: "access", - RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID, - ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenDuration)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - }, - } - accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, atClaims) - - t, err := accessToken.SignedString([]byte(jwtSecret)) - if err != nil { - return nil, err - } - - rtClaims := userClaims{ - Admin: isAdmin, - TokenType: "refresh", - RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID, - ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenDuration)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - }, - } - refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, rtClaims) - - rt, err := refreshToken.SignedString([]byte(jwtSecret)) - if err != nil { - return nil, err - } - - return &tokenPair{AccessToken: t, RefreshToken: rt}, nil -} diff --git a/server/pkg/service/tokens_test.go b/server/pkg/service/tokens_test.go deleted file mode 100644 index c7db4a0..0000000 --- a/server/pkg/service/tokens_test.go +++ /dev/null @@ -1,211 +0,0 @@ -package service - -import ( - "context" - "crypto/sha256" - "encoding/hex" - "errors" - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" - - "git.umbrella.haus/ae/notatest/pkg/data" - "github.com/golang-jwt/jwt/v5" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" -) - -type mockTokenStore struct { - CreateRefreshTokenFunc func(context.Context, data.CreateRefreshTokenParams) (data.RefreshToken, error) - GetRefreshTokenByHashFunc func(context.Context, string) (data.RefreshToken, error) - RevokeRefreshTokenFunc func(context.Context, string) error - RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error -} - -func (m *mockTokenStore) CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) { - return m.CreateRefreshTokenFunc(ctx, arg) -} - -func (m *mockTokenStore) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) { - return m.GetRefreshTokenByHashFunc(ctx, tokenHash) -} - -func (m *mockTokenStore) RevokeRefreshToken(ctx context.Context, token string) error { - return m.RevokeRefreshTokenFunc(ctx, token) -} - -func (m *mockTokenStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error { - return m.RevokeAllUserRefreshTokensFunc(ctx, id) -} - -func TestGenerateTokenPair_Success(t *testing.T) { - userID := uuid.New() - called := false - - mockStore := &mockTokenStore{ - CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) { - called = true - assert.Equal(t, userID, arg.UserID) - return data.RefreshToken{}, nil - }, - } - - rs := tokensResource{ - JWTSecret: "test-secret", - Tokens: mockStore, - } - - pair, err := rs.GenerateTokenPair(context.Background(), userID, false) - assert.NoError(t, err) - assert.True(t, called) - assert.NotEmpty(t, pair.AccessToken) - assert.NotEmpty(t, pair.RefreshToken) -} - -func TestGenerateTokenPair_DBError(t *testing.T) { - mockStore := &mockTokenStore{ - CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) { - return data.RefreshToken{}, errors.New("db error") - }, - } - - rs := tokensResource{Tokens: mockStore} - _, err := rs.GenerateTokenPair(context.Background(), uuid.New(), false) - assert.ErrorContains(t, err, "db error") -} - -func TestValidateRefreshToken_Valid(t *testing.T) { - token := "valid-jwt-token" - hash := sha256.Sum256([]byte(token)) - expectedHash := hex.EncodeToString(hash[:]) - - mockStore := &mockTokenStore{ - GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) { - assert.Equal(t, expectedHash, tokenHash) - return data.RefreshToken{ExpiresAt: time.Now().Add(1 * time.Hour)}, nil - }, - } - - rs := tokensResource{Tokens: mockStore} - _, err := rs.ValidateRefreshToken(context.Background(), token) - assert.NoError(t, err) -} - -func TestRefreshAccessToken_Success(t *testing.T) { - userID := uuid.New() - refreshToken := "valid-jwt-token" - - // Expected hash of the test token - hash := sha256.Sum256([]byte(refreshToken)) - expectedHash := hex.EncodeToString(hash[:]) - - mockStore := &mockTokenStore{ - GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) { - assert.Equal(t, expectedHash, tokenHash) - - // Must return an unrevoked token with future expiration - return data.RefreshToken{ - TokenHash: tokenHash, - ExpiresAt: time.Now().Add(1 * time.Hour), - Revoked: false, - }, nil - }, - RevokeRefreshTokenFunc: func(ctx context.Context, tokenHash string) error { - assert.Equal(t, expectedHash, tokenHash) - return nil - }, - CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) { - return data.RefreshToken{}, nil - }, - } - - rs := tokensResource{ - JWTSecret: "test-secret", - Tokens: mockStore, - } - - req := httptest.NewRequest("POST", "/", nil) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", refreshToken)) - req = req.WithContext(context.WithValue( - req.Context(), - userCtxKey{}, - &userClaims{ - Admin: false, - TokenType: "refresh", - RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID.String(), - }, - }, - )) - - w := httptest.NewRecorder() - rs.RefreshAccessToken(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "access_token") - - cookies := w.Result().Cookies() - var refreshCookie *http.Cookie - for _, cookie := range cookies { - if cookie.Name == "refresh_token" { - refreshCookie = cookie - break - } - } - - if refreshCookie == nil { - t.Fatal("refresh token cookie not set") - } - - assert.True(t, refreshCookie.HttpOnly, "cookie should be HttpOnly") - assert.Equal(t, http.SameSiteStrictMode, refreshCookie.SameSite, "invalid SameSite mode") - assert.Equal(t, "/", refreshCookie.Path, "invalid cookie path") - assert.Greater(t, refreshCookie.MaxAge, 0, "cookie should have expiration") -} - -func TestHandleLogout_Success(t *testing.T) { - userID := uuid.New() - called := false - - mockStore := &mockTokenStore{ - RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error { - called = true - assert.Equal(t, userID, id) - return nil - }, - } - - rs := tokensResource{Tokens: mockStore} - - req := httptest.NewRequest("POST", "/", nil) - req = req.WithContext(context.WithValue( - req.Context(), - userCtxKey{}, - &userClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - ID: userID.String(), - }, - }, - )) - - w := httptest.NewRecorder() - rs.HandleLogout(w, req) - - assert.True(t, called) - assert.Equal(t, http.StatusNoContent, w.Code) // 204 - - cookies := w.Result().Cookies() - var refreshCookie *http.Cookie - for _, cookie := range cookies { - if cookie.Name == "refresh_token" { - refreshCookie = cookie - break - } - } - - if refreshCookie != nil && refreshCookie.MaxAge != -1 { - t.Fatal("refresh token cookie not invalidated") - } -} diff --git a/server/pkg/service/users.go b/server/pkg/service/users.go deleted file mode 100644 index 4fbbd96..0000000 --- a/server/pkg/service/users.go +++ /dev/null @@ -1,356 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "errors" - "net/http" - "strconv" - "time" - - "git.umbrella.haus/ae/notatest/pkg/data" - "github.com/go-chi/chi/v5" - "github.com/google/uuid" - "github.com/jackc/pgx/v5/pgconn" - "github.com/rs/zerolog/log" - "golang.org/x/crypto/bcrypt" -) - -type userCtxKey struct{} - -// Stripped object that only contains non-critical data -type userResponse struct { - ID uuid.UUID `json:"id"` - Username string `json:"username"` - IsAdmin bool `json:"is_admin"` - CreatedAt *time.Time `json:"created_at"` - UpdatedAt *time.Time `json:"updated_at"` -} - -// Mockable database operations interface -type UserStore interface { - CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error) - ListUsers(ctx context.Context) ([]data.User, error) - GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error) - GetUserByUsername(ctx context.Context, username string) (data.User, error) - UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error - DeleteUser(ctx context.Context, id uuid.UUID) error - RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error -} - -type usersResource struct { - JWTSecret string - Users UserStore -} - -func (rs usersResource) Routes() chi.Router { - r := chi.NewRouter() - - // Public routes (no tokens required) - r.Post("/", rs.Create) // POST /users - registration/signup - r.Post("/login", rs.Login) // POST /users/login - login as existing user - - // Protected routes (access token required) - r.Group(func(r chi.Router) { - r.Use(requireAccessToken(rs.JWTSecret)) - - r.Get("/me", rs.Get) // GET /users/me - get current user data - - // Admin only general routes - r.Group(func(r chi.Router) { - r.Use(adminOnlyMiddleware) - r.Get("/", rs.List) // GET /users - list all users - }) - - // User specific routes - r.Route("/{id}", func(r chi.Router) { - r.Use(userCtx(rs.Users)) // DB -> req. context - - // Admin routes - r.Route("/admin", func(r chi.Router) { - r.Use(adminOnlyMiddleware) - r.Delete("/", rs.AdminDelete) // DELETE /users/admin/{id} - delete user - }) - - // Owner routes - r.Route("/owner", func(r chi.Router) { - r.Use(ownerOnlyMiddleware) - r.Put("/", rs.UpdatePassword) // PUT /users/owner/{id} - update user password - r.Delete("/", rs.OwnerDelete) // DELETE /users/owner/{id} - delete user - }) - }) - }) - - return r -} - -func (rs usersResource) Create(w http.ResponseWriter, r *http.Request) { - type request struct { - Username string `json:"username"` - Password string `json:"password"` - } - - var req request - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - respondError(w, http.StatusBadRequest, "Invalid request body") - return - } - - normalizedUsername := normalizeUsername(req.Username) - if err := validateUsername(normalizedUsername); err != nil { - respondError(w, http.StatusBadRequest, err.Error()) - return - } - - if err := validatePassword(req.Password); err != nil { - respondError(w, http.StatusBadRequest, err.Error()) - return - } - - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) - if err != nil { - respondError(w, http.StatusInternalServerError, "Failed to create user") - return - } - - user, err := rs.Users.CreateUser(r.Context(), data.CreateUserParams{ - Username: normalizedUsername, - PasswordHash: string(hashedPassword), - }) - if err != nil { - if isDuplicateEntry(err) { - respondError(w, http.StatusConflict, "Username is already in use") - } else { - respondError(w, http.StatusInternalServerError, "Failed to create user") - } - return - } - - respondJSON(w, http.StatusCreated, map[string]string{ - "id": user.ID.String(), - "username": user.Username, - }) -} - -func (rs usersResource) Login(w http.ResponseWriter, r *http.Request) { - type request struct { - Username string `json:"username"` - Password string `json:"password"` - } - - var req request - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - respondError(w, http.StatusBadRequest, "Invalid request body") - return - } - - user, err := rs.Users.GetUserByUsername(r.Context(), normalizeUsername(req.Username)) - if err != nil { - respondError(w, http.StatusUnauthorized, "Invalid credentials") - return - } - - if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil { - respondError(w, http.StatusUnauthorized, "Invalid credentials") - return - } - - tokenPair, err := generateTokenPair(user.ID.String(), user.IsAdmin, rs.JWTSecret) - if err != nil { - respondError(w, http.StatusInternalServerError, "Failed to generate tokens") - return - } - - // Set refresh token in HTTP-only cookie - http.SetCookie(w, &http.Cookie{ - Name: "refresh_token", - Value: tokenPair.RefreshToken, - Path: "/", - MaxAge: int(refreshTokenDuration.Seconds()), - HttpOnly: true, - Secure: true, - SameSite: http.SameSiteStrictMode, - }) - - // Build response - response := map[string]any{ - "access_token": tokenPair.AccessToken, - } - - // Include user data if the client has requested it (`?includeUser=true`) - if includeUser, _ := strconv.ParseBool(r.URL.Query().Get("includeUser")); includeUser { - response["user"] = userResponse{ - ID: user.ID, - Username: user.Username, - IsAdmin: user.IsAdmin, - CreatedAt: user.CreatedAt, - UpdatedAt: user.UpdatedAt, - } - } - - respondJSON(w, http.StatusOK, response) -} - -func (rs usersResource) List(w http.ResponseWriter, r *http.Request) { - users, err := rs.Users.ListUsers(r.Context()) - if err != nil { - respondError(w, http.StatusInternalServerError, "Failed to retrieve users") - return - } - - // Output sanitization - var output []map[string]any - for _, user := range users { - output = append(output, map[string]any{ - "id": user.ID, - "username": user.Username, - "created_at": user.CreatedAt, - "updated_at": user.UpdatedAt, - }) - } - - respondJSON(w, http.StatusOK, output) -} - -func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) { - claims, ok := r.Context().Value(userCtxKey{}).(*userClaims) - if !ok { - respondError(w, http.StatusUnauthorized, "Unauthorized") - return - } - - userID, err := uuid.Parse(claims.Subject) - if err != nil { - respondError(w, http.StatusInternalServerError, "Invalid user ID") - return - } - - user, err := rs.Users.GetUserByID(r.Context(), userID) - if err != nil { - respondError(w, http.StatusNotFound, "User not found") - return - } - - respondJSON(w, http.StatusOK, userResponse{ - ID: user.ID, - Username: user.Username, - IsAdmin: user.IsAdmin, - CreatedAt: user.CreatedAt, - UpdatedAt: user.UpdatedAt, - }) -} - -func (rs usersResource) UpdatePassword(w http.ResponseWriter, r *http.Request) { - type request struct { - OldPassword string `json:"old_password"` - NewPassword string `json:"new_password"` - } - - var req request - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - respondError(w, http.StatusBadRequest, "Invalid request body") - return - } - - user, ok := r.Context().Value(userCtxKey{}).(data.User) - if !ok { - respondError(w, http.StatusNotFound, "User not found") - return - } - - // Verify the old password before allowing the update - if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.OldPassword)); err != nil { - respondError(w, http.StatusUnauthorized, "Invalid credentials") - return - } - - if err := validatePassword(req.NewPassword); err != nil { - respondError(w, http.StatusBadRequest, err.Error()) - return - } - - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost) - if err != nil { - respondError(w, http.StatusInternalServerError, "Failed to update password") - return - } - - if err := rs.Users.UpdatePassword(r.Context(), data.UpdatePasswordParams{ - ID: user.ID, - PasswordHash: string(hashedPassword), - }); err != nil { - respondError(w, http.StatusInternalServerError, "Failed to update password") - return - } - - if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil { - log.Error().Msgf("Failed to revoke refresh tokens: %s", err) - } - - w.WriteHeader(http.StatusNoContent) -} - -func (rs usersResource) AdminDelete(w http.ResponseWriter, r *http.Request) { - user, ok := r.Context().Value(userCtxKey{}).(data.User) - if !ok { - respondError(w, http.StatusNotFound, "User not found") - return - } - - if err := rs.Users.DeleteUser(r.Context(), user.ID); err != nil { - respondError(w, http.StatusInternalServerError, "Failed to delete user") - return - } - - if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil { - log.Error().Msgf("Failed to revoke refresh tokens: %s", err) - } - - w.WriteHeader(http.StatusNoContent) -} - -func (rs usersResource) OwnerDelete(w http.ResponseWriter, r *http.Request) { - type request struct { - Password string `json:"password"` - } - - var req request - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - respondError(w, http.StatusBadRequest, "Invalid request body") - return - } - - user, ok := r.Context().Value(userCtxKey{}).(data.User) - if !ok { - respondError(w, http.StatusNotFound, "User not found") - return - } - - // Verify the old password before allowing the deletion - if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil { - respondError(w, http.StatusUnauthorized, "Invalid credentials") - return - } - - err := rs.Users.DeleteUser(r.Context(), user.ID) - if err != nil { - respondError(w, http.StatusInternalServerError, "Failed to delete user") - return - } - - if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil { - log.Error().Msgf("Failed to revoke refresh tokens: %s", err) - } - - w.WriteHeader(http.StatusNoContent) -} - -// Check if the given error is a PostgreSQL error for `unique_violation`, i.e. whether an entry -// with the given details already exists in the database table. -func isDuplicateEntry(err error) bool { - var pgErr *pgconn.PgError - if errors.As(err, &pgErr) { - return pgErr.Code == "23505" - } - return false -} diff --git a/server/pkg/service/users_test.go b/server/pkg/service/users_test.go deleted file mode 100644 index 2574cce..0000000 --- a/server/pkg/service/users_test.go +++ /dev/null @@ -1,557 +0,0 @@ -package service - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - "time" - - "git.umbrella.haus/ae/notatest/pkg/data" - "github.com/golang-jwt/jwt/v5" - "github.com/google/uuid" - "github.com/jackc/pgx/v5/pgconn" - "github.com/stretchr/testify/assert" - "golang.org/x/crypto/bcrypt" -) - -type mockUserStore struct { - CreateUserFunc func(context.Context, data.CreateUserParams) (data.User, error) - ListUsersFunc func(context.Context) ([]data.User, error) - GetUserByIDFunc func(context.Context, uuid.UUID) (data.User, error) - GetUserByUsernameFunc func(context.Context, string) (data.User, error) - UpdatePasswordFunc func(context.Context, data.UpdatePasswordParams) error - DeleteUserFunc func(context.Context, uuid.UUID) error - RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error -} - -func (m *mockUserStore) CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error) { - return m.CreateUserFunc(ctx, arg) -} - -func (m *mockUserStore) ListUsers(ctx context.Context) ([]data.User, error) { - return m.ListUsersFunc(ctx) -} - -func (m *mockUserStore) GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error) { - return m.GetUserByIDFunc(ctx, id) -} - -func (m *mockUserStore) GetUserByUsername(ctx context.Context, username string) (data.User, error) { - return m.GetUserByUsernameFunc(ctx, username) -} - -func (m *mockUserStore) UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error { - return m.UpdatePasswordFunc(ctx, arg) -} - -func (m *mockUserStore) DeleteUser(ctx context.Context, id uuid.UUID) error { - return m.DeleteUserFunc(ctx, id) -} - -func (m *mockUserStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error { - return m.RevokeAllUserRefreshTokensFunc(ctx, id) -} - -func TestCreateUser_Duplicate(t *testing.T) { - mockStore := &mockUserStore{ - CreateUserFunc: func(ctx context.Context, arg data.CreateUserParams) (data.User, error) { - return data.User{}, &pgconn.PgError{Code: "23505"} - }, - } - - rs := usersResource{Users: mockStore, JWTSecret: "test-secret"} - - reqBody := `{"username": "existing", "password": "validPass123!"}` - req := httptest.NewRequest("POST", "/", strings.NewReader(reqBody)) - w := httptest.NewRecorder() - - rs.Create(w, req) - - assert.Equal(t, http.StatusConflict, w.Code) - assert.Contains(t, w.Body.String(), "Username is already in use") -} - -func TestCreateUser_InvalidUsername(t *testing.T) { - mockStore := &mockUserStore{} // No DB calls expected - rs := usersResource{Users: mockStore, JWTSecret: "test-secret"} - - // Test various invalid usernames - tests := []struct { - name string - body string - }{ - { - "too short", - `{"username": "a", "password": "validPass123!"}`, - }, - { - "invalid chars", - `{"username": "user@name", "password": "validPass123!"}`, - }, - { - "empty", - `{"username": "", "password": "validPass123!"}`, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body)) - w := httptest.NewRecorder() - rs.Create(w, req) - assert.Equal(t, http.StatusBadRequest, w.Code) - }) - } -} - -func TestCreateUser_InvalidPassword(t *testing.T) { - mockStore := &mockUserStore{} - rs := usersResource{Users: mockStore, JWTSecret: "test-secret"} - - tests := []struct { - name string - body string - }{ - { - "too short", - fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength-1)), - }, - { - "too long", - fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", maxPasswordLength+1)), - }, - { - "low entropy", - fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength)), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body)) - w := httptest.NewRecorder() - rs.Create(w, req) - assert.Equal(t, http.StatusBadRequest, w.Code) - }) - } -} - -func TestListUsers_Success(t *testing.T) { - testUsers := []data.User{ - {ID: uuid.New(), Username: "user1"}, - {ID: uuid.New(), Username: "user2"}, - } - - mockStore := &mockUserStore{ - ListUsersFunc: func(ctx context.Context) ([]data.User, error) { - return testUsers, nil - }, - } - - rs := usersResource{Users: mockStore} - req := httptest.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - - rs.List(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - - var response []map[string]any - if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { - t.Fatal(err) - } - - assert.Len(t, response, 2) - assert.Equal(t, "user1", response[0]["username"]) - assert.NotContains(t, response[0], "password_hash") -} - -func TestUpdatePassword_InvalidOldPassword(t *testing.T) { - // User with password hash that won't match "wrongpassword" - user := data.User{ - ID: uuid.New(), - PasswordHash: "$2a$10$PHhno.bZBF8IEINdFRZAPujMxIN65msElATgJG6FIxZdeWYVLSfFi", // Hash of "correctpassword" - } - - mockStore := &mockUserStore{} - rs := usersResource{Users: mockStore} - - reqBody := `{"old_password": "wrongpassword", "new_password": "NewValidPass321!"}` - req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody)) - req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user)) - w := httptest.NewRecorder() - - rs.UpdatePassword(w, req) - - assert.Equal(t, http.StatusUnauthorized, w.Code) -} - -func TestAdminDelete_Success(t *testing.T) { - user := data.User{ID: uuid.New()} - deleteCalled := false - revokeCalled := false - - mockStore := &mockUserStore{ - DeleteUserFunc: func(ctx context.Context, id uuid.UUID) error { - deleteCalled = true - assert.Equal(t, user.ID, id) - return nil - }, - RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error { - revokeCalled = true - assert.Equal(t, user.ID, id) - return nil - }, - } - - rs := usersResource{Users: mockStore} - req := httptest.NewRequest("DELETE", "/", nil) - req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user)) - w := httptest.NewRecorder() - - rs.AdminDelete(w, req) - - assert.Equal(t, http.StatusNoContent, w.Code) - assert.True(t, deleteCalled) - assert.True(t, revokeCalled) -} - -func TestOwnerDelete_InvalidCredentials(t *testing.T) { - // Create user with known password hash - correctPassword := "CorrectPass123!" - hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(correctPassword), bcrypt.DefaultCost) - user := data.User{ - ID: uuid.New(), - PasswordHash: string(hashedPassword), - } - - mockStore := &mockUserStore{} - rs := usersResource{Users: mockStore} - - reqBody := `{"password": "wrongpassword"}` - req := httptest.NewRequest("DELETE", "/", strings.NewReader(reqBody)) - req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user)) - w := httptest.NewRecorder() - - rs.OwnerDelete(w, req) - - assert.Equal(t, http.StatusUnauthorized, w.Code) -} - -func TestUsersGetCurrentUser(t *testing.T) { - validUserID := uuid.New() - testTime := time.Now().UTC().Truncate(time.Second) - testUser := data.User{ - ID: validUserID, - Username: "testuser", - CreatedAt: &testTime, - UpdatedAt: &testTime, - IsAdmin: false, - } - - tests := []struct { - name string - setupContext func(context.Context) context.Context - mockSetup func(*mockUserStore) - wantStatus int - wantResponse string - }{ - { - name: "success", - setupContext: func(ctx context.Context) context.Context { - return context.WithValue(ctx, userCtxKey{}, &userClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Subject: validUserID.String(), - }, - }) - }, - mockSetup: func(m *mockUserStore) { - m.GetUserByIDFunc = func(_ context.Context, id uuid.UUID) (data.User, error) { - assert.Equal(t, validUserID, id) - return testUser, nil - } - }, - wantStatus: http.StatusOK, - wantResponse: fmt.Sprintf( - `{"created_at":"%s","id":"%s","is_admin":false,"updated_at":"%s","username":"testuser"}`, - testUser.CreatedAt.Format(time.RFC3339Nano), - validUserID.String(), - testUser.UpdatedAt.Format(time.RFC3339Nano), - ), - }, - { - name: "user not found", - setupContext: func(ctx context.Context) context.Context { - return context.WithValue(ctx, userCtxKey{}, &userClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Subject: validUserID.String(), - }, - }) - }, - mockSetup: func(m *mockUserStore) { - m.GetUserByIDFunc = func(_ context.Context, id uuid.UUID) (data.User, error) { - return data.User{}, errors.New("not found") - } - }, - wantStatus: http.StatusNotFound, - wantResponse: `{"error":"User not found"}`, - }, - { - name: "unauthorized", - setupContext: func(ctx context.Context) context.Context { - return ctx // No user claims in context - }, - mockSetup: func(m *mockUserStore) {}, - wantStatus: http.StatusUnauthorized, - wantResponse: `{"error":"Unauthorized"}`, - }, - { - name: "invalid user ID", - setupContext: func(ctx context.Context) context.Context { - return context.WithValue(ctx, userCtxKey{}, &userClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Subject: "invalid", - }, - }) - }, - mockSetup: func(m *mockUserStore) {}, - wantStatus: http.StatusInternalServerError, - wantResponse: `{"error":"Invalid user ID"}`, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockStore := &mockUserStore{} - tt.mockSetup(mockStore) - - rs := usersResource{Users: mockStore} - req := httptest.NewRequest("GET", "/me", nil) - req = req.WithContext(tt.setupContext(req.Context())) - - w := httptest.NewRecorder() - rs.Get(w, req) - - assert.Equal(t, tt.wantStatus, w.Code) - - if tt.wantResponse != "" { - actual := strings.TrimSpace(w.Body.String()) - assert.JSONEq(t, tt.wantResponse, actual) - } - - // Verify sensitive fields are never exposed - if w.Code == http.StatusOK { - var response map[string]any - json.Unmarshal(w.Body.Bytes(), &response) - _, exists := response["password_hash"] - assert.False(t, exists, "password_hash should not be exposed") - } - }) - } -} - -func TestUpdatePassword_DatabaseError(t *testing.T) { - // Add user with a valid password to the context - oldPassword := "OldValidPass321!" - hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(oldPassword), bcrypt.DefaultCost) - user := data.User{ - ID: uuid.New(), - PasswordHash: string(hashedPassword), - } - mockStore := &mockUserStore{ - UpdatePasswordFunc: func(ctx context.Context, arg data.UpdatePasswordParams) error { - return errors.New("database error") - }, - } - - rs := usersResource{Users: mockStore} - - reqBody := fmt.Sprintf(`{"old_password": "%s", "new_password": "NewValidPass123!"}`, oldPassword) - req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody)) - req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user)) - w := httptest.NewRecorder() - - rs.UpdatePassword(w, req) - - assert.Equal(t, http.StatusInternalServerError, w.Code) - assert.Contains(t, w.Body.String(), "Failed to update password") -} - -func TestUsersLogin(t *testing.T) { - validPassword := "validPass123!" - hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(validPassword), bcrypt.DefaultCost) - testUser := data.User{ - ID: uuid.New(), - Username: "test_username", - PasswordHash: string(hashedPassword), - } - jwtSecret := "test-secret" - - tests := []struct { - name string - includeUser string - wantUserData bool - requestBody any - mockSetup func(*mockUserStore) - wantStatus int - wantResponse string - checkCookie bool - }{ - { - name: "invalid request body", - requestBody: "invalid", - mockSetup: func(m *mockUserStore) {}, - wantStatus: http.StatusBadRequest, - wantResponse: `{"error":"Invalid request body"}`, - }, - { - name: "user not found", - requestBody: map[string]string{ - "username": "nouser", - "password": validPassword, - }, - mockSetup: func(m *mockUserStore) { - m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) { - return data.User{}, errors.New("not found") - } - }, - wantStatus: http.StatusUnauthorized, - wantResponse: `{"error":"Invalid credentials"}`, - }, - { - name: "invalid password", - requestBody: map[string]string{ - "username": testUser.Username, - "password": "wrongpassword", - }, - mockSetup: func(m *mockUserStore) { - m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) { - return testUser, nil - } - }, - wantStatus: http.StatusUnauthorized, - wantResponse: `{"error":"Invalid credentials"}`, - }, - { - name: "successful login with user data", - includeUser: "true", - wantUserData: true, - requestBody: map[string]string{ - "username": testUser.Username, - "password": validPassword, - }, - mockSetup: func(m *mockUserStore) { - m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) { - return testUser, nil - } - }, - wantStatus: http.StatusOK, - checkCookie: true, - }, - { - name: "successful login without user data", - includeUser: "false", - wantUserData: false, - requestBody: map[string]string{ - "username": testUser.Username, - "password": validPassword, - }, - mockSetup: func(m *mockUserStore) { - m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) { - return testUser, nil - } - }, - wantStatus: http.StatusOK, - checkCookie: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockStore := &mockUserStore{} - tt.mockSetup(mockStore) - - rs := usersResource{ - Users: mockStore, - JWTSecret: jwtSecret, - } - - body, _ := json.Marshal(tt.requestBody) - req := httptest.NewRequest("POST", "/login", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - - // Add the necessary query parameters - q := url.Values{} - q.Add("includeUser", tt.includeUser) - req.URL.RawQuery = q.Encode() - - w := httptest.NewRecorder() - rs.Login(w, req) - - if w.Code != tt.wantStatus { - t.Errorf("expected status %d, got %d", tt.wantStatus, w.Code) - } - - if tt.wantResponse != "" && strings.TrimSpace(w.Body.String()) != tt.wantResponse { - t.Errorf("expected response %q, got %q", tt.wantResponse, w.Body.String()) - } - - if tt.wantUserData { - var response struct { - AccessToken string `json:"access_token"` - User data.User `json:"user"` // Cast to the "raw" type to allow checking for sensitive data fields - } - json.Unmarshal(w.Body.Bytes(), &response) - - assert.Equal(t, testUser.ID, response.User.ID) - assert.Equal(t, testUser.Username, response.User.Username) - assert.Empty(t, response.User.PasswordHash) // Ensure sensitive data excluded - } - - if tt.checkCookie { - cookies := w.Result().Cookies() - var refreshCookie *http.Cookie - for _, cookie := range cookies { - if cookie.Name == "refresh_token" { - refreshCookie = cookie - break - } - } - - if refreshCookie == nil { - t.Fatal("refresh token cookie not set") - } - - assert.True(t, refreshCookie.HttpOnly, "cookie should be HttpOnly") - assert.Equal(t, http.SameSiteStrictMode, refreshCookie.SameSite, "invalid SameSite mode") - assert.Equal(t, "/", refreshCookie.Path, "invalid cookie path") - assert.Greater(t, refreshCookie.MaxAge, 0, "cookie should have expiration") - - // Validate access token in response - var response map[string]string - json.Unmarshal(w.Body.Bytes(), &response) - if response["access_token"] == "" { - t.Error("access token not in response") - } - - // Verify JWT validity - token, err := jwt.ParseWithClaims( - response["access_token"], - &userClaims{}, - func(token *jwt.Token) (any, error) { - return []byte(jwtSecret), nil - }, - ) - assert.NoError(t, err, "invalid JWT") - assert.True(t, token.Valid, "invalid JWT") - } - }) - } -} diff --git a/server/sql/migrations/0001_initial.up.sql b/server/sql/migrations/0001_initial.up.sql index 78ce5fd..8a7eca3 100644 --- a/server/sql/migrations/0001_initial.up.sql +++ b/server/sql/migrations/0001_initial.up.sql @@ -1,10 +1,5 @@ CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; -CREATE TABLE IF NOT EXISTS schema_migrations ( - version BIGINT PRIMARY KEY, - applied_at TIMESTAMPTZ DEFAULT NOW() -); - CREATE TABLE IF NOT EXISTS users ( id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), username TEXT UNIQUE NOT NULL, @@ -26,6 +21,8 @@ CREATE TABLE IF NOT EXISTS refresh_tokens ( CREATE TABLE IF NOT EXISTS notes ( id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + current_version INT NOT NULL DEFAULT 1, -- active version (can be historical) + latest_version INT NOT NULL DEFAULT 1, -- highest version number created_at TIMESTAMPTZ DEFAULT NOW(), updated_at TIMESTAMPTZ DEFAULT NOW() ); @@ -35,15 +32,18 @@ CREATE TABLE IF NOT EXISTS note_versions ( note_id UUID NOT NULL REFERENCES notes(id) ON DELETE CASCADE, title TEXT NOT NULL, content TEXT NOT NULL, - version_number INT NOT NULL, + version_number INT NOT NULL DEFAULT 1, content_hash TEXT NOT NULL, - created_at TIMESTAMPTZ DEFAULT NOW() + created_at TIMESTAMPTZ DEFAULT NOW(), + UNIQUE (note_id, version_number) ); CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username); + CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_id ON refresh_tokens(user_id); CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at ON refresh_tokens(expires_at); -CREATE UNIQUE INDEX IF NOT EXISTS idx_note_version_unique ON note_versions(note_id, version_number); -CREATE INDEX IF NOT EXISTS idx_note_versions_note ON note_versions(note_id); -CREATE INDEX IF NOT EXISTS idx_note_versions_number ON note_versions(version_number DESC); +CREATE INDEX IF NOT EXISTS idx_notes_user_updated ON notes(user_id, updated_at DESC); +CREATE INDEX IF NOT EXISTS idx_notes_current_version ON notes(current_version); + +CREATE INDEX IF NOT EXISTS idx_note_versions_content_hash ON note_versions(note_id, content_hash); diff --git a/server/sql/queries/note_versions.sql b/server/sql/queries/note_versions.sql index 493ded8..0dab6c3 100644 --- a/server/sql/queries/note_versions.sql +++ b/server/sql/queries/note_versions.sql @@ -1,27 +1,58 @@ --- name: CreateNoteVersion :one -INSERT INTO note_versions (note_id, title, content, version_number, content_hash) -VALUES ( - $1, - $2, - $3, - (SELECT COALESCE(MAX(version_number), 0) + 1 FROM note_versions WHERE note_id = $1), - encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex') +-- name: CreateNoteVersion :exec +WITH potential_duplicate AS ( + SELECT version_number + FROM note_versions + WHERE + note_id = $1 + AND content_hash = $2 + ORDER BY version_number DESC + LIMIT 1 +), +note_update AS ( + UPDATE notes + SET + current_version = COALESCE( + (SELECT version_number FROM potential_duplicate), + latest_version + 1 -- increment only if we don't jump into a historical version + ), + latest_version = CASE + WHEN (SELECT version_number FROM potential_duplicate) IS NULL + THEN latest_version + 1 + ELSE latest_version + END, + updated_at = NOW() + WHERE id = $1 + RETURNING current_version, latest_version ) -RETURNING *; +INSERT INTO note_versions ( + note_id, title, content, version_number, content_hash +) +SELECT + $1, -- note_id + $3, -- title + $4, -- content + current_version, + $2 -- content_hash +FROM note_update +WHERE NOT EXISTS (SELECT 1 FROM potential_duplicate); --- name: GetNoteVersions :many -SELECT * FROM note_versions +-- name: GetVersionHistory :many +SELECT + id AS version_id, + title, + version_number, + created_at +FROM note_versions WHERE note_id = $1 ORDER BY version_number DESC LIMIT $2 OFFSET $3; --- name: GetNoteVersion :one -SELECT * FROM note_versions -WHERE note_id = $1 AND version_number = $2 LIMIT 1; - --- name: FindDuplicateContent :one -SELECT EXISTS( - SELECT 1 FROM note_versions - WHERE note_id = $1 - AND content_hash = encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex') -); +-- name: GetVersion :one +SELECT + id AS version_id, + title, + content, + version_number, + created_at +FROM note_versions +WHERE note_id = $1 AND id = $2; diff --git a/server/sql/queries/notes.sql b/server/sql/queries/notes.sql index de58d34..15140a6 100644 --- a/server/sql/queries/notes.sql +++ b/server/sql/queries/notes.sql @@ -3,16 +3,34 @@ INSERT INTO notes (user_id) VALUES ($1) RETURNING *; --- name: GetNote :one -SELECT * FROM notes -WHERE id = $1 AND user_id = $2 LIMIT 1; - -- name: ListNotes :many -SELECT * FROM notes -WHERE user_id = $1 -ORDER BY created_at DESC +SELECT + n.id AS note_id, + n.user_id AS owner_id, + nv.title, + n.updated_at +FROM notes n +JOIN note_versions nv + ON n.id = nv.note_id AND n.current_version = nv.version_number +WHERE n.user_id = $1 +ORDER BY n.updated_at DESC LIMIT $2 OFFSET $3; +-- name: GetFullNote :one +SELECT + n.id AS note_id, + n.user_id AS owner_id, + nv.title, + nv.content, + nv.version_number, + nv.created_at AS version_created_at, + n.created_at AS note_created_at, + n.updated_at AS note_updated_at +FROM notes n +JOIN note_versions nv + ON n.id = nv.note_id AND n.current_version = nv.version_number +WHERE n.id = $1; + -- name: DeleteNote :exec DELETE FROM notes WHERE id = $1 AND user_id = $2; diff --git a/server/sql/queries/users.sql b/server/sql/queries/users.sql index 5f24219..a2b073b 100644 --- a/server/sql/queries/users.sql +++ b/server/sql/queries/users.sql @@ -3,9 +3,18 @@ INSERT INTO users (username, password_hash) VALUES ($1, $2) RETURNING *; +-- name: CreateAdmin :one +INSERT INTO users (username, password_hash, is_admin) +VALUES ($1, $2, true) +RETURNING *; + -- name: ListUsers :many SELECT * FROM users; +-- name: ListAdmins :many +SELECT * FROM users +WHERE is_admin = true; + -- name: GetUserByID :one SELECT * FROM users WHERE id = $1 LIMIT 1; diff --git a/server/sql/sqlc.yaml b/server/sql/sqlc.yaml index 572ccb1..42bd096 100644 --- a/server/sql/sqlc.yaml +++ b/server/sql/sqlc.yaml @@ -7,7 +7,7 @@ sql: gen: go: package: "data" - out: "../pkg/data" + out: "../internal/data" sql_package: "pgx/v5" emit_json_tags: true overrides: