feat!: trimming & logic/schema improvements

- build: somewhat polished dockerization setup
- build: io/fs migrations with `golang-migrate`
- feat: automatic init. admin account creation (.env creds)
- feat(routers): combined user & token routers into single auth router
- feat(routers): improved route layouts (`Routes`)
- feat(middlewares): removed redundant `userCtx` middleware
- fix(schema): note <-> note_versions relation (versioning)
- feat(queries): removed redundant rollback functionality
- feat(queries): combined duplicate version check & insertion/creation
- tests: decreased redundancy by removing 'unnecessary' unit tests
- refactor: hid internal packages behind `server/internal`
- docs: notes & auth handler comments
This commit is contained in:
ae 2025-04-09 01:58:38 +03:00
parent b1edbeb0a3
commit 62b1a58e56
Signed by: ae
GPG Key ID: 995EFD5C1B532B3E
32 changed files with 2184 additions and 2987 deletions

View File

@ -0,0 +1,19 @@
# Build stage
FROM golang:1.24.2-alpine AS builder
WORKDIR /app
COPY go.mod go.sum .
RUN go mod download
COPY . .
# Optionally we could also strip debug symbols with `-ldflags '-s'`
RUN CGO_ENABLED=0 GOOS=linux go build -o /notatest .
# Final stage (optimized image size)
FROM alpine:latest
WORKDIR /app
COPY --from=builder /notatest /app/notatest
EXPOSE 8080
CMD ["/app/notatest"]

View File

@ -3,9 +3,10 @@ module git.umbrella.haus/ae/notatest
go 1.24.1 go 1.24.1
require ( require (
github.com/caarlos0/env v3.5.0+incompatible github.com/caarlos0/env/v10 v10.0.0
github.com/go-chi/chi/v5 v5.2.1 github.com/go-chi/chi/v5 v5.2.1
github.com/golang-jwt/jwt/v5 v5.2.2 github.com/golang-jwt/jwt/v5 v5.2.2
github.com/golang-migrate/migrate/v4 v4.18.2
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.7.4 github.com/jackc/pgx/v5 v5.7.4
github.com/rs/zerolog v1.34.0 github.com/rs/zerolog v1.34.0
@ -16,14 +17,16 @@ require (
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/kr/text v0.2.0 // indirect github.com/lib/pq v1.10.9 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/stretchr/objx v0.5.2 // indirect github.com/stretchr/objx v0.5.2 // indirect
go.uber.org/atomic v1.11.0 // indirect
golang.org/x/sys v0.31.0 // indirect golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect golang.org/x/text v0.23.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect

View File

@ -1,17 +1,45 @@
github.com/caarlos0/env v3.5.0+incompatible h1:Yy0UN8o9Wtr/jGHZDpCBLpNrzcFLLM2yixi/rBrKyJs= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/caarlos0/env v3.5.0+incompatible/go.mod h1:tdCsowwCzMLdkqRYDlHpZCp2UooDD3MspDBjZ2AD02Y= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/caarlos0/env/v10 v10.0.0 h1:yIHUBZGsyqCnpTkbjk8asUlx6RFhhEs+h7TOBdgdzXA=
github.com/caarlos0/env/v10 v10.0.0/go.mod h1:ZfulV76NvVPw3tm591U4SwL3Xx9ldzBP9aGxzeN7G18=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dhui/dktest v0.4.4 h1:+I4s6JRE1yGuqflzwqG+aIaMdgXIorCf5P98JnaAWa8=
github.com/dhui/dktest v0.4.4/go.mod h1:4+22R4lgsdAXrDyaH4Nqx2JEz2hLp49MqQmm9HLCQhM=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v27.2.0+incompatible h1:Rk9nIVdfH3+Vz4cyI/uhbINhEZ/oLmc+CBXmH6fbNk4=
github.com/docker/docker v27.2.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8= github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8=
github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@ -24,6 +52,8 @@ github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
@ -31,11 +61,22 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
@ -48,6 +89,16 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/wagslane/go-password-validator v0.3.0 h1:vfxOPzGHkz5S146HDpavl0cw1DSVP061Ry2PX0/ON6I= github.com/wagslane/go-password-validator v0.3.0 h1:vfxOPzGHkz5S146HDpavl0cw1DSVP061Ry2PX0/ON6I=
github.com/wagslane/go-password-validator v0.3.0/go.mod h1:TI1XJ6T5fRdRnHqHt14pvy1tNVnrwe7m3/f1f2fDphQ= github.com/wagslane/go-password-validator v0.3.0/go.mod h1:TI1XJ6T5fRdRnHqHt14pvy1tNVnrwe7m3/f1f2fDphQ=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8=
go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw=
go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8=
go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc=
go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8=
go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4=
go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=

View File

@ -11,10 +11,12 @@ import (
) )
type Note struct { type Note struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"` UserID uuid.UUID `json:"user_id"`
CreatedAt *time.Time `json:"created_at"` CurrentVersion int32 `json:"current_version"`
UpdatedAt *time.Time `json:"updated_at"` LatestVersion int32 `json:"latest_version"`
CreatedAt *time.Time `json:"created_at"`
UpdatedAt *time.Time `json:"updated_at"`
} }
type NoteVersion struct { type NoteVersion struct {
@ -36,11 +38,6 @@ type RefreshToken struct {
Revoked bool `json:"revoked"` Revoked bool `json:"revoked"`
} }
type SchemaMigration struct {
Version int64 `json:"version"`
AppliedAt *time.Time `json:"applied_at"`
}
type User struct { type User struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
Username string `json:"username"` Username string `json:"username"`

View File

@ -0,0 +1,156 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.28.0
// source: note_versions.sql
package data
import (
"context"
"time"
"github.com/google/uuid"
)
const createNoteVersion = `-- name: CreateNoteVersion :exec
WITH potential_duplicate AS (
SELECT version_number
FROM note_versions
WHERE
note_id = $1
AND content_hash = $2
ORDER BY version_number DESC
LIMIT 1
),
note_update AS (
UPDATE notes
SET
current_version = COALESCE(
(SELECT version_number FROM potential_duplicate),
latest_version + 1 -- increment only if we don't jump into a historical version
),
latest_version = CASE
WHEN (SELECT version_number FROM potential_duplicate) IS NULL
THEN latest_version + 1
ELSE latest_version
END,
updated_at = NOW()
WHERE id = $1
RETURNING current_version, latest_version
)
INSERT INTO note_versions (
note_id, title, content, version_number, content_hash
)
SELECT
$1, -- note_id
$3, -- title
$4, -- content
current_version,
$2 -- content_hash
FROM note_update
WHERE NOT EXISTS (SELECT 1 FROM potential_duplicate)
`
type CreateNoteVersionParams struct {
NoteID uuid.UUID `json:"note_id"`
ContentHash string `json:"content_hash"`
Title string `json:"title"`
Content string `json:"content"`
}
func (q *Queries) CreateNoteVersion(ctx context.Context, arg CreateNoteVersionParams) error {
_, err := q.db.Exec(ctx, createNoteVersion,
arg.NoteID,
arg.ContentHash,
arg.Title,
arg.Content,
)
return err
}
const getVersion = `-- name: GetVersion :one
SELECT
id AS version_id,
title,
content,
version_number,
created_at
FROM note_versions
WHERE note_id = $1 AND id = $2
`
type GetVersionParams struct {
NoteID uuid.UUID `json:"note_id"`
ID uuid.UUID `json:"id"`
}
type GetVersionRow struct {
VersionID uuid.UUID `json:"version_id"`
Title string `json:"title"`
Content string `json:"content"`
VersionNumber int32 `json:"version_number"`
CreatedAt *time.Time `json:"created_at"`
}
func (q *Queries) GetVersion(ctx context.Context, arg GetVersionParams) (GetVersionRow, error) {
row := q.db.QueryRow(ctx, getVersion, arg.NoteID, arg.ID)
var i GetVersionRow
err := row.Scan(
&i.VersionID,
&i.Title,
&i.Content,
&i.VersionNumber,
&i.CreatedAt,
)
return i, err
}
const getVersionHistory = `-- name: GetVersionHistory :many
SELECT
id AS version_id,
title,
version_number,
created_at
FROM note_versions
WHERE note_id = $1
ORDER BY version_number DESC
LIMIT $2 OFFSET $3
`
type GetVersionHistoryParams struct {
NoteID uuid.UUID `json:"note_id"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
}
type GetVersionHistoryRow struct {
VersionID uuid.UUID `json:"version_id"`
Title string `json:"title"`
VersionNumber int32 `json:"version_number"`
CreatedAt *time.Time `json:"created_at"`
}
func (q *Queries) GetVersionHistory(ctx context.Context, arg GetVersionHistoryParams) ([]GetVersionHistoryRow, error) {
rows, err := q.db.Query(ctx, getVersionHistory, arg.NoteID, arg.Limit, arg.Offset)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetVersionHistoryRow
for rows.Next() {
var i GetVersionHistoryRow
if err := rows.Scan(
&i.VersionID,
&i.Title,
&i.VersionNumber,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -0,0 +1,143 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.28.0
// source: notes.sql
package data
import (
"context"
"time"
"github.com/google/uuid"
)
const createNote = `-- name: CreateNote :one
INSERT INTO notes (user_id)
VALUES ($1)
RETURNING id, user_id, current_version, latest_version, created_at, updated_at
`
func (q *Queries) CreateNote(ctx context.Context, userID uuid.UUID) (Note, error) {
row := q.db.QueryRow(ctx, createNote, userID)
var i Note
err := row.Scan(
&i.ID,
&i.UserID,
&i.CurrentVersion,
&i.LatestVersion,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const deleteNote = `-- name: DeleteNote :exec
DELETE FROM notes
WHERE id = $1 AND user_id = $2
`
type DeleteNoteParams struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
}
func (q *Queries) DeleteNote(ctx context.Context, arg DeleteNoteParams) error {
_, err := q.db.Exec(ctx, deleteNote, arg.ID, arg.UserID)
return err
}
const getFullNote = `-- name: GetFullNote :one
SELECT
n.id AS note_id,
n.user_id AS owner_id,
nv.title,
nv.content,
nv.version_number,
nv.created_at AS version_created_at,
n.created_at AS note_created_at,
n.updated_at AS note_updated_at
FROM notes n
JOIN note_versions nv
ON n.id = nv.note_id AND n.current_version = nv.version_number
WHERE n.id = $1
`
type GetFullNoteRow struct {
NoteID uuid.UUID `json:"note_id"`
OwnerID uuid.UUID `json:"owner_id"`
Title string `json:"title"`
Content string `json:"content"`
VersionNumber int32 `json:"version_number"`
VersionCreatedAt *time.Time `json:"version_created_at"`
NoteCreatedAt *time.Time `json:"note_created_at"`
NoteUpdatedAt *time.Time `json:"note_updated_at"`
}
func (q *Queries) GetFullNote(ctx context.Context, id uuid.UUID) (GetFullNoteRow, error) {
row := q.db.QueryRow(ctx, getFullNote, id)
var i GetFullNoteRow
err := row.Scan(
&i.NoteID,
&i.OwnerID,
&i.Title,
&i.Content,
&i.VersionNumber,
&i.VersionCreatedAt,
&i.NoteCreatedAt,
&i.NoteUpdatedAt,
)
return i, err
}
const listNotes = `-- name: ListNotes :many
SELECT
n.id AS note_id,
n.user_id AS owner_id,
nv.title,
n.updated_at
FROM notes n
JOIN note_versions nv
ON n.id = nv.note_id AND n.current_version = nv.version_number
WHERE n.user_id = $1
ORDER BY n.updated_at DESC
LIMIT $2 OFFSET $3
`
type ListNotesParams struct {
UserID uuid.UUID `json:"user_id"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
}
type ListNotesRow struct {
NoteID uuid.UUID `json:"note_id"`
OwnerID uuid.UUID `json:"owner_id"`
Title string `json:"title"`
UpdatedAt *time.Time `json:"updated_at"`
}
func (q *Queries) ListNotes(ctx context.Context, arg ListNotesParams) ([]ListNotesRow, error) {
rows, err := q.db.Query(ctx, listNotes, arg.UserID, arg.Limit, arg.Offset)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ListNotesRow
for rows.Next() {
var i ListNotesRow
if err := rows.Scan(
&i.NoteID,
&i.OwnerID,
&i.Title,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -11,6 +11,31 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
const createAdmin = `-- name: CreateAdmin :one
INSERT INTO users (username, password_hash, is_admin)
VALUES ($1, $2, true)
RETURNING id, username, password_hash, is_admin, created_at, updated_at
`
type CreateAdminParams struct {
Username string `json:"username"`
PasswordHash string `json:"password_hash"`
}
func (q *Queries) CreateAdmin(ctx context.Context, arg CreateAdminParams) (User, error) {
row := q.db.QueryRow(ctx, createAdmin, arg.Username, arg.PasswordHash)
var i User
err := row.Scan(
&i.ID,
&i.Username,
&i.PasswordHash,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const createUser = `-- name: CreateUser :one const createUser = `-- name: CreateUser :one
INSERT INTO users (username, password_hash) INSERT INTO users (username, password_hash)
VALUES ($1, $2) VALUES ($1, $2)
@ -84,6 +109,38 @@ func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User,
return i, err return i, err
} }
const listAdmins = `-- name: ListAdmins :many
SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users
WHERE is_admin = true
`
func (q *Queries) ListAdmins(ctx context.Context) ([]User, error) {
rows, err := q.db.Query(ctx, listAdmins)
if err != nil {
return nil, err
}
defer rows.Close()
var items []User
for rows.Next() {
var i User
if err := rows.Scan(
&i.ID,
&i.Username,
&i.PasswordHash,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listUsers = `-- name: ListUsers :many const listUsers = `-- name: ListUsers :many
SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users
` `

View File

@ -0,0 +1,658 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"git.umbrella.haus/ae/notatest/internal/data"
"github.com/go-chi/chi/v5"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgconn"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
)
const (
accessTokenDuration = 15 * time.Minute
refreshTokenDuration = 7 * 24 * time.Hour
)
var (
ErrInvalidToken = errors.New("invalid token")
ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header")
)
// User object context key for incoming requests (handled by middlewares). Only `*userClaims` type
// objects should be stored behind this key for consistency.
type userCtxKey struct{}
// DTO without sensitive data fields such as user's password hash.
type userResponse struct {
ID uuid.UUID `json:"id"`
Username string `json:"username"`
IsAdmin bool `json:"is_admin"`
CreatedAt *time.Time `json:"created_at"`
UpdatedAt *time.Time `json:"updated_at"`
}
// Custom JWT claims (should always be handled in the middleware layer).
type userClaims struct {
Admin bool `json:"admin"`
TokenType string `json:"type"` // "access" or "refresh"
jwt.RegisteredClaims // User's UUID should be stored in the subject claim
}
type tokenPair struct {
AccessToken string
RefreshToken string
}
// Mockable token related database operations interface.
type TokenStore interface {
CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error)
GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error)
RevokeRefreshToken(ctx context.Context, tokenHash string) error
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
}
// Mockable user related database operations interface.
type UserStore interface {
CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
ListUsers(ctx context.Context) ([]data.User, error)
GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error)
GetUserByUsername(ctx context.Context, username string) (data.User, error)
UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error
DeleteUser(ctx context.Context, id uuid.UUID) error
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
}
// Chi HTTP router for authentication/authorization related actions (users/tokens). In theory
// (especially in production) the `UserStore` and `TokenStore` will point to the same database
// handler, but for code readability they should be kept in separate structs.
type authResource struct {
JWTSecret string
Users UserStore
Tokens TokenStore
}
func (rs authResource) Routes() chi.Router {
r := chi.NewRouter()
// Public routes
r.Post("/signup", rs.Create) // POST /auth/signup - registration
r.Post("/login", rs.Login) // POST /auth/login - login
// Protected routes (access token required)
r.Group(func(r chi.Router) {
r.Use(requireAccessToken(rs.JWTSecret)) // JWT claims -> ctx.
r.Get("/me", rs.Get) // GET /auth/me - current user data
r.Post("/logout", rs.Logout) // POST /auth/logout - revoke all refresh cookies
// Owner routes
r.Route("/owner", func(r chi.Router) {
r.Put("/", rs.UpdatePassword) // PUT /auth/owner - update user password
r.Delete("/", rs.OwnerDelete) // DELETE /auth/owner - delete user (owner)
})
// Administration routes (admin claim required)
r.Route("/admin", func(r chi.Router) {
r.Use(adminOnlyMiddleware)
r.Get("/all", rs.List) // GET /auth/admin/all - list all users
r.Route(fmt.Sprintf("/{%s}", targetUserUUIDCtxParameter), func(r chi.Router) {
r.Use(uuidCtx(targetUserUUIDCtxParameter))
r.Delete("/", rs.AdminDelete) // DELETE /auth/admin/{id} - delete user (admin)
})
})
})
// Protected routes (refresh token required)
r.Group(func(r chi.Router) {
r.Use(requireRefreshToken(rs.JWTSecret))
r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair
})
return r
}
// Handler for new user creation. Will check the incoming JSON object's integrity, validate/normalize
// the username, and validate the password (check whether it's compromised via the HIBP API and
// calculate its entropy).
func (rs authResource) Create(w http.ResponseWriter, r *http.Request) {
type request struct {
Username *string `json:"username"`
Password *string `json:"password"`
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Username == nil || req.Password == nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
// Username normalization (to lowercase)
normalizedUsername := normalizeUsername(*req.Username)
if err := validateUsername(normalizedUsername); err != nil {
respondError(w, http.StatusBadRequest, err.Error())
return
}
// Password validation (length, HIBP API, and entropy)
if err := validatePassword(*req.Password); err != nil {
respondError(w, http.StatusBadRequest, err.Error())
return
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(*req.Password), bcrypt.DefaultCost)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to create user")
return
}
user, err := rs.Users.CreateUser(r.Context(), data.CreateUserParams{
Username: normalizedUsername,
PasswordHash: string(hashedPassword),
})
if err != nil {
if isDuplicateEntry(err) {
respondError(w, http.StatusConflict, "Username is already in use")
} else {
respondError(w, http.StatusInternalServerError, "Failed to create user")
}
return
}
respondJSON(w, http.StatusCreated, map[string]string{
"id": user.ID.String(),
"username": user.Username,
})
}
// Handler for logging in to an existing user account using a username-password credentials pair.
// Will check the incoming JSON object's integrity, use a normalized version of the username for a
// database lookup, and compare the given password's hash against the one stored in the database.
// By default only returns a fresh access tokens (and a refresh token as a httpOnly cookie), but
// if the `includeUser` parameter is set to `true` the user DTO will also be included.
func (rs authResource) Login(w http.ResponseWriter, r *http.Request) {
type request struct {
Username *string `json:"username"`
Password *string `json:"password"`
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Username == nil || req.Password == nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
user, err := rs.Users.GetUserByUsername(r.Context(), normalizeUsername(*req.Username))
if err != nil {
respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(*req.Password)); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
// Generate a new access/refresh token pair and store a SHA256 hash of the refresh token into
// the database for further token rotations.
tokenPair, err := rs.GenerateTokenPair(r.Context(), user.ID, user.IsAdmin)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
return
}
// Set refresh token into a httpOnly cookie
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Value: tokenPair.RefreshToken,
Path: "/",
MaxAge: int(refreshTokenDuration.Seconds()),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
})
// Build response
response := map[string]any{
"access_token": tokenPair.AccessToken,
}
// Include user data DTO into the response if the `includeUser` parameter was set to `true`
if includeUser, _ := strconv.ParseBool(r.URL.Query().Get("includeUser")); includeUser {
response["user"] = userResponse{
ID: user.ID,
Username: user.Username,
IsAdmin: user.IsAdmin,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
}
}
respondJSON(w, http.StatusOK, response)
}
// Handler for getting full data of the current user as a DTO (database lookup) based on the JWT
// claims set into the request's context by a middleware.
func (rs authResource) Get(w http.ResponseWriter, r *http.Request) {
user := rs.userFromCtxClaims(w, r)
respondJSON(w, http.StatusOK, userResponse{
ID: user.ID,
Username: user.Username,
IsAdmin: user.IsAdmin,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
})
}
// Handler for updating the current user's password. Performs the same password strength checks as
// the registration handler (`rs.Create`) and revokes any existing refresh tokens the user has
// stored in the database.
func (rs authResource) UpdatePassword(w http.ResponseWriter, r *http.Request) {
type request struct {
OldPassword string `json:"old_password"`
NewPassword string `json:"new_password"`
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
user := rs.userFromCtxClaims(w, r)
// Verify the old password before proceeding with the update
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.OldPassword)); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
if err := validatePassword(req.NewPassword); err != nil {
respondError(w, http.StatusBadRequest, err.Error())
return
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to update password")
return
}
if err := rs.Users.UpdatePassword(r.Context(), data.UpdatePasswordParams{
ID: user.ID,
PasswordHash: string(hashedPassword),
}); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to update password")
return
}
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
}
w.WriteHeader(http.StatusNoContent)
}
// Handler for hard deleting the current user. Requires the user's password as JSON input as a precaution.
func (rs authResource) OwnerDelete(w http.ResponseWriter, r *http.Request) {
type request struct {
Password string `json:"password"`
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
user := rs.userFromCtxClaims(w, r)
// Verify the old password before allowing the deletion
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
err := rs.Users.DeleteUser(r.Context(), user.ID)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to delete user")
return
}
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
}
w.WriteHeader(http.StatusNoContent)
}
// Handler for listing all users stored in the database. Should only be allowed to be called by
// administrator level users.
func (rs authResource) List(w http.ResponseWriter, r *http.Request) {
users, err := rs.Users.ListUsers(r.Context())
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to retrieve users")
return
}
// Output sanitization to DTO
var output []userResponse
for _, user := range users {
output = append(output, userResponse{
ID: user.ID,
Username: user.Username,
IsAdmin: user.IsAdmin,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
})
}
respondJSON(w, http.StatusOK, output)
}
// Handler for deleting another user account based on their ID. Will check the existence of the
// user based on the given ID and additionally revoke all the stored refresh tokens on successful
// deletion. Should only be allowed to be called by administrator level users.
func (rs authResource) AdminDelete(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
targetID, ok := ctx.Value(uuidCtxKey{Name: targetUserUUIDCtxParameter}).(uuid.UUID)
if !ok {
respondError(w, http.StatusBadRequest, "Resource ID missing")
return
}
if err := rs.Users.DeleteUser(r.Context(), targetID); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to delete user")
return
}
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), targetID); err != nil {
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
}
w.WriteHeader(http.StatusNoContent)
}
// Generate a new pair of access and refresh tokens (JWTs) with the user's UUID as the identifying
// `Subject` claim and custom claims for the user's administrator status (boolean) and token type
// ("refresh"/"access"). Stores a SHA256 hash of the refresh token into the database for further
// token rotations.
func (rs authResource) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
tokenPair, err := generateTokenPair(userID.String(), isAdmin, rs.JWTSecret)
if err != nil {
return nil, err
}
hash := sha256.Sum256([]byte(tokenPair.RefreshToken))
tokenHash := hex.EncodeToString(hash[:])
// Store the SHA256 hash of the refresh token with (almost) identical expiration timestamp
expiresAt := time.Now().Add(refreshTokenDuration)
_, err = rs.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{
UserID: userID,
TokenHash: tokenHash,
ExpiresAt: expiresAt,
})
if err != nil {
return nil, err
}
return tokenPair, nil
}
// Revoke the given token from the database by calculating its SHA256 hash.
func (rs authResource) RevokeRefreshToken(ctx context.Context, token string) error {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
return rs.Tokens.RevokeRefreshToken(ctx, tokenHash)
}
// Validate the given refresh token by performing a database lookup with its SHA256 hash. Returns
// the refresh token database object on successful lookup. Fails if the token has been revoked
// (soft database operation) or expired (i.e. the corresponding user has to log in again).
func (rs authResource) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
dbToken, err := rs.Tokens.GetRefreshTokenByHash(ctx, tokenHash)
if err != nil {
return nil, err
}
// Check for soft revocation and/or expiration
if dbToken.Revoked || time.Now().After(dbToken.ExpiresAt) {
return nil, ErrInvalidToken
}
return &dbToken, nil
}
// Handler for performing a token rotation, i.e. invalidating the given refresh token (each refresh
// token is a single use utility) and exchanging it for a new pair of refresh and access tokens.
func (rs authResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
// Get claims from context
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok || claims.TokenType != "refresh" {
respondError(w, http.StatusUnauthorized, "Invalid token")
return
}
// Attempt to get the token from Authorization header (formatted as "Bearer <token>")
refreshToken, err := getTokenFromRequest(r)
if err != nil {
respondError(w, http.StatusUnauthorized, "Unauthorized")
return
}
// Validate the refresh token in the database
if _, err := rs.ValidateRefreshToken(r.Context(), refreshToken); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid refresh token")
return
}
// Revoke the given (single use) refresh token
if err := rs.RevokeRefreshToken(r.Context(), refreshToken); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to revoke token")
return
}
userID, err := uuid.Parse(claims.Subject)
if err != nil {
respondError(w, http.StatusBadRequest, "Invalid user ID")
return
}
// Generate a new pair (access & refresh tokens)
tokenPair, err := rs.GenerateTokenPair(r.Context(), userID, claims.Admin)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
return
}
// Set refresh token into a httpOnly cookie
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Value: tokenPair.RefreshToken,
Path: "/",
MaxAge: int(refreshTokenDuration.Seconds()),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
})
// Return the access token in the response body (it should be stored in browser's memory client-side)
respondJSON(w, http.StatusOK, map[string]string{
"access_token": tokenPair.AccessToken,
})
}
// Handler for performing a logout process for the current user, i.e. replacing the current
// httpOnly `refresh_token` cookie with one that expires immediately. Theoretically the user
// will still be able to authenticate until the access token (stored client-side) expires,
// but that's up to the client to handle.
func (rs authResource) Logout(w http.ResponseWriter, r *http.Request) {
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok {
respondError(w, http.StatusUnauthorized, "Not authenticated")
return
}
userID, err := uuid.Parse(claims.Subject)
if err != nil {
respondError(w, http.StatusBadRequest, "Invalid user ID")
return
}
// Clear the refresh token cookie
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Value: "",
Path: "/",
MaxAge: 0, // Expires immediately
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
})
if err := rs.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to logout")
return
}
w.WriteHeader(http.StatusNoContent)
}
// Helper function for generating the initial administrator level account if one doesn't already
// exists in the database.
func CreateAdminIfNotExists(ctx context.Context, q *data.Queries, username, password string) error {
admins, err := q.ListAdmins(ctx)
if err != nil {
return err
}
if len(admins) > 0 {
log.Debug().Msg("Admin accounts already exist, skipping creation")
return nil
}
// Username normalization (to lowercase)
normalizedUsername := normalizeUsername(username)
if err := validateUsername(normalizedUsername); err != nil {
return err
}
// Password validation (length, HIBP API, and entropy)
if err := validatePassword(password); err != nil {
return err
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
_, err = q.CreateAdmin(ctx, data.CreateAdminParams{
Username: normalizedUsername,
PasswordHash: string(hashedPassword),
})
if err != nil {
return err
}
log.Info().Msgf("Initial admin user '%s' created successfully", username)
return nil
}
// Parse the JWT bearer token from the request's Authorization header.
func getTokenFromRequest(r *http.Request) (string, error) {
bearerToken := r.Header.Get("Authorization")
bearerFields := strings.Fields(bearerToken)
if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" {
return bearerFields[1], nil
}
return "", ErrAuthHeaderInvalid
}
// Helper function for generating a new JWT token pair with the given specifications.
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
atClaims := userClaims{
Admin: isAdmin,
TokenType: "access",
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenDuration)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},
}
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, atClaims)
t, err := accessToken.SignedString([]byte(jwtSecret))
if err != nil {
return nil, err
}
rtClaims := userClaims{
Admin: isAdmin,
TokenType: "refresh",
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenDuration)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},
}
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, rtClaims)
rt, err := refreshToken.SignedString([]byte(jwtSecret))
if err != nil {
return nil, err
}
return &tokenPair{AccessToken: t, RefreshToken: rt}, nil
}
// Parse JWT claims (`userClaims`) from the request's context, and perform a database lookup based
// on `Subject` (after parsing it to `uuid.UUID`) to fetch the corresponding user's data.
func (rs authResource) userFromCtxClaims(w http.ResponseWriter, r *http.Request) *data.User {
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok {
respondError(w, http.StatusUnauthorized, "Unauthorized")
return nil
}
userID, err := uuid.Parse(claims.Subject)
if err != nil {
respondError(w, http.StatusBadRequest, "Invalid user ID")
return nil
}
user, err := rs.Users.GetUserByID(r.Context(), userID)
if err != nil {
respondError(w, http.StatusNotFound, "User not found")
return nil
}
return &user
}
// Check if the given error is a PostgreSQL error for `unique_violation` (error code 23505), i.e.
// whether an entry with the given details already exists in the database table.
func isDuplicateEntry(err error) bool {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
return pgErr.Code == "23505"
}
return false
}

