feat!: trimming & logic/schema improvements
- build: somewhat polished dockerization setup - build: io/fs migrations with `golang-migrate` - feat: automatic init. admin account creation (.env creds) - feat(routers): combined user & token routers into single auth router - feat(routers): improved route layouts (`Routes`) - feat(middlewares): removed redundant `userCtx` middleware - fix(schema): note <-> note_versions relation (versioning) - feat(queries): removed redundant rollback functionality - feat(queries): combined duplicate version check & insertion/creation - tests: decreased redundancy by removing 'unnecessary' unit tests - refactor: hid internal packages behind `server/internal` - docs: notes & auth handler comments
This commit is contained in:
parent
b1edbeb0a3
commit
62b1a58e56
@ -0,0 +1,19 @@
|
|||||||
|
# Build stage
|
||||||
|
FROM golang:1.24.2-alpine AS builder
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
COPY go.mod go.sum .
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
# Optionally we could also strip debug symbols with `-ldflags '-s'`
|
||||||
|
RUN CGO_ENABLED=0 GOOS=linux go build -o /notatest .
|
||||||
|
|
||||||
|
# Final stage (optimized image size)
|
||||||
|
FROM alpine:latest
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
COPY --from=builder /notatest /app/notatest
|
||||||
|
EXPOSE 8080
|
||||||
|
|
||||||
|
CMD ["/app/notatest"]
|
@ -3,9 +3,10 @@ module git.umbrella.haus/ae/notatest
|
|||||||
go 1.24.1
|
go 1.24.1
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/caarlos0/env v3.5.0+incompatible
|
github.com/caarlos0/env/v10 v10.0.0
|
||||||
github.com/go-chi/chi/v5 v5.2.1
|
github.com/go-chi/chi/v5 v5.2.1
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||||
|
github.com/golang-migrate/migrate/v4 v4.18.2
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/jackc/pgx/v5 v5.7.4
|
github.com/jackc/pgx/v5 v5.7.4
|
||||||
github.com/rs/zerolog v1.34.0
|
github.com/rs/zerolog v1.34.0
|
||||||
@ -16,14 +17,16 @@ require (
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||||
|
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
github.com/kr/text v0.2.0 // indirect
|
github.com/lib/pq v1.10.9 // indirect
|
||||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
|
||||||
github.com/stretchr/objx v0.5.2 // indirect
|
github.com/stretchr/objx v0.5.2 // indirect
|
||||||
|
go.uber.org/atomic v1.11.0 // indirect
|
||||||
golang.org/x/sys v0.31.0 // indirect
|
golang.org/x/sys v0.31.0 // indirect
|
||||||
golang.org/x/text v0.23.0 // indirect
|
golang.org/x/text v0.23.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
@ -1,17 +1,45 @@
|
|||||||
github.com/caarlos0/env v3.5.0+incompatible h1:Yy0UN8o9Wtr/jGHZDpCBLpNrzcFLLM2yixi/rBrKyJs=
|
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
|
||||||
github.com/caarlos0/env v3.5.0+incompatible/go.mod h1:tdCsowwCzMLdkqRYDlHpZCp2UooDD3MspDBjZ2AD02Y=
|
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||||
|
github.com/caarlos0/env/v10 v10.0.0 h1:yIHUBZGsyqCnpTkbjk8asUlx6RFhhEs+h7TOBdgdzXA=
|
||||||
|
github.com/caarlos0/env/v10 v10.0.0/go.mod h1:ZfulV76NvVPw3tm591U4SwL3Xx9ldzBP9aGxzeN7G18=
|
||||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dhui/dktest v0.4.4 h1:+I4s6JRE1yGuqflzwqG+aIaMdgXIorCf5P98JnaAWa8=
|
||||||
|
github.com/dhui/dktest v0.4.4/go.mod h1:4+22R4lgsdAXrDyaH4Nqx2JEz2hLp49MqQmm9HLCQhM=
|
||||||
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
|
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||||
|
github.com/docker/docker v27.2.0+incompatible h1:Rk9nIVdfH3+Vz4cyI/uhbINhEZ/oLmc+CBXmH6fbNk4=
|
||||||
|
github.com/docker/docker v27.2.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||||
|
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
||||||
|
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
||||||
|
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||||
|
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
|
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||||
|
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
|
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
|
||||||
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
||||||
|
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||||
|
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
|
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||||
|
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
|
github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8=
|
||||||
|
github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||||
|
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
|
||||||
|
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||||
|
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||||
|
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
@ -24,6 +52,8 @@ github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
|||||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
|
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||||
@ -31,11 +61,22 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
|
|||||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||||
|
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||||
|
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
||||||
|
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||||
|
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||||
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
|
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
|
||||||
|
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||||
@ -48,6 +89,16 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf
|
|||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/wagslane/go-password-validator v0.3.0 h1:vfxOPzGHkz5S146HDpavl0cw1DSVP061Ry2PX0/ON6I=
|
github.com/wagslane/go-password-validator v0.3.0 h1:vfxOPzGHkz5S146HDpavl0cw1DSVP061Ry2PX0/ON6I=
|
||||||
github.com/wagslane/go-password-validator v0.3.0/go.mod h1:TI1XJ6T5fRdRnHqHt14pvy1tNVnrwe7m3/f1f2fDphQ=
|
github.com/wagslane/go-password-validator v0.3.0/go.mod h1:TI1XJ6T5fRdRnHqHt14pvy1tNVnrwe7m3/f1f2fDphQ=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8=
|
||||||
|
go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw=
|
||||||
|
go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8=
|
||||||
|
go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc=
|
||||||
|
go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8=
|
||||||
|
go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4=
|
||||||
|
go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ=
|
||||||
|
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||||
|
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||||
|
@ -11,10 +11,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Note struct {
|
type Note struct {
|
||||||
ID uuid.UUID `json:"id"`
|
ID uuid.UUID `json:"id"`
|
||||||
UserID uuid.UUID `json:"user_id"`
|
UserID uuid.UUID `json:"user_id"`
|
||||||
CreatedAt *time.Time `json:"created_at"`
|
CurrentVersion int32 `json:"current_version"`
|
||||||
UpdatedAt *time.Time `json:"updated_at"`
|
LatestVersion int32 `json:"latest_version"`
|
||||||
|
CreatedAt *time.Time `json:"created_at"`
|
||||||
|
UpdatedAt *time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type NoteVersion struct {
|
type NoteVersion struct {
|
||||||
@ -36,11 +38,6 @@ type RefreshToken struct {
|
|||||||
Revoked bool `json:"revoked"`
|
Revoked bool `json:"revoked"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SchemaMigration struct {
|
|
||||||
Version int64 `json:"version"`
|
|
||||||
AppliedAt *time.Time `json:"applied_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID uuid.UUID `json:"id"`
|
ID uuid.UUID `json:"id"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
156
server/internal/data/note_versions.sql.go
Normal file
156
server/internal/data/note_versions.sql.go
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.28.0
|
||||||
|
// source: note_versions.sql
|
||||||
|
|
||||||
|
package data
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const createNoteVersion = `-- name: CreateNoteVersion :exec
|
||||||
|
WITH potential_duplicate AS (
|
||||||
|
SELECT version_number
|
||||||
|
FROM note_versions
|
||||||
|
WHERE
|
||||||
|
note_id = $1
|
||||||
|
AND content_hash = $2
|
||||||
|
ORDER BY version_number DESC
|
||||||
|
LIMIT 1
|
||||||
|
),
|
||||||
|
note_update AS (
|
||||||
|
UPDATE notes
|
||||||
|
SET
|
||||||
|
current_version = COALESCE(
|
||||||
|
(SELECT version_number FROM potential_duplicate),
|
||||||
|
latest_version + 1 -- increment only if we don't jump into a historical version
|
||||||
|
),
|
||||||
|
latest_version = CASE
|
||||||
|
WHEN (SELECT version_number FROM potential_duplicate) IS NULL
|
||||||
|
THEN latest_version + 1
|
||||||
|
ELSE latest_version
|
||||||
|
END,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $1
|
||||||
|
RETURNING current_version, latest_version
|
||||||
|
)
|
||||||
|
INSERT INTO note_versions (
|
||||||
|
note_id, title, content, version_number, content_hash
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
$1, -- note_id
|
||||||
|
$3, -- title
|
||||||
|
$4, -- content
|
||||||
|
current_version,
|
||||||
|
$2 -- content_hash
|
||||||
|
FROM note_update
|
||||||
|
WHERE NOT EXISTS (SELECT 1 FROM potential_duplicate)
|
||||||
|
`
|
||||||
|
|
||||||
|
type CreateNoteVersionParams struct {
|
||||||
|
NoteID uuid.UUID `json:"note_id"`
|
||||||
|
ContentHash string `json:"content_hash"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateNoteVersion(ctx context.Context, arg CreateNoteVersionParams) error {
|
||||||
|
_, err := q.db.Exec(ctx, createNoteVersion,
|
||||||
|
arg.NoteID,
|
||||||
|
arg.ContentHash,
|
||||||
|
arg.Title,
|
||||||
|
arg.Content,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getVersion = `-- name: GetVersion :one
|
||||||
|
SELECT
|
||||||
|
id AS version_id,
|
||||||
|
title,
|
||||||
|
content,
|
||||||
|
version_number,
|
||||||
|
created_at
|
||||||
|
FROM note_versions
|
||||||
|
WHERE note_id = $1 AND id = $2
|
||||||
|
`
|
||||||
|
|
||||||
|
type GetVersionParams struct {
|
||||||
|
NoteID uuid.UUID `json:"note_id"`
|
||||||
|
ID uuid.UUID `json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetVersionRow struct {
|
||||||
|
VersionID uuid.UUID `json:"version_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
VersionNumber int32 `json:"version_number"`
|
||||||
|
CreatedAt *time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) GetVersion(ctx context.Context, arg GetVersionParams) (GetVersionRow, error) {
|
||||||
|
row := q.db.QueryRow(ctx, getVersion, arg.NoteID, arg.ID)
|
||||||
|
var i GetVersionRow
|
||||||
|
err := row.Scan(
|
||||||
|
&i.VersionID,
|
||||||
|
&i.Title,
|
||||||
|
&i.Content,
|
||||||
|
&i.VersionNumber,
|
||||||
|
&i.CreatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getVersionHistory = `-- name: GetVersionHistory :many
|
||||||
|
SELECT
|
||||||
|
id AS version_id,
|
||||||
|
title,
|
||||||
|
version_number,
|
||||||
|
created_at
|
||||||
|
FROM note_versions
|
||||||
|
WHERE note_id = $1
|
||||||
|
ORDER BY version_number DESC
|
||||||
|
LIMIT $2 OFFSET $3
|
||||||
|
`
|
||||||
|
|
||||||
|
type GetVersionHistoryParams struct {
|
||||||
|
NoteID uuid.UUID `json:"note_id"`
|
||||||
|
Limit int32 `json:"limit"`
|
||||||
|
Offset int32 `json:"offset"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetVersionHistoryRow struct {
|
||||||
|
VersionID uuid.UUID `json:"version_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
VersionNumber int32 `json:"version_number"`
|
||||||
|
CreatedAt *time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) GetVersionHistory(ctx context.Context, arg GetVersionHistoryParams) ([]GetVersionHistoryRow, error) {
|
||||||
|
rows, err := q.db.Query(ctx, getVersionHistory, arg.NoteID, arg.Limit, arg.Offset)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []GetVersionHistoryRow
|
||||||
|
for rows.Next() {
|
||||||
|
var i GetVersionHistoryRow
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.VersionID,
|
||||||
|
&i.Title,
|
||||||
|
&i.VersionNumber,
|
||||||
|
&i.CreatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
143
server/internal/data/notes.sql.go
Normal file
143
server/internal/data/notes.sql.go
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.28.0
|
||||||
|
// source: notes.sql
|
||||||
|
|
||||||
|
package data
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const createNote = `-- name: CreateNote :one
|
||||||
|
INSERT INTO notes (user_id)
|
||||||
|
VALUES ($1)
|
||||||
|
RETURNING id, user_id, current_version, latest_version, created_at, updated_at
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) CreateNote(ctx context.Context, userID uuid.UUID) (Note, error) {
|
||||||
|
row := q.db.QueryRow(ctx, createNote, userID)
|
||||||
|
var i Note
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.UserID,
|
||||||
|
&i.CurrentVersion,
|
||||||
|
&i.LatestVersion,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteNote = `-- name: DeleteNote :exec
|
||||||
|
DELETE FROM notes
|
||||||
|
WHERE id = $1 AND user_id = $2
|
||||||
|
`
|
||||||
|
|
||||||
|
type DeleteNoteParams struct {
|
||||||
|
ID uuid.UUID `json:"id"`
|
||||||
|
UserID uuid.UUID `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) DeleteNote(ctx context.Context, arg DeleteNoteParams) error {
|
||||||
|
_, err := q.db.Exec(ctx, deleteNote, arg.ID, arg.UserID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getFullNote = `-- name: GetFullNote :one
|
||||||
|
SELECT
|
||||||
|
n.id AS note_id,
|
||||||
|
n.user_id AS owner_id,
|
||||||
|
nv.title,
|
||||||
|
nv.content,
|
||||||
|
nv.version_number,
|
||||||
|
nv.created_at AS version_created_at,
|
||||||
|
n.created_at AS note_created_at,
|
||||||
|
n.updated_at AS note_updated_at
|
||||||
|
FROM notes n
|
||||||
|
JOIN note_versions nv
|
||||||
|
ON n.id = nv.note_id AND n.current_version = nv.version_number
|
||||||
|
WHERE n.id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
type GetFullNoteRow struct {
|
||||||
|
NoteID uuid.UUID `json:"note_id"`
|
||||||
|
OwnerID uuid.UUID `json:"owner_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
VersionNumber int32 `json:"version_number"`
|
||||||
|
VersionCreatedAt *time.Time `json:"version_created_at"`
|
||||||
|
NoteCreatedAt *time.Time `json:"note_created_at"`
|
||||||
|
NoteUpdatedAt *time.Time `json:"note_updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) GetFullNote(ctx context.Context, id uuid.UUID) (GetFullNoteRow, error) {
|
||||||
|
row := q.db.QueryRow(ctx, getFullNote, id)
|
||||||
|
var i GetFullNoteRow
|
||||||
|
err := row.Scan(
|
||||||
|
&i.NoteID,
|
||||||
|
&i.OwnerID,
|
||||||
|
&i.Title,
|
||||||
|
&i.Content,
|
||||||
|
&i.VersionNumber,
|
||||||
|
&i.VersionCreatedAt,
|
||||||
|
&i.NoteCreatedAt,
|
||||||
|
&i.NoteUpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const listNotes = `-- name: ListNotes :many
|
||||||
|
SELECT
|
||||||
|
n.id AS note_id,
|
||||||
|
n.user_id AS owner_id,
|
||||||
|
nv.title,
|
||||||
|
n.updated_at
|
||||||
|
FROM notes n
|
||||||
|
JOIN note_versions nv
|
||||||
|
ON n.id = nv.note_id AND n.current_version = nv.version_number
|
||||||
|
WHERE n.user_id = $1
|
||||||
|
ORDER BY n.updated_at DESC
|
||||||
|
LIMIT $2 OFFSET $3
|
||||||
|
`
|
||||||
|
|
||||||
|
type ListNotesParams struct {
|
||||||
|
UserID uuid.UUID `json:"user_id"`
|
||||||
|
Limit int32 `json:"limit"`
|
||||||
|
Offset int32 `json:"offset"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ListNotesRow struct {
|
||||||
|
NoteID uuid.UUID `json:"note_id"`
|
||||||
|
OwnerID uuid.UUID `json:"owner_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
UpdatedAt *time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) ListNotes(ctx context.Context, arg ListNotesParams) ([]ListNotesRow, error) {
|
||||||
|
rows, err := q.db.Query(ctx, listNotes, arg.UserID, arg.Limit, arg.Offset)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []ListNotesRow
|
||||||
|
for rows.Next() {
|
||||||
|
var i ListNotesRow
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.NoteID,
|
||||||
|
&i.OwnerID,
|
||||||
|
&i.Title,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
@ -11,6 +11,31 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const createAdmin = `-- name: CreateAdmin :one
|
||||||
|
INSERT INTO users (username, password_hash, is_admin)
|
||||||
|
VALUES ($1, $2, true)
|
||||||
|
RETURNING id, username, password_hash, is_admin, created_at, updated_at
|
||||||
|
`
|
||||||
|
|
||||||
|
type CreateAdminParams struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
PasswordHash string `json:"password_hash"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateAdmin(ctx context.Context, arg CreateAdminParams) (User, error) {
|
||||||
|
row := q.db.QueryRow(ctx, createAdmin, arg.Username, arg.PasswordHash)
|
||||||
|
var i User
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.Username,
|
||||||
|
&i.PasswordHash,
|
||||||
|
&i.IsAdmin,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
const createUser = `-- name: CreateUser :one
|
const createUser = `-- name: CreateUser :one
|
||||||
INSERT INTO users (username, password_hash)
|
INSERT INTO users (username, password_hash)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
@ -84,6 +109,38 @@ func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User,
|
|||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const listAdmins = `-- name: ListAdmins :many
|
||||||
|
SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users
|
||||||
|
WHERE is_admin = true
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) ListAdmins(ctx context.Context) ([]User, error) {
|
||||||
|
rows, err := q.db.Query(ctx, listAdmins)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []User
|
||||||
|
for rows.Next() {
|
||||||
|
var i User
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.Username,
|
||||||
|
&i.PasswordHash,
|
||||||
|
&i.IsAdmin,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
const listUsers = `-- name: ListUsers :many
|
const listUsers = `-- name: ListUsers :many
|
||||||
SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users
|
SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users
|
||||||
`
|
`
|
658
server/internal/service/auth.go
Normal file
658
server/internal/service/auth.go
Normal file
@ -0,0 +1,658 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.umbrella.haus/ae/notatest/internal/data"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
accessTokenDuration = 15 * time.Minute
|
||||||
|
refreshTokenDuration = 7 * 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidToken = errors.New("invalid token")
|
||||||
|
ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header")
|
||||||
|
)
|
||||||
|
|
||||||
|
// User object context key for incoming requests (handled by middlewares). Only `*userClaims` type
|
||||||
|
// objects should be stored behind this key for consistency.
|
||||||
|
type userCtxKey struct{}
|
||||||
|
|
||||||
|
// DTO without sensitive data fields such as user's password hash.
|
||||||
|
type userResponse struct {
|
||||||
|
ID uuid.UUID `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
IsAdmin bool `json:"is_admin"`
|
||||||
|
CreatedAt *time.Time `json:"created_at"`
|
||||||
|
UpdatedAt *time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom JWT claims (should always be handled in the middleware layer).
|
||||||
|
type userClaims struct {
|
||||||
|
Admin bool `json:"admin"`
|
||||||
|
TokenType string `json:"type"` // "access" or "refresh"
|
||||||
|
jwt.RegisteredClaims // User's UUID should be stored in the subject claim
|
||||||
|
}
|
||||||
|
|
||||||
|
type tokenPair struct {
|
||||||
|
AccessToken string
|
||||||
|
RefreshToken string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mockable token related database operations interface.
|
||||||
|
type TokenStore interface {
|
||||||
|
CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error)
|
||||||
|
GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error)
|
||||||
|
RevokeRefreshToken(ctx context.Context, tokenHash string) error
|
||||||
|
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mockable user related database operations interface.
|
||||||
|
type UserStore interface {
|
||||||
|
CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
|
||||||
|
ListUsers(ctx context.Context) ([]data.User, error)
|
||||||
|
GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error)
|
||||||
|
GetUserByUsername(ctx context.Context, username string) (data.User, error)
|
||||||
|
UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error
|
||||||
|
DeleteUser(ctx context.Context, id uuid.UUID) error
|
||||||
|
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chi HTTP router for authentication/authorization related actions (users/tokens). In theory
|
||||||
|
// (especially in production) the `UserStore` and `TokenStore` will point to the same database
|
||||||
|
// handler, but for code readability they should be kept in separate structs.
|
||||||
|
type authResource struct {
|
||||||
|
JWTSecret string
|
||||||
|
Users UserStore
|
||||||
|
Tokens TokenStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs authResource) Routes() chi.Router {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
|
||||||
|
// Public routes
|
||||||
|
r.Post("/signup", rs.Create) // POST /auth/signup - registration
|
||||||
|
r.Post("/login", rs.Login) // POST /auth/login - login
|
||||||
|
|
||||||
|
// Protected routes (access token required)
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Use(requireAccessToken(rs.JWTSecret)) // JWT claims -> ctx.
|
||||||
|
r.Get("/me", rs.Get) // GET /auth/me - current user data
|
||||||
|
r.Post("/logout", rs.Logout) // POST /auth/logout - revoke all refresh cookies
|
||||||
|
|
||||||
|
// Owner routes
|
||||||
|
r.Route("/owner", func(r chi.Router) {
|
||||||
|
r.Put("/", rs.UpdatePassword) // PUT /auth/owner - update user password
|
||||||
|
r.Delete("/", rs.OwnerDelete) // DELETE /auth/owner - delete user (owner)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Administration routes (admin claim required)
|
||||||
|
r.Route("/admin", func(r chi.Router) {
|
||||||
|
r.Use(adminOnlyMiddleware)
|
||||||
|
r.Get("/all", rs.List) // GET /auth/admin/all - list all users
|
||||||
|
r.Route(fmt.Sprintf("/{%s}", targetUserUUIDCtxParameter), func(r chi.Router) {
|
||||||
|
r.Use(uuidCtx(targetUserUUIDCtxParameter))
|
||||||
|
r.Delete("/", rs.AdminDelete) // DELETE /auth/admin/{id} - delete user (admin)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Protected routes (refresh token required)
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Use(requireRefreshToken(rs.JWTSecret))
|
||||||
|
r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair
|
||||||
|
})
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for new user creation. Will check the incoming JSON object's integrity, validate/normalize
|
||||||
|
// the username, and validate the password (check whether it's compromised via the HIBP API and
|
||||||
|
// calculate its entropy).
|
||||||
|
func (rs authResource) Create(w http.ResponseWriter, r *http.Request) {
|
||||||
|
type request struct {
|
||||||
|
Username *string `json:"username"`
|
||||||
|
Password *string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req request
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Username == nil || req.Password == nil {
|
||||||
|
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Username normalization (to lowercase)
|
||||||
|
normalizedUsername := normalizeUsername(*req.Username)
|
||||||
|
if err := validateUsername(normalizedUsername); err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Password validation (length, HIBP API, and entropy)
|
||||||
|
if err := validatePassword(*req.Password); err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(*req.Password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to create user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := rs.Users.CreateUser(r.Context(), data.CreateUserParams{
|
||||||
|
Username: normalizedUsername,
|
||||||
|
PasswordHash: string(hashedPassword),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if isDuplicateEntry(err) {
|
||||||
|
respondError(w, http.StatusConflict, "Username is already in use")
|
||||||
|
} else {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to create user")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusCreated, map[string]string{
|
||||||
|
"id": user.ID.String(),
|
||||||
|
"username": user.Username,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for logging in to an existing user account using a username-password credentials pair.
|
||||||
|
// Will check the incoming JSON object's integrity, use a normalized version of the username for a
|
||||||
|
// database lookup, and compare the given password's hash against the one stored in the database.
|
||||||
|
// By default only returns a fresh access tokens (and a refresh token as a httpOnly cookie), but
|
||||||
|
// if the `includeUser` parameter is set to `true` the user DTO will also be included.
|
||||||
|
func (rs authResource) Login(w http.ResponseWriter, r *http.Request) {
|
||||||
|
type request struct {
|
||||||
|
Username *string `json:"username"`
|
||||||
|
Password *string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req request
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Username == nil || req.Password == nil {
|
||||||
|
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := rs.Users.GetUserByUsername(r.Context(), normalizeUsername(*req.Username))
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(*req.Password)); err != nil {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a new access/refresh token pair and store a SHA256 hash of the refresh token into
|
||||||
|
// the database for further token rotations.
|
||||||
|
tokenPair, err := rs.GenerateTokenPair(r.Context(), user.ID, user.IsAdmin)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set refresh token into a httpOnly cookie
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "refresh_token",
|
||||||
|
Value: tokenPair.RefreshToken,
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: int(refreshTokenDuration.Seconds()),
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteStrictMode,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Build response
|
||||||
|
response := map[string]any{
|
||||||
|
"access_token": tokenPair.AccessToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Include user data DTO into the response if the `includeUser` parameter was set to `true`
|
||||||
|
if includeUser, _ := strconv.ParseBool(r.URL.Query().Get("includeUser")); includeUser {
|
||||||
|
response["user"] = userResponse{
|
||||||
|
ID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
IsAdmin: user.IsAdmin,
|
||||||
|
CreatedAt: user.CreatedAt,
|
||||||
|
UpdatedAt: user.UpdatedAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for getting full data of the current user as a DTO (database lookup) based on the JWT
|
||||||
|
// claims set into the request's context by a middleware.
|
||||||
|
func (rs authResource) Get(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user := rs.userFromCtxClaims(w, r)
|
||||||
|
respondJSON(w, http.StatusOK, userResponse{
|
||||||
|
ID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
IsAdmin: user.IsAdmin,
|
||||||
|
CreatedAt: user.CreatedAt,
|
||||||
|
UpdatedAt: user.UpdatedAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for updating the current user's password. Performs the same password strength checks as
|
||||||
|
// the registration handler (`rs.Create`) and revokes any existing refresh tokens the user has
|
||||||
|
// stored in the database.
|
||||||
|
func (rs authResource) UpdatePassword(w http.ResponseWriter, r *http.Request) {
|
||||||
|
type request struct {
|
||||||
|
OldPassword string `json:"old_password"`
|
||||||
|
NewPassword string `json:"new_password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req request
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user := rs.userFromCtxClaims(w, r)
|
||||||
|
|
||||||
|
// Verify the old password before proceeding with the update
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.OldPassword)); err != nil {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validatePassword(req.NewPassword); err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to update password")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rs.Users.UpdatePassword(r.Context(), data.UpdatePasswordParams{
|
||||||
|
ID: user.ID,
|
||||||
|
PasswordHash: string(hashedPassword),
|
||||||
|
}); err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to update password")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
|
||||||
|
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for hard deleting the current user. Requires the user's password as JSON input as a precaution.
|
||||||
|
func (rs authResource) OwnerDelete(w http.ResponseWriter, r *http.Request) {
|
||||||
|
type request struct {
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req request
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user := rs.userFromCtxClaims(w, r)
|
||||||
|
|
||||||
|
// Verify the old password before allowing the deletion
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := rs.Users.DeleteUser(r.Context(), user.ID)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to delete user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
|
||||||
|
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for listing all users stored in the database. Should only be allowed to be called by
|
||||||
|
// administrator level users.
|
||||||
|
func (rs authResource) List(w http.ResponseWriter, r *http.Request) {
|
||||||
|
users, err := rs.Users.ListUsers(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to retrieve users")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output sanitization to DTO
|
||||||
|
var output []userResponse
|
||||||
|
for _, user := range users {
|
||||||
|
output = append(output, userResponse{
|
||||||
|
ID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
IsAdmin: user.IsAdmin,
|
||||||
|
CreatedAt: user.CreatedAt,
|
||||||
|
UpdatedAt: user.UpdatedAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for deleting another user account based on their ID. Will check the existence of the
|
||||||
|
// user based on the given ID and additionally revoke all the stored refresh tokens on successful
|
||||||
|
// deletion. Should only be allowed to be called by administrator level users.
|
||||||
|
func (rs authResource) AdminDelete(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
targetID, ok := ctx.Value(uuidCtxKey{Name: targetUserUUIDCtxParameter}).(uuid.UUID)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusBadRequest, "Resource ID missing")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rs.Users.DeleteUser(r.Context(), targetID); err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to delete user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), targetID); err != nil {
|
||||||
|
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a new pair of access and refresh tokens (JWTs) with the user's UUID as the identifying
|
||||||
|
// `Subject` claim and custom claims for the user's administrator status (boolean) and token type
|
||||||
|
// ("refresh"/"access"). Stores a SHA256 hash of the refresh token into the database for further
|
||||||
|
// token rotations.
|
||||||
|
func (rs authResource) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
|
||||||
|
tokenPair, err := generateTokenPair(userID.String(), isAdmin, rs.JWTSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := sha256.Sum256([]byte(tokenPair.RefreshToken))
|
||||||
|
tokenHash := hex.EncodeToString(hash[:])
|
||||||
|
|
||||||
|
// Store the SHA256 hash of the refresh token with (almost) identical expiration timestamp
|
||||||
|
expiresAt := time.Now().Add(refreshTokenDuration)
|
||||||
|
_, err = rs.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{
|
||||||
|
UserID: userID,
|
||||||
|
TokenHash: tokenHash,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenPair, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke the given token from the database by calculating its SHA256 hash.
|
||||||
|
func (rs authResource) RevokeRefreshToken(ctx context.Context, token string) error {
|
||||||
|
hash := sha256.Sum256([]byte(token))
|
||||||
|
tokenHash := hex.EncodeToString(hash[:])
|
||||||
|
return rs.Tokens.RevokeRefreshToken(ctx, tokenHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the given refresh token by performing a database lookup with its SHA256 hash. Returns
|
||||||
|
// the refresh token database object on successful lookup. Fails if the token has been revoked
|
||||||
|
// (soft database operation) or expired (i.e. the corresponding user has to log in again).
|
||||||
|
func (rs authResource) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) {
|
||||||
|
hash := sha256.Sum256([]byte(token))
|
||||||
|
tokenHash := hex.EncodeToString(hash[:])
|
||||||
|
|
||||||
|
dbToken, err := rs.Tokens.GetRefreshTokenByHash(ctx, tokenHash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for soft revocation and/or expiration
|
||||||
|
if dbToken.Revoked || time.Now().After(dbToken.ExpiresAt) {
|
||||||
|
return nil, ErrInvalidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
return &dbToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for performing a token rotation, i.e. invalidating the given refresh token (each refresh
|
||||||
|
// token is a single use utility) and exchanging it for a new pair of refresh and access tokens.
|
||||||
|
func (rs authResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Get claims from context
|
||||||
|
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
if !ok || claims.TokenType != "refresh" {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to get the token from Authorization header (formatted as "Bearer <token>")
|
||||||
|
refreshToken, err := getTokenFromRequest(r)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the refresh token in the database
|
||||||
|
if _, err := rs.ValidateRefreshToken(r.Context(), refreshToken); err != nil {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid refresh token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke the given (single use) refresh token
|
||||||
|
if err := rs.RevokeRefreshToken(r.Context(), refreshToken); err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to revoke token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := uuid.Parse(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a new pair (access & refresh tokens)
|
||||||
|
tokenPair, err := rs.GenerateTokenPair(r.Context(), userID, claims.Admin)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set refresh token into a httpOnly cookie
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "refresh_token",
|
||||||
|
Value: tokenPair.RefreshToken,
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: int(refreshTokenDuration.Seconds()),
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteStrictMode,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Return the access token in the response body (it should be stored in browser's memory client-side)
|
||||||
|
respondJSON(w, http.StatusOK, map[string]string{
|
||||||
|
"access_token": tokenPair.AccessToken,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for performing a logout process for the current user, i.e. replacing the current
|
||||||
|
// httpOnly `refresh_token` cookie with one that expires immediately. Theoretically the user
|
||||||
|
// will still be able to authenticate until the access token (stored client-side) expires,
|
||||||
|
// but that's up to the client to handle.
|
||||||
|
func (rs authResource) Logout(w http.ResponseWriter, r *http.Request) {
|
||||||
|
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Not authenticated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := uuid.Parse(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear the refresh token cookie
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "refresh_token",
|
||||||
|
Value: "",
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: 0, // Expires immediately
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteStrictMode,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := rs.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to logout")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function for generating the initial administrator level account if one doesn't already
|
||||||
|
// exists in the database.
|
||||||
|
func CreateAdminIfNotExists(ctx context.Context, q *data.Queries, username, password string) error {
|
||||||
|
admins, err := q.ListAdmins(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(admins) > 0 {
|
||||||
|
log.Debug().Msg("Admin accounts already exist, skipping creation")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Username normalization (to lowercase)
|
||||||
|
normalizedUsername := normalizeUsername(username)
|
||||||
|
if err := validateUsername(normalizedUsername); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Password validation (length, HIBP API, and entropy)
|
||||||
|
if err := validatePassword(password); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = q.CreateAdmin(ctx, data.CreateAdminParams{
|
||||||
|
Username: normalizedUsername,
|
||||||
|
PasswordHash: string(hashedPassword),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Msgf("Initial admin user '%s' created successfully", username)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the JWT bearer token from the request's Authorization header.
|
||||||
|
func getTokenFromRequest(r *http.Request) (string, error) {
|
||||||
|
bearerToken := r.Header.Get("Authorization")
|
||||||
|
bearerFields := strings.Fields(bearerToken)
|
||||||
|
if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" {
|
||||||
|
return bearerFields[1], nil
|
||||||
|
}
|
||||||
|
return "", ErrAuthHeaderInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function for generating a new JWT token pair with the given specifications.
|
||||||
|
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
|
||||||
|
atClaims := userClaims{
|
||||||
|
Admin: isAdmin,
|
||||||
|
TokenType: "access",
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID,
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenDuration)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, atClaims)
|
||||||
|
|
||||||
|
t, err := accessToken.SignedString([]byte(jwtSecret))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rtClaims := userClaims{
|
||||||
|
Admin: isAdmin,
|
||||||
|
TokenType: "refresh",
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID,
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenDuration)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, rtClaims)
|
||||||
|
|
||||||
|
rt, err := refreshToken.SignedString([]byte(jwtSecret))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokenPair{AccessToken: t, RefreshToken: rt}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse JWT claims (`userClaims`) from the request's context, and perform a database lookup based
|
||||||
|
// on `Subject` (after parsing it to `uuid.UUID`) to fetch the corresponding user's data.
|
||||||
|
func (rs authResource) userFromCtxClaims(w http.ResponseWriter, r *http.Request) *data.User {
|
||||||
|
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := uuid.Parse(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, "Invalid user ID")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := rs.Users.GetUserByID(r.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusNotFound, "User not found")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &user
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the given error is a PostgreSQL error for `unique_violation` (error code 23505), i.e.
|
||||||
|
// whether an entry with the given details already exists in the database table.
|
||||||
|
func isDuplicateEntry(err error) bool {
|
||||||
|
var pgErr *pgconn.PgError
|
||||||
|
if errors.As(err, &pgErr) {
|
||||||
|
return pgErr.Code == "23505"
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
@ -2,11 +2,10 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
"git.umbrella.haus/ae/notatest/internal/data"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
@ -17,8 +16,17 @@ import (
|
|||||||
const (
|
const (
|
||||||
panicRecoveryMsg = "panic recovered"
|
panicRecoveryMsg = "panic recovered"
|
||||||
defaultLogMsg = "incoming request"
|
defaultLogMsg = "incoming request"
|
||||||
|
|
||||||
|
noteUUIDCtxParameter = "noteID"
|
||||||
|
versionUUIDCtxParameter = "versionID"
|
||||||
|
targetUserUUIDCtxParameter = "targetID"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// General resource ID (UUID) context key.
|
||||||
|
type uuidCtxKey struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
// Get JWT bearer from request's authorization header, parse it with custom user claims, and
|
// Get JWT bearer from request's authorization header, parse it with custom user claims, and
|
||||||
// ensure its validity before attaching the claims to the request's context.
|
// ensure its validity before attaching the claims to the request's context.
|
||||||
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
|
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
|
||||||
@ -26,7 +34,8 @@ func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) ht
|
|||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
tokenString, err := getTokenFromRequest(r)
|
tokenString, err := getTokenFromRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
|
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (any, error) {
|
token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (any, error) {
|
||||||
@ -59,7 +68,8 @@ func requireRefreshToken(jwtSecret string) func(http.Handler) http.Handler {
|
|||||||
return authMiddleware(jwtSecret, "refresh")
|
return authMiddleware(jwtSecret, "refresh")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the current user is an administrator.
|
// Ensure the current user is an administrator. Can be used to protect routes that can be utilized
|
||||||
|
// to view/modify/delete accounts that the current user isn't the owner of.
|
||||||
func adminOnlyMiddleware(next http.Handler) http.Handler {
|
func adminOnlyMiddleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
@ -71,57 +81,38 @@ func adminOnlyMiddleware(next http.Handler) http.Handler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the targeted resource is owned by the current user (i.e. current user's ID matches with
|
// Append UUID from the given URL parameter to the request's context (`uuidCtxKey` with the
|
||||||
// the one stored into the resource).
|
// parameter name as the "context identifier").
|
||||||
func ownerOnlyMiddleware(next http.Handler) http.Handler {
|
func uuidCtx(parameter string) func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
||||||
requestedID := chi.URLParam(r, "id")
|
|
||||||
if !ok || user.Subject != requestedID {
|
|
||||||
respondError(w, http.StatusForbidden, "Forbidden")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Append user data into request's context based on user ID as a URL parameter.
|
|
||||||
func userCtx(store UserStore) func(http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
userIDStr := chi.URLParam(r, "id")
|
uuidParam := chi.URLParam(r, parameter)
|
||||||
userID, err := uuid.Parse(userIDStr)
|
resourceID, err := uuid.Parse(uuidParam)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondError(w, http.StatusNotFound, "Invalid user ID")
|
respondError(w, http.StatusBadRequest, "Invalid resource ID")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := store.GetUserByID(r.Context(), userID)
|
ctx := context.WithValue(r.Context(), uuidCtxKey{Name: parameter}, resourceID)
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusNotFound, "User not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.WithValue(r.Context(), userCtxKey{}, user)
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append note data into request's context based on note ID as a URL parameter and user ID as
|
// Append full note data (metadata + active version) into request's context based on note ID as a
|
||||||
// context parameter.
|
// URL parameter and user ID as context parameter. Must be chained with `uuidCtx` to parse the
|
||||||
|
// resource ID into the request's context.
|
||||||
func noteCtx(store NoteStore) func(http.Handler) http.Handler {
|
func noteCtx(store NoteStore) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
noteIDStr := chi.URLParam(r, "id")
|
ctx := r.Context()
|
||||||
noteID, err := uuid.Parse(noteIDStr)
|
noteID, ok := ctx.Value(uuidCtxKey{Name: noteUUIDCtxParameter}).(uuid.UUID)
|
||||||
if err != nil {
|
if !ok {
|
||||||
respondError(w, http.StatusNotFound, "Invalid note ID")
|
respondError(w, http.StatusBadRequest, "Resource ID missing")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: user must already be in the context (e.g. via JWT middleware)
|
user, ok := ctx.Value(userCtxKey{}).(*userClaims)
|
||||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||||
return
|
return
|
||||||
@ -129,20 +120,58 @@ func noteCtx(store NoteStore) func(http.Handler) http.Handler {
|
|||||||
|
|
||||||
userID, err := uuid.Parse(user.Subject)
|
userID, err := uuid.Parse(user.Subject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
respondError(w, http.StatusUnauthorized, "Invalid token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
note, err := store.GetNote(r.Context(), data.GetNoteParams{
|
// Get the "full note" (metadata + active version) with a single query
|
||||||
ID: noteID,
|
fullNote, err := store.GetFullNote(r.Context(), noteID)
|
||||||
UserID: userID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondError(w, http.StatusNotFound, "Note not found")
|
respondError(w, http.StatusNotFound, "Note not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.WithValue(r.Context(), noteCtxKey{}, note)
|
// Validate note ownership
|
||||||
|
if userID != fullNote.OwnerID {
|
||||||
|
respondError(w, http.StatusForbidden, "Forbidden")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = context.WithValue(r.Context(), noteCtxKey{}, &fullNote)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append single version's data into request's context based on version ID as a URL parameter and
|
||||||
|
// note ID as context parameter. Must be chained with `noteCtx` and `uuidCtx` to parse the necessary
|
||||||
|
// resource IDs into request's context.
|
||||||
|
func versionCtx(store NoteStore) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
fullNote, ok := ctx.Value(noteCtxKey{}).(*data.GetFullNoteRow)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "Note not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
versionID, ok := ctx.Value(uuidCtxKey{Name: versionUUIDCtxParameter}).(uuid.UUID)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusBadRequest, "Resource ID missing")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
version, err := store.GetVersion(r.Context(), data.GetVersionParams{
|
||||||
|
NoteID: fullNote.NoteID,
|
||||||
|
ID: versionID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusNotFound, "Version not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = context.WithValue(r.Context(), versionCtxKey{}, &version)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
519
server/internal/service/middleware_test.go
Normal file
519
server/internal/service/middleware_test.go
Normal file
@ -0,0 +1,519 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.umbrella.haus/ae/notatest/internal/data"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockNoteStore struct {
|
||||||
|
CreateNoteFunc func(context.Context, uuid.UUID) (data.Note, error)
|
||||||
|
DeleteNoteFunc func(context.Context, data.DeleteNoteParams) error
|
||||||
|
GetFullNoteFunc func(context.Context, uuid.UUID) (data.GetFullNoteRow, error)
|
||||||
|
ListNotesFunc func(context.Context, data.ListNotesParams) ([]data.ListNotesRow, error)
|
||||||
|
CreateNoteVersionFunc func(context.Context, data.CreateNoteVersionParams) error
|
||||||
|
GetVersionFunc func(context.Context, data.GetVersionParams) (data.GetVersionRow, error)
|
||||||
|
GetVersionHistoryFunc func(context.Context, data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) CreateNote(ctx context.Context, id uuid.UUID) (data.Note, error) {
|
||||||
|
return m.CreateNoteFunc(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error {
|
||||||
|
return m.DeleteNoteFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) GetFullNote(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) {
|
||||||
|
return m.GetFullNoteFunc(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.ListNotesRow, error) {
|
||||||
|
return m.ListNotesFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) error {
|
||||||
|
return m.CreateNoteVersionFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) GetVersion(ctx context.Context, arg data.GetVersionParams) (data.GetVersionRow, error) {
|
||||||
|
return m.GetVersionFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNoteStore) GetVersionHistory(ctx context.Context, arg data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error) {
|
||||||
|
return m.GetVersionHistoryFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware(t *testing.T) {
|
||||||
|
secret := "test-jwt-secret"
|
||||||
|
testUserID := uuid.New().String()
|
||||||
|
|
||||||
|
validRT := generateTestToken(t, secret, "refresh", testUserID, true)
|
||||||
|
validAT := generateTestToken(t, secret, "access", testUserID, true)
|
||||||
|
expiredAT := generateTestToken(t, secret, "access", testUserID, true, func(claims *userClaims) {
|
||||||
|
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
expectedErr string
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"no token",
|
||||||
|
"",
|
||||||
|
"Unauthorized",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invalid token",
|
||||||
|
"invalid",
|
||||||
|
"Invalid token",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"expired token",
|
||||||
|
expiredAT,
|
||||||
|
"Invalid token",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"wrong token type",
|
||||||
|
validRT,
|
||||||
|
"Invalid token type",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"valid token",
|
||||||
|
validAT,
|
||||||
|
"",
|
||||||
|
http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
atAuthMiddleware := requireAccessToken(secret)
|
||||||
|
|
||||||
|
// Mock request
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
if tc.token != "" {
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token))
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
called := false
|
||||||
|
|
||||||
|
// Mock endpoint that the middleware protects
|
||||||
|
handler := atAuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
_, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
assert.True(t, ok)
|
||||||
|
}))
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.statusCode, w.Code)
|
||||||
|
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
||||||
|
if tc.expectedErr != "" {
|
||||||
|
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminOnlyMiddleware(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
user *userClaims
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"no user",
|
||||||
|
nil,
|
||||||
|
http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"non admin user",
|
||||||
|
&userClaims{
|
||||||
|
Admin: false,
|
||||||
|
},
|
||||||
|
http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"admin user",
|
||||||
|
&userClaims{
|
||||||
|
Admin: true,
|
||||||
|
},
|
||||||
|
http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Mock request
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
if tc.user != nil {
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
called := false
|
||||||
|
|
||||||
|
// Mock endpoint that the middleware protects
|
||||||
|
handler := adminOnlyMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.statusCode, w.Code)
|
||||||
|
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUUIDCtxMiddleware(t *testing.T) {
|
||||||
|
testKeyName := "testKey"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
parameter string
|
||||||
|
expectedErr string
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"missing uuid",
|
||||||
|
"",
|
||||||
|
"Invalid resource ID",
|
||||||
|
http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invalid uuid",
|
||||||
|
"invalid",
|
||||||
|
"Invalid resource ID",
|
||||||
|
http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"valid uuid",
|
||||||
|
uuid.New().String(),
|
||||||
|
"",
|
||||||
|
http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
uuidCtxMiddleware := uuidCtx(testKeyName)
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
|
// We need to mock the URL parameter as we don't setup an actual router in this test env.
|
||||||
|
rctx := chi.NewRouteContext()
|
||||||
|
rctx.URLParams.Add(testKeyName, tc.parameter)
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
called := false
|
||||||
|
|
||||||
|
// Mock endpoint that the middleware protects
|
||||||
|
handler := uuidCtxMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
_, ok := r.Context().Value(uuidCtxKey{Name: testKeyName}).(uuid.UUID)
|
||||||
|
assert.True(t, ok)
|
||||||
|
}))
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.statusCode, w.Code)
|
||||||
|
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
||||||
|
if tc.expectedErr != "" {
|
||||||
|
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoteCtxMiddleware(t *testing.T) {
|
||||||
|
testTitle := "Test title"
|
||||||
|
tesTContent := "## Test content\nData 123"
|
||||||
|
testVersion := int32(3)
|
||||||
|
noteID := uuid.New()
|
||||||
|
|
||||||
|
ownerUserID := uuid.New()
|
||||||
|
testOwnerClaims := userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: ownerUserID.String(),
|
||||||
|
}}
|
||||||
|
|
||||||
|
otherUserID := uuid.New()
|
||||||
|
testOtherClaims := userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: otherUserID.String(),
|
||||||
|
}}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
resourceID *uuid.UUID
|
||||||
|
user *userClaims
|
||||||
|
mock func(*mockNoteStore)
|
||||||
|
statusCode int
|
||||||
|
expectedErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"no resource id",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
func(m *mockNoteStore) {},
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"Resource ID missing",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"unauthorized",
|
||||||
|
¬eID,
|
||||||
|
nil,
|
||||||
|
func(m *mockNoteStore) {},
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
"Unauthorized",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"note not found",
|
||||||
|
¬eID,
|
||||||
|
&testOwnerClaims,
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) {
|
||||||
|
return data.GetFullNoteRow{}, errors.New("not found")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusNotFound,
|
||||||
|
"Note not found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"not owner",
|
||||||
|
¬eID,
|
||||||
|
&testOtherClaims,
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) {
|
||||||
|
assert.Equal(t, noteID, id)
|
||||||
|
testTs := time.Now()
|
||||||
|
return data.GetFullNoteRow{
|
||||||
|
NoteID: id,
|
||||||
|
OwnerID: ownerUserID,
|
||||||
|
Title: testTitle,
|
||||||
|
Content: tesTContent,
|
||||||
|
VersionNumber: testVersion,
|
||||||
|
VersionCreatedAt: &testTs,
|
||||||
|
NoteCreatedAt: &testTs,
|
||||||
|
NoteUpdatedAt: &testTs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusForbidden,
|
||||||
|
"Forbidden",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"success",
|
||||||
|
¬eID,
|
||||||
|
&testOwnerClaims,
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) {
|
||||||
|
assert.Equal(t, noteID, id)
|
||||||
|
testTs := time.Now()
|
||||||
|
return data.GetFullNoteRow{
|
||||||
|
NoteID: id,
|
||||||
|
OwnerID: ownerUserID,
|
||||||
|
Title: testTitle,
|
||||||
|
Content: tesTContent,
|
||||||
|
VersionNumber: testVersion,
|
||||||
|
VersionCreatedAt: &testTs,
|
||||||
|
NoteCreatedAt: &testTs,
|
||||||
|
NoteUpdatedAt: &testTs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusOK,
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// UUID ctx. (mock) -> note ctx. (tested here)
|
||||||
|
mockStore := &mockNoteStore{}
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
tc.mock(mockStore)
|
||||||
|
|
||||||
|
// Mock endpoint that the middleware protects (where the attached note data is actually utilized)
|
||||||
|
handler := noteCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, noteID, fullNote.NoteID)
|
||||||
|
assert.Equal(t, testTitle, fullNote.Title)
|
||||||
|
assert.Equal(t, tesTContent, fullNote.Content)
|
||||||
|
assert.Equal(t, testVersion, fullNote.VersionNumber)
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Request parameters don't need to be mocked, as parsing of them isn't handled
|
||||||
|
// by this middleware, and thus that portion shouldn't be tested here.
|
||||||
|
|
||||||
|
if tc.resourceID != nil {
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), uuidCtxKey{Name: noteUUIDCtxParameter}, *tc.resourceID))
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.user != nil {
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.statusCode, w.Code)
|
||||||
|
if tc.expectedErr != "" {
|
||||||
|
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVersionCtxMiddleware(t *testing.T) {
|
||||||
|
testTitle := "Test title"
|
||||||
|
tesTContent := "## Test content\nData 123"
|
||||||
|
testVersion := int32(3)
|
||||||
|
versionID := uuid.New()
|
||||||
|
|
||||||
|
noteID := uuid.New()
|
||||||
|
testNote := data.GetFullNoteRow{
|
||||||
|
NoteID: noteID,
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
resourceID *uuid.UUID
|
||||||
|
note *data.GetFullNoteRow
|
||||||
|
mock func(*mockNoteStore)
|
||||||
|
statusCode int
|
||||||
|
expectedErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"no note",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
func(m *mockNoteStore) {},
|
||||||
|
http.StatusNotFound,
|
||||||
|
"Note not found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"no resource id",
|
||||||
|
nil,
|
||||||
|
&testNote,
|
||||||
|
func(m *mockNoteStore) {},
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"Resource ID missing",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"version not found",
|
||||||
|
&versionID,
|
||||||
|
&testNote,
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.GetVersionFunc = func(ctx context.Context, gvp data.GetVersionParams) (data.GetVersionRow, error) {
|
||||||
|
return data.GetVersionRow{}, errors.New("not found")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusNotFound,
|
||||||
|
"Version not found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"success",
|
||||||
|
&versionID,
|
||||||
|
&testNote,
|
||||||
|
func(m *mockNoteStore) {
|
||||||
|
m.GetVersionFunc = func(ctx context.Context, gvp data.GetVersionParams) (data.GetVersionRow, error) {
|
||||||
|
assert.Equal(t, versionID, gvp.ID)
|
||||||
|
assert.Equal(t, noteID, gvp.NoteID)
|
||||||
|
testTs := time.Now()
|
||||||
|
return data.GetVersionRow{
|
||||||
|
VersionID: gvp.ID,
|
||||||
|
Title: testTitle,
|
||||||
|
Content: tesTContent,
|
||||||
|
VersionNumber: testVersion,
|
||||||
|
CreatedAt: &testTs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
http.StatusOK,
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// note ctx. (mock) -> UUID ctx. (mock) -> version ctx. (tested here)
|
||||||
|
mockStore := &mockNoteStore{}
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
tc.mock(mockStore)
|
||||||
|
|
||||||
|
// Mock endpoint that the middleware protects (where the attached note data is actually utilized)
|
||||||
|
handler := versionCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fullVersion, ok := r.Context().Value(versionCtxKey{}).(*data.GetVersionRow)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, versionID, fullVersion.VersionID)
|
||||||
|
assert.Equal(t, testTitle, fullVersion.Title)
|
||||||
|
assert.Equal(t, tesTContent, fullVersion.Content)
|
||||||
|
assert.Equal(t, testVersion, fullVersion.VersionNumber)
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Request parameters don't need to be mocked, as parsing of them isn't handled
|
||||||
|
// by this middleware, and thus that portion shouldn't be tested here.
|
||||||
|
|
||||||
|
if tc.note != nil {
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), noteCtxKey{}, tc.note))
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.resourceID != nil {
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), uuidCtxKey{Name: versionUUIDCtxParameter}, *tc.resourceID))
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.statusCode, w.Code)
|
||||||
|
if tc.expectedErr != "" {
|
||||||
|
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
claims := &userClaims{
|
||||||
|
Admin: isAdmin,
|
||||||
|
TokenType: tokenType,
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID,
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
signedToken, err := token.SignedString([]byte(secret))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate test token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return signedToken
|
||||||
|
}
|
331
server/internal/service/notes.go
Normal file
331
server/internal/service/notes.go
Normal file
@ -0,0 +1,331 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.umbrella.haus/ae/notatest/internal/data"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
titleMaxLength = 150
|
||||||
|
initVersionTitle = "Untitled"
|
||||||
|
initVersionContent = ""
|
||||||
|
)
|
||||||
|
|
||||||
|
// Note object context key for incoming requests (handled my middlewares). Only `*data.GetFullNoteRow`
|
||||||
|
// type objects should be stored behind this key for consistency.
|
||||||
|
type noteCtxKey struct{}
|
||||||
|
|
||||||
|
// Note version object context key for incoming requests (handled by middlewares). Only
|
||||||
|
// `*data.GetVersionRow` type objects should be stored behind this key for consistency.
|
||||||
|
type versionCtxKey struct{}
|
||||||
|
|
||||||
|
// Mockable database operations interface
|
||||||
|
type NoteStore interface {
|
||||||
|
CreateNote(ctx context.Context, userID uuid.UUID) (data.Note, error)
|
||||||
|
DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error
|
||||||
|
GetFullNote(ctx context.Context, noteID uuid.UUID) (data.GetFullNoteRow, error)
|
||||||
|
ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.ListNotesRow, error)
|
||||||
|
CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) error
|
||||||
|
GetVersion(ctx context.Context, arg data.GetVersionParams) (data.GetVersionRow, error)
|
||||||
|
GetVersionHistory(ctx context.Context, arg data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chi HTTP router for notes related CRUD actions.
|
||||||
|
type notesResource struct {
|
||||||
|
JWTSecret string
|
||||||
|
Notes NoteStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs notesResource) Routes() chi.Router {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Use(requireAccessToken(rs.JWTSecret)) // JWT claims -> ctx.
|
||||||
|
|
||||||
|
r.Post("/", rs.Create) // POST /notes - create new note
|
||||||
|
r.Get("/", rs.ListMetadata) // GET /notes - get all notes (metadata + titles)
|
||||||
|
|
||||||
|
/*
|
||||||
|
Clients should utilize `rs.ListMetadata` to load index of user's available notes (e.g.
|
||||||
|
sidebar view), use `rs.GetFullNote` to get full notes individually, and if request the
|
||||||
|
versioning history with `rs.GetVersionHistory` if necessary (and similarly fetch each
|
||||||
|
version individually if the client wants to view them).
|
||||||
|
*/
|
||||||
|
|
||||||
|
r.Route(fmt.Sprintf("/{%s}", noteUUIDCtxParameter), func(r chi.Router) {
|
||||||
|
r.Use(uuidCtx(noteUUIDCtxParameter))
|
||||||
|
r.Use(noteCtx(rs.Notes)) // DB -> req. context (metadata + active version)
|
||||||
|
r.Get("/", rs.GetFullNote) // GET /notes/{id} - get note from context
|
||||||
|
r.Delete("/", rs.Delete) // DELETE /notes/{id} - delete note
|
||||||
|
r.Get("/versions", rs.GetVersionHistory) // GET /notes/{id}/versions - get full versioning history
|
||||||
|
r.Post("/versions", rs.CreateVersion) // POST /notes/{id}/versions - create new version
|
||||||
|
|
||||||
|
r.Route(fmt.Sprintf("/{%s}", versionUUIDCtxParameter), func(r chi.Router) {
|
||||||
|
r.Use(uuidCtx(versionUUIDCtxParameter))
|
||||||
|
r.Use(versionCtx(rs.Notes)) // DB -> req. context (scoped version)
|
||||||
|
r.Get("/", rs.GetFullVersion) // GET /notes/{id}/{id} - get
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for new note creation. Creates the parent metadata object (`notes` table) and an initial
|
||||||
|
// placeholder content version (`note_versions` table), and returns the placeholder contents to the
|
||||||
|
// caller in the HTTP response.
|
||||||
|
func (rs *notesResource) Create(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := uuid.Parse(user.Subject)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metadata object (parent)
|
||||||
|
note, err := rs.Notes.CreateNote(r.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to create note")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initial (empty) placeholder version of the contents
|
||||||
|
err = rs.Notes.CreateNoteVersion(r.Context(), data.CreateNoteVersionParams{
|
||||||
|
NoteID: note.ID,
|
||||||
|
Title: initVersionTitle,
|
||||||
|
Content: initVersionContent,
|
||||||
|
ContentHash: sha1ContentHash(initVersionTitle, initVersionContent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to create initial version")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Placeholder contents are decided server-side, so we need to inform the client of them via a
|
||||||
|
// one-time-use DTO
|
||||||
|
type response struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
res := response{
|
||||||
|
Title: initVersionTitle,
|
||||||
|
Content: initVersionContent,
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusCreated, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *notesResource) ListMetadata(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := uuid.Parse(user.Subject)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, offset := getPaginationParams(r)
|
||||||
|
notes, err := rs.Notes.ListNotes(r.Context(), data.ListNotesParams{
|
||||||
|
UserID: userID,
|
||||||
|
Limit: limit,
|
||||||
|
Offset: offset,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to retrieve notes")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, note := range notes {
|
||||||
|
if userID != note.OwnerID {
|
||||||
|
respondError(w, http.StatusForbidden, "Forbidden")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, notes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for returning the currently scoped (included to the request's context by a middleware)
|
||||||
|
// full note object.
|
||||||
|
func (rs *notesResource) GetFullNote(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "Note not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, fullNote)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for hard deelting the currently scoped note (including its versions via database cascade).
|
||||||
|
func (rs *notesResource) Delete(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "Note not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := uuid.Parse(user.Subject)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rs.Notes.DeleteNote(r.Context(), data.DeleteNoteParams{
|
||||||
|
ID: fullNote.NoteID,
|
||||||
|
UserID: userID, // NOTE: using `fullNote.userID` here'd be insecure
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to delete note")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for listing the currently scoped note's version history. If pagination parameters
|
||||||
|
// (`limit` and `offset`) aren't defined, limit of 50 versions (with offset 0) will be returned.
|
||||||
|
func (rs *notesResource) GetVersionHistory(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "Note not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, offset := getPaginationParams(r)
|
||||||
|
versions, err := rs.Notes.GetVersionHistory(r.Context(), data.GetVersionHistoryParams{
|
||||||
|
NoteID: fullNote.NoteID,
|
||||||
|
Limit: limit,
|
||||||
|
Offset: offset,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to get version history")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, versions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for creating a new content version for the currently scoped note. Will check the incoming
|
||||||
|
// JSON object's integrity and perform a de-duplication check for identical versions stored in the
|
||||||
|
// database (SHA-1 hash of version contents). If a duplicate version is found, it'll be placed as the
|
||||||
|
// active version by swapping its version number to HEAD+1.
|
||||||
|
func (rs *notesResource) CreateVersion(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "Note not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req struct {
|
||||||
|
Title *string `json:"title"`
|
||||||
|
Content *string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Title == nil || req.Content == nil {
|
||||||
|
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extra check for frontend readability reasons (max. length isn't specifically limited in the database)
|
||||||
|
if len(*req.Title) > titleMaxLength {
|
||||||
|
respondError(w, http.StatusBadRequest, fmt.Sprintf("Title must be shorter than %d characters", titleMaxLength))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
The SQL query handles de-duplication checks and "intelligent" versioning increments, so we
|
||||||
|
don't have to worry about them here (`latest_version` = highest version number that exists
|
||||||
|
in this note's context; `current_version` = note's active content version):
|
||||||
|
|
||||||
|
- New version's contents are a duplicate of a historical version:
|
||||||
|
- Don't increment `latest_version`
|
||||||
|
- Sync `current_version` with the `version_number` of the duplicate version
|
||||||
|
- New version's contents are unique:
|
||||||
|
- Increment `latest_version`
|
||||||
|
- Sync `current_version` with `latest_version`
|
||||||
|
*/
|
||||||
|
err := rs.Notes.CreateNoteVersion(r.Context(), data.CreateNoteVersionParams{
|
||||||
|
NoteID: fullNote.NoteID,
|
||||||
|
Title: *req.Title,
|
||||||
|
Content: *req.Content,
|
||||||
|
ContentHash: sha1ContentHash(*req.Title, *req.Content),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to create note version")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for returning full data of the currently scoped note version. Identical to the beginning
|
||||||
|
// of the `RollbackNoteVersion` handler.
|
||||||
|
func (rs *notesResource) GetFullVersion(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fullVersion, ok := r.Context().Value(noteCtxKey{}).(*data.GetVersionRow)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "Note not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, fullVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse `limit` and `offset` 32-bit integer URL parameters from the given request. Defaults to
|
||||||
|
// limit of 50 and offset 0 if parameters are missing/invalid.
|
||||||
|
func getPaginationParams(r *http.Request) (limit int32, offset int32) {
|
||||||
|
defaultLimit := 50
|
||||||
|
defaultOffset := 0
|
||||||
|
|
||||||
|
limitStr := r.URL.Query().Get("limit")
|
||||||
|
if limitStr != "" {
|
||||||
|
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
||||||
|
defaultLimit = l
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
offsetStr := r.URL.Query().Get("offset")
|
||||||
|
if offsetStr != "" {
|
||||||
|
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
|
||||||
|
defaultOffset = o
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return int32(defaultLimit), int32(defaultOffset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concatenate the title and content strings, calculate a SHA-1 hash of the resulting string, and
|
||||||
|
// return the resulting hash as a string.
|
||||||
|
func sha1ContentHash(title, content string) string {
|
||||||
|
hashContent := title + content
|
||||||
|
hash := sha1.Sum([]byte(hashContent))
|
||||||
|
hashStr := strings.ToUpper(hex.EncodeToString(hash[:]))
|
||||||
|
|
||||||
|
return hashStr
|
||||||
|
}
|
@ -3,26 +3,25 @@ package service
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
"git.umbrella.haus/ae/notatest/internal/data"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Run(conn *pgx.Conn, jwtSecret string) error {
|
func Run(conn *pgx.Conn, q *data.Queries, jwtSecret string) error {
|
||||||
q := data.New(conn)
|
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
|
|
||||||
tokensRouter := tokensResource{
|
authRouter := authResource{
|
||||||
JWTSecret: jwtSecret,
|
|
||||||
Tokens: q,
|
|
||||||
}
|
|
||||||
usersRouter := usersResource{
|
|
||||||
JWTSecret: jwtSecret,
|
JWTSecret: jwtSecret,
|
||||||
Users: q,
|
Users: q,
|
||||||
|
Tokens: q,
|
||||||
|
}
|
||||||
|
notesRouter := notesResource{
|
||||||
|
JWTSecret: jwtSecret,
|
||||||
|
Notes: q,
|
||||||
}
|
}
|
||||||
notesRouter := notesResource{}
|
|
||||||
|
|
||||||
// Global middlewares
|
// Global middlewares
|
||||||
r.Use(middleware.RequestID)
|
r.Use(middleware.RequestID)
|
||||||
@ -31,10 +30,12 @@ func Run(conn *pgx.Conn, jwtSecret string) error {
|
|||||||
r.Use(middleware.Recoverer)
|
r.Use(middleware.Recoverer)
|
||||||
r.Use(middleware.AllowContentType("application/json"))
|
r.Use(middleware.AllowContentType("application/json"))
|
||||||
|
|
||||||
// Routes grouped by functionality
|
// Routes grouped by functionality (we must prefix the API routes with `/api`
|
||||||
r.Mount("/auth", tokensRouter.Routes())
|
// as the domain will be the same for the front and back ends)
|
||||||
r.Mount("/users", usersRouter.Routes())
|
r.Route("/api", func(r chi.Router) {
|
||||||
r.Mount("/notes", notesRouter.Routes())
|
r.Mount("/auth", authRouter.Routes())
|
||||||
|
r.Mount("/notes", notesRouter.Routes())
|
||||||
|
})
|
||||||
|
|
||||||
log.Info().Msg("Starting server on :8080")
|
log.Info().Msg("Starting server on :8080")
|
||||||
return http.ListenAndServe(":8080", r)
|
return http.ListenAndServe(":8080", r)
|
@ -41,7 +41,7 @@ func respondError(w http.ResponseWriter, status int, message string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Client-side check:
|
Example client-side check:
|
||||||
|
|
||||||
```
|
```
|
||||||
function estimateEntropy(password: string): number {
|
function estimateEntropy(password: string): number {
|
||||||
@ -102,7 +102,7 @@ func normalizeUsername(username string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Client-side check (additionally input should automatically perform the normalization steps):
|
Example client-side check (without input normalization):
|
||||||
|
|
||||||
```
|
```
|
||||||
function validateUsername(username: string): string {
|
function validateUsername(username: string): string {
|
@ -3,11 +3,15 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"embed"
|
"embed"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/migrate"
|
"git.umbrella.haus/ae/notatest/internal/data"
|
||||||
"git.umbrella.haus/ae/notatest/pkg/service"
|
"git.umbrella.haus/ae/notatest/internal/service"
|
||||||
"github.com/caarlos0/env"
|
"github.com/caarlos0/env/v10"
|
||||||
|
"github.com/golang-migrate/migrate/v4"
|
||||||
|
_ "github.com/golang-migrate/migrate/v4/database/postgres"
|
||||||
|
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@ -17,60 +21,72 @@ import (
|
|||||||
var migrationsFS embed.FS
|
var migrationsFS embed.FS
|
||||||
|
|
||||||
var (
|
var (
|
||||||
isDevelopment = false
|
config Config
|
||||||
config Config
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
JWTSecret string `env:"JWT_SECRET,notEmpty"`
|
JWTSecret string `env:"JWT_SECRET,notEmpty"`
|
||||||
DBURL string `env:"PG_URL,notEmpty"`
|
DatabaseURL string `env:"DB_URL,notEmpty"`
|
||||||
RunMode string `env:"GO_ENV" envDefault:"production"`
|
LogLevel string `env:"LOG_LEVEL" envDefault:"info"`
|
||||||
|
AdminUsername string `env:"ADMIN_USERNAME,notEmpty,unset"`
|
||||||
|
AdminPassword string `env:"ADMIN_PASSWORD,notEmpty,unset"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
initLogger()
|
|
||||||
config = Config{}
|
config = Config{}
|
||||||
env.Parse(&config)
|
if err := env.Parse(&config); err != nil {
|
||||||
|
log.Fatal().Err(err).Msg("Failed to parse environment variables")
|
||||||
if config.RunMode == "development" {
|
|
||||||
log.Info().Msg("Development mode enabled")
|
|
||||||
isDevelopment = true
|
|
||||||
}
|
}
|
||||||
|
initLogger()
|
||||||
log.Debug().Msg("Initialization completed")
|
log.Debug().Msg("Initialization completed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
conn, err := pgx.Connect(context.Background(), config.DBURL)
|
log.Debug().Msgf("Database URL: %s", config.DatabaseURL)
|
||||||
|
conn, err := pgx.Connect(context.Background(), config.DatabaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Msgf("Failed connecting to database: %s", err)
|
log.Fatal().Err(err).Msg("Failed to connect to database")
|
||||||
}
|
}
|
||||||
log.Info().Msg("Successfully connected to the database")
|
log.Info().Msg("Successfully connected to the database")
|
||||||
log.Debug().Msg(config.DBURL)
|
log.Info().Msg("Applying migrations...")
|
||||||
|
|
||||||
if isDevelopment {
|
d, err := iofs.New(migrationsFS, "sql/migrations")
|
||||||
if err := migrate.Run(context.Background(), conn, migrationsFS); err != nil {
|
if err != nil {
|
||||||
log.Fatal().Msgf("Failed running migrations: %s", err)
|
log.Fatal().Err(err).Msg("Failed constructing io/fs driver")
|
||||||
}
|
}
|
||||||
|
migrator, err := migrate.NewWithSourceInstance("iofs", d, config.DatabaseURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal().Err(err).Msg("Failed to apply migrations")
|
||||||
|
}
|
||||||
|
defer migrator.Close()
|
||||||
|
|
||||||
|
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
|
||||||
|
log.Fatal().Err(err).Msg("Failed to apply migrations")
|
||||||
}
|
}
|
||||||
|
|
||||||
service.Run(conn, config.JWTSecret)
|
q := data.New(conn)
|
||||||
|
err = service.CreateAdminIfNotExists(context.Background(), q, config.AdminUsername, config.AdminPassword)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal().Err(err).Msg("Failed initial admin account creation")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Msg("Migrations applied succesfully, proceeding to HTTP server startup")
|
||||||
|
service.Run(conn, q, config.JWTSecret)
|
||||||
}
|
}
|
||||||
|
|
||||||
func initLogger() {
|
func initLogger() {
|
||||||
logLevel := os.Getenv("LOG_LEVEL")
|
fmt.Println(config.LogLevel)
|
||||||
level, err := zerolog.ParseLevel(logLevel)
|
level, err := zerolog.ParseLevel(config.LogLevel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Default to INFO
|
|
||||||
level = zerolog.InfoLevel
|
level = zerolog.InfoLevel
|
||||||
}
|
}
|
||||||
zerolog.SetGlobalLevel(level)
|
zerolog.SetGlobalLevel(level)
|
||||||
|
|
||||||
if isDevelopment {
|
output := zerolog.ConsoleWriter{
|
||||||
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
|
Out: os.Stdout,
|
||||||
} else {
|
TimeFormat: "2006-01-02 15:04:05",
|
||||||
log.Logger = log.Output(os.Stderr) // JSON to stdout/stderr
|
|
||||||
}
|
}
|
||||||
|
log.Logger = log.Output(output).With().Timestamp().Caller().Logger()
|
||||||
|
|
||||||
log.Debug().Msg("Logger initialized")
|
log.Info().Msgf("Logger initialized (log level: %s)", level)
|
||||||
}
|
}
|
||||||
|
@ -1,132 +0,0 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
|
||||||
// versions:
|
|
||||||
// sqlc v1.28.0
|
|
||||||
// source: note_versions.sql
|
|
||||||
|
|
||||||
package data
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
const createNoteVersion = `-- name: CreateNoteVersion :one
|
|
||||||
INSERT INTO note_versions (note_id, title, content, version_number, content_hash)
|
|
||||||
VALUES (
|
|
||||||
$1,
|
|
||||||
$2,
|
|
||||||
$3,
|
|
||||||
(SELECT COALESCE(MAX(version_number), 0) + 1 FROM note_versions WHERE note_id = $1),
|
|
||||||
encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
|
|
||||||
)
|
|
||||||
RETURNING id, note_id, title, content, version_number, content_hash, created_at
|
|
||||||
`
|
|
||||||
|
|
||||||
type CreateNoteVersionParams struct {
|
|
||||||
NoteID uuid.UUID `json:"note_id"`
|
|
||||||
Title string `json:"title"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *Queries) CreateNoteVersion(ctx context.Context, arg CreateNoteVersionParams) (NoteVersion, error) {
|
|
||||||
row := q.db.QueryRow(ctx, createNoteVersion, arg.NoteID, arg.Title, arg.Content)
|
|
||||||
var i NoteVersion
|
|
||||||
err := row.Scan(
|
|
||||||
&i.ID,
|
|
||||||
&i.NoteID,
|
|
||||||
&i.Title,
|
|
||||||
&i.Content,
|
|
||||||
&i.VersionNumber,
|
|
||||||
&i.ContentHash,
|
|
||||||
&i.CreatedAt,
|
|
||||||
)
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
|
|
||||||
const findDuplicateContent = `-- name: FindDuplicateContent :one
|
|
||||||
SELECT EXISTS(
|
|
||||||
SELECT 1 FROM note_versions
|
|
||||||
WHERE note_id = $1
|
|
||||||
AND content_hash = encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
|
|
||||||
)
|
|
||||||
`
|
|
||||||
|
|
||||||
type FindDuplicateContentParams struct {
|
|
||||||
NoteID uuid.UUID `json:"note_id"`
|
|
||||||
Column2 []byte `json:"column_2"`
|
|
||||||
Column3 []byte `json:"column_3"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *Queries) FindDuplicateContent(ctx context.Context, arg FindDuplicateContentParams) (bool, error) {
|
|
||||||
row := q.db.QueryRow(ctx, findDuplicateContent, arg.NoteID, arg.Column2, arg.Column3)
|
|
||||||
var exists bool
|
|
||||||
err := row.Scan(&exists)
|
|
||||||
return exists, err
|
|
||||||
}
|
|
||||||
|
|
||||||
const getNoteVersion = `-- name: GetNoteVersion :one
|
|
||||||
SELECT id, note_id, title, content, version_number, content_hash, created_at FROM note_versions
|
|
||||||
WHERE note_id = $1 AND version_number = $2 LIMIT 1
|
|
||||||
`
|
|
||||||
|
|
||||||
type GetNoteVersionParams struct {
|
|
||||||
NoteID uuid.UUID `json:"note_id"`
|
|
||||||
VersionNumber int32 `json:"version_number"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *Queries) GetNoteVersion(ctx context.Context, arg GetNoteVersionParams) (NoteVersion, error) {
|
|
||||||
row := q.db.QueryRow(ctx, getNoteVersion, arg.NoteID, arg.VersionNumber)
|
|
||||||
var i NoteVersion
|
|
||||||
err := row.Scan(
|
|
||||||
&i.ID,
|
|
||||||
&i.NoteID,
|
|
||||||
&i.Title,
|
|
||||||
&i.Content,
|
|
||||||
&i.VersionNumber,
|
|
||||||
&i.ContentHash,
|
|
||||||
&i.CreatedAt,
|
|
||||||
)
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
|
|
||||||
const getNoteVersions = `-- name: GetNoteVersions :many
|
|
||||||
SELECT id, note_id, title, content, version_number, content_hash, created_at FROM note_versions
|
|
||||||
WHERE note_id = $1
|
|
||||||
ORDER BY version_number DESC
|
|
||||||
LIMIT $2 OFFSET $3
|
|
||||||
`
|
|
||||||
|
|
||||||
type GetNoteVersionsParams struct {
|
|
||||||
NoteID uuid.UUID `json:"note_id"`
|
|
||||||
Limit int32 `json:"limit"`
|
|
||||||
Offset int32 `json:"offset"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *Queries) GetNoteVersions(ctx context.Context, arg GetNoteVersionsParams) ([]NoteVersion, error) {
|
|
||||||
rows, err := q.db.Query(ctx, getNoteVersions, arg.NoteID, arg.Limit, arg.Offset)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
var items []NoteVersion
|
|
||||||
for rows.Next() {
|
|
||||||
var i NoteVersion
|
|
||||||
if err := rows.Scan(
|
|
||||||
&i.ID,
|
|
||||||
&i.NoteID,
|
|
||||||
&i.Title,
|
|
||||||
&i.Content,
|
|
||||||
&i.VersionNumber,
|
|
||||||
&i.ContentHash,
|
|
||||||
&i.CreatedAt,
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
items = append(items, i)
|
|
||||||
}
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return items, nil
|
|
||||||
}
|
|
@ -1,105 +0,0 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
|
||||||
// versions:
|
|
||||||
// sqlc v1.28.0
|
|
||||||
// source: notes.sql
|
|
||||||
|
|
||||||
package data
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
const createNote = `-- name: CreateNote :one
|
|
||||||
INSERT INTO notes (user_id)
|
|
||||||
VALUES ($1)
|
|
||||||
RETURNING id, user_id, created_at, updated_at
|
|
||||||
`
|
|
||||||
|
|
||||||
func (q *Queries) CreateNote(ctx context.Context, userID uuid.UUID) (Note, error) {
|
|
||||||
row := q.db.QueryRow(ctx, createNote, userID)
|
|
||||||
var i Note
|
|
||||||
err := row.Scan(
|
|
||||||
&i.ID,
|
|
||||||
&i.UserID,
|
|
||||||
&i.CreatedAt,
|
|
||||||
&i.UpdatedAt,
|
|
||||||
)
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
|
|
||||||
const deleteNote = `-- name: DeleteNote :exec
|
|
||||||
DELETE FROM notes
|
|
||||||
WHERE id = $1 AND user_id = $2
|
|
||||||
`
|
|
||||||
|
|
||||||
type DeleteNoteParams struct {
|
|
||||||
ID uuid.UUID `json:"id"`
|
|
||||||
UserID uuid.UUID `json:"user_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *Queries) DeleteNote(ctx context.Context, arg DeleteNoteParams) error {
|
|
||||||
_, err := q.db.Exec(ctx, deleteNote, arg.ID, arg.UserID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
const getNote = `-- name: GetNote :one
|
|
||||||
SELECT id, user_id, created_at, updated_at FROM notes
|
|
||||||
WHERE id = $1 AND user_id = $2 LIMIT 1
|
|
||||||
`
|
|
||||||
|
|
||||||
type GetNoteParams struct {
|
|
||||||
ID uuid.UUID `json:"id"`
|
|
||||||
UserID uuid.UUID `json:"user_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *Queries) GetNote(ctx context.Context, arg GetNoteParams) (Note, error) {
|
|
||||||
row := q.db.QueryRow(ctx, getNote, arg.ID, arg.UserID)
|
|
||||||
var i Note
|
|
||||||
err := row.Scan(
|
|
||||||
&i.ID,
|
|
||||||
&i.UserID,
|
|
||||||
&i.CreatedAt,
|
|
||||||
&i.UpdatedAt,
|
|
||||||
)
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
|
|
||||||
const listNotes = `-- name: ListNotes :many
|
|
||||||
SELECT id, user_id, created_at, updated_at FROM notes
|
|
||||||
WHERE user_id = $1
|
|
||||||
ORDER BY created_at DESC
|
|
||||||
LIMIT $2 OFFSET $3
|
|
||||||
`
|
|
||||||
|
|
||||||
type ListNotesParams struct {
|
|
||||||
UserID uuid.UUID `json:"user_id"`
|
|
||||||
Limit int32 `json:"limit"`
|
|
||||||
Offset int32 `json:"offset"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *Queries) ListNotes(ctx context.Context, arg ListNotesParams) ([]Note, error) {
|
|
||||||
rows, err := q.db.Query(ctx, listNotes, arg.UserID, arg.Limit, arg.Offset)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
var items []Note
|
|
||||||
for rows.Next() {
|
|
||||||
var i Note
|
|
||||||
if err := rows.Scan(
|
|
||||||
&i.ID,
|
|
||||||
&i.UserID,
|
|
||||||
&i.CreatedAt,
|
|
||||||
&i.UpdatedAt,
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
items = append(items, i)
|
|
||||||
}
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return items, nil
|
|
||||||
}
|
|
@ -1,74 +0,0 @@
|
|||||||
// pkg/migrate/migrate.go
|
|
||||||
package migrate
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"embed"
|
|
||||||
"fmt"
|
|
||||||
"io/fs"
|
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Run(ctx context.Context, conn *pgx.Conn, migrationsFS embed.FS) error {
|
|
||||||
// Get already applied migrations
|
|
||||||
rows, _ := conn.Query(ctx, "SELECT version FROM schema_migrations")
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
applied := make(map[int64]bool)
|
|
||||||
for rows.Next() {
|
|
||||||
var version int64
|
|
||||||
if err := rows.Scan(&version); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
applied[version] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
files, err := migrationsFS.ReadDir("migrations")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply the migrations sequentially based on their ordinal number
|
|
||||||
for _, f := range sortMigrations(files) {
|
|
||||||
version, err := strconv.ParseInt(strings.Split(f.Name(), "_")[0], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid migration name: %s", f.Name())
|
|
||||||
}
|
|
||||||
|
|
||||||
if applied[version] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run migration
|
|
||||||
sql, err := migrationsFS.ReadFile("migrations/" + f.Name())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := conn.Exec(ctx, string(sql)); err != nil {
|
|
||||||
return fmt.Errorf("migration %d failed: %w", version, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := conn.Exec(ctx,
|
|
||||||
"INSERT INTO schema_migrations (version) VALUES ($1)", version,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort the migration files based on their ordinal number prefix.
|
|
||||||
func sortMigrations(files []fs.DirEntry) []fs.DirEntry {
|
|
||||||
sort.Slice(files, func(i, j int) bool {
|
|
||||||
v1, _ := strconv.ParseInt(strings.Split(files[i].Name(), "_")[0], 10, 64)
|
|
||||||
v2, _ := strconv.ParseInt(strings.Split(files[j].Name(), "_")[0], 10, 64)
|
|
||||||
return v1 < v2
|
|
||||||
})
|
|
||||||
return files
|
|
||||||
}
|
|
@ -1,386 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAuthMiddleware(t *testing.T) {
|
|
||||||
secret := "test-secret"
|
|
||||||
validToken := generateTestToken(t, secret, "access", uuid.New().String(), true)
|
|
||||||
expiredToken := generateTestToken(t, secret, "access", uuid.New().String(), true, func(claims *userClaims) {
|
|
||||||
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
|
|
||||||
})
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
token string
|
|
||||||
expectedErr string
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"no token",
|
|
||||||
"",
|
|
||||||
"Unauthorized",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"invalid token",
|
|
||||||
"invalid",
|
|
||||||
"Invalid token",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"expired token",
|
|
||||||
expiredToken,
|
|
||||||
"Invalid token",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"wrong type",
|
|
||||||
generateTestToken(
|
|
||||||
t,
|
|
||||||
secret,
|
|
||||||
"refresh",
|
|
||||||
uuid.New().String(),
|
|
||||||
true,
|
|
||||||
),
|
|
||||||
"Invalid token type",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"valid token",
|
|
||||||
validToken,
|
|
||||||
"",
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
mw := requireAccessToken(secret)
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
if tc.token != "" {
|
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token))
|
|
||||||
}
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
called := false
|
|
||||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
called = true
|
|
||||||
_, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
||||||
assert.True(t, ok)
|
|
||||||
}))
|
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
if tc.expectedErr != "" {
|
|
||||||
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAdminOnlyMiddleware(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
user *userClaims
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"no user",
|
|
||||||
nil,
|
|
||||||
http.StatusForbidden,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"non admin user",
|
|
||||||
&userClaims{
|
|
||||||
Admin: false,
|
|
||||||
},
|
|
||||||
http.StatusForbidden,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"admin user",
|
|
||||||
&userClaims{
|
|
||||||
Admin: true,
|
|
||||||
},
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
mw := adminOnlyMiddleware
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
if tc.user != nil {
|
|
||||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
|
|
||||||
}
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
called := false
|
|
||||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
called = true
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOwnerOnlyMiddleware(t *testing.T) {
|
|
||||||
userID := uuid.New().String()
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
user *userClaims
|
|
||||||
urlID string
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"no user",
|
|
||||||
nil,
|
|
||||||
userID,
|
|
||||||
http.StatusForbidden,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"different ID",
|
|
||||||
&userClaims{
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: uuid.New().String(),
|
|
||||||
}},
|
|
||||||
userID,
|
|
||||||
http.StatusForbidden,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"matching ID",
|
|
||||||
&userClaims{
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: userID,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
userID,
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
r := chi.NewRouter()
|
|
||||||
|
|
||||||
handlerChain := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
r.With(
|
|
||||||
// Add user with the given claims to request's context
|
|
||||||
func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var ctx context.Context = r.Context()
|
|
||||||
if tc.user != nil {
|
|
||||||
ctx = context.WithValue(ctx, userCtxKey{}, tc.user)
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
})
|
|
||||||
},
|
|
||||||
ownerOnlyMiddleware,
|
|
||||||
).Get("/{id}", handlerChain)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if tc.urlID == "invalid" {
|
|
||||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUserCtxMiddleware(t *testing.T) {
|
|
||||||
validUserID := uuid.New()
|
|
||||||
invalidUserID := "invalid"
|
|
||||||
|
|
||||||
mockStore := &mockUserStore{
|
|
||||||
GetUserByIDFunc: func(ctx context.Context, id uuid.UUID) (data.User, error) {
|
|
||||||
if id == validUserID {
|
|
||||||
return data.User{ID: validUserID}, nil
|
|
||||||
}
|
|
||||||
return data.User{}, errors.New("not found")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
urlID string
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"valid ID",
|
|
||||||
validUserID.String(),
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"invalid ID",
|
|
||||||
invalidUserID,
|
|
||||||
http.StatusNotFound,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"non existent ID",
|
|
||||||
uuid.New().String(),
|
|
||||||
http.StatusNotFound,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
mw := userCtx(mockStore)
|
|
||||||
r := chi.NewRouter()
|
|
||||||
r.With(mw).Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, validUserID, user.ID)
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNoteCtxMiddleware(t *testing.T) {
|
|
||||||
userID := uuid.New()
|
|
||||||
noteID := uuid.New()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
noteID string
|
|
||||||
user any
|
|
||||||
mock func(*mockNoteStore)
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"invalid note ID",
|
|
||||||
"invalid",
|
|
||||||
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: userID.String(),
|
|
||||||
}},
|
|
||||||
func(m *mockNoteStore) {},
|
|
||||||
http.StatusNotFound,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"unauthorized user",
|
|
||||||
noteID.String(),
|
|
||||||
nil,
|
|
||||||
func(m *mockNoteStore) {},
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"note not found",
|
|
||||||
noteID.String(),
|
|
||||||
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: userID.String(),
|
|
||||||
}},
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.GetNoteFunc = func(ctx context.Context, arg data.GetNoteParams) (data.Note, error) {
|
|
||||||
return data.Note{}, errors.New("not found")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusNotFound,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"success",
|
|
||||||
noteID.String(),
|
|
||||||
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: userID.String(),
|
|
||||||
}},
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.GetNoteFunc = func(ctx context.Context, arg data.GetNoteParams) (data.Note, error) {
|
|
||||||
assert.Equal(t, noteID, arg.ID)
|
|
||||||
assert.Equal(t, userID, arg.UserID)
|
|
||||||
return data.Note{ID: noteID}, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
mockStore := &mockNoteStore{}
|
|
||||||
tc.mock(mockStore)
|
|
||||||
|
|
||||||
handler := noteCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
_, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
|
||||||
assert.True(t, ok)
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/notes/%s", tc.noteID), nil)
|
|
||||||
|
|
||||||
// Chi router context mocks ID passed in a URL parameter
|
|
||||||
rctx := chi.NewRouteContext()
|
|
||||||
rctx.URLParams.Add("id", tc.noteID)
|
|
||||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
|
||||||
|
|
||||||
if tc.user != nil {
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
userCtxKey{},
|
|
||||||
tc.user,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
handler.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
claims := &userClaims{
|
|
||||||
Admin: isAdmin,
|
|
||||||
TokenType: tokenType,
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: userID,
|
|
||||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(claims)
|
|
||||||
}
|
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
||||||
signedToken, err := token.SignedString([]byte(secret))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to generate test token: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return signedToken
|
|
||||||
}
|
|
@ -1,262 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
type noteCtxKey struct{}
|
|
||||||
|
|
||||||
// Mockable database operations interface
|
|
||||||
type NoteStore interface {
|
|
||||||
CreateNote(ctx context.Context, userID uuid.UUID) (data.Note, error)
|
|
||||||
DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error
|
|
||||||
GetNote(ctx context.Context, arg data.GetNoteParams) (data.Note, error)
|
|
||||||
ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.Note, error)
|
|
||||||
CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error)
|
|
||||||
FindDuplicateContent(ctx context.Context, arg data.FindDuplicateContentParams) (bool, error)
|
|
||||||
GetNoteVersion(ctx context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error)
|
|
||||||
GetNoteVersions(ctx context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type notesResource struct {
|
|
||||||
JWTSecret string
|
|
||||||
Notes NoteStore
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs notesResource) Routes() chi.Router {
|
|
||||||
r := chi.NewRouter()
|
|
||||||
|
|
||||||
r.Group(func(r chi.Router) {
|
|
||||||
r.Use(requireAccessToken(rs.JWTSecret))
|
|
||||||
|
|
||||||
r.Post("/", rs.CreateNote) // POST /notes - note creation
|
|
||||||
r.Get("/", rs.ListNotes) // GET /notes - get all notes
|
|
||||||
|
|
||||||
r.Route("/{id}", func(r chi.Router) {
|
|
||||||
r.Use(noteCtx(rs.Notes))
|
|
||||||
|
|
||||||
r.Get("/", rs.GetNote) // GET /notes/{id} - get specific note
|
|
||||||
r.Delete("/", rs.DeleteNote) // DELETE /notes/{id} - delete specific note
|
|
||||||
|
|
||||||
r.Route("/versions", func(r chi.Router) {
|
|
||||||
r.Post("/", rs.CreateNoteVersion) // POST /notes/{id}/versions - create new version
|
|
||||||
r.Get("/", rs.ListNoteVersions) // GET /notes/{id}/versions - get all existing versions
|
|
||||||
r.Get("/{version}", rs.GetNoteVersion) // GET /notes/{id}/versions/{version} - get specific version
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs *notesResource) CreateNote(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
||||||
if !ok {
|
|
||||||
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userID, err := uuid.Parse(user.Subject)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
note, err := rs.Notes.CreateNote(r.Context(), userID)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to create note")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
respondJSON(w, http.StatusCreated, note)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs *notesResource) ListNotes(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
||||||
if !ok {
|
|
||||||
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userID, err := uuid.Parse(user.Subject)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
limit, offset := getPaginationParams(r)
|
|
||||||
|
|
||||||
notes, err := rs.Notes.ListNotes(r.Context(), data.ListNotesParams{
|
|
||||||
UserID: userID,
|
|
||||||
Limit: limit,
|
|
||||||
Offset: offset,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to retrieve notes")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
respondJSON(w, http.StatusOK, notes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs *notesResource) GetNote(w http.ResponseWriter, r *http.Request) {
|
|
||||||
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
|
||||||
if !ok {
|
|
||||||
respondError(w, http.StatusNotFound, "Note not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
respondJSON(w, http.StatusOK, note)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs *notesResource) DeleteNote(w http.ResponseWriter, r *http.Request) {
|
|
||||||
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
|
||||||
if !ok {
|
|
||||||
respondError(w, http.StatusNotFound, "Note not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
||||||
if !ok {
|
|
||||||
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userID, err := uuid.Parse(user.Subject)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rs.Notes.DeleteNote(r.Context(), data.DeleteNoteParams{
|
|
||||||
ID: note.ID,
|
|
||||||
UserID: userID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to delete note")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs *notesResource) CreateNoteVersion(w http.ResponseWriter, r *http.Request) {
|
|
||||||
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
|
||||||
if !ok {
|
|
||||||
respondError(w, http.StatusNotFound, "Note not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req struct {
|
|
||||||
Title string `json:"title"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// De-duplication check
|
|
||||||
duplicate, err := rs.Notes.FindDuplicateContent(r.Context(), data.FindDuplicateContentParams{
|
|
||||||
NoteID: note.ID,
|
|
||||||
Column2: []byte(req.Title),
|
|
||||||
Column3: []byte(req.Content),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to check for duplicate content")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if duplicate {
|
|
||||||
respondError(w, http.StatusConflict, "Duplicate content detected")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
version, err := rs.Notes.CreateNoteVersion(r.Context(), data.CreateNoteVersionParams{
|
|
||||||
NoteID: note.ID,
|
|
||||||
Title: req.Title,
|
|
||||||
Content: req.Content,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to create note version")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
respondJSON(w, http.StatusCreated, version)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs *notesResource) ListNoteVersions(w http.ResponseWriter, r *http.Request) {
|
|
||||||
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
|
||||||
if !ok {
|
|
||||||
respondError(w, http.StatusNotFound, "Note not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
limit, offset := getPaginationParams(r)
|
|
||||||
|
|
||||||
versions, err := rs.Notes.GetNoteVersions(r.Context(), data.GetNoteVersionsParams{
|
|
||||||
NoteID: note.ID,
|
|
||||||
Limit: limit,
|
|
||||||
Offset: offset,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to retrieve versions")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
respondJSON(w, http.StatusOK, versions)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs *notesResource) GetNoteVersion(w http.ResponseWriter, r *http.Request) {
|
|
||||||
note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
|
|
||||||
if !ok {
|
|
||||||
respondError(w, http.StatusNotFound, "Note not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
versionStr := chi.URLParam(r, "version")
|
|
||||||
versionNumber, err := strconv.ParseInt(versionStr, 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusBadRequest, "Invalid version number")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
version, err := rs.Notes.GetNoteVersion(r.Context(), data.GetNoteVersionParams{
|
|
||||||
NoteID: note.ID,
|
|
||||||
VersionNumber: int32(versionNumber),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusNotFound, "Version not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
respondJSON(w, http.StatusOK, version)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPaginationParams(r *http.Request) (limit int32, offset int32) {
|
|
||||||
defaultLimit := 50
|
|
||||||
defaultOffset := 0
|
|
||||||
|
|
||||||
limitStr := r.URL.Query().Get("limit")
|
|
||||||
if limitStr != "" {
|
|
||||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
|
||||||
defaultLimit = l
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
offsetStr := r.URL.Query().Get("offset")
|
|
||||||
if offsetStr != "" {
|
|
||||||
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
|
|
||||||
defaultOffset = o
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return int32(defaultLimit), int32(defaultOffset)
|
|
||||||
}
|
|
@ -1,509 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockNoteStore struct {
|
|
||||||
CreateNoteFunc func(context.Context, uuid.UUID) (data.Note, error)
|
|
||||||
DeleteNoteFunc func(context.Context, data.DeleteNoteParams) error
|
|
||||||
GetNoteFunc func(context.Context, data.GetNoteParams) (data.Note, error)
|
|
||||||
ListNotesFunc func(context.Context, data.ListNotesParams) ([]data.Note, error)
|
|
||||||
CreateNoteVersionFunc func(context.Context, data.CreateNoteVersionParams) (data.NoteVersion, error)
|
|
||||||
FindDuplicateContentFunc func(context.Context, data.FindDuplicateContentParams) (bool, error)
|
|
||||||
GetNoteVersionFunc func(context.Context, data.GetNoteVersionParams) (data.NoteVersion, error)
|
|
||||||
GetNoteVersionsFunc func(context.Context, data.GetNoteVersionsParams) ([]data.NoteVersion, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNoteStore) CreateNote(ctx context.Context, userID uuid.UUID) (data.Note, error) {
|
|
||||||
return m.CreateNoteFunc(ctx, userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNoteStore) DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error {
|
|
||||||
return m.DeleteNoteFunc(ctx, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNoteStore) GetNote(ctx context.Context, arg data.GetNoteParams) (data.Note, error) {
|
|
||||||
return m.GetNoteFunc(ctx, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNoteStore) ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.Note, error) {
|
|
||||||
return m.ListNotesFunc(ctx, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNoteStore) CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error) {
|
|
||||||
return m.CreateNoteVersionFunc(ctx, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNoteStore) FindDuplicateContent(ctx context.Context, arg data.FindDuplicateContentParams) (bool, error) {
|
|
||||||
return m.FindDuplicateContentFunc(ctx, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNoteStore) GetNoteVersion(ctx context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error) {
|
|
||||||
return m.GetNoteVersionFunc(ctx, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNoteStore) GetNoteVersions(ctx context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) {
|
|
||||||
return m.GetNoteVersionsFunc(ctx, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotes_CreateNote(t *testing.T) {
|
|
||||||
userID := uuid.New()
|
|
||||||
testNote := data.Note{ID: uuid.New(), UserID: userID}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
mock func(*mockNoteStore)
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"success",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.CreateNoteFunc = func(_ context.Context, uid uuid.UUID) (data.Note, error) {
|
|
||||||
assert.Equal(t, userID, uid)
|
|
||||||
return testNote, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusCreated,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"database error",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.CreateNoteFunc = func(context.Context, uuid.UUID) (data.Note, error) {
|
|
||||||
return data.Note{}, errors.New("db error")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusInternalServerError,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
mockStore := &mockNoteStore{}
|
|
||||||
tc.mock(mockStore)
|
|
||||||
|
|
||||||
rs := notesResource{Notes: mockStore}
|
|
||||||
req := httptest.NewRequest("POST", "/", nil)
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
userCtxKey{},
|
|
||||||
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: userID.String(),
|
|
||||||
}},
|
|
||||||
))
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.CreateNote(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
if tc.statusCode == http.StatusCreated {
|
|
||||||
var note data.Note
|
|
||||||
json.Unmarshal(w.Body.Bytes(), ¬e)
|
|
||||||
assert.Equal(t, testNote.ID, note.ID)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotes_ListNotes(t *testing.T) {
|
|
||||||
userID := uuid.New()
|
|
||||||
notes := []data.Note{
|
|
||||||
{ID: uuid.New(), UserID: userID},
|
|
||||||
{ID: uuid.New(), UserID: userID},
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
query string
|
|
||||||
mock func(*mockNoteStore)
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"success",
|
|
||||||
"",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.ListNotesFunc = func(_ context.Context, arg data.ListNotesParams) ([]data.Note, error) {
|
|
||||||
assert.Equal(t, userID, arg.UserID)
|
|
||||||
return notes, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"with pagination",
|
|
||||||
"?limit=10&offset=20",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.ListNotesFunc = func(_ context.Context, arg data.ListNotesParams) ([]data.Note, error) {
|
|
||||||
assert.EqualValues(t, 10, arg.Limit)
|
|
||||||
assert.EqualValues(t, 20, arg.Offset)
|
|
||||||
return notes, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"database error",
|
|
||||||
"",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.ListNotesFunc = func(context.Context, data.ListNotesParams) ([]data.Note, error) {
|
|
||||||
return nil, errors.New("db error")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusInternalServerError,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
mockStore := &mockNoteStore{}
|
|
||||||
tc.mock(mockStore)
|
|
||||||
|
|
||||||
rs := notesResource{Notes: mockStore}
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.query), nil)
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
userCtxKey{},
|
|
||||||
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: userID.String(),
|
|
||||||
}},
|
|
||||||
))
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.ListNotes(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotes_GetNote(t *testing.T) {
|
|
||||||
noteID := uuid.New()
|
|
||||||
userID := uuid.New()
|
|
||||||
validNote := data.Note{ID: noteID, UserID: userID}
|
|
||||||
|
|
||||||
t.Run("success", func(t *testing.T) {
|
|
||||||
rs := notesResource{}
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
noteCtxKey{},
|
|
||||||
validNote,
|
|
||||||
))
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.GetNote(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
var note data.Note
|
|
||||||
json.Unmarshal(w.Body.Bytes(), ¬e)
|
|
||||||
assert.Equal(t, validNote.ID, note.ID)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("not found", func(t *testing.T) {
|
|
||||||
rs := notesResource{}
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.GetNote(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotes_DeleteNote(t *testing.T) {
|
|
||||||
noteID := uuid.New()
|
|
||||||
userID := uuid.New()
|
|
||||||
validNote := data.Note{ID: noteID, UserID: userID}
|
|
||||||
|
|
||||||
t.Run("success", func(t *testing.T) {
|
|
||||||
mockStore := &mockNoteStore{
|
|
||||||
DeleteNoteFunc: func(_ context.Context, arg data.DeleteNoteParams) error {
|
|
||||||
assert.Equal(t, noteID, arg.ID)
|
|
||||||
assert.Equal(t, userID, arg.UserID)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rs := notesResource{Notes: mockStore}
|
|
||||||
req := httptest.NewRequest("DELETE", "/", nil)
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
noteCtxKey{},
|
|
||||||
validNote,
|
|
||||||
))
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
userCtxKey{},
|
|
||||||
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: userID.String(),
|
|
||||||
}},
|
|
||||||
))
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.DeleteNote(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusNoContent, w.Code)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("database error", func(t *testing.T) {
|
|
||||||
mockStore := &mockNoteStore{
|
|
||||||
DeleteNoteFunc: func(context.Context, data.DeleteNoteParams) error {
|
|
||||||
return errors.New("db error")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rs := notesResource{Notes: mockStore}
|
|
||||||
req := httptest.NewRequest("DELETE", "/", nil)
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
noteCtxKey{},
|
|
||||||
validNote,
|
|
||||||
))
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
userCtxKey{},
|
|
||||||
&userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: userID.String(),
|
|
||||||
}},
|
|
||||||
))
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.DeleteNote(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotes_CreateNoteVersion(t *testing.T) {
|
|
||||||
noteID := uuid.New()
|
|
||||||
validRequest := `{"title": "Test", "content": "Content"}`
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
body string
|
|
||||||
mock func(*mockNoteStore)
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"success",
|
|
||||||
validRequest,
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
m.CreateNoteVersionFunc = func(_ context.Context, arg data.CreateNoteVersionParams) (data.NoteVersion, error) {
|
|
||||||
assert.Equal(t, noteID, arg.NoteID)
|
|
||||||
return data.NoteVersion{}, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusCreated,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"duplicate content",
|
|
||||||
validRequest,
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusConflict,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"invalid request",
|
|
||||||
"{invalid}",
|
|
||||||
func(m *mockNoteStore) {},
|
|
||||||
http.StatusBadRequest,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"database error",
|
|
||||||
validRequest,
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.FindDuplicateContentFunc = func(context.Context, data.FindDuplicateContentParams) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
m.CreateNoteVersionFunc = func(context.Context, data.CreateNoteVersionParams) (data.NoteVersion, error) {
|
|
||||||
return data.NoteVersion{}, errors.New("db error")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusInternalServerError,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
mockStore := &mockNoteStore{}
|
|
||||||
tc.mock(mockStore)
|
|
||||||
|
|
||||||
rs := notesResource{Notes: mockStore}
|
|
||||||
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
noteCtxKey{},
|
|
||||||
data.Note{ID: noteID},
|
|
||||||
))
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.CreateNoteVersion(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotes_ListNoteVersions(t *testing.T) {
|
|
||||||
noteID := uuid.New()
|
|
||||||
versions := []data.NoteVersion{
|
|
||||||
{ID: uuid.New(), NoteID: noteID, VersionNumber: 1},
|
|
||||||
{ID: uuid.New(), NoteID: noteID, VersionNumber: 2},
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
query string
|
|
||||||
mock func(*mockNoteStore)
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"success",
|
|
||||||
"",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.GetNoteVersionsFunc = func(_ context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) {
|
|
||||||
assert.Equal(t, noteID, arg.NoteID)
|
|
||||||
return versions, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"with pagination",
|
|
||||||
"?limit=5&offset=10",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.GetNoteVersionsFunc = func(_ context.Context, arg data.GetNoteVersionsParams) ([]data.NoteVersion, error) {
|
|
||||||
assert.EqualValues(t, 5, arg.Limit)
|
|
||||||
assert.EqualValues(t, 10, arg.Offset)
|
|
||||||
return versions, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"database error",
|
|
||||||
"",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.GetNoteVersionsFunc = func(context.Context, data.GetNoteVersionsParams) ([]data.NoteVersion, error) {
|
|
||||||
return nil, errors.New("db error")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusInternalServerError,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
mockStore := &mockNoteStore{}
|
|
||||||
tc.mock(mockStore)
|
|
||||||
|
|
||||||
rs := notesResource{Notes: mockStore}
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/versions/%s", tc.query), nil)
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
noteCtxKey{},
|
|
||||||
data.Note{ID: noteID},
|
|
||||||
))
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.ListNoteVersions(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
if tc.statusCode == http.StatusOK {
|
|
||||||
var result []data.NoteVersion
|
|
||||||
json.Unmarshal(w.Body.Bytes(), &result)
|
|
||||||
assert.Len(t, result, 2)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotes_GetNoteVersion(t *testing.T) {
|
|
||||||
noteID := uuid.New()
|
|
||||||
version := data.NoteVersion{ID: uuid.New(), NoteID: noteID, VersionNumber: 1}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
version string
|
|
||||||
mock func(*mockNoteStore)
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"invalid version",
|
|
||||||
"invalid",
|
|
||||||
func(m *mockNoteStore) {},
|
|
||||||
http.StatusBadRequest,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"version not found",
|
|
||||||
"1",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.GetNoteVersionFunc = func(context.Context, data.GetNoteVersionParams) (data.NoteVersion, error) {
|
|
||||||
return data.NoteVersion{}, errors.New("not found")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusNotFound,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"success",
|
|
||||||
"1",
|
|
||||||
func(m *mockNoteStore) {
|
|
||||||
m.GetNoteVersionFunc = func(_ context.Context, arg data.GetNoteVersionParams) (data.NoteVersion, error) {
|
|
||||||
assert.Equal(t, noteID, arg.NoteID)
|
|
||||||
assert.EqualValues(t, 1, arg.VersionNumber)
|
|
||||||
return version, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
mockStore := &mockNoteStore{}
|
|
||||||
tc.mock(mockStore)
|
|
||||||
|
|
||||||
rs := notesResource{Notes: mockStore}
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/versions/%s", tc.version), nil)
|
|
||||||
|
|
||||||
// Chi router context mocks ID (passed in a URL param.) and the note object (passed in req. ctx.)
|
|
||||||
rctx := chi.NewRouteContext()
|
|
||||||
rctx.URLParams.Add("version", tc.version)
|
|
||||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
noteCtxKey{},
|
|
||||||
data.Note{ID: noteID},
|
|
||||||
))
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.GetNoteVersion(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
if tc.statusCode == http.StatusOK {
|
|
||||||
var result data.NoteVersion
|
|
||||||
json.Unmarshal(w.Body.Bytes(), &result)
|
|
||||||
assert.Equal(t, version.VersionNumber, result.VersionNumber)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,249 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
accessTokenDuration = 15 * time.Minute
|
|
||||||
refreshTokenDuration = 7 * 24 * time.Hour
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrInvalidToken = errors.New("invalid token")
|
|
||||||
ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header")
|
|
||||||
)
|
|
||||||
|
|
||||||
type userClaims struct {
|
|
||||||
Admin bool `json:"admin"`
|
|
||||||
TokenType string `json:"type"` // "access" or "refresh"
|
|
||||||
jwt.RegisteredClaims // User's UUID should be stored in the subject claim
|
|
||||||
}
|
|
||||||
|
|
||||||
type tokenPair struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mockable database operations interface
|
|
||||||
type TokenStore interface {
|
|
||||||
CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error)
|
|
||||||
GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error)
|
|
||||||
RevokeRefreshToken(ctx context.Context, tokenHash string) error
|
|
||||||
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type tokensResource struct {
|
|
||||||
JWTSecret string
|
|
||||||
Tokens TokenStore
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs tokensResource) Routes() chi.Router {
|
|
||||||
r := chi.NewRouter()
|
|
||||||
|
|
||||||
// Protected routes (access token required)
|
|
||||||
r.Group(func(r chi.Router) {
|
|
||||||
r.Use(requireAccessToken(rs.JWTSecret))
|
|
||||||
r.Post("/logout", rs.HandleLogout) // POST /auth/logout - revoke all refresh cookies
|
|
||||||
})
|
|
||||||
|
|
||||||
// Protected routes (refresh token required)
|
|
||||||
r.Group(func(r chi.Router) {
|
|
||||||
r.Use(requireRefreshToken(rs.JWTSecret))
|
|
||||||
r.Post("/refresh", rs.RefreshAccessToken) // POST /auth/refresh - convert refresh token to new token pair
|
|
||||||
})
|
|
||||||
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs tokensResource) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
|
|
||||||
tokenPair, err := generateTokenPair(userID.String(), isAdmin, rs.JWTSecret)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hash := sha256.Sum256([]byte(tokenPair.RefreshToken))
|
|
||||||
tokenHash := hex.EncodeToString(hash[:])
|
|
||||||
|
|
||||||
// Store to DB with (almost) identical expiration timestamp
|
|
||||||
expiresAt := time.Now().Add(refreshTokenDuration)
|
|
||||||
_, err = rs.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{
|
|
||||||
UserID: userID,
|
|
||||||
TokenHash: tokenHash,
|
|
||||||
ExpiresAt: expiresAt,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenPair, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs tokensResource) RevokeRefreshToken(ctx context.Context, token string) error {
|
|
||||||
hash := sha256.Sum256([]byte(token))
|
|
||||||
tokenHash := hex.EncodeToString(hash[:])
|
|
||||||
return rs.Tokens.RevokeRefreshToken(ctx, tokenHash)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs tokensResource) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) {
|
|
||||||
hash := sha256.Sum256([]byte(token))
|
|
||||||
tokenHash := hex.EncodeToString(hash[:])
|
|
||||||
|
|
||||||
dbToken, err := rs.Tokens.GetRefreshTokenByHash(ctx, tokenHash)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if dbToken.Revoked || time.Now().After(dbToken.ExpiresAt) {
|
|
||||||
return nil, ErrInvalidToken
|
|
||||||
}
|
|
||||||
|
|
||||||
return &dbToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs tokensResource) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Get claims from context
|
|
||||||
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
||||||
if !ok || claims.TokenType != "refresh" {
|
|
||||||
respondError(w, http.StatusUnauthorized, "Invalid token")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to get the token from Authentication header ("Bearer <token>")
|
|
||||||
refreshToken, err := getTokenFromRequest(r)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate the refresh token in DB
|
|
||||||
if _, err := rs.ValidateRefreshToken(r.Context(), refreshToken); err != nil {
|
|
||||||
respondError(w, http.StatusUnauthorized, "Invalid refresh token")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Revoke the used refresh token
|
|
||||||
if err := rs.RevokeRefreshToken(r.Context(), refreshToken); err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to revoke token")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate a new pair (access & refresh tokens)
|
|
||||||
userID, err := uuid.Parse(claims.Subject)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenPair, err := rs.GenerateTokenPair(r.Context(), userID, claims.Admin)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set refresh token in HTTP-only cookie
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: "refresh_token",
|
|
||||||
Value: tokenPair.RefreshToken,
|
|
||||||
Path: "/",
|
|
||||||
MaxAge: int(refreshTokenDuration.Seconds()),
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: true,
|
|
||||||
SameSite: http.SameSiteStrictMode,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Return the access token in the response body (it should be stored in browser's memory client-side)
|
|
||||||
respondJSON(w, http.StatusOK, map[string]string{
|
|
||||||
"access_token": tokenPair.AccessToken,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs tokensResource) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
|
||||||
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
||||||
if !ok {
|
|
||||||
respondError(w, http.StatusUnauthorized, "Not authenticated")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userID, err := uuid.Parse(claims.ID)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear the refresh token cookie
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: "refresh_token",
|
|
||||||
Value: "",
|
|
||||||
Path: "/",
|
|
||||||
MaxAge: -1, // Expires immediately
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: true,
|
|
||||||
SameSite: http.SameSiteStrictMode,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := rs.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to logout")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTokenFromRequest(r *http.Request) (string, error) {
|
|
||||||
bearerToken := r.Header.Get("Authorization")
|
|
||||||
bearerFields := strings.Fields(bearerToken)
|
|
||||||
if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" {
|
|
||||||
return bearerFields[1], nil
|
|
||||||
}
|
|
||||||
return "", ErrAuthHeaderInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
|
|
||||||
atClaims := userClaims{
|
|
||||||
Admin: isAdmin,
|
|
||||||
TokenType: "access",
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: userID,
|
|
||||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenDuration)),
|
|
||||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
||||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, atClaims)
|
|
||||||
|
|
||||||
t, err := accessToken.SignedString([]byte(jwtSecret))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rtClaims := userClaims{
|
|
||||||
Admin: isAdmin,
|
|
||||||
TokenType: "refresh",
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: userID,
|
|
||||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenDuration)),
|
|
||||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
||||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, rtClaims)
|
|
||||||
|
|
||||||
rt, err := refreshToken.SignedString([]byte(jwtSecret))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &tokenPair{AccessToken: t, RefreshToken: rt}, nil
|
|
||||||
}
|
|
@ -1,211 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockTokenStore struct {
|
|
||||||
CreateRefreshTokenFunc func(context.Context, data.CreateRefreshTokenParams) (data.RefreshToken, error)
|
|
||||||
GetRefreshTokenByHashFunc func(context.Context, string) (data.RefreshToken, error)
|
|
||||||
RevokeRefreshTokenFunc func(context.Context, string) error
|
|
||||||
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockTokenStore) CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
|
||||||
return m.CreateRefreshTokenFunc(ctx, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockTokenStore) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
|
||||||
return m.GetRefreshTokenByHashFunc(ctx, tokenHash)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockTokenStore) RevokeRefreshToken(ctx context.Context, token string) error {
|
|
||||||
return m.RevokeRefreshTokenFunc(ctx, token)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockTokenStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
|
|
||||||
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateTokenPair_Success(t *testing.T) {
|
|
||||||
userID := uuid.New()
|
|
||||||
called := false
|
|
||||||
|
|
||||||
mockStore := &mockTokenStore{
|
|
||||||
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
|
||||||
called = true
|
|
||||||
assert.Equal(t, userID, arg.UserID)
|
|
||||||
return data.RefreshToken{}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rs := tokensResource{
|
|
||||||
JWTSecret: "test-secret",
|
|
||||||
Tokens: mockStore,
|
|
||||||
}
|
|
||||||
|
|
||||||
pair, err := rs.GenerateTokenPair(context.Background(), userID, false)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.True(t, called)
|
|
||||||
assert.NotEmpty(t, pair.AccessToken)
|
|
||||||
assert.NotEmpty(t, pair.RefreshToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateTokenPair_DBError(t *testing.T) {
|
|
||||||
mockStore := &mockTokenStore{
|
|
||||||
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
|
||||||
return data.RefreshToken{}, errors.New("db error")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rs := tokensResource{Tokens: mockStore}
|
|
||||||
_, err := rs.GenerateTokenPair(context.Background(), uuid.New(), false)
|
|
||||||
assert.ErrorContains(t, err, "db error")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateRefreshToken_Valid(t *testing.T) {
|
|
||||||
token := "valid-jwt-token"
|
|
||||||
hash := sha256.Sum256([]byte(token))
|
|
||||||
expectedHash := hex.EncodeToString(hash[:])
|
|
||||||
|
|
||||||
mockStore := &mockTokenStore{
|
|
||||||
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
|
||||||
assert.Equal(t, expectedHash, tokenHash)
|
|
||||||
return data.RefreshToken{ExpiresAt: time.Now().Add(1 * time.Hour)}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rs := tokensResource{Tokens: mockStore}
|
|
||||||
_, err := rs.ValidateRefreshToken(context.Background(), token)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRefreshAccessToken_Success(t *testing.T) {
|
|
||||||
userID := uuid.New()
|
|
||||||
refreshToken := "valid-jwt-token"
|
|
||||||
|
|
||||||
// Expected hash of the test token
|
|
||||||
hash := sha256.Sum256([]byte(refreshToken))
|
|
||||||
expectedHash := hex.EncodeToString(hash[:])
|
|
||||||
|
|
||||||
mockStore := &mockTokenStore{
|
|
||||||
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
|
||||||
assert.Equal(t, expectedHash, tokenHash)
|
|
||||||
|
|
||||||
// Must return an unrevoked token with future expiration
|
|
||||||
return data.RefreshToken{
|
|
||||||
TokenHash: tokenHash,
|
|
||||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
|
||||||
Revoked: false,
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
RevokeRefreshTokenFunc: func(ctx context.Context, tokenHash string) error {
|
|
||||||
assert.Equal(t, expectedHash, tokenHash)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
|
||||||
return data.RefreshToken{}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rs := tokensResource{
|
|
||||||
JWTSecret: "test-secret",
|
|
||||||
Tokens: mockStore,
|
|
||||||
}
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/", nil)
|
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", refreshToken))
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
userCtxKey{},
|
|
||||||
&userClaims{
|
|
||||||
Admin: false,
|
|
||||||
TokenType: "refresh",
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: userID.String(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
))
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.RefreshAccessToken(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
assert.Contains(t, w.Body.String(), "access_token")
|
|
||||||
|
|
||||||
cookies := w.Result().Cookies()
|
|
||||||
var refreshCookie *http.Cookie
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
if cookie.Name == "refresh_token" {
|
|
||||||
refreshCookie = cookie
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if refreshCookie == nil {
|
|
||||||
t.Fatal("refresh token cookie not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.True(t, refreshCookie.HttpOnly, "cookie should be HttpOnly")
|
|
||||||
assert.Equal(t, http.SameSiteStrictMode, refreshCookie.SameSite, "invalid SameSite mode")
|
|
||||||
assert.Equal(t, "/", refreshCookie.Path, "invalid cookie path")
|
|
||||||
assert.Greater(t, refreshCookie.MaxAge, 0, "cookie should have expiration")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleLogout_Success(t *testing.T) {
|
|
||||||
userID := uuid.New()
|
|
||||||
called := false
|
|
||||||
|
|
||||||
mockStore := &mockTokenStore{
|
|
||||||
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
|
|
||||||
called = true
|
|
||||||
assert.Equal(t, userID, id)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rs := tokensResource{Tokens: mockStore}
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/", nil)
|
|
||||||
req = req.WithContext(context.WithValue(
|
|
||||||
req.Context(),
|
|
||||||
userCtxKey{},
|
|
||||||
&userClaims{
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
ID: userID.String(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
))
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.HandleLogout(w, req)
|
|
||||||
|
|
||||||
assert.True(t, called)
|
|
||||||
assert.Equal(t, http.StatusNoContent, w.Code) // 204
|
|
||||||
|
|
||||||
cookies := w.Result().Cookies()
|
|
||||||
var refreshCookie *http.Cookie
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
if cookie.Name == "refresh_token" {
|
|
||||||
refreshCookie = cookie
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if refreshCookie != nil && refreshCookie.MaxAge != -1 {
|
|
||||||
t.Fatal("refresh token cookie not invalidated")
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,356 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type userCtxKey struct{}
|
|
||||||
|
|
||||||
// Stripped object that only contains non-critical data
|
|
||||||
type userResponse struct {
|
|
||||||
ID uuid.UUID `json:"id"`
|
|
||||||
Username string `json:"username"`
|
|
||||||
IsAdmin bool `json:"is_admin"`
|
|
||||||
CreatedAt *time.Time `json:"created_at"`
|
|
||||||
UpdatedAt *time.Time `json:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mockable database operations interface
|
|
||||||
type UserStore interface {
|
|
||||||
CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
|
|
||||||
ListUsers(ctx context.Context) ([]data.User, error)
|
|
||||||
GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error)
|
|
||||||
GetUserByUsername(ctx context.Context, username string) (data.User, error)
|
|
||||||
UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error
|
|
||||||
DeleteUser(ctx context.Context, id uuid.UUID) error
|
|
||||||
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type usersResource struct {
|
|
||||||
JWTSecret string
|
|
||||||
Users UserStore
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs usersResource) Routes() chi.Router {
|
|
||||||
r := chi.NewRouter()
|
|
||||||
|
|
||||||
// Public routes (no tokens required)
|
|
||||||
r.Post("/", rs.Create) // POST /users - registration/signup
|
|
||||||
r.Post("/login", rs.Login) // POST /users/login - login as existing user
|
|
||||||
|
|
||||||
// Protected routes (access token required)
|
|
||||||
r.Group(func(r chi.Router) {
|
|
||||||
r.Use(requireAccessToken(rs.JWTSecret))
|
|
||||||
|
|
||||||
r.Get("/me", rs.Get) // GET /users/me - get current user data
|
|
||||||
|
|
||||||
// Admin only general routes
|
|
||||||
r.Group(func(r chi.Router) {
|
|
||||||
r.Use(adminOnlyMiddleware)
|
|
||||||
r.Get("/", rs.List) // GET /users - list all users
|
|
||||||
})
|
|
||||||
|
|
||||||
// User specific routes
|
|
||||||
r.Route("/{id}", func(r chi.Router) {
|
|
||||||
r.Use(userCtx(rs.Users)) // DB -> req. context
|
|
||||||
|
|
||||||
// Admin routes
|
|
||||||
r.Route("/admin", func(r chi.Router) {
|
|
||||||
r.Use(adminOnlyMiddleware)
|
|
||||||
r.Delete("/", rs.AdminDelete) // DELETE /users/admin/{id} - delete user
|
|
||||||
})
|
|
||||||
|
|
||||||
// Owner routes
|
|
||||||
r.Route("/owner", func(r chi.Router) {
|
|
||||||
r.Use(ownerOnlyMiddleware)
|
|
||||||
r.Put("/", rs.UpdatePassword) // PUT /users/owner/{id} - update user password
|
|
||||||
r.Delete("/", rs.OwnerDelete) // DELETE /users/owner/{id} - delete user
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs usersResource) Create(w http.ResponseWriter, r *http.Request) {
|
|
||||||
type request struct {
|
|
||||||
Username string `json:"username"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var req request
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
normalizedUsername := normalizeUsername(req.Username)
|
|
||||||
if err := validateUsername(normalizedUsername); err != nil {
|
|
||||||
respondError(w, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validatePassword(req.Password); err != nil {
|
|
||||||
respondError(w, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to create user")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := rs.Users.CreateUser(r.Context(), data.CreateUserParams{
|
|
||||||
Username: normalizedUsername,
|
|
||||||
PasswordHash: string(hashedPassword),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
if isDuplicateEntry(err) {
|
|
||||||
respondError(w, http.StatusConflict, "Username is already in use")
|
|
||||||
} else {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to create user")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
respondJSON(w, http.StatusCreated, map[string]string{
|
|
||||||
"id": user.ID.String(),
|
|
||||||
"username": user.Username,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs usersResource) Login(w http.ResponseWriter, r *http.Request) {
|
|
||||||
type request struct {
|
|
||||||
Username string `json:"username"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var req request
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := rs.Users.GetUserByUsername(r.Context(), normalizeUsername(req.Username))
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
|
|
||||||
respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenPair, err := generateTokenPair(user.ID.String(), user.IsAdmin, rs.JWTSecret)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set refresh token in HTTP-only cookie
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: "refresh_token",
|
|
||||||
Value: tokenPair.RefreshToken,
|
|
||||||
Path: "/",
|
|
||||||
MaxAge: int(refreshTokenDuration.Seconds()),
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: true,
|
|
||||||
SameSite: http.SameSiteStrictMode,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Build response
|
|
||||||
response := map[string]any{
|
|
||||||
"access_token": tokenPair.AccessToken,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Include user data if the client has requested it (`?includeUser=true`)
|
|
||||||
if includeUser, _ := strconv.ParseBool(r.URL.Query().Get("includeUser")); includeUser {
|
|
||||||
response["user"] = userResponse{
|
|
||||||
ID: user.ID,
|
|
||||||
Username: user.Username,
|
|
||||||
IsAdmin: user.IsAdmin,
|
|
||||||
CreatedAt: user.CreatedAt,
|
|
||||||
UpdatedAt: user.UpdatedAt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
respondJSON(w, http.StatusOK, response)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs usersResource) List(w http.ResponseWriter, r *http.Request) {
|
|
||||||
users, err := rs.Users.ListUsers(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to retrieve users")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Output sanitization
|
|
||||||
var output []map[string]any
|
|
||||||
for _, user := range users {
|
|
||||||
output = append(output, map[string]any{
|
|
||||||
"id": user.ID,
|
|
||||||
"username": user.Username,
|
|
||||||
"created_at": user.CreatedAt,
|
|
||||||
"updated_at": user.UpdatedAt,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
respondJSON(w, http.StatusOK, output)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) {
|
|
||||||
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
||||||
if !ok {
|
|
||||||
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userID, err := uuid.Parse(claims.Subject)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := rs.Users.GetUserByID(r.Context(), userID)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusNotFound, "User not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
respondJSON(w, http.StatusOK, userResponse{
|
|
||||||
ID: user.ID,
|
|
||||||
Username: user.Username,
|
|
||||||
IsAdmin: user.IsAdmin,
|
|
||||||
CreatedAt: user.CreatedAt,
|
|
||||||
UpdatedAt: user.UpdatedAt,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs usersResource) UpdatePassword(w http.ResponseWriter, r *http.Request) {
|
|
||||||
type request struct {
|
|
||||||
OldPassword string `json:"old_password"`
|
|
||||||
NewPassword string `json:"new_password"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var req request
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
|
||||||
if !ok {
|
|
||||||
respondError(w, http.StatusNotFound, "User not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the old password before allowing the update
|
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.OldPassword)); err != nil {
|
|
||||||
respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validatePassword(req.NewPassword); err != nil {
|
|
||||||
respondError(w, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to update password")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rs.Users.UpdatePassword(r.Context(), data.UpdatePasswordParams{
|
|
||||||
ID: user.ID,
|
|
||||||
PasswordHash: string(hashedPassword),
|
|
||||||
}); err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to update password")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
|
|
||||||
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs usersResource) AdminDelete(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
|
||||||
if !ok {
|
|
||||||
respondError(w, http.StatusNotFound, "User not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rs.Users.DeleteUser(r.Context(), user.ID); err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to delete user")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
|
|
||||||
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs usersResource) OwnerDelete(w http.ResponseWriter, r *http.Request) {
|
|
||||||
type request struct {
|
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var req request
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
|
||||||
if !ok {
|
|
||||||
respondError(w, http.StatusNotFound, "User not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the old password before allowing the deletion
|
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
|
|
||||||
respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err := rs.Users.DeleteUser(r.Context(), user.ID)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, "Failed to delete user")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
|
|
||||||
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the given error is a PostgreSQL error for `unique_violation`, i.e. whether an entry
|
|
||||||
// with the given details already exists in the database table.
|
|
||||||
func isDuplicateEntry(err error) bool {
|
|
||||||
var pgErr *pgconn.PgError
|
|
||||||
if errors.As(err, &pgErr) {
|
|
||||||
return pgErr.Code == "23505"
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
@ -1,557 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockUserStore struct {
|
|
||||||
CreateUserFunc func(context.Context, data.CreateUserParams) (data.User, error)
|
|
||||||
ListUsersFunc func(context.Context) ([]data.User, error)
|
|
||||||
GetUserByIDFunc func(context.Context, uuid.UUID) (data.User, error)
|
|
||||||
GetUserByUsernameFunc func(context.Context, string) (data.User, error)
|
|
||||||
UpdatePasswordFunc func(context.Context, data.UpdatePasswordParams) error
|
|
||||||
DeleteUserFunc func(context.Context, uuid.UUID) error
|
|
||||||
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockUserStore) CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
|
|
||||||
return m.CreateUserFunc(ctx, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockUserStore) ListUsers(ctx context.Context) ([]data.User, error) {
|
|
||||||
return m.ListUsersFunc(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockUserStore) GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error) {
|
|
||||||
return m.GetUserByIDFunc(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockUserStore) GetUserByUsername(ctx context.Context, username string) (data.User, error) {
|
|
||||||
return m.GetUserByUsernameFunc(ctx, username)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockUserStore) UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error {
|
|
||||||
return m.UpdatePasswordFunc(ctx, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockUserStore) DeleteUser(ctx context.Context, id uuid.UUID) error {
|
|
||||||
return m.DeleteUserFunc(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockUserStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
|
|
||||||
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateUser_Duplicate(t *testing.T) {
|
|
||||||
mockStore := &mockUserStore{
|
|
||||||
CreateUserFunc: func(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
|
|
||||||
return data.User{}, &pgconn.PgError{Code: "23505"}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
|
||||||
|
|
||||||
reqBody := `{"username": "existing", "password": "validPass123!"}`
|
|
||||||
req := httptest.NewRequest("POST", "/", strings.NewReader(reqBody))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
rs.Create(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusConflict, w.Code)
|
|
||||||
assert.Contains(t, w.Body.String(), "Username is already in use")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateUser_InvalidUsername(t *testing.T) {
|
|
||||||
mockStore := &mockUserStore{} // No DB calls expected
|
|
||||||
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
|
||||||
|
|
||||||
// Test various invalid usernames
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
body string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"too short",
|
|
||||||
`{"username": "a", "password": "validPass123!"}`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"invalid chars",
|
|
||||||
`{"username": "user@name", "password": "validPass123!"}`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"empty",
|
|
||||||
`{"username": "", "password": "validPass123!"}`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.Create(w, req)
|
|
||||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateUser_InvalidPassword(t *testing.T) {
|
|
||||||
mockStore := &mockUserStore{}
|
|
||||||
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
body string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"too short",
|
|
||||||
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength-1)),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"too long",
|
|
||||||
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", maxPasswordLength+1)),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"low entropy",
|
|
||||||
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength)),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.Create(w, req)
|
|
||||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestListUsers_Success(t *testing.T) {
|
|
||||||
testUsers := []data.User{
|
|
||||||
{ID: uuid.New(), Username: "user1"},
|
|
||||||
{ID: uuid.New(), Username: "user2"},
|
|
||||||
}
|
|
||||||
|
|
||||||
mockStore := &mockUserStore{
|
|
||||||
ListUsersFunc: func(ctx context.Context) ([]data.User, error) {
|
|
||||||
return testUsers, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rs := usersResource{Users: mockStore}
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
rs.List(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
|
|
||||||
var response []map[string]any
|
|
||||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Len(t, response, 2)
|
|
||||||
assert.Equal(t, "user1", response[0]["username"])
|
|
||||||
assert.NotContains(t, response[0], "password_hash")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdatePassword_InvalidOldPassword(t *testing.T) {
|
|
||||||
// User with password hash that won't match "wrongpassword"
|
|
||||||
user := data.User{
|
|
||||||
ID: uuid.New(),
|
|
||||||
PasswordHash: "$2a$10$PHhno.bZBF8IEINdFRZAPujMxIN65msElATgJG6FIxZdeWYVLSfFi", // Hash of "correctpassword"
|
|
||||||
}
|
|
||||||
|
|
||||||
mockStore := &mockUserStore{}
|
|
||||||
rs := usersResource{Users: mockStore}
|
|
||||||
|
|
||||||
reqBody := `{"old_password": "wrongpassword", "new_password": "NewValidPass321!"}`
|
|
||||||
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
|
|
||||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
rs.UpdatePassword(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAdminDelete_Success(t *testing.T) {
|
|
||||||
user := data.User{ID: uuid.New()}
|
|
||||||
deleteCalled := false
|
|
||||||
revokeCalled := false
|
|
||||||
|
|
||||||
mockStore := &mockUserStore{
|
|
||||||
DeleteUserFunc: func(ctx context.Context, id uuid.UUID) error {
|
|
||||||
deleteCalled = true
|
|
||||||
assert.Equal(t, user.ID, id)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
|
|
||||||
revokeCalled = true
|
|
||||||
assert.Equal(t, user.ID, id)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rs := usersResource{Users: mockStore}
|
|
||||||
req := httptest.NewRequest("DELETE", "/", nil)
|
|
||||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
rs.AdminDelete(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusNoContent, w.Code)
|
|
||||||
assert.True(t, deleteCalled)
|
|
||||||
assert.True(t, revokeCalled)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOwnerDelete_InvalidCredentials(t *testing.T) {
|
|
||||||
// Create user with known password hash
|
|
||||||
correctPassword := "CorrectPass123!"
|
|
||||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(correctPassword), bcrypt.DefaultCost)
|
|
||||||
user := data.User{
|
|
||||||
ID: uuid.New(),
|
|
||||||
PasswordHash: string(hashedPassword),
|
|
||||||
}
|
|
||||||
|
|
||||||
mockStore := &mockUserStore{}
|
|
||||||
rs := usersResource{Users: mockStore}
|
|
||||||
|
|
||||||
reqBody := `{"password": "wrongpassword"}`
|
|
||||||
req := httptest.NewRequest("DELETE", "/", strings.NewReader(reqBody))
|
|
||||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
rs.OwnerDelete(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUsersGetCurrentUser(t *testing.T) {
|
|
||||||
validUserID := uuid.New()
|
|
||||||
testTime := time.Now().UTC().Truncate(time.Second)
|
|
||||||
testUser := data.User{
|
|
||||||
ID: validUserID,
|
|
||||||
Username: "testuser",
|
|
||||||
CreatedAt: &testTime,
|
|
||||||
UpdatedAt: &testTime,
|
|
||||||
IsAdmin: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setupContext func(context.Context) context.Context
|
|
||||||
mockSetup func(*mockUserStore)
|
|
||||||
wantStatus int
|
|
||||||
wantResponse string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "success",
|
|
||||||
setupContext: func(ctx context.Context) context.Context {
|
|
||||||
return context.WithValue(ctx, userCtxKey{}, &userClaims{
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: validUserID.String(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
},
|
|
||||||
mockSetup: func(m *mockUserStore) {
|
|
||||||
m.GetUserByIDFunc = func(_ context.Context, id uuid.UUID) (data.User, error) {
|
|
||||||
assert.Equal(t, validUserID, id)
|
|
||||||
return testUser, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
wantStatus: http.StatusOK,
|
|
||||||
wantResponse: fmt.Sprintf(
|
|
||||||
`{"created_at":"%s","id":"%s","is_admin":false,"updated_at":"%s","username":"testuser"}`,
|
|
||||||
testUser.CreatedAt.Format(time.RFC3339Nano),
|
|
||||||
validUserID.String(),
|
|
||||||
testUser.UpdatedAt.Format(time.RFC3339Nano),
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user not found",
|
|
||||||
setupContext: func(ctx context.Context) context.Context {
|
|
||||||
return context.WithValue(ctx, userCtxKey{}, &userClaims{
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: validUserID.String(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
},
|
|
||||||
mockSetup: func(m *mockUserStore) {
|
|
||||||
m.GetUserByIDFunc = func(_ context.Context, id uuid.UUID) (data.User, error) {
|
|
||||||
return data.User{}, errors.New("not found")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
wantStatus: http.StatusNotFound,
|
|
||||||
wantResponse: `{"error":"User not found"}`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "unauthorized",
|
|
||||||
setupContext: func(ctx context.Context) context.Context {
|
|
||||||
return ctx // No user claims in context
|
|
||||||
},
|
|
||||||
mockSetup: func(m *mockUserStore) {},
|
|
||||||
wantStatus: http.StatusUnauthorized,
|
|
||||||
wantResponse: `{"error":"Unauthorized"}`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid user ID",
|
|
||||||
setupContext: func(ctx context.Context) context.Context {
|
|
||||||
return context.WithValue(ctx, userCtxKey{}, &userClaims{
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: "invalid",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
},
|
|
||||||
mockSetup: func(m *mockUserStore) {},
|
|
||||||
wantStatus: http.StatusInternalServerError,
|
|
||||||
wantResponse: `{"error":"Invalid user ID"}`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
mockStore := &mockUserStore{}
|
|
||||||
tt.mockSetup(mockStore)
|
|
||||||
|
|
||||||
rs := usersResource{Users: mockStore}
|
|
||||||
req := httptest.NewRequest("GET", "/me", nil)
|
|
||||||
req = req.WithContext(tt.setupContext(req.Context()))
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.Get(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tt.wantStatus, w.Code)
|
|
||||||
|
|
||||||
if tt.wantResponse != "" {
|
|
||||||
actual := strings.TrimSpace(w.Body.String())
|
|
||||||
assert.JSONEq(t, tt.wantResponse, actual)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify sensitive fields are never exposed
|
|
||||||
if w.Code == http.StatusOK {
|
|
||||||
var response map[string]any
|
|
||||||
json.Unmarshal(w.Body.Bytes(), &response)
|
|
||||||
_, exists := response["password_hash"]
|
|
||||||
assert.False(t, exists, "password_hash should not be exposed")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdatePassword_DatabaseError(t *testing.T) {
|
|
||||||
// Add user with a valid password to the context
|
|
||||||
oldPassword := "OldValidPass321!"
|
|
||||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(oldPassword), bcrypt.DefaultCost)
|
|
||||||
user := data.User{
|
|
||||||
ID: uuid.New(),
|
|
||||||
PasswordHash: string(hashedPassword),
|
|
||||||
}
|
|
||||||
mockStore := &mockUserStore{
|
|
||||||
UpdatePasswordFunc: func(ctx context.Context, arg data.UpdatePasswordParams) error {
|
|
||||||
return errors.New("database error")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rs := usersResource{Users: mockStore}
|
|
||||||
|
|
||||||
reqBody := fmt.Sprintf(`{"old_password": "%s", "new_password": "NewValidPass123!"}`, oldPassword)
|
|
||||||
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
|
|
||||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
rs.UpdatePassword(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
|
||||||
assert.Contains(t, w.Body.String(), "Failed to update password")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUsersLogin(t *testing.T) {
|
|
||||||
validPassword := "validPass123!"
|
|
||||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(validPassword), bcrypt.DefaultCost)
|
|
||||||
testUser := data.User{
|
|
||||||
ID: uuid.New(),
|
|
||||||
Username: "test_username",
|
|
||||||
PasswordHash: string(hashedPassword),
|
|
||||||
}
|
|
||||||
jwtSecret := "test-secret"
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
includeUser string
|
|
||||||
wantUserData bool
|
|
||||||
requestBody any
|
|
||||||
mockSetup func(*mockUserStore)
|
|
||||||
wantStatus int
|
|
||||||
wantResponse string
|
|
||||||
checkCookie bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "invalid request body",
|
|
||||||
requestBody: "invalid",
|
|
||||||
mockSetup: func(m *mockUserStore) {},
|
|
||||||
wantStatus: http.StatusBadRequest,
|
|
||||||
wantResponse: `{"error":"Invalid request body"}`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user not found",
|
|
||||||
requestBody: map[string]string{
|
|
||||||
"username": "nouser",
|
|
||||||
"password": validPassword,
|
|
||||||
},
|
|
||||||
mockSetup: func(m *mockUserStore) {
|
|
||||||
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
|
|
||||||
return data.User{}, errors.New("not found")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
wantStatus: http.StatusUnauthorized,
|
|
||||||
wantResponse: `{"error":"Invalid credentials"}`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid password",
|
|
||||||
requestBody: map[string]string{
|
|
||||||
"username": testUser.Username,
|
|
||||||
"password": "wrongpassword",
|
|
||||||
},
|
|
||||||
mockSetup: func(m *mockUserStore) {
|
|
||||||
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
|
|
||||||
return testUser, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
wantStatus: http.StatusUnauthorized,
|
|
||||||
wantResponse: `{"error":"Invalid credentials"}`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "successful login with user data",
|
|
||||||
includeUser: "true",
|
|
||||||
wantUserData: true,
|
|
||||||
requestBody: map[string]string{
|
|
||||||
"username": testUser.Username,
|
|
||||||
"password": validPassword,
|
|
||||||
},
|
|
||||||
mockSetup: func(m *mockUserStore) {
|
|
||||||
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
|
|
||||||
return testUser, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
wantStatus: http.StatusOK,
|
|
||||||
checkCookie: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "successful login without user data",
|
|
||||||
includeUser: "false",
|
|
||||||
wantUserData: false,
|
|
||||||
requestBody: map[string]string{
|
|
||||||
"username": testUser.Username,
|
|
||||||
"password": validPassword,
|
|
||||||
},
|
|
||||||
mockSetup: func(m *mockUserStore) {
|
|
||||||
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
|
|
||||||
return testUser, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
wantStatus: http.StatusOK,
|
|
||||||
checkCookie: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
mockStore := &mockUserStore{}
|
|
||||||
tt.mockSetup(mockStore)
|
|
||||||
|
|
||||||
rs := usersResource{
|
|
||||||
Users: mockStore,
|
|
||||||
JWTSecret: jwtSecret,
|
|
||||||
}
|
|
||||||
|
|
||||||
body, _ := json.Marshal(tt.requestBody)
|
|
||||||
req := httptest.NewRequest("POST", "/login", bytes.NewReader(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
// Add the necessary query parameters
|
|
||||||
q := url.Values{}
|
|
||||||
q.Add("includeUser", tt.includeUser)
|
|
||||||
req.URL.RawQuery = q.Encode()
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
rs.Login(w, req)
|
|
||||||
|
|
||||||
if w.Code != tt.wantStatus {
|
|
||||||
t.Errorf("expected status %d, got %d", tt.wantStatus, w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.wantResponse != "" && strings.TrimSpace(w.Body.String()) != tt.wantResponse {
|
|
||||||
t.Errorf("expected response %q, got %q", tt.wantResponse, w.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.wantUserData {
|
|
||||||
var response struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
User data.User `json:"user"` // Cast to the "raw" type to allow checking for sensitive data fields
|
|
||||||
}
|
|
||||||
json.Unmarshal(w.Body.Bytes(), &response)
|
|
||||||
|
|
||||||
assert.Equal(t, testUser.ID, response.User.ID)
|
|
||||||
assert.Equal(t, testUser.Username, response.User.Username)
|
|
||||||
assert.Empty(t, response.User.PasswordHash) // Ensure sensitive data excluded
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.checkCookie {
|
|
||||||
cookies := w.Result().Cookies()
|
|
||||||
var refreshCookie *http.Cookie
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
if cookie.Name == "refresh_token" {
|
|
||||||
refreshCookie = cookie
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if refreshCookie == nil {
|
|
||||||
t.Fatal("refresh token cookie not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.True(t, refreshCookie.HttpOnly, "cookie should be HttpOnly")
|
|
||||||
assert.Equal(t, http.SameSiteStrictMode, refreshCookie.SameSite, "invalid SameSite mode")
|
|
||||||
assert.Equal(t, "/", refreshCookie.Path, "invalid cookie path")
|
|
||||||
assert.Greater(t, refreshCookie.MaxAge, 0, "cookie should have expiration")
|
|
||||||
|
|
||||||
// Validate access token in response
|
|
||||||
var response map[string]string
|
|
||||||
json.Unmarshal(w.Body.Bytes(), &response)
|
|
||||||
if response["access_token"] == "" {
|
|
||||||
t.Error("access token not in response")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify JWT validity
|
|
||||||
token, err := jwt.ParseWithClaims(
|
|
||||||
response["access_token"],
|
|
||||||
&userClaims{},
|
|
||||||
func(token *jwt.Token) (any, error) {
|
|
||||||
return []byte(jwtSecret), nil
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert.NoError(t, err, "invalid JWT")
|
|
||||||
assert.True(t, token.Valid, "invalid JWT")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,10 +1,5 @@
|
|||||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
|
||||||
version BIGINT PRIMARY KEY,
|
|
||||||
applied_at TIMESTAMPTZ DEFAULT NOW()
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS users (
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||||
username TEXT UNIQUE NOT NULL,
|
username TEXT UNIQUE NOT NULL,
|
||||||
@ -26,6 +21,8 @@ CREATE TABLE IF NOT EXISTS refresh_tokens (
|
|||||||
CREATE TABLE IF NOT EXISTS notes (
|
CREATE TABLE IF NOT EXISTS notes (
|
||||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
current_version INT NOT NULL DEFAULT 1, -- active version (can be historical)
|
||||||
|
latest_version INT NOT NULL DEFAULT 1, -- highest version number
|
||||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
updated_at TIMESTAMPTZ DEFAULT NOW()
|
updated_at TIMESTAMPTZ DEFAULT NOW()
|
||||||
);
|
);
|
||||||
@ -35,15 +32,18 @@ CREATE TABLE IF NOT EXISTS note_versions (
|
|||||||
note_id UUID NOT NULL REFERENCES notes(id) ON DELETE CASCADE,
|
note_id UUID NOT NULL REFERENCES notes(id) ON DELETE CASCADE,
|
||||||
title TEXT NOT NULL,
|
title TEXT NOT NULL,
|
||||||
content TEXT NOT NULL,
|
content TEXT NOT NULL,
|
||||||
version_number INT NOT NULL,
|
version_number INT NOT NULL DEFAULT 1,
|
||||||
content_hash TEXT NOT NULL,
|
content_hash TEXT NOT NULL,
|
||||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
|
UNIQUE (note_id, version_number)
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username);
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_id ON refresh_tokens(user_id);
|
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_id ON refresh_tokens(user_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
|
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
|
||||||
|
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_note_version_unique ON note_versions(note_id, version_number);
|
CREATE INDEX IF NOT EXISTS idx_notes_user_updated ON notes(user_id, updated_at DESC);
|
||||||
CREATE INDEX IF NOT EXISTS idx_note_versions_note ON note_versions(note_id);
|
CREATE INDEX IF NOT EXISTS idx_notes_current_version ON notes(current_version);
|
||||||
CREATE INDEX IF NOT EXISTS idx_note_versions_number ON note_versions(version_number DESC);
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_note_versions_content_hash ON note_versions(note_id, content_hash);
|
||||||
|
@ -1,27 +1,58 @@
|
|||||||
-- name: CreateNoteVersion :one
|
-- name: CreateNoteVersion :exec
|
||||||
INSERT INTO note_versions (note_id, title, content, version_number, content_hash)
|
WITH potential_duplicate AS (
|
||||||
VALUES (
|
SELECT version_number
|
||||||
$1,
|
FROM note_versions
|
||||||
$2,
|
WHERE
|
||||||
$3,
|
note_id = $1
|
||||||
(SELECT COALESCE(MAX(version_number), 0) + 1 FROM note_versions WHERE note_id = $1),
|
AND content_hash = $2
|
||||||
encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
|
ORDER BY version_number DESC
|
||||||
|
LIMIT 1
|
||||||
|
),
|
||||||
|
note_update AS (
|
||||||
|
UPDATE notes
|
||||||
|
SET
|
||||||
|
current_version = COALESCE(
|
||||||
|
(SELECT version_number FROM potential_duplicate),
|
||||||
|
latest_version + 1 -- increment only if we don't jump into a historical version
|
||||||
|
),
|
||||||
|
latest_version = CASE
|
||||||
|
WHEN (SELECT version_number FROM potential_duplicate) IS NULL
|
||||||
|
THEN latest_version + 1
|
||||||
|
ELSE latest_version
|
||||||
|
END,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $1
|
||||||
|
RETURNING current_version, latest_version
|
||||||
)
|
)
|
||||||
RETURNING *;
|
INSERT INTO note_versions (
|
||||||
|
note_id, title, content, version_number, content_hash
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
$1, -- note_id
|
||||||
|
$3, -- title
|
||||||
|
$4, -- content
|
||||||
|
current_version,
|
||||||
|
$2 -- content_hash
|
||||||
|
FROM note_update
|
||||||
|
WHERE NOT EXISTS (SELECT 1 FROM potential_duplicate);
|
||||||
|
|
||||||
-- name: GetNoteVersions :many
|
-- name: GetVersionHistory :many
|
||||||
SELECT * FROM note_versions
|
SELECT
|
||||||
|
id AS version_id,
|
||||||
|
title,
|
||||||
|
version_number,
|
||||||
|
created_at
|
||||||
|
FROM note_versions
|
||||||
WHERE note_id = $1
|
WHERE note_id = $1
|
||||||
ORDER BY version_number DESC
|
ORDER BY version_number DESC
|
||||||
LIMIT $2 OFFSET $3;
|
LIMIT $2 OFFSET $3;
|
||||||
|
|
||||||
-- name: GetNoteVersion :one
|
-- name: GetVersion :one
|
||||||
SELECT * FROM note_versions
|
SELECT
|
||||||
WHERE note_id = $1 AND version_number = $2 LIMIT 1;
|
id AS version_id,
|
||||||
|
title,
|
||||||
-- name: FindDuplicateContent :one
|
content,
|
||||||
SELECT EXISTS(
|
version_number,
|
||||||
SELECT 1 FROM note_versions
|
created_at
|
||||||
WHERE note_id = $1
|
FROM note_versions
|
||||||
AND content_hash = encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
|
WHERE note_id = $1 AND id = $2;
|
||||||
);
|
|
||||||
|
@ -3,16 +3,34 @@ INSERT INTO notes (user_id)
|
|||||||
VALUES ($1)
|
VALUES ($1)
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|
||||||
-- name: GetNote :one
|
|
||||||
SELECT * FROM notes
|
|
||||||
WHERE id = $1 AND user_id = $2 LIMIT 1;
|
|
||||||
|
|
||||||
-- name: ListNotes :many
|
-- name: ListNotes :many
|
||||||
SELECT * FROM notes
|
SELECT
|
||||||
WHERE user_id = $1
|
n.id AS note_id,
|
||||||
ORDER BY created_at DESC
|
n.user_id AS owner_id,
|
||||||
|
nv.title,
|
||||||
|
n.updated_at
|
||||||
|
FROM notes n
|
||||||
|
JOIN note_versions nv
|
||||||
|
ON n.id = nv.note_id AND n.current_version = nv.version_number
|
||||||
|
WHERE n.user_id = $1
|
||||||
|
ORDER BY n.updated_at DESC
|
||||||
LIMIT $2 OFFSET $3;
|
LIMIT $2 OFFSET $3;
|
||||||
|
|
||||||
|
-- name: GetFullNote :one
|
||||||
|
SELECT
|
||||||
|
n.id AS note_id,
|
||||||
|
n.user_id AS owner_id,
|
||||||
|
nv.title,
|
||||||
|
nv.content,
|
||||||
|
nv.version_number,
|
||||||
|
nv.created_at AS version_created_at,
|
||||||
|
n.created_at AS note_created_at,
|
||||||
|
n.updated_at AS note_updated_at
|
||||||
|
FROM notes n
|
||||||
|
JOIN note_versions nv
|
||||||
|
ON n.id = nv.note_id AND n.current_version = nv.version_number
|
||||||
|
WHERE n.id = $1;
|
||||||
|
|
||||||
-- name: DeleteNote :exec
|
-- name: DeleteNote :exec
|
||||||
DELETE FROM notes
|
DELETE FROM notes
|
||||||
WHERE id = $1 AND user_id = $2;
|
WHERE id = $1 AND user_id = $2;
|
||||||
|
@ -3,9 +3,18 @@ INSERT INTO users (username, password_hash)
|
|||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: CreateAdmin :one
|
||||||
|
INSERT INTO users (username, password_hash, is_admin)
|
||||||
|
VALUES ($1, $2, true)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
-- name: ListUsers :many
|
-- name: ListUsers :many
|
||||||
SELECT * FROM users;
|
SELECT * FROM users;
|
||||||
|
|
||||||
|
-- name: ListAdmins :many
|
||||||
|
SELECT * FROM users
|
||||||
|
WHERE is_admin = true;
|
||||||
|
|
||||||
-- name: GetUserByID :one
|
-- name: GetUserByID :one
|
||||||
SELECT * FROM users
|
SELECT * FROM users
|
||||||
WHERE id = $1 LIMIT 1;
|
WHERE id = $1 LIMIT 1;
|
||||||
|
@ -7,7 +7,7 @@ sql:
|
|||||||
gen:
|
gen:
|
||||||
go:
|
go:
|
||||||
package: "data"
|
package: "data"
|
||||||
out: "../pkg/data"
|
out: "../internal/data"
|
||||||
sql_package: "pgx/v5"
|
sql_package: "pgx/v5"
|
||||||
emit_json_tags: true
|
emit_json_tags: true
|
||||||
overrides:
|
overrides:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user