View File

@ -2,11 +2,10 @@ package service
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"time" "time"
"git.umbrella.haus/ae/notatest/pkg/data" "git.umbrella.haus/ae/notatest/internal/data"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
@ -17,8 +16,17 @@ import (
const ( const (
panicRecoveryMsg = "panic recovered" panicRecoveryMsg = "panic recovered"
defaultLogMsg = "incoming request" defaultLogMsg = "incoming request"
noteUUIDCtxParameter = "noteID"
versionUUIDCtxParameter = "versionID"
targetUserUUIDCtxParameter = "targetID"
) )
// General resource ID (UUID) context key.
type uuidCtxKey struct {
Name string
}
// Get JWT bearer from request's authorization header, parse it with custom user claims, and // Get JWT bearer from request's authorization header, parse it with custom user claims, and
// ensure its validity before attaching the claims to the request's context. // ensure its validity before attaching the claims to the request's context.
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler { func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
@ -26,7 +34,8 @@ func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) ht
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenString, err := getTokenFromRequest(r) tokenString, err := getTokenFromRequest(r)
if err != nil { if err != nil {
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err)) respondError(w, http.StatusUnauthorized, "Unauthorized")
return
} }
token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (any, error) { token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (any, error) {
@ -59,7 +68,8 @@ func requireRefreshToken(jwtSecret string) func(http.Handler) http.Handler {
return authMiddleware(jwtSecret, "refresh") return authMiddleware(jwtSecret, "refresh")
} }
// Ensure the current user is an administrator. // Ensure the current user is an administrator. Can be used to protect routes that can be utilized
// to view/modify/delete accounts that the current user isn't the owner of.
func adminOnlyMiddleware(next http.Handler) http.Handler { func adminOnlyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(*userClaims) user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
@ -71,57 +81,38 @@ func adminOnlyMiddleware(next http.Handler) http.Handler {
}) })
} }
// Ensure the targeted resource is owned by the current user (i.e. current user's ID matches with // Append UUID from the given URL parameter to the request's context (`uuidCtxKey` with the
// the one stored into the resource). // parameter name as the "context identifier").
func ownerOnlyMiddleware(next http.Handler) http.Handler { func uuidCtx(parameter string) func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
requestedID := chi.URLParam(r, "id")
if !ok || user.Subject != requestedID {
respondError(w, http.StatusForbidden, "Forbidden")
return
}
next.ServeHTTP(w, r)
})
}
// Append user data into request's context based on user ID as a URL parameter.
func userCtx(store UserStore) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userIDStr := chi.URLParam(r, "id") uuidParam := chi.URLParam(r, parameter)
userID, err := uuid.Parse(userIDStr) resourceID, err := uuid.Parse(uuidParam)
if err != nil { if err != nil {
respondError(w, http.StatusNotFound, "Invalid user ID") respondError(w, http.StatusBadRequest, "Invalid resource ID")
return return
} }
user, err := store.GetUserByID(r.Context(), userID) ctx := context.WithValue(r.Context(), uuidCtxKey{Name: parameter}, resourceID)
if err != nil {
respondError(w, http.StatusNotFound, "User not found")
return
}
ctx := context.WithValue(r.Context(), userCtxKey{}, user)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
}) })
} }
} }
// Append note data into request's context based on note ID as a URL parameter and user ID as // Append full note data (metadata + active version) into request's context based on note ID as a
// context parameter. // URL parameter and user ID as context parameter. Must be chained with `uuidCtx` to parse the
// resource ID into the request's context.
func noteCtx(store NoteStore) func(http.Handler) http.Handler { func noteCtx(store NoteStore) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
noteIDStr := chi.URLParam(r, "id") ctx := r.Context()
noteID, err := uuid.Parse(noteIDStr) noteID, ok := ctx.Value(uuidCtxKey{Name: noteUUIDCtxParameter}).(uuid.UUID)
if err != nil { if !ok {
respondError(w, http.StatusNotFound, "Invalid note ID") respondError(w, http.StatusBadRequest, "Resource ID missing")
return return
} }
// NOTE: user must already be in the context (e.g. via JWT middleware) user, ok := ctx.Value(userCtxKey{}).(*userClaims)
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok { if !ok {
respondError(w, http.StatusUnauthorized, "Unauthorized") respondError(w, http.StatusUnauthorized, "Unauthorized")
return return
@ -129,20 +120,58 @@ func noteCtx(store NoteStore) func(http.Handler) http.Handler {
userID, err := uuid.Parse(user.Subject) userID, err := uuid.Parse(user.Subject)
if err != nil { if err != nil {
respondError(w, http.StatusInternalServerError, "Invalid user ID") respondError(w, http.StatusUnauthorized, "Invalid token")
return return
} }
note, err := store.GetNote(r.Context(), data.GetNoteParams{ // Get the "full note" (metadata + active version) with a single query
ID: noteID, fullNote, err := store.GetFullNote(r.Context(), noteID)
UserID: userID,
})
if err != nil { if err != nil {
respondError(w, http.StatusNotFound, "Note not found") respondError(w, http.StatusNotFound, "Note not found")
return return
} }
ctx := context.WithValue(r.Context(), noteCtxKey{}, note) // Validate note ownership
if userID != fullNote.OwnerID {
respondError(w, http.StatusForbidden, "Forbidden")
return
}
ctx = context.WithValue(r.Context(), noteCtxKey{}, &fullNote)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// Append single version's data into request's context based on version ID as a URL parameter and
// note ID as context parameter. Must be chained with `noteCtx` and `uuidCtx` to parse the necessary
// resource IDs into request's context.
func versionCtx(store NoteStore) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
fullNote, ok := ctx.Value(noteCtxKey{}).(*data.GetFullNoteRow)
if !ok {
respondError(w, http.StatusNotFound, "Note not found")
return
}
versionID, ok := ctx.Value(uuidCtxKey{Name: versionUUIDCtxParameter}).(uuid.UUID)
if !ok {
respondError(w, http.StatusBadRequest, "Resource ID missing")
return
}
version, err := store.GetVersion(r.Context(), data.GetVersionParams{
NoteID: fullNote.NoteID,
ID: versionID,
})
if err != nil {
respondError(w, http.StatusNotFound, "Version not found")
return
}
ctx = context.WithValue(r.Context(), versionCtxKey{}, &version)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
}) })
} }

View File

@ -0,0 +1,519 @@
package service
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"git.umbrella.haus/ae/notatest/internal/data"
"github.com/go-chi/chi/v5"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
type mockNoteStore struct {
CreateNoteFunc func(context.Context, uuid.UUID) (data.Note, error)
DeleteNoteFunc func(context.Context, data.DeleteNoteParams) error
GetFullNoteFunc func(context.Context, uuid.UUID) (data.GetFullNoteRow, error)
ListNotesFunc func(context.Context, data.ListNotesParams) ([]data.ListNotesRow, error)
CreateNoteVersionFunc func(context.Context, data.CreateNoteVersionParams) error
GetVersionFunc func(context.Context, data.GetVersionParams) (data.GetVersionRow, error)
GetVersionHistoryFunc func(context.Context, data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error)
}
func (m *mockNoteStore) CreateNote(ctx context.Context, id uuid.UUID) (data.Note, error) {
return m.CreateNoteFunc(ctx, id)
}
func (m *mockNoteStore) DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error {
return m.DeleteNoteFunc(ctx, arg)
}
func (m *mockNoteStore) GetFullNote(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) {
return m.GetFullNoteFunc(ctx, id)
}
func (m *mockNoteStore) ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.ListNotesRow, error) {
return m.ListNotesFunc(ctx, arg)
}
func (m *mockNoteStore) CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) error {
return m.CreateNoteVersionFunc(ctx, arg)
}
func (m *mockNoteStore) GetVersion(ctx context.Context, arg data.GetVersionParams) (data.GetVersionRow, error) {
return m.GetVersionFunc(ctx, arg)
}
func (m *mockNoteStore) GetVersionHistory(ctx context.Context, arg data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error) {
return m.GetVersionHistoryFunc(ctx, arg)
}
func TestAuthMiddleware(t *testing.T) {
secret := "test-jwt-secret"
testUserID := uuid.New().String()
validRT := generateTestToken(t, secret, "refresh", testUserID, true)
validAT := generateTestToken(t, secret, "access", testUserID, true)
expiredAT := generateTestToken(t, secret, "access", testUserID, true, func(claims *userClaims) {
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
})
tests := []struct {
name string
token string
expectedErr string
statusCode int
}{
{
"no token",
"",
"Unauthorized",
http.StatusUnauthorized,
},
{
"invalid token",
"invalid",
"Invalid token",
http.StatusUnauthorized,
},
{
"expired token",
expiredAT,
"Invalid token",
http.StatusUnauthorized,
},
{
"wrong token type",
validRT,
"Invalid token type",
http.StatusUnauthorized,
},
{
"valid token",
validAT,
"",
http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
atAuthMiddleware := requireAccessToken(secret)
// Mock request
req := httptest.NewRequest("GET", "/", nil)
if tc.token != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token))
}
w := httptest.NewRecorder()
called := false
// Mock endpoint that the middleware protects
handler := atAuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
_, ok := r.Context().Value(userCtxKey{}).(*userClaims)
assert.True(t, ok)
}))
handler.ServeHTTP(w, req)
assert.Equal(t, tc.statusCode, w.Code)
assert.Equal(t, tc.statusCode == http.StatusOK, called)
if tc.expectedErr != "" {
assert.Contains(t, w.Body.String(), tc.expectedErr)
}
})
}
}
func TestAdminOnlyMiddleware(t *testing.T) {
tests := []struct {
name string
user *userClaims
statusCode int
}{
{
"no user",
nil,
http.StatusForbidden,
},
{
"non admin user",
&userClaims{
Admin: false,
},
http.StatusForbidden,
},
{
"admin user",
&userClaims{
Admin: true,
},
http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Mock request
req := httptest.NewRequest("GET", "/", nil)
if tc.user != nil {
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
}
w := httptest.NewRecorder()
called := false
// Mock endpoint that the middleware protects
handler := adminOnlyMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))
handler.ServeHTTP(w, req)
assert.Equal(t, tc.statusCode, w.Code)
assert.Equal(t, tc.statusCode == http.StatusOK, called)
})
}
}
func TestUUIDCtxMiddleware(t *testing.T) {
testKeyName := "testKey"
tests := []struct {
name string
parameter string
expectedErr string
statusCode int
}{
{
"missing uuid",
"",
"Invalid resource ID",
http.StatusBadRequest,
},
{
"invalid uuid",
"invalid",
"Invalid resource ID",
http.StatusBadRequest,
},
{
"valid uuid",
uuid.New().String(),
"",
http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
uuidCtxMiddleware := uuidCtx(testKeyName)
req := httptest.NewRequest("GET", "/", nil)
// We need to mock the URL parameter as we don't setup an actual router in this test env.
rctx := chi.NewRouteContext()
rctx.URLParams.Add(testKeyName, tc.parameter)
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
w := httptest.NewRecorder()
called := false
// Mock endpoint that the middleware protects
handler := uuidCtxMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
_, ok := r.Context().Value(uuidCtxKey{Name: testKeyName}).(uuid.UUID)
assert.True(t, ok)
}))
handler.ServeHTTP(w, req)
assert.Equal(t, tc.statusCode, w.Code)
assert.Equal(t, tc.statusCode == http.StatusOK, called)
if tc.expectedErr != "" {
assert.Contains(t, w.Body.String(), tc.expectedErr)
}
})
}
}
func TestNoteCtxMiddleware(t *testing.T) {
testTitle := "Test title"
tesTContent := "## Test content\nData 123"
testVersion := int32(3)
noteID := uuid.New()
ownerUserID := uuid.New()
testOwnerClaims := userClaims{RegisteredClaims: jwt.RegisteredClaims{
Subject: ownerUserID.String(),
}}
otherUserID := uuid.New()
testOtherClaims := userClaims{RegisteredClaims: jwt.RegisteredClaims{
Subject: otherUserID.String(),
}}
tests := []struct {
name string
resourceID *uuid.UUID
user *userClaims
mock func(*mockNoteStore)
statusCode int
expectedErr string
}{
{
"no resource id",
nil,
nil,
func(m *mockNoteStore) {},
http.StatusBadRequest,
"Resource ID missing",
},
{
"unauthorized",
&noteID,
nil,
func(m *mockNoteStore) {},
http.StatusUnauthorized,
"Unauthorized",
},
{
"note not found",
&noteID,
&testOwnerClaims,
func(m *mockNoteStore) {
m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) {
return data.GetFullNoteRow{}, errors.New("not found")
}
},
http.StatusNotFound,
"Note not found",
},
{
"not owner",
&noteID,
&testOtherClaims,
func(m *mockNoteStore) {
m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) {
assert.Equal(t, noteID, id)
testTs := time.Now()
return data.GetFullNoteRow{
NoteID: id,
OwnerID: ownerUserID,
Title: testTitle,
Content: tesTContent,
VersionNumber: testVersion,
VersionCreatedAt: &testTs,
NoteCreatedAt: &testTs,
NoteUpdatedAt: &testTs,
}, nil
}
},
http.StatusForbidden,
"Forbidden",
},
{
"success",
&noteID,
&testOwnerClaims,
func(m *mockNoteStore) {
m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) {
assert.Equal(t, noteID, id)
testTs := time.Now()
return data.GetFullNoteRow{
NoteID: id,
OwnerID: ownerUserID,
Title: testTitle,
Content: tesTContent,
VersionNumber: testVersion,
VersionCreatedAt: &testTs,
NoteCreatedAt: &testTs,
NoteUpdatedAt: &testTs,
}, nil
}
},
http.StatusOK,
"",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// UUID ctx. (mock) -> note ctx. (tested here)
mockStore := &mockNoteStore{}
req := httptest.NewRequest("GET", "/", nil)
tc.mock(mockStore)
// Mock endpoint that the middleware protects (where the attached note data is actually utilized)
handler := noteCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow)
assert.True(t, ok)
assert.Equal(t, noteID, fullNote.NoteID)
assert.Equal(t, testTitle, fullNote.Title)
assert.Equal(t, tesTContent, fullNote.Content)
assert.Equal(t, testVersion, fullNote.VersionNumber)
w.WriteHeader(http.StatusOK)
}))
// Request parameters don't need to be mocked, as parsing of them isn't handled
// by this middleware, and thus that portion shouldn't be tested here.
if tc.resourceID != nil {
req = req.WithContext(context.WithValue(req.Context(), uuidCtxKey{Name: noteUUIDCtxParameter}, *tc.resourceID))
}
if tc.user != nil {
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, tc.statusCode, w.Code)
if tc.expectedErr != "" {
assert.Contains(t, w.Body.String(), tc.expectedErr)
}
})
}
}
func TestVersionCtxMiddleware(t *testing.T) {
testTitle := "Test title"
tesTContent := "## Test content\nData 123"
testVersion := int32(3)
versionID := uuid.New()
noteID := uuid.New()
testNote := data.GetFullNoteRow{
NoteID: noteID,
}
tests := []struct {
name string
resourceID *uuid.UUID
note *data.GetFullNoteRow
mock func(*mockNoteStore)
statusCode int
expectedErr string
}{
{
"no note",
nil,
nil,
func(m *mockNoteStore) {},
http.StatusNotFound,
"Note not found",
},
{
"no resource id",
nil,
&testNote,
func(m *mockNoteStore) {},
http.StatusBadRequest,
"Resource ID missing",
},
{
"version not found",
&versionID,
&testNote,
func(m *mockNoteStore) {
m.GetVersionFunc = func(ctx context.Context, gvp data.GetVersionParams) (data.GetVersionRow, error) {
return data.GetVersionRow{}, errors.New("not found")
}
},
http.StatusNotFound,
"Version not found",
},
{
"success",
&versionID,
&testNote,
func(m *mockNoteStore) {
m.GetVersionFunc = func(ctx context.Context, gvp data.GetVersionParams) (data.GetVersionRow, error) {
assert.Equal(t, versionID, gvp.ID)
assert.Equal(t, noteID, gvp.NoteID)
testTs := time.Now()
return data.GetVersionRow{
VersionID: gvp.ID,
Title: testTitle,
Content: tesTContent,
VersionNumber: testVersion,
CreatedAt: &testTs,
}, nil
}
},
http.StatusOK,
"",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// note ctx. (mock) -> UUID ctx. (mock) -> version ctx. (tested here)
mockStore := &mockNoteStore{}
req := httptest.NewRequest("GET", "/", nil)
tc.mock(mockStore)
// Mock endpoint that the middleware protects (where the attached note data is actually utilized)
handler := versionCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fullVersion, ok := r.Context().Value(versionCtxKey{}).(*data.GetVersionRow)
assert.True(t, ok)
assert.Equal(t, versionID, fullVersion.VersionID)
assert.Equal(t, testTitle, fullVersion.Title)
assert.Equal(t, tesTContent, fullVersion.Content)
assert.Equal(t, testVersion, fullVersion.VersionNumber)
w.WriteHeader(http.StatusOK)
}))
// Request parameters don't need to be mocked, as parsing of them isn't handled
// by this middleware, and thus that portion shouldn't be tested here.
if tc.note != nil {
req = req.WithContext(context.WithValue(req.Context(), noteCtxKey{}, tc.note))
}
if tc.resourceID != nil {
req = req.WithContext(context.WithValue(req.Context(), uuidCtxKey{Name: versionUUIDCtxParameter}, *tc.resourceID))
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, tc.statusCode, w.Code)
if tc.expectedErr != "" {
assert.Contains(t, w.Body.String(), tc.expectedErr)
}
})
}
}
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
t.Helper()
claims := &userClaims{
Admin: isAdmin,
TokenType: tokenType,
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
},
}
for _, opt := range opts {
opt(claims)
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signedToken, err := token.SignedString([]byte(secret))
if err != nil {
t.Fatalf("Failed to generate test token: %v", err)
}
return signedToken
}

View File

@ -0,0 +1,331 @@
package service
import (
"context"
"crypto/sha1"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"git.umbrella.haus/ae/notatest/internal/data"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
)
const (
titleMaxLength = 150
initVersionTitle = "Untitled"
initVersionContent = ""
)
// Note object context key for incoming requests (handled my middlewares). Only `*data.GetFullNoteRow`
// type objects should be stored behind this key for consistency.
type noteCtxKey struct{}
// Note version object context key for incoming requests (handled by middlewares). Only
// `*data.GetVersionRow` type objects should be stored behind this key for consistency.
type versionCtxKey struct{}
// Mockable database operations interface
type NoteStore interface {
CreateNote(ctx context.Context, userID uuid.UUID) (data.Note, error)
DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error
GetFullNote(ctx context.Context, noteID uuid.UUID) (data.GetFullNoteRow, error)
ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.ListNotesRow, error)
CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) error
GetVersion(ctx context.Context, arg data.GetVersionParams) (data.GetVersionRow, error)
GetVersionHistory(ctx context.Context, arg data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error)
}
// Chi HTTP router for notes related CRUD actions.
type notesResource struct {
JWTSecret string
Notes NoteStore
}
func (rs notesResource) Routes() chi.Router {
r := chi.NewRouter()
r.Group(func(r chi.Router) {
r.Use(requireAccessToken(rs.JWTSecret)) // JWT claims -> ctx.
r.Post("/", rs.Create) // POST /notes - create new note
r.Get("/", rs.ListMetadata) // GET /notes - get all notes (metadata + titles)
/*
Clients should utilize `rs.ListMetadata` to load index of user's available notes (e.g.
sidebar view), use `rs.GetFullNote` to get full notes individually, and if request the
versioning history with `rs.GetVersionHistory` if necessary (and similarly fetch each
version individually if the client wants to view them).
*/
r.Route(fmt.Sprintf("/{%s}", noteUUIDCtxParameter), func(r chi.Router) {
r.Use(uuidCtx(noteUUIDCtxParameter))
r.Use(noteCtx(rs.Notes)) // DB -> req. context (metadata + active version)
r.Get("/", rs.GetFullNote) // GET /notes/{id} - get note from context
r.Delete("/", rs.Delete) // DELETE /notes/{id} - delete note
r.Get("/versions", rs.GetVersionHistory) // GET /notes/{id}/versions - get full versioning history
r.Post("/versions", rs.CreateVersion) // POST /notes/{id}/versions - create new version
r.Route(fmt.Sprintf("/{%s}", versionUUIDCtxParameter), func(r chi.Router) {
r.Use(uuidCtx(versionUUIDCtxParameter))
r.Use(versionCtx(rs.Notes)) // DB -> req. context (scoped version)
r.Get("/", rs.GetFullVersion) // GET /notes/{id}/{id} - get
})
})
})
return r
}
// Handler for new note creation. Creates the parent metadata object (`notes` table) and an initial
// placeholder content version (`note_versions` table), and returns the placeholder contents to the
// caller in the HTTP response.
func (rs *notesResource) Create(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok {
respondError(w, http.StatusUnauthorized, "Unauthorized")
return
}
userID, err := uuid.Parse(user.Subject)
if err != nil {
respondError(w, http.StatusUnauthorized, "Invalid user ID")
return
}
// Metadata object (parent)
note, err := rs.Notes.CreateNote(r.Context(), userID)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to create note")
return
}
// Initial (empty) placeholder version of the contents
err = rs.Notes.CreateNoteVersion(r.Context(), data.CreateNoteVersionParams{
NoteID: note.ID,
Title: initVersionTitle,
Content: initVersionContent,
ContentHash: sha1ContentHash(initVersionTitle, initVersionContent),
})
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to create initial version")
return
}
// Placeholder contents are decided server-side, so we need to inform the client of them via a
// one-time-use DTO
type response struct {
Title string `json:"title"`
Content string `json:"content"`
}
res := response{
Title: initVersionTitle,
Content: initVersionContent,
}
respondJSON(w, http.StatusCreated, res)
}
func (rs *notesResource) ListMetadata(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok {
respondError(w, http.StatusUnauthorized, "Unauthorized")
return
}
userID, err := uuid.Parse(user.Subject)
if err != nil {
respondError(w, http.StatusUnauthorized, "Invalid user ID")
return
}
limit, offset := getPaginationParams(r)
notes, err := rs.Notes.ListNotes(r.Context(), data.ListNotesParams{
UserID: userID,
Limit: limit,
Offset: offset,
})
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to retrieve notes")
return
}
for _, note := range notes {
if userID != note.OwnerID {
respondError(w, http.StatusForbidden, "Forbidden")
return
}
}
respondJSON(w, http.StatusOK, notes)
}
// Handler for returning the currently scoped (included to the request's context by a middleware)
// full note object.
func (rs *notesResource) GetFullNote(w http.ResponseWriter, r *http.Request) {
fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow)
if !ok {
respondError(w, http.StatusNotFound, "Note not found")
return
}
respondJSON(w, http.StatusOK, fullNote)
}
// Handler for hard deelting the currently scoped note (including its versions via database cascade).
func (rs *notesResource) Delete(w http.ResponseWriter, r *http.Request) {
fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow)
if !ok {
respondError(w, http.StatusNotFound, "Note not found")
return
}
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok {
respondError(w, http.StatusUnauthorized, "Unauthorized")
return
}
userID, err := uuid.Parse(user.Subject)
if err != nil {
respondError(w, http.StatusUnauthorized, "Invalid user ID")
return
}
err = rs.Notes.DeleteNote(r.Context(), data.DeleteNoteParams{
ID: fullNote.NoteID,
UserID: userID, // NOTE: using `fullNote.userID` here'd be insecure
})
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to delete note")
return
}
w.WriteHeader(http.StatusNoContent)
}
// Handler for listing the currently scoped note's version history. If pagination parameters
// (`limit` and `offset`) aren't defined, limit of 50 versions (with offset 0) will be returned.
func (rs *notesResource) GetVersionHistory(w http.ResponseWriter, r *http.Request) {
fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow)
if !ok {
respondError(w, http.StatusNotFound, "Note not found")
return
}
limit, offset := getPaginationParams(r)
versions, err := rs.Notes.GetVersionHistory(r.Context(), data.GetVersionHistoryParams{
NoteID: fullNote.NoteID,
Limit: limit,
Offset: offset,
})
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to get version history")
return
}
respondJSON(w, http.StatusOK, versions)
}
// Handler for creating a new content version for the currently scoped note. Will check the incoming
// JSON object's integrity and perform a de-duplication check for identical versions stored in the
// database (SHA-1 hash of version contents). If a duplicate version is found, it'll be placed as the
// active version by swapping its version number to HEAD+1.
func (rs *notesResource) CreateVersion(w http.ResponseWriter, r *http.Request) {
fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow)
if !ok {
respondError(w, http.StatusNotFound, "Note not found")
return
}
var req struct {
Title *string `json:"title"`
Content *string `json:"content"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Title == nil || req.Content == nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
// Extra check for frontend readability reasons (max. length isn't specifically limited in the database)
if len(*req.Title) > titleMaxLength {
respondError(w, http.StatusBadRequest, fmt.Sprintf("Title must be shorter than %d characters", titleMaxLength))
return
}
/*
The SQL query handles de-duplication checks and "intelligent" versioning increments, so we
don't have to worry about them here (`latest_version` = highest version number that exists
in this note's context; `current_version` = note's active content version):
- New version's contents are a duplicate of a historical version:
- Don't increment `latest_version`
- Sync `current_version` with the `version_number` of the duplicate version
- New version's contents are unique:
- Increment `latest_version`
- Sync `current_version` with `latest_version`
*/
err := rs.Notes.CreateNoteVersion(r.Context(), data.CreateNoteVersionParams{
NoteID: fullNote.NoteID,
Title: *req.Title,
Content: *req.Content,
ContentHash: sha1ContentHash(*req.Title, *req.Content),
})
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to create note version")
return
}
w.WriteHeader(http.StatusNoContent)
}
// Handler for returning full data of the currently scoped note version. Identical to the beginning
// of the `RollbackNoteVersion` handler.
func (rs *notesResource) GetFullVersion(w http.ResponseWriter, r *http.Request) {
fullVersion, ok := r.Context().Value(noteCtxKey{}).(*data.GetVersionRow)
if !ok {
respondError(w, http.StatusNotFound, "Note not found")
return
}
respondJSON(w, http.StatusOK, fullVersion)
}
// Parse `limit` and `offset` 32-bit integer URL parameters from the given request. Defaults to
// limit of 50 and offset 0 if parameters are missing/invalid.
func getPaginationParams(r *http.Request) (limit int32, offset int32) {
defaultLimit := 50
defaultOffset := 0
limitStr := r.URL.Query().Get("limit")
if limitStr != "" {
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
defaultLimit = l
}
}
offsetStr := r.URL.Query().Get("offset")
if offsetStr != "" {
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
defaultOffset = o
}
}
return int32(defaultLimit), int32(defaultOffset)
}
// Concatenate the title and content strings, calculate a SHA-1 hash of the resulting string, and
// return the resulting hash as a string.
func sha1ContentHash(title, content string) string {
hashContent := title + content
hash := sha1.Sum([]byte(hashContent))
hashStr := strings.ToUpper(hex.EncodeToString(hash[:]))
return hashStr
}

View File

@ -3,26 +3,25 @@ package service
import ( import (
"net/http" "net/http"
"git.umbrella.haus/ae/notatest/pkg/data" "git.umbrella.haus/ae/notatest/internal/data"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func Run(conn *pgx.Conn, jwtSecret string) error { func Run(conn *pgx.Conn, q *data.Queries, jwtSecret string) error {
q := data.New(conn)
r := chi.NewRouter() r := chi.NewRouter()
tokensRouter := tokensResource{ authRouter := authResource{
JWTSecret: jwtSecret,
Tokens: q,
}
usersRouter := usersResource{
JWTSecret: jwtSecret, JWTSecret: jwtSecret,
Users: q, Users: q,
Tokens: q,
}
notesRouter := notesResource{
JWTSecret: jwtSecret,
Notes: q,
} }
notesRouter := notesResource{}
// Global middlewares // Global middlewares
r.Use(middleware.RequestID) r.Use(middleware.RequestID)
@ -31,10 +30,12 @@ func Run(conn *pgx.Conn, jwtSecret string) error {
r.Use(middleware.Recoverer) r.Use(middleware.Recoverer)
r.Use(middleware.AllowContentType("application/json")) r.Use(middleware.AllowContentType("application/json"))
// Routes grouped by functionality // Routes grouped by functionality (we must prefix the API routes with `/api`
r.Mount("/auth", tokensRouter.Routes()) // as the domain will be the same for the front and back ends)
r.Mount("/users", usersRouter.Routes()) r.Route("/api", func(r chi.Router) {
r.Mount("/notes", notesRouter.Routes()) r.Mount("/auth", authRouter.Routes())
r.Mount("/notes", notesRouter.Routes())
})
log.Info().Msg("Starting server on :8080") log.Info().Msg("Starting server on :8080")
return http.ListenAndServe(":8080", r) return http.ListenAndServe(":8080", r)

View File

@ -41,7 +41,7 @@ func respondError(w http.ResponseWriter, status int, message string) {
} }
/* /*
Client-side check: Example client-side check:
``` ```
function estimateEntropy(password: string): number { function estimateEntropy(password: string): number {
@ -102,7 +102,7 @@ func normalizeUsername(username string) string {
} }
/* /*
Client-side check (additionally input should automatically perform the normalization steps): Example client-side check (without input normalization):
``` ```
function validateUsername(username: string): string { function validateUsername(username: string): string {

View File

@ -3,11 +3,15 @@ package main
import ( import (
"context" "context"
"embed" "embed"
"fmt"
"os" "os"
"git.umbrella.haus/ae/notatest/pkg/migrate" "git.umbrella.haus/ae/notatest/internal/data"
"git.umbrella.haus/ae/notatest/pkg/service" "git.umbrella.haus/ae/notatest/internal/service"
"github.com/caarlos0/env" "github.com/caarlos0/env/v10"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/golang-migrate/migrate/v4/source/iofs"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -17,60 +21,72 @@ import (
var migrationsFS embed.FS var migrationsFS embed.FS
var ( var (
isDevelopment = false config Config
config Config
) )
type Config struct { type Config struct {
JWTSecret string `env:"JWT_SECRET,notEmpty"` JWTSecret string `env:"JWT_SECRET,notEmpty"`
DBURL string `env:"PG_URL,notEmpty"` DatabaseURL string `env:"DB_URL,notEmpty"`
RunMode string `env:"GO_ENV" envDefault:"production"` LogLevel string `env:"LOG_LEVEL" envDefault:"info"`
AdminUsername string `env:"ADMIN_USERNAME,notEmpty,unset"`
AdminPassword string `env:"ADMIN_PASSWORD,notEmpty,unset"`
} }
func init() { func init() {
initLogger()
config = Config{} config = Config{}
env.Parse(&config) if err := env.Parse(&config); err != nil {
log.Fatal().Err(err).Msg("Failed to parse environment variables")
if config.RunMode == "development" {
log.Info().Msg("Development mode enabled")
isDevelopment = true
} }
initLogger()
log.Debug().Msg("Initialization completed") log.Debug().Msg("Initialization completed")
} }
func main() { func main() {
conn, err := pgx.Connect(context.Background(), config.DBURL) log.Debug().Msgf("Database URL: %s", config.DatabaseURL)
conn, err := pgx.Connect(context.Background(), config.DatabaseURL)
if err != nil { if err != nil {
log.Fatal().Msgf("Failed connecting to database: %s", err) log.Fatal().Err(err).Msg("Failed to connect to database")
} }
log.Info().Msg("Successfully connected to the database") log.Info().Msg("Successfully connected to the database")
log.Debug().Msg(config.DBURL) log.Info().Msg("Applying migrations...")
if isDevelopment { d, err := iofs.New(migrationsFS, "sql/migrations")
if err := migrate.Run(context.Background(), conn, migrationsFS); err != nil { if err != nil {
log.Fatal().Msgf("Failed running migrations: %s", err) log.Fatal().Err(err).Msg("Failed constructing io/fs driver")
} }
migrator, err := migrate.NewWithSourceInstance("iofs", d, config.DatabaseURL)
if err != nil {
log.Fatal().Err(err).Msg("Failed to apply migrations")
}
defer migrator.Close()
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
log.Fatal().Err(err).Msg("Failed to apply migrations")
} }
service.Run(conn, config.JWTSecret) q := data.New(conn)
err = service.CreateAdminIfNotExists(context.Background(), q, config.AdminUsername, config.AdminPassword)
if err != nil {
log.Fatal().Err(err).Msg("Failed initial admin account creation")
}
log.Info().Msg("Migrations applied succesfully, proceeding to HTTP server startup")
service.Run(conn, q, config.JWTSecret)
} }
func initLogger() { func initLogger() {
logLevel := os.Getenv("LOG_LEVEL") fmt.Println(config.LogLevel)
level, err := zerolog.ParseLevel(logLevel) level, err := zerolog.ParseLevel(config.LogLevel)
if err != nil { if err != nil {
// Default to INFO
level = zerolog.InfoLevel level = zerolog.InfoLevel
} }
zerolog.SetGlobalLevel(level) zerolog.SetGlobalLevel(level)
if isDevelopment { output := zerolog.ConsoleWriter{
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) Out: os.Stdout,
} else { TimeFormat: "2006-01-02 15:04:05",
log.Logger = log.Output(os.Stderr) // JSON to stdout/stderr
} }
log.Logger = log.Output(output).With().Timestamp().Caller().Logger()
log.Debug().Msg("Logger initialized") log.Info().Msgf("Logger initialized (log level: %s)", level)
} }

View File

@ -1,132 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.28.0
// source: note_versions.sql
package data
import (
"context"
"github.com/google/uuid"
)
const createNoteVersion = `-- name: CreateNoteVersion :one
INSERT INTO note_versions (note_id, title, content, version_number, content_hash)
VALUES (
$1,
$2,
$3,
(SELECT COALESCE(MAX(version_number), 0) + 1 FROM note_versions WHERE note_id = $1),
encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
)
RETURNING id, note_id, title, content, version_number, content_hash, created_at
`
type CreateNoteVersionParams struct {
NoteID uuid.UUID `json:"note_id"`
Title string `json:"title"`
Content string `json:"content"`
}
func (q *Queries) CreateNoteVersion(ctx context.Context, arg CreateNoteVersionParams) (NoteVersion, error) {
row := q.db.QueryRow(ctx, createNoteVersion, arg.NoteID, arg.Title, arg.Content)
var i NoteVersion
err := row.Scan(
&i.ID,
&i.NoteID,
&i.Title,
&i.Content,
&i.VersionNumber,
&i.ContentHash,
&i.CreatedAt,
)
return i, err
}
const findDuplicateContent = `-- name: FindDuplicateContent :one
SELECT EXISTS(
SELECT 1 FROM note_versions
WHERE note_id = $1
AND content_hash = encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
)
`
type FindDuplicateContentParams struct {
NoteID uuid.UUID `json:"note_id"`
Column2 []byte `json:"column_2"`
Column3 []byte `json:"column_3"`
}
func (q *Queries) FindDuplicateContent(ctx context.Context, arg FindDuplicateContentParams) (bool, error) {
row := q.db.QueryRow(ctx, findDuplicateContent, arg.NoteID, arg.Column2, arg.Column3)
var exists bool
err := row.Scan(&exists)
return exists, err
}
const getNoteVersion = `-- name: GetNoteVersion :one
SELECT id, note_id, title, content, version_number, content_hash, created_at FROM note_versions
WHERE note_id = $1 AND version_number = $2 LIMIT 1
`
type GetNoteVersionParams struct {
NoteID uuid.UUID `json:"note_id"`
VersionNumber int32 `json:"version_number"`
}
func (q *Queries) GetNoteVersion(ctx context.Context, arg GetNoteVersionParams) (NoteVersion, error) {
row := q.db.QueryRow(ctx, getNoteVersion, arg.NoteID, arg.VersionNumber)
var i NoteVersion
err := row.Scan(
&i.ID,
&i.NoteID,
&i.Title,
&i.Content,
&i.VersionNumber,
&i.ContentHash,
&i.CreatedAt,
)
return i, err
}
const getNoteVersions = `-- name: GetNoteVersions :many
SELECT id, note_id, title, content, version_number, content_hash, created_at FROM note_versions
WHERE note_id = $1
ORDER BY version_number DESC
LIMIT $2 OFFSET $3
`
type GetNoteVersionsParams struct {
NoteID uuid.UUID `json:"note_id"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
}
func (q *Queries) GetNoteVersions(ctx context.Context, arg GetNoteVersionsParams) ([]NoteVersion, error) {
rows, err := q.db.Query(ctx, getNoteVersions, arg.NoteID, arg.Limit, arg.Offset)
if err != nil {
return nil, err
}
defer rows.Close()
var items []NoteVersion
for rows.Next() {
var i NoteVersion
if err := rows.Scan(
&i.ID,
&i.NoteID,
&i.Title,
&i.Content,
&i.VersionNumber,
&i.ContentHash,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -1,105 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.28.0
// source: notes.sql
package data
import (
"context"
"github.com/google/uuid"
)
const createNote = `-- name: CreateNote :one
INSERT INTO notes (user_id)
VALUES ($1)
RETURNING id, user_id, created_at, updated_at
`
func (q *Queries) CreateNote(ctx context.Context, userID uuid.UUID) (Note, error) {
row := q.db.QueryRow(ctx, createNote, userID)
var i Note
err := row.Scan(
&i.ID,
&i.UserID,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const deleteNote = `-- name: DeleteNote :exec
DELETE FROM notes
WHERE id = $1 AND user_id = $2
`
type DeleteNoteParams struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
}
func (q *Queries) DeleteNote(ctx context.Context, arg DeleteNoteParams) error {
_, err := q.db.Exec(ctx, deleteNote, arg.ID, arg.UserID)
return err
}
const getNote = `-- name: GetNote :one
SELECT id, user_id, created_at, updated_at FROM notes
WHERE id = $1 AND user_id = $2 LIMIT 1
`
type GetNoteParams struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
}
func (q *Queries) GetNote(ctx context.Context, arg GetNoteParams) (Note, error) {
row := q.db.QueryRow(ctx, getNote, arg.ID, arg.UserID)
var i Note
err := row.Scan(
&i.ID,
&i.UserID,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const listNotes = `-- name: ListNotes :many
SELECT id, user_id, created_at, updated_at FROM notes
WHERE user_id = $1
ORDER BY created_at DESC
LIMIT $2 OFFSET $3
`
type ListNotesParams struct {
UserID uuid.UUID `json:"user_id"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
}
func (q *Queries) ListNotes(ctx context.Context, arg ListNotesParams) ([]Note, error) {
rows, err := q.db.Query(ctx, listNotes, arg.UserID, arg.Limit, arg.Offset)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Note
for rows.Next() {
var i Note
if err := rows.Scan(
&i.ID,
&i.UserID,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -1,74 +0,0 @@
// pkg/migrate/migrate.go
package migrate
import (
"context"
"embed"
"fmt"
"io/fs"
"sort"
"strconv"
"strings"
"github.com/jackc/pgx/v5"
)
func Run(ctx context.Context, conn *pgx.Conn, migrationsFS embed.FS) error {
// Get already applied migrations
rows, _ := conn.Query(ctx, "SELECT version FROM schema_migrations")
defer rows.Close()
applied := make(map[int64]bool)
for rows.Next() {
var version int64
if err := rows.Scan(&version); err != nil {
return err
}
applied[version] = true
}
files, err := migrationsFS.ReadDir("migrations")
if err != nil {
return err
}
// Apply the migrations sequentially based on their ordinal number
for _, f := range sortMigrations(files) {
version, err := strconv.ParseInt(strings.Split(f.Name(), "_")[0], 10, 64)
if err != nil {
return fmt.Errorf("invalid migration name: %s", f.Name())
}
if applied[version] {
continue
}
// Run migration
sql, err := migrationsFS.ReadFile("migrations/" + f.Name())
if err != nil {
return err
}
if _, err := conn.Exec(ctx, string(sql)); err != nil {
return fmt.Errorf("migration %d failed: %w", version, err)
}
if _, err := conn.Exec(ctx,
"INSERT INTO schema_migrations (version) VALUES ($1)", version,
); err != nil {
return err
}
}
return nil
}
// Sort the migration files based on their ordinal number prefix.
func sortMigrations(files []fs.DirEntry) []fs.DirEntry {
sort.Slice(files, func(i, j int) bool {
v1, _ := strconv.ParseInt(strings.Split(files[i].Name(), "_")[0], 10, 64)
v2, _ := strconv.ParseInt(strings.Split(files[j].Name(), "_")[0], 10, 64)
return v1 < v2
})
return files
}

View File

@ -1,386 +0,0 @@
package service
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/go-chi/chi/v5"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func TestAuthMiddleware(t *testing.T) {
secret := "test-secret"
validToken := generateTestToken(t, secret, "access", uuid.New().String(), true)
expiredToken := generateTestToken(t, secret, "access", uuid.New().String(), true, func(claims *userClaims) {
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
})
tests := []struct {
name string
token string
expectedErr string
statusCode int
}{
{
"no token",
"",
"Unauthorized",
http.StatusUnauthorized,
},
{
"invalid token",
"invalid",
"Invalid token",
http.StatusUnauthorized,
},
{
"expired token",
expiredToken,
"Invalid token",
http.StatusUnauthorized,
},
{
"wrong type",
generateTestToken(
t,
secret,
"refresh",
uuid.New().String(),
true,
),
"Invalid token type",
http.StatusUnauthorized,
},
{
"valid token",
validToken,
"",
http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mw := requireAccessToken(secret)
req := httptest.NewRequest("GET", "/", nil)
if tc.token != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token))
}
w := httptest.NewRecorder()
called := false
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
_, ok := r.Context().Value(userCtxKey{}).(*userClaims)
assert.True(t, ok)
}))
handler.ServeHTTP(w, req)
assert.Equal(t, tc.statusCode, w.Code)
if tc.expectedErr != "" {
assert.Contains(t, w.Body.String(), tc.expectedErr)
}
assert.Equal(t, tc.statusCode == http.StatusOK, called)
})
}
}
func TestAdminOnlyMiddleware(t *testing.T) {
tests := []struct {
name string
user *userClaims
statusCode int
}{
{
"no user",
nil,
http.StatusForbidden,
},
{
"non admin user",
&userClaims{
Admin: false,
},
http.StatusForbidden,
},
{
"admin user",
&userClaims{
Admin: true,
},
http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mw := adminOnlyMiddleware
req := httptest.NewRequest("GET", "/", nil)
if tc.user != nil {
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
}
w := httptest.NewRecorder()
called := false
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))
handler.ServeHTTP(w, req)
assert.Equal(t, tc.statusCode, w.Code)
assert.Equal(t, tc.statusCode == http.StatusOK, called)
})
}
}
func TestOwnerOnlyMiddleware(t *testing.T) {
userID := uuid.New().String()
tests := []struct {
name string
user *userClaims
urlID string
statusCode int
}{
{
"no user",
nil,
userID,
http.StatusForbidden,
},
{
"different ID",
&userClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: uuid.New().String(),
}},
userID,
http.StatusForbidden,
},
{
"matching ID",
&userClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
},
},
userID,
http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
r := chi.NewRouter()
handlerChain := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
r.With(
// Add user with the given claims to request's context
func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var ctx context.Context = r.Context()
if tc.user != nil {
ctx = context.WithValue(ctx, userCtxKey{}, tc.user)
}
next.ServeHTTP(w, r.WithContext(ctx))
})
},
ownerOnlyMiddleware,
).Get("/{id}", handlerChain)
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if tc.urlID == "invalid" {
assert.Equal(t, http.StatusNotFound, w.Code)
} else {
assert.Equal(t, tc.statusCode, w.Code)
}
})
}
}
func TestUserCtxMiddleware(t *testing.T) {
validUserID := uuid.New()
invalidUserID := "invalid"
mockStore := &mockUserStore{
GetUserByIDFunc: func(ctx context.Context, id uuid.UUID) (data.User, error) {
if id == validUserID {
return data.User{ID: validUserID}, nil
}
return data.User{}, errors.New("not found")
},
}
tests := []struct {
name string
urlID string
statusCode int
}{
{
"valid ID",
validUserID.String(),
http.StatusOK,
},
{
"invalid ID",
invalidUserID,
http.StatusNotFound,
},
{
"non existent ID",
uuid.New().String(),
http.StatusNotFound,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mw := userCtx(mockStore)
r := chi.NewRouter()
r.With(mw).Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(data.User)
assert.True(t, ok)
assert.Equal(t, validUserID, user.ID)
w.WriteHeader(http.StatusOK)
})
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, tc.statusCode, w.Code)
})
}
}
func TestNoteCtxMiddleware(t *testing.T) {
userID := uuid.New()
noteID := uuid.New()
tests := []struct {
name string
noteID string
user any
mock func(*mockNoteStore)
statusCode int
}{
{
"invalid note ID",
"invalid",
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
Subject: userID.String(),
}},
func(m *mockNoteStore) {},
http.StatusNotFound,
},
{
"unauthorized user",
noteID.String(),
nil,
func(m *mockNoteStore) {},
http.StatusUnauthorized,
},
{
"note not found",
noteID.String(),
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
Subject: userID.String(),
}},
func(m *mockNoteStore) {
m.GetNoteFunc = func(ctx context.Context, arg data.GetNoteParams) (data.Note, error) {
return data.Note{}, errors.New("not found")
}
},
http.StatusNotFound,
},
{
"success",
noteID.String(),
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
Subject: userID.String(),
}},
func(m *mockNoteStore) {
m.GetNoteFunc = func(ctx context.Context, arg data.GetNoteParams) (data.Note, error) {
assert.Equal(t, noteID, arg.ID)
assert.Equal(t, userID, arg.UserID)
return data.Note{ID: noteID}, nil
}
},
http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockStore := &mockNoteStore{}
tc.mock(mockStore)
handler := noteCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := r.Context().Value(noteCtxKey{}).(data.Note)
assert.True(t, ok)
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", fmt.Sprintf("/notes/%s", tc.noteID), nil)
// Chi router context mocks ID passed in a URL parameter
rctx := chi.NewRouteContext()
rctx.URLParams.Add("id", tc.noteID)
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
if tc.user != nil {
req = req.WithContext(context.WithValue(
req.Context(),
userCtxKey{},
tc.user,
))
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, tc.statusCode, w.Code)
})
}
}
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
t.Helper()
claims := &userClaims{
Admin: isAdmin,
TokenType: tokenType,
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
},
}
for _, opt := range opts {
opt(claims)
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signedToken, err := token.SignedString([]byte(secret))
if err != nil {
t.Fatalf("Failed to generate test token: %v", err)
}
return signedToken
}

View File

@ -1,262 +0,0 @@
package service
import (
"context"
"encoding/json"
"net/http"
"strconv"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
)
type noteCtxKey struct{}
// Mockable database operations interface
type NoteStore interface {
CreateNote(ctx context.Context, userID uuid.UUID) (data.Note, error)
DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error
GetNote(ctx context.Context, arg data.GetNoteParams) (data.Note, error)
ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.Note, error)
CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error)
FindDuplicateContent(ctx context.Context, arg data.FindDuplicateContentParams) (bool, error)
GetNoteVersion(ctx context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error)
GetNoteVersions(ctx context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error)
}
type notesResource struct {
JWTSecret string
Notes NoteStore
}
func (rs notesResource) Routes() chi.Router {
r := chi.NewRouter()
r.Group(func(r chi.Router) {
r.Use(requireAccessToken(rs.JWTSecret))
r.Post("/", rs.CreateNote) // POST /notes - note creation
r.Get("/", rs.ListNotes) // GET /notes - get all notes
r.Route("/{id}", func(r chi.Router) {
r.Use(noteCtx(rs.Notes))
r.Get("/", rs.GetNote) // GET /notes/{id} - get specific note
r.Delete("/", rs.DeleteNote) // DELETE /notes/{id} - delete specific note
r.Route("/versions", func(r chi.Router) {
r.Post("/", rs.CreateNoteVersion) // POST /notes/{id}/versions - create new version
r.Get("/", rs.ListNoteVersions) // GET /notes/{id}/versions - get all existing versions
r.Get("/{version}", rs.GetNoteVersion) // GET /notes/{id}/versions/{version} - get specific version
})
})
})
return r
}
func (rs *notesResource) CreateNote(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok {
respondError(w, http.StatusUnauthorized, "Unauthorized")
return
}
userID, err := uuid.Parse(user.Subject)
if err != nil {
respondError(w, http.StatusInternalServerError, "Invalid user ID")
return
}
note, err := rs.Notes.CreateNote(r.Context(), userID)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to create note")
return
}
respondJSON(w, http.StatusCreated, note)
}
func (rs *notesResource) ListNotes(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok {
respondError(w, http.StatusUnauthorized, "Unauthorized")
return
}
userID, err := uuid.Parse(user.Subject)
if err != nil {
respondError(w, http.StatusInternalServerError, "Invalid user ID")
return
}
limit, offset := getPaginationParams(r)
notes, err := rs.Notes.ListNotes(r.Context(), data.ListNotesParams{
UserID: userID,
Limit: limit,
Offset: offset,
})
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to retrieve notes")
return
}
respondJSON(w, http.StatusOK, notes)
}
func (rs *notesResource) GetNote(w http.ResponseWriter, r *http.Request) {
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
if !ok {
respondError(w, http.StatusNotFound, "Note not found")
return
}
respondJSON(w, http.StatusOK, note)
}
func (rs *notesResource) DeleteNote(w http.ResponseWriter, r *http.Request) {
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
if !ok {
respondError(w, http.StatusNotFound, "Note not found")
return
}
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok {
respondError(w, http.StatusUnauthorized, "Unauthorized")
return
}
userID, err := uuid.Parse(user.Subject)
if err != nil {
respondError(w, http.StatusInternalServerError, "Invalid user ID")
return
}
err = rs.Notes.DeleteNote(r.Context(), data.DeleteNoteParams{
ID: note.ID,
UserID: userID,
})
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to delete note")
return
}
w.WriteHeader(http.StatusNoContent)
}
func (rs *notesResource) CreateNoteVersion(w http.ResponseWriter, r *http.Request) {
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
if !ok {
respondError(w, http.StatusNotFound, "Note not found")
return
}
var req struct {
Title string `json:"title"`
Content string `json:"content"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
// De-duplication check
duplicate, err := rs.Notes.FindDuplicateContent(r.Context(), data.FindDuplicateContentParams{
NoteID: note.ID,
Column2: []byte(req.Title),
Column3: []byte(req.Content),
})
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to check for duplicate content")
return
}
if duplicate {
respondError(w, http.StatusConflict, "Duplicate content detected")
return
}
version, err := rs.Notes.CreateNoteVersion(r.Context(), data.CreateNoteVersionParams{
NoteID: note.ID,
Title: req.Title,
Content: req.Content,
})
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to create note version")
return
}
respondJSON(w, http.StatusCreated, version)
}
func (rs *notesResource) ListNoteVersions(w http.ResponseWriter, r *http.Request) {
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
if !ok {
respondError(w, http.StatusNotFound, "Note not found")
return
}
limit, offset := getPaginationParams(r)
versions, err := rs.Notes.GetNoteVersions(r.Context(), data.GetNoteVersionsParams{
NoteID: note.ID,
Limit: limit,
Offset: offset,
})
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to retrieve versions")
return
}
respondJSON(w, http.StatusOK, versions)
}
func (rs *notesResource) GetNoteVersion(w http.ResponseWriter, r *http.Request) {
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
if !ok {
respondError(w, http.StatusNotFound, "Note not found")
return
}
versionStr := chi.URLParam(r, "version")
versionNumber, err := strconv.ParseInt(versionStr, 10, 32)
if err != nil {
respondError(w, http.StatusBadRequest, "Invalid version number")
return
}
version, err := rs.Notes.GetNoteVersion(r.Context(), data.GetNoteVersionParams{
NoteID: note.ID,
VersionNumber: int32(versionNumber),
})
if err != nil {
respondError(w, http.StatusNotFound, "Version not found")
return
}
respondJSON(w, http.StatusOK, version)
}
func getPaginationParams(r *http.Request) (limit int32, offset int32) {
defaultLimit := 50
defaultOffset := 0
limitStr := r.URL.Query().Get("limit")
if limitStr != "" {
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
defaultLimit = l
}
}
offsetStr := r.URL.Query().Get("offset")
if offsetStr != "" {
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
defaultOffset = o
}
}
return int32(defaultLimit), int32(defaultOffset)
}

View File

@ -1,509 +0,0 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/go-chi/chi/v5"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
type mockNoteStore struct {
CreateNoteFunc func(context.Context, uuid.UUID) (data.Note, error)
DeleteNoteFunc func(context.Context, data.DeleteNoteParams) error
GetNoteFunc func(context.Context, data.GetNoteParams) (data.Note, error)
ListNotesFunc func(context.Context, data.ListNotesParams) ([]data.Note, error)
CreateNoteVersionFunc func(context.Context, data.CreateNoteVersionParams) (data.NoteVersion, error)
FindDuplicateContentFunc func(context.Context, data.FindDuplicateContentParams) (bool, error)
GetNoteVersionFunc func(context.Context, data.GetNoteVersionParams) (data.NoteVersion, error)
GetNoteVersionsFunc func(context.Context, data.GetNoteVersionsParams) ([]data.NoteVersion, error)
}
func (m *mockNoteStore) CreateNote(ctx context.Context, userID uuid.UUID) (data.Note, error) {
return m.CreateNoteFunc(ctx, userID)
}
func (m *mockNoteStore) DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error {
return m.DeleteNoteFunc(ctx, arg)
}
func (m *mockNoteStore) GetNote(ctx context.Context, arg data.GetNoteParams) (data.Note, error) {
return m.GetNoteFunc(ctx, arg)
}
func (m *mockNoteStore) ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.Note, error) {
return m.ListNotesFunc(ctx, arg)
}
func (m *mockNoteStore) CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error) {
return m.CreateNoteVersionFunc(ctx, arg)
}
func (m *mockNoteStore) FindDuplicateContent(ctx context.Context, arg data.FindDuplicateContentParams) (bool, error) {
return m.FindDuplicateContentFunc(ctx, arg)
}
func (m *mockNoteStore) GetNoteVersion(ctx context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error) {
return m.GetNoteVersionFunc(ctx, arg)
}
func (m *mockNoteStore) GetNoteVersions(ctx context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) {
return m.GetNoteVersionsFunc(ctx, arg)
}
func TestNotes_CreateNote(t *testing.T) {
userID := uuid.New()
testNote := data.Note{ID: uuid.New(), UserID: userID}
tests := []struct {
name string
mock func(*mockNoteStore)
statusCode int
}{
{
"success",
func(m *mockNoteStore) {
m.CreateNoteFunc = func(_ context.Context, uid uuid.UUID) (data.Note, error) {
assert.Equal(t, userID, uid)
return testNote, nil
}
},
http.StatusCreated,
},
{
"database error",
func(m *mockNoteStore) {
m.CreateNoteFunc = func(context.Context, uuid.UUID) (data.Note, error) {
return data.Note{}, errors.New("db error")
}
},
http.StatusInternalServerError,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockStore := &mockNoteStore{}
tc.mock(mockStore)
rs := notesResource{Notes: mockStore}
req := httptest.NewRequest("POST", "/", nil)
req = req.WithContext(context.WithValue(
req.Context(),
userCtxKey{},
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
Subject: userID.String(),
}},
))
w := httptest.NewRecorder()
rs.CreateNote(w, req)
assert.Equal(t, tc.statusCode, w.Code)
if tc.statusCode == http.StatusCreated {
var note data.Note
json.Unmarshal(w.Body.Bytes(), &note)
assert.Equal(t, testNote.ID, note.ID)
}
})
}
}
func TestNotes_ListNotes(t *testing.T) {
userID := uuid.New()
notes := []data.Note{
{ID: uuid.New(), UserID: userID},
{ID: uuid.New(), UserID: userID},
}
tests := []struct {
name string
query string
mock func(*mockNoteStore)
statusCode int
}{
{
"success",
"",
func(m *mockNoteStore) {
m.ListNotesFunc = func(_ context.Context, arg data.ListNotesParams) ([]data.Note, error) {
assert.Equal(t, userID, arg.UserID)
return notes, nil
}
},
http.StatusOK,
},
{
"with pagination",
"?limit=10&offset=20",
func(m *mockNoteStore) {
m.ListNotesFunc = func(_ context.Context, arg data.ListNotesParams) ([]data.Note, error) {
assert.EqualValues(t, 10, arg.Limit)
assert.EqualValues(t, 20, arg.Offset)
return notes, nil
}
},
http.StatusOK,
},
{
"database error",
"",
func(m *mockNoteStore) {
m.ListNotesFunc = func(context.Context, data.ListNotesParams) ([]data.Note, error) {
return nil, errors.New("db error")
}
},
http.StatusInternalServerError,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockStore := &mockNoteStore{}
tc.mock(mockStore)
rs := notesResource{Notes: mockStore}
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.query), nil)
req = req.WithContext(context.WithValue(
req.Context(),
userCtxKey{},
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
Subject: userID.String(),
}},
))
w := httptest.NewRecorder()
rs.ListNotes(w, req)
assert.Equal(t, tc.statusCode, w.Code)
})
}
}
func TestNotes_GetNote(t *testing.T) {
noteID := uuid.New()
userID := uuid.New()
validNote := data.Note{ID: noteID, UserID: userID}
t.Run("success", func(t *testing.T) {
rs := notesResource{}
req := httptest.NewRequest("GET", "/", nil)
req = req.WithContext(context.WithValue(
req.Context(),
noteCtxKey{},
validNote,
))
w := httptest.NewRecorder()
rs.GetNote(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var note data.Note
json.Unmarshal(w.Body.Bytes(), &note)
assert.Equal(t, validNote.ID, note.ID)
})
t.Run("not found", func(t *testing.T) {
rs := notesResource{}
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
rs.GetNote(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
})
}
func TestNotes_DeleteNote(t *testing.T) {
noteID := uuid.New()
userID := uuid.New()
validNote := data.Note{ID: noteID, UserID: userID}
t.Run("success", func(t *testing.T) {
mockStore := &mockNoteStore{
DeleteNoteFunc: func(_ context.Context, arg data.DeleteNoteParams) error {
assert.Equal(t, noteID, arg.ID)
assert.Equal(t, userID, arg.UserID)
return nil
},
}
rs := notesResource{Notes: mockStore}
req := httptest.NewRequest("DELETE", "/", nil)
req = req.WithContext(context.WithValue(
req.Context(),
noteCtxKey{},
validNote,
))
req = req.WithContext(context.WithValue(
req.Context(),
userCtxKey{},
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
Subject: userID.String(),
}},
))
w := httptest.NewRecorder()
rs.DeleteNote(w, req)
assert.Equal(t, http.StatusNoContent, w.Code)
})
t.Run("database error", func(t *testing.T) {
mockStore := &mockNoteStore{
DeleteNoteFunc: func(context.Context, data.DeleteNoteParams) error {
return errors.New("db error")
},
}
rs := notesResource{Notes: mockStore}
req := httptest.NewRequest("DELETE", "/", nil)
req = req.WithContext(context.WithValue(
req.Context(),
noteCtxKey{},
validNote,
))
req = req.WithContext(context.WithValue(
req.Context(),
userCtxKey{},
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
Subject: userID.String(),
}},
))
w := httptest.NewRecorder()
rs.DeleteNote(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
})
}
func TestNotes_CreateNoteVersion(t *testing.T) {
noteID := uuid.New()
validRequest := `{"title": "Test", "content": "Content"}`
tests := []struct {
name string
body string
mock func(*mockNoteStore)
statusCode int
}{
{
"success",
validRequest,
func(m *mockNoteStore) {
m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) {
return false, nil
}
m.CreateNoteVersionFunc = func(_ context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error) {
assert.Equal(t, noteID, arg.NoteID)
return data.NoteVersion{}, nil
}
},
http.StatusCreated,
},
{
"duplicate content",
validRequest,
func(m *mockNoteStore) {
m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) {
return true, nil
}
},
http.StatusConflict,
},
{
"invalid request",
"{invalid}",
func(m *mockNoteStore) {},
http.StatusBadRequest,
},
{
"database error",
validRequest,
func(m *mockNoteStore) {
m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) {
return false, nil
}
m.CreateNoteVersionFunc = func(context.Context, data.CreateNoteVersionParams) (data.NoteVersion, error) {
return data.NoteVersion{}, errors.New("db error")
}
},
http.StatusInternalServerError,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockStore := &mockNoteStore{}
tc.mock(mockStore)
rs := notesResource{Notes: mockStore}
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
req = req.WithContext(context.WithValue(
req.Context(),
noteCtxKey{},
data.Note{ID: noteID},
))
w := httptest.NewRecorder()
rs.CreateNoteVersion(w, req)
assert.Equal(t, tc.statusCode, w.Code)
})
}
}
func TestNotes_ListNoteVersions(t *testing.T) {
noteID := uuid.New()
versions := []data.NoteVersion{
{ID: uuid.New(), NoteID: noteID, VersionNumber: 1},
{ID: uuid.New(), NoteID: noteID, VersionNumber: 2},
}
tests := []struct {
name string
query string
mock func(*mockNoteStore)
statusCode int
}{
{
"success",
"",
func(m *mockNoteStore) {
m.GetNoteVersionsFunc = func(_ context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) {
assert.Equal(t, noteID, arg.NoteID)
return versions, nil
}
},
http.StatusOK,
},
{
"with pagination",
"?limit=5&offset=10",
func(m *mockNoteStore) {
m.GetNoteVersionsFunc = func(_ context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) {
assert.EqualValues(t, 5, arg.Limit)
assert.EqualValues(t, 10, arg.Offset)
return versions, nil
}
},
http.StatusOK,
},
{
"database error",
"",
func(m *mockNoteStore) {
m.GetNoteVersionsFunc = func(context.Context, data.GetNoteVersionsParams) ([]data.NoteVersion, error) {
return nil, errors.New("db error")
}
},
http.StatusInternalServerError,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockStore := &mockNoteStore{}
tc.mock(mockStore)
rs := notesResource{Notes: mockStore}
req := httptest.NewRequest("GET", fmt.Sprintf("/versions/%s", tc.query), nil)
req = req.WithContext(context.WithValue(
req.Context(),
noteCtxKey{},
data.Note{ID: noteID},
))
w := httptest.NewRecorder()
rs.ListNoteVersions(w, req)
assert.Equal(t, tc.statusCode, w.Code)
if tc.statusCode == http.StatusOK {
var result []data.NoteVersion
json.Unmarshal(w.Body.Bytes(), &result)
assert.Len(t, result, 2)
}
})
}
}
func TestNotes_GetNoteVersion(t *testing.T) {
noteID := uuid.New()
version := data.NoteVersion{ID: uuid.New(), NoteID: noteID, VersionNumber: 1}
tests := []struct {
name string
version string
mock func(*mockNoteStore)
statusCode int
}{
{
"invalid version",
"invalid",
func(m *mockNoteStore) {},
http.StatusBadRequest,
},
{
"version not found",
"1",
func(m *mockNoteStore) {
m.GetNoteVersionFunc = func(context.Context, data.GetNoteVersionParams) (data.NoteVersion, error) {
return data.NoteVersion{}, errors.New("not found")
}
},
http.StatusNotFound,
},
{
"success",
"1",
func(m *mockNoteStore) {
m.GetNoteVersionFunc = func(_ context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error) {
assert.Equal(t, noteID, arg.NoteID)
assert.EqualValues(t, 1, arg.VersionNumber)
return version, nil
}
},
http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockStore := &mockNoteStore{}
tc.mock(mockStore)
rs := notesResource{Notes: mockStore}
req := httptest.NewRequest("GET", fmt.Sprintf("/versions/%s", tc.version), nil)
// Chi router context mocks ID (passed in a URL param.) and the note object (passed in req. ctx.)
rctx := chi.NewRouteContext()
rctx.URLParams.Add("version", tc.version)
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
req = req.WithContext(context.WithValue(
req.Context(),
noteCtxKey{},
data.Note{ID: noteID},
))
w := httptest.NewRecorder()
rs.GetNoteVersion(w, req)
assert.Equal(t, tc.statusCode, w.Code)
if tc.statusCode == http.StatusOK {
var result data.NoteVersion
json.Unmarshal(w.Body.Bytes(), &result)
assert.Equal(t, version.VersionNumber, result.VersionNumber)
}
})
}
}

View File

@ -1,249 +0,0 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net/http"
"strings"
"time"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/go-chi/chi/v5"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
const (
accessTokenDuration = 15 * time.Minute
refreshTokenDuration = 7 * 24 * time.Hour
)
var (
ErrInvalidToken = errors.New("invalid token")
ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header")
)
type userClaims struct {
Admin bool `json:"admin"`
TokenType string `json:"type"` // "access" or "refresh"
jwt.RegisteredClaims // User's UUID should be stored in the subject claim
}
type tokenPair struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
// Mockable database operations interface
type TokenStore interface {
CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error)
GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error)
RevokeRefreshToken(ctx context.Context, tokenHash string) error
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
}
type tokensResource struct {
JWTSecret string
Tokens TokenStore
}
func (rs tokensResource) Routes() chi.Router {
r := chi.NewRouter()
// Protected routes (access token required)
r.Group(func(r chi.Router) {
r.Use(requireAccessToken(rs.JWTSecret))
r.Post("/logout", rs.HandleLogout) // POST /auth/logout - revoke all refresh cookies
})
// Protected routes (refresh token required)
r.Group(func(r chi.Router) {
r.Use(requireRefreshToken(rs.JWTSecret))
r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair
})
return r
}
func (rs tokensResource) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
tokenPair, err := generateTokenPair(userID.String(), isAdmin, rs.JWTSecret)
if err != nil {
return nil, err
}
hash := sha256.Sum256([]byte(tokenPair.RefreshToken))
tokenHash := hex.EncodeToString(hash[:])
// Store to DB with (almost) identical expiration timestamp
expiresAt := time.Now().Add(refreshTokenDuration)
_, err = rs.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{
UserID: userID,
TokenHash: tokenHash,
ExpiresAt: expiresAt,
})
if err != nil {
return nil, err
}
return tokenPair, nil
}
func (rs tokensResource) RevokeRefreshToken(ctx context.Context, token string) error {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
return rs.Tokens.RevokeRefreshToken(ctx, tokenHash)
}
func (rs tokensResource) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
dbToken, err := rs.Tokens.GetRefreshTokenByHash(ctx, tokenHash)
if err != nil {
return nil, err
}
if dbToken.Revoked || time.Now().After(dbToken.ExpiresAt) {
return nil, ErrInvalidToken
}
return &dbToken, nil
}
func (rs tokensResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
// Get claims from context
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok || claims.TokenType != "refresh" {
respondError(w, http.StatusUnauthorized, "Invalid token")
return
}
// Attempt to get the token from Authentication header ("Bearer <token>")
refreshToken, err := getTokenFromRequest(r)
if err != nil {
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
}
// Validate the refresh token in DB
if _, err := rs.ValidateRefreshToken(r.Context(), refreshToken); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid refresh token")
return
}
// Revoke the used refresh token
if err := rs.RevokeRefreshToken(r.Context(), refreshToken); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to revoke token")
return
}
// Generate a new pair (access & refresh tokens)
userID, err := uuid.Parse(claims.Subject)
if err != nil {
respondError(w, http.StatusInternalServerError, "Invalid user ID")
return
}
tokenPair, err := rs.GenerateTokenPair(r.Context(), userID, claims.Admin)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
return
}
// Set refresh token in HTTP-only cookie
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Value: tokenPair.RefreshToken,
Path: "/",
MaxAge: int(refreshTokenDuration.Seconds()),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
})
// Return the access token in the response body (it should be stored in browser's memory client-side)
respondJSON(w, http.StatusOK, map[string]string{
"access_token": tokenPair.AccessToken,
})
}
func (rs tokensResource) HandleLogout(w http.ResponseWriter, r *http.Request) {
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok {
respondError(w, http.StatusUnauthorized, "Not authenticated")
return
}
userID, err := uuid.Parse(claims.ID)
if err != nil {
respondError(w, http.StatusInternalServerError, "Invalid user ID")
return
}
// Clear the refresh token cookie
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Value: "",
Path: "/",
MaxAge: -1, // Expires immediately
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
})
if err := rs.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to logout")
return
}
w.WriteHeader(http.StatusNoContent)
}
func getTokenFromRequest(r *http.Request) (string, error) {
bearerToken := r.Header.Get("Authorization")
bearerFields := strings.Fields(bearerToken)
if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" {
return bearerFields[1], nil
}
return "", ErrAuthHeaderInvalid
}
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
atClaims := userClaims{
Admin: isAdmin,
TokenType: "access",
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenDuration)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},
}
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, atClaims)
t, err := accessToken.SignedString([]byte(jwtSecret))
if err != nil {
return nil, err
}
rtClaims := userClaims{
Admin: isAdmin,
TokenType: "refresh",
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenDuration)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},
}
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, rtClaims)
rt, err := refreshToken.SignedString([]byte(jwtSecret))
if err != nil {
return nil, err
}
return &tokenPair{AccessToken: t, RefreshToken: rt}, nil
}

View File

@ -1,211 +0,0 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
type mockTokenStore struct {
CreateRefreshTokenFunc func(context.Context, data.CreateRefreshTokenParams) (data.RefreshToken, error)
GetRefreshTokenByHashFunc func(context.Context, string) (data.RefreshToken, error)
RevokeRefreshTokenFunc func(context.Context, string) error
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
}
func (m *mockTokenStore) CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
return m.CreateRefreshTokenFunc(ctx, arg)
}
func (m *mockTokenStore) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
return m.GetRefreshTokenByHashFunc(ctx, tokenHash)
}
func (m *mockTokenStore) RevokeRefreshToken(ctx context.Context, token string) error {
return m.RevokeRefreshTokenFunc(ctx, token)
}
func (m *mockTokenStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
}
func TestGenerateTokenPair_Success(t *testing.T) {
userID := uuid.New()
called := false
mockStore := &mockTokenStore{
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
called = true
assert.Equal(t, userID, arg.UserID)
return data.RefreshToken{}, nil
},
}
rs := tokensResource{
JWTSecret: "test-secret",
Tokens: mockStore,
}
pair, err := rs.GenerateTokenPair(context.Background(), userID, false)
assert.NoError(t, err)
assert.True(t, called)
assert.NotEmpty(t, pair.AccessToken)
assert.NotEmpty(t, pair.RefreshToken)
}
func TestGenerateTokenPair_DBError(t *testing.T) {
mockStore := &mockTokenStore{
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
return data.RefreshToken{}, errors.New("db error")
},
}
rs := tokensResource{Tokens: mockStore}
_, err := rs.GenerateTokenPair(context.Background(), uuid.New(), false)
assert.ErrorContains(t, err, "db error")
}
func TestValidateRefreshToken_Valid(t *testing.T) {
token := "valid-jwt-token"
hash := sha256.Sum256([]byte(token))
expectedHash := hex.EncodeToString(hash[:])
mockStore := &mockTokenStore{
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
assert.Equal(t, expectedHash, tokenHash)
return data.RefreshToken{ExpiresAt: time.Now().Add(1 * time.Hour)}, nil
},
}
rs := tokensResource{Tokens: mockStore}
_, err := rs.ValidateRefreshToken(context.Background(), token)
assert.NoError(t, err)
}
func TestRefreshAccessToken_Success(t *testing.T) {
userID := uuid.New()
refreshToken := "valid-jwt-token"
// Expected hash of the test token
hash := sha256.Sum256([]byte(refreshToken))
expectedHash := hex.EncodeToString(hash[:])
mockStore := &mockTokenStore{
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
assert.Equal(t, expectedHash, tokenHash)
// Must return an unrevoked token with future expiration
return data.RefreshToken{
TokenHash: tokenHash,
ExpiresAt: time.Now().Add(1 * time.Hour),
Revoked: false,
}, nil
},
RevokeRefreshTokenFunc: func(ctx context.Context, tokenHash string) error {
assert.Equal(t, expectedHash, tokenHash)
return nil
},
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
return data.RefreshToken{}, nil
},
}
rs := tokensResource{
JWTSecret: "test-secret",
Tokens: mockStore,
}
req := httptest.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", refreshToken))
req = req.WithContext(context.WithValue(
req.Context(),
userCtxKey{},
&userClaims{
Admin: false,
TokenType: "refresh",
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID.String(),
},
},
))
w := httptest.NewRecorder()
rs.RefreshAccessToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "access_token")
cookies := w.Result().Cookies()
var refreshCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "refresh_token" {
refreshCookie = cookie
break
}
}
if refreshCookie == nil {
t.Fatal("refresh token cookie not set")
}
assert.True(t, refreshCookie.HttpOnly, "cookie should be HttpOnly")
assert.Equal(t, http.SameSiteStrictMode, refreshCookie.SameSite, "invalid SameSite mode")
assert.Equal(t, "/", refreshCookie.Path, "invalid cookie path")
assert.Greater(t, refreshCookie.MaxAge, 0, "cookie should have expiration")
}
func TestHandleLogout_Success(t *testing.T) {
userID := uuid.New()
called := false
mockStore := &mockTokenStore{
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
called = true
assert.Equal(t, userID, id)
return nil
},
}
rs := tokensResource{Tokens: mockStore}
req := httptest.NewRequest("POST", "/", nil)
req = req.WithContext(context.WithValue(
req.Context(),
userCtxKey{},
&userClaims{
RegisteredClaims: jwt.RegisteredClaims{
ID: userID.String(),
},
},
))
w := httptest.NewRecorder()
rs.HandleLogout(w, req)
assert.True(t, called)
assert.Equal(t, http.StatusNoContent, w.Code) // 204
cookies := w.Result().Cookies()
var refreshCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "refresh_token" {
refreshCookie = cookie
break
}
}
if refreshCookie != nil && refreshCookie.MaxAge != -1 {
t.Fatal("refresh token cookie not invalidated")
}
}

View File

@ -1,356 +0,0 @@
package service
import (
"context"
"encoding/json"
"errors"
"net/http"
"strconv"
"time"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgconn"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
)
type userCtxKey struct{}
// Stripped object that only contains non-critical data
type userResponse struct {
ID uuid.UUID `json:"id"`
Username string `json:"username"`
IsAdmin bool `json:"is_admin"`
CreatedAt *time.Time `json:"created_at"`
UpdatedAt *time.Time `json:"updated_at"`
}
// Mockable database operations interface
type UserStore interface {
CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
ListUsers(ctx context.Context) ([]data.User, error)
GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error)
GetUserByUsername(ctx context.Context, username string) (data.User, error)
UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error
DeleteUser(ctx context.Context, id uuid.UUID) error
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
}
type usersResource struct {
JWTSecret string
Users UserStore
}
func (rs usersResource) Routes() chi.Router {
r := chi.NewRouter()
// Public routes (no tokens required)
r.Post("/", rs.Create) // POST /users - registration/signup
r.Post("/login", rs.Login) // POST /users/login - login as existing user
// Protected routes (access token required)
r.Group(func(r chi.Router) {
r.Use(requireAccessToken(rs.JWTSecret))
r.Get("/me", rs.Get) // GET /users/me - get current user data
// Admin only general routes
r.Group(func(r chi.Router) {
r.Use(adminOnlyMiddleware)
r.Get("/", rs.List) // GET /users - list all users
})
// User specific routes
r.Route("/{id}", func(r chi.Router) {
r.Use(userCtx(rs.Users)) // DB -> req. context
// Admin routes
r.Route("/admin", func(r chi.Router) {
r.Use(adminOnlyMiddleware)
r.Delete("/", rs.AdminDelete) // DELETE /users/admin/{id} - delete user
})
// Owner routes
r.Route("/owner", func(r chi.Router) {
r.Use(ownerOnlyMiddleware)
r.Put("/", rs.UpdatePassword) // PUT /users/owner/{id} - update user password
r.Delete("/", rs.OwnerDelete) // DELETE /users/owner/{id} - delete user
})
})
})
return r
}
func (rs usersResource) Create(w http.ResponseWriter, r *http.Request) {
type request struct {
Username string `json:"username"`
Password string `json:"password"`
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
normalizedUsername := normalizeUsername(req.Username)
if err := validateUsername(normalizedUsername); err != nil {
respondError(w, http.StatusBadRequest, err.Error())
return
}
if err := validatePassword(req.Password); err != nil {
respondError(w, http.StatusBadRequest, err.Error())
return
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to create user")
return
}
user, err := rs.Users.CreateUser(r.Context(), data.CreateUserParams{
Username: normalizedUsername,
PasswordHash: string(hashedPassword),
})
if err != nil {
if isDuplicateEntry(err) {
respondError(w, http.StatusConflict, "Username is already in use")
} else {
respondError(w, http.StatusInternalServerError, "Failed to create user")
}
return
}
respondJSON(w, http.StatusCreated, map[string]string{
"id": user.ID.String(),
"username": user.Username,
})
}
func (rs usersResource) Login(w http.ResponseWriter, r *http.Request) {
type request struct {
Username string `json:"username"`
Password string `json:"password"`
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
user, err := rs.Users.GetUserByUsername(r.Context(), normalizeUsername(req.Username))
if err != nil {
respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
tokenPair, err := generateTokenPair(user.ID.String(), user.IsAdmin, rs.JWTSecret)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
return
}
// Set refresh token in HTTP-only cookie
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Value: tokenPair.RefreshToken,
Path: "/",
MaxAge: int(refreshTokenDuration.Seconds()),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
})
// Build response
response := map[string]any{
"access_token": tokenPair.AccessToken,
}
// Include user data if the client has requested it (`?includeUser=true`)
if includeUser, _ := strconv.ParseBool(r.URL.Query().Get("includeUser")); includeUser {
response["user"] = userResponse{
ID: user.ID,
Username: user.Username,
IsAdmin: user.IsAdmin,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
}
}
respondJSON(w, http.StatusOK, response)
}
func (rs usersResource) List(w http.ResponseWriter, r *http.Request) {
users, err := rs.Users.ListUsers(r.Context())
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to retrieve users")
return
}
// Output sanitization
var output []map[string]any
for _, user := range users {
output = append(output, map[string]any{
"id": user.ID,
"username": user.Username,
"created_at": user.CreatedAt,
"updated_at": user.UpdatedAt,
})
}
respondJSON(w, http.StatusOK, output)
}
func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) {
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok {
respondError(w, http.StatusUnauthorized, "Unauthorized")
return
}
userID, err := uuid.Parse(claims.Subject)
if err != nil {
respondError(w, http.StatusInternalServerError, "Invalid user ID")
return
}
user, err := rs.Users.GetUserByID(r.Context(), userID)
if err != nil {
respondError(w, http.StatusNotFound, "User not found")
return
}
respondJSON(w, http.StatusOK, userResponse{
ID: user.ID,
Username: user.Username,
IsAdmin: user.IsAdmin,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
})
}
func (rs usersResource) UpdatePassword(w http.ResponseWriter, r *http.Request) {
type request struct {
OldPassword string `json:"old_password"`
NewPassword string `json:"new_password"`
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
user, ok := r.Context().Value(userCtxKey{}).(data.User)
if !ok {
respondError(w, http.StatusNotFound, "User not found")
return
}
// Verify the old password before allowing the update
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.OldPassword)); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
if err := validatePassword(req.NewPassword); err != nil {
respondError(w, http.StatusBadRequest, err.Error())
return
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to update password")
return
}
if err := rs.Users.UpdatePassword(r.Context(), data.UpdatePasswordParams{
ID: user.ID,
PasswordHash: string(hashedPassword),
}); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to update password")
return
}
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
}
w.WriteHeader(http.StatusNoContent)
}
func (rs usersResource) AdminDelete(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(data.User)
if !ok {
respondError(w, http.StatusNotFound, "User not found")
return
}
if err := rs.Users.DeleteUser(r.Context(), user.ID); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to delete user")
return
}
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
}
w.WriteHeader(http.StatusNoContent)
}
func (rs usersResource) OwnerDelete(w http.ResponseWriter, r *http.Request) {
type request struct {
Password string `json:"password"`
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
user, ok := r.Context().Value(userCtxKey{}).(data.User)
if !ok {
respondError(w, http.StatusNotFound, "User not found")
return
}
// Verify the old password before allowing the deletion
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
err := rs.Users.DeleteUser(r.Context(), user.ID)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to delete user")
return
}
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
}
w.WriteHeader(http.StatusNoContent)
}
// Check if the given error is a PostgreSQL error for `unique_violation`, i.e. whether an entry
// with the given details already exists in the database table.
func isDuplicateEntry(err error) bool {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
return pgErr.Code == "23505"
}
return false
}

View File

@ -1,557 +0,0 @@
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
)
type mockUserStore struct {
CreateUserFunc func(context.Context, data.CreateUserParams) (data.User, error)
ListUsersFunc func(context.Context) ([]data.User, error)
GetUserByIDFunc func(context.Context, uuid.UUID) (data.User, error)
GetUserByUsernameFunc func(context.Context, string) (data.User, error)
UpdatePasswordFunc func(context.Context, data.UpdatePasswordParams) error
DeleteUserFunc func(context.Context, uuid.UUID) error
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
}
func (m *mockUserStore) CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
return m.CreateUserFunc(ctx, arg)
}
func (m *mockUserStore) ListUsers(ctx context.Context) ([]data.User, error) {
return m.ListUsersFunc(ctx)
}
func (m *mockUserStore) GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error) {
return m.GetUserByIDFunc(ctx, id)
}
func (m *mockUserStore) GetUserByUsername(ctx context.Context, username string) (data.User, error) {
return m.GetUserByUsernameFunc(ctx, username)
}
func (m *mockUserStore) UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error {
return m.UpdatePasswordFunc(ctx, arg)
}
func (m *mockUserStore) DeleteUser(ctx context.Context, id uuid.UUID) error {
return m.DeleteUserFunc(ctx, id)
}
func (m *mockUserStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
}
func TestCreateUser_Duplicate(t *testing.T) {
mockStore := &mockUserStore{
CreateUserFunc: func(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
return data.User{}, &pgconn.PgError{Code: "23505"}
},
}
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
reqBody := `{"username": "existing", "password": "validPass123!"}`
req := httptest.NewRequest("POST", "/", strings.NewReader(reqBody))
w := httptest.NewRecorder()
rs.Create(w, req)
assert.Equal(t, http.StatusConflict, w.Code)
assert.Contains(t, w.Body.String(), "Username is already in use")
}
func TestCreateUser_InvalidUsername(t *testing.T) {
mockStore := &mockUserStore{} // No DB calls expected
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
// Test various invalid usernames
tests := []struct {
name string
body string
}{
{
"too short",
`{"username": "a", "password": "validPass123!"}`,
},
{
"invalid chars",
`{"username": "user@name", "password": "validPass123!"}`,
},
{
"empty",
`{"username": "", "password": "validPass123!"}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
w := httptest.NewRecorder()
rs.Create(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
})
}
}
func TestCreateUser_InvalidPassword(t *testing.T) {
mockStore := &mockUserStore{}
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
tests := []struct {
name string
body string
}{
{
"too short",
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength-1)),
},
{
"too long",
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", maxPasswordLength+1)),
},
{
"low entropy",
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength)),
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
w := httptest.NewRecorder()
rs.Create(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
})
}
}
func TestListUsers_Success(t *testing.T) {
testUsers := []data.User{
{ID: uuid.New(), Username: "user1"},
{ID: uuid.New(), Username: "user2"},
}
mockStore := &mockUserStore{
ListUsersFunc: func(ctx context.Context) ([]data.User, error) {
return testUsers, nil
},
}
rs := usersResource{Users: mockStore}
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
rs.List(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response []map[string]any
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatal(err)
}
assert.Len(t, response, 2)
assert.Equal(t, "user1", response[0]["username"])
assert.NotContains(t, response[0], "password_hash")
}
func TestUpdatePassword_InvalidOldPassword(t *testing.T) {
// User with password hash that won't match "wrongpassword"
user := data.User{
ID: uuid.New(),
PasswordHash: "$2a$10$PHhno.bZBF8IEINdFRZAPujMxIN65msElATgJG6FIxZdeWYVLSfFi", // Hash of "correctpassword"
}
mockStore := &mockUserStore{}
rs := usersResource{Users: mockStore}
reqBody := `{"old_password": "wrongpassword", "new_password": "NewValidPass321!"}`
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
w := httptest.NewRecorder()
rs.UpdatePassword(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func TestAdminDelete_Success(t *testing.T) {
user := data.User{ID: uuid.New()}
deleteCalled := false
revokeCalled := false
mockStore := &mockUserStore{
DeleteUserFunc: func(ctx context.Context, id uuid.UUID) error {
deleteCalled = true
assert.Equal(t, user.ID, id)
return nil
},
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
revokeCalled = true
assert.Equal(t, user.ID, id)
return nil
},
}
rs := usersResource{Users: mockStore}
req := httptest.NewRequest("DELETE", "/", nil)
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
w := httptest.NewRecorder()
rs.AdminDelete(w, req)
assert.Equal(t, http.StatusNoContent, w.Code)
assert.True(t, deleteCalled)
assert.True(t, revokeCalled)
}
func TestOwnerDelete_InvalidCredentials(t *testing.T) {
// Create user with known password hash
correctPassword := "CorrectPass123!"
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(correctPassword), bcrypt.DefaultCost)
user := data.User{
ID: uuid.New(),
PasswordHash: string(hashedPassword),
}
mockStore := &mockUserStore{}
rs := usersResource{Users: mockStore}
reqBody := `{"password": "wrongpassword"}`
req := httptest.NewRequest("DELETE", "/", strings.NewReader(reqBody))
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
w := httptest.NewRecorder()
rs.OwnerDelete(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func TestUsersGetCurrentUser(t *testing.T) {
validUserID := uuid.New()
testTime := time.Now().UTC().Truncate(time.Second)
testUser := data.User{
ID: validUserID,
Username: "testuser",
CreatedAt: &testTime,
UpdatedAt: &testTime,
IsAdmin: false,
}
tests := []struct {
name string
setupContext func(context.Context) context.Context
mockSetup func(*mockUserStore)
wantStatus int
wantResponse string
}{
{
name: "success",
setupContext: func(ctx context.Context) context.Context {
return context.WithValue(ctx, userCtxKey{}, &userClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: validUserID.String(),
},
})
},
mockSetup: func(m *mockUserStore) {
m.GetUserByIDFunc = func(_ context.Context, id uuid.UUID) (data.User, error) {
assert.Equal(t, validUserID, id)
return testUser, nil
}
},
wantStatus: http.StatusOK,
wantResponse: fmt.Sprintf(
`{"created_at":"%s","id":"%s","is_admin":false,"updated_at":"%s","username":"testuser"}`,
testUser.CreatedAt.Format(time.RFC3339Nano),
validUserID.String(),
testUser.UpdatedAt.Format(time.RFC3339Nano),
),
},
{
name: "user not found",
setupContext: func(ctx context.Context) context.Context {
return context.WithValue(ctx, userCtxKey{}, &userClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: validUserID.String(),
},
})
},
mockSetup: func(m *mockUserStore) {
m.GetUserByIDFunc = func(_ context.Context, id uuid.UUID) (data.User, error) {
return data.User{}, errors.New("not found")
}
},
wantStatus: http.StatusNotFound,
wantResponse: `{"error":"User not found"}`,
},
{
name: "unauthorized",
setupContext: func(ctx context.Context) context.Context {
return ctx // No user claims in context
},
mockSetup: func(m *mockUserStore) {},
wantStatus: http.StatusUnauthorized,
wantResponse: `{"error":"Unauthorized"}`,
},
{
name: "invalid user ID",
setupContext: func(ctx context.Context) context.Context {
return context.WithValue(ctx, userCtxKey{}, &userClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: "invalid",
},
})
},
mockSetup: func(m *mockUserStore) {},
wantStatus: http.StatusInternalServerError,
wantResponse: `{"error":"Invalid user ID"}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore := &mockUserStore{}
tt.mockSetup(mockStore)
rs := usersResource{Users: mockStore}
req := httptest.NewRequest("GET", "/me", nil)
req = req.WithContext(tt.setupContext(req.Context()))
w := httptest.NewRecorder()
rs.Get(w, req)
assert.Equal(t, tt.wantStatus, w.Code)
if tt.wantResponse != "" {
actual := strings.TrimSpace(w.Body.String())
assert.JSONEq(t, tt.wantResponse, actual)
}
// Verify sensitive fields are never exposed
if w.Code == http.StatusOK {
var response map[string]any
json.Unmarshal(w.Body.Bytes(), &response)
_, exists := response["password_hash"]
assert.False(t, exists, "password_hash should not be exposed")
}
})
}
}
func TestUpdatePassword_DatabaseError(t *testing.T) {
// Add user with a valid password to the context
oldPassword := "OldValidPass321!"
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(oldPassword), bcrypt.DefaultCost)
user := data.User{
ID: uuid.New(),
PasswordHash: string(hashedPassword),
}
mockStore := &mockUserStore{
UpdatePasswordFunc: func(ctx context.Context, arg data.UpdatePasswordParams) error {
return errors.New("database error")
},
}
rs := usersResource{Users: mockStore}
reqBody := fmt.Sprintf(`{"old_password": "%s", "new_password": "NewValidPass123!"}`, oldPassword)
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
w := httptest.NewRecorder()
rs.UpdatePassword(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Contains(t, w.Body.String(), "Failed to update password")
}
func TestUsersLogin(t *testing.T) {
validPassword := "validPass123!"
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(validPassword), bcrypt.DefaultCost)
testUser := data.User{
ID: uuid.New(),
Username: "test_username",
PasswordHash: string(hashedPassword),
}
jwtSecret := "test-secret"
tests := []struct {
name string
includeUser string
wantUserData bool
requestBody any
mockSetup func(*mockUserStore)
wantStatus int
wantResponse string
checkCookie bool
}{
{
name: "invalid request body",
requestBody: "invalid",
mockSetup: func(m *mockUserStore) {},
wantStatus: http.StatusBadRequest,
wantResponse: `{"error":"Invalid request body"}`,
},
{
name: "user not found",
requestBody: map[string]string{
"username": "nouser",
"password": validPassword,
},
mockSetup: func(m *mockUserStore) {
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
return data.User{}, errors.New("not found")
}
},
wantStatus: http.StatusUnauthorized,
wantResponse: `{"error":"Invalid credentials"}`,
},
{
name: "invalid password",
requestBody: map[string]string{
"username": testUser.Username,
"password": "wrongpassword",
},
mockSetup: func(m *mockUserStore) {
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
return testUser, nil
}
},
wantStatus: http.StatusUnauthorized,
wantResponse: `{"error":"Invalid credentials"}`,
},
{
name: "successful login with user data",
includeUser: "true",
wantUserData: true,
requestBody: map[string]string{
"username": testUser.Username,
"password": validPassword,
},
mockSetup: func(m *mockUserStore) {
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
return testUser, nil
}
},
wantStatus: http.StatusOK,
checkCookie: true,
},
{
name: "successful login without user data",
includeUser: "false",
wantUserData: false,
requestBody: map[string]string{
"username": testUser.Username,
"password": validPassword,
},
mockSetup: func(m *mockUserStore) {
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
return testUser, nil
}
},
wantStatus: http.StatusOK,
checkCookie: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore := &mockUserStore{}
tt.mockSetup(mockStore)
rs := usersResource{
Users: mockStore,
JWTSecret: jwtSecret,
}
body, _ := json.Marshal(tt.requestBody)
req := httptest.NewRequest("POST", "/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
// Add the necessary query parameters
q := url.Values{}
q.Add("includeUser", tt.includeUser)
req.URL.RawQuery = q.Encode()
w := httptest.NewRecorder()
rs.Login(w, req)
if w.Code != tt.wantStatus {
t.Errorf("expected status %d, got %d", tt.wantStatus, w.Code)
}
if tt.wantResponse != "" && strings.TrimSpace(w.Body.String()) != tt.wantResponse {
t.Errorf("expected response %q, got %q", tt.wantResponse, w.Body.String())
}
if tt.wantUserData {
var response struct {
AccessToken string `json:"access_token"`
User data.User `json:"user"` // Cast to the "raw" type to allow checking for sensitive data fields
}
json.Unmarshal(w.Body.Bytes(), &response)
assert.Equal(t, testUser.ID, response.User.ID)
assert.Equal(t, testUser.Username, response.User.Username)
assert.Empty(t, response.User.PasswordHash) // Ensure sensitive data excluded
}
if tt.checkCookie {
cookies := w.Result().Cookies()
var refreshCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "refresh_token" {
refreshCookie = cookie
break
}
}
if refreshCookie == nil {
t.Fatal("refresh token cookie not set")
}
assert.True(t, refreshCookie.HttpOnly, "cookie should be HttpOnly")
assert.Equal(t, http.SameSiteStrictMode, refreshCookie.SameSite, "invalid SameSite mode")
assert.Equal(t, "/", refreshCookie.Path, "invalid cookie path")
assert.Greater(t, refreshCookie.MaxAge, 0, "cookie should have expiration")
// Validate access token in response
var response map[string]string
json.Unmarshal(w.Body.Bytes(), &response)
if response["access_token"] == "" {
t.Error("access token not in response")
}
// Verify JWT validity
token, err := jwt.ParseWithClaims(
response["access_token"],
&userClaims{},
func(token *jwt.Token) (any, error) {
return []byte(jwtSecret), nil
},
)
assert.NoError(t, err, "invalid JWT")
assert.True(t, token.Valid, "invalid JWT")
}
})
}
}

View File

@ -1,10 +1,5 @@
CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
CREATE TABLE IF NOT EXISTS schema_migrations (
version BIGINT PRIMARY KEY,
applied_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS users (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
username TEXT UNIQUE NOT NULL, username TEXT UNIQUE NOT NULL,
@ -26,6 +21,8 @@ CREATE TABLE IF NOT EXISTS refresh_tokens (
CREATE TABLE IF NOT EXISTS notes ( CREATE TABLE IF NOT EXISTS notes (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
current_version INT NOT NULL DEFAULT 1, -- active version (can be historical)
latest_version INT NOT NULL DEFAULT 1, -- highest version number
created_at TIMESTAMPTZ DEFAULT NOW(), created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW() updated_at TIMESTAMPTZ DEFAULT NOW()
); );
@ -35,15 +32,18 @@ CREATE TABLE IF NOT EXISTS note_versions (
note_id UUID NOT NULL REFERENCES notes(id) ON DELETE CASCADE, note_id UUID NOT NULL REFERENCES notes(id) ON DELETE CASCADE,
title TEXT NOT NULL, title TEXT NOT NULL,
content TEXT NOT NULL, content TEXT NOT NULL,
version_number INT NOT NULL, version_number INT NOT NULL DEFAULT 1,
content_hash TEXT NOT NULL, content_hash TEXT NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW() created_at TIMESTAMPTZ DEFAULT NOW(),
UNIQUE (note_id, version_number)
); );
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username); CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username);
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_id ON refresh_tokens(user_id); CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_id ON refresh_tokens(user_id);
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at ON refresh_tokens(expires_at); CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
CREATE UNIQUE INDEX IF NOT EXISTS idx_note_version_unique ON note_versions(note_id, version_number); CREATE INDEX IF NOT EXISTS idx_notes_user_updated ON notes(user_id, updated_at DESC);
CREATE INDEX IF NOT EXISTS idx_note_versions_note ON note_versions(note_id); CREATE INDEX IF NOT EXISTS idx_notes_current_version ON notes(current_version);
CREATE INDEX IF NOT EXISTS idx_note_versions_number ON note_versions(version_number DESC);
CREATE INDEX IF NOT EXISTS idx_note_versions_content_hash ON note_versions(note_id, content_hash);

View File

@ -1,27 +1,58 @@
-- name: CreateNoteVersion :one -- name: CreateNoteVersion :exec
INSERT INTO note_versions (note_id, title, content, version_number, content_hash) WITH potential_duplicate AS (
VALUES ( SELECT version_number
$1, FROM note_versions
$2, WHERE
$3, note_id = $1
(SELECT COALESCE(MAX(version_number), 0) + 1 FROM note_versions WHERE note_id = $1), AND content_hash = $2
encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex') ORDER BY version_number DESC
LIMIT 1
),
note_update AS (
UPDATE notes
SET
current_version = COALESCE(
(SELECT version_number FROM potential_duplicate),
latest_version + 1 -- increment only if we don't jump into a historical version
),
latest_version = CASE
WHEN (SELECT version_number FROM potential_duplicate) IS NULL
THEN latest_version + 1
ELSE latest_version
END,
updated_at = NOW()
WHERE id = $1
RETURNING current_version, latest_version
) )
RETURNING *; INSERT INTO note_versions (
note_id, title, content, version_number, content_hash
)
SELECT
$1, -- note_id
$3, -- title
$4, -- content
current_version,
$2 -- content_hash
FROM note_update
WHERE NOT EXISTS (SELECT 1 FROM potential_duplicate);
-- name: GetNoteVersions :many -- name: GetVersionHistory :many
SELECT * FROM note_versions SELECT
id AS version_id,
title,
version_number,
created_at
FROM note_versions
WHERE note_id = $1 WHERE note_id = $1
ORDER BY version_number DESC ORDER BY version_number DESC
LIMIT $2 OFFSET $3; LIMIT $2 OFFSET $3;
-- name: GetNoteVersion :one -- name: GetVersion :one
SELECT * FROM note_versions SELECT
WHERE note_id = $1 AND version_number = $2 LIMIT 1; id AS version_id,
title,
-- name: FindDuplicateContent :one content,
SELECT EXISTS( version_number,
SELECT 1 FROM note_versions created_at
WHERE note_id = $1 FROM note_versions
AND content_hash = encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex') WHERE note_id = $1 AND id = $2;
);

View File

@ -3,16 +3,34 @@ INSERT INTO notes (user_id)
VALUES ($1) VALUES ($1)
RETURNING *; RETURNING *;
-- name: GetNote :one
SELECT * FROM notes
WHERE id = $1 AND user_id = $2 LIMIT 1;
-- name: ListNotes :many -- name: ListNotes :many
SELECT * FROM notes SELECT
WHERE user_id = $1 n.id AS note_id,
ORDER BY created_at DESC n.user_id AS owner_id,
nv.title,
n.updated_at
FROM notes n
JOIN note_versions nv
ON n.id = nv.note_id AND n.current_version = nv.version_number
WHERE n.user_id = $1
ORDER BY n.updated_at DESC
LIMIT $2 OFFSET $3; LIMIT $2 OFFSET $3;
-- name: GetFullNote :one
SELECT
n.id AS note_id,
n.user_id AS owner_id,
nv.title,
nv.content,
nv.version_number,
nv.created_at AS version_created_at,
n.created_at AS note_created_at,
n.updated_at AS note_updated_at
FROM notes n
JOIN note_versions nv
ON n.id = nv.note_id AND n.current_version = nv.version_number
WHERE n.id = $1;
-- name: DeleteNote :exec -- name: DeleteNote :exec
DELETE FROM notes DELETE FROM notes
WHERE id = $1 AND user_id = $2; WHERE id = $1 AND user_id = $2;

View File

@ -3,9 +3,18 @@ INSERT INTO users (username, password_hash)
VALUES ($1, $2) VALUES ($1, $2)
RETURNING *; RETURNING *;
-- name: CreateAdmin :one
INSERT INTO users (username, password_hash, is_admin)
VALUES ($1, $2, true)
RETURNING *;
-- name: ListUsers :many -- name: ListUsers :many
SELECT * FROM users; SELECT * FROM users;
-- name: ListAdmins :many
SELECT * FROM users
WHERE is_admin = true;
-- name: GetUserByID :one -- name: GetUserByID :one
SELECT * FROM users SELECT * FROM users
WHERE id = $1 LIMIT 1; WHERE id = $1 LIMIT 1;

View File

@ -7,7 +7,7 @@ sql:
gen: gen:
go: go:
package: "data" package: "data"
out: "../pkg/data" out: "../internal/data"
sql_package: "pgx/v5" sql_package: "pgx/v5"
emit_json_tags: true emit_json_tags: true
overrides: overrides: