Compare commits
7 Commits
ac169851d3
...
3257b19313
Author | SHA1 | Date | |
---|---|---|---|
3257b19313 | |||
41d1336f58 | |||
9ba182d925 | |||
66fde0a700 | |||
6569a399e3 | |||
de72ea53e1 | |||
e0bdf32bfd |
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
|||||||
TASKS.md
|
TASKS.md
|
||||||
|
.env
|
||||||
|
@ -1,3 +1,30 @@
|
|||||||
module git.umbrella.haus/ae/notatest
|
module git.umbrella.haus/ae/notatest
|
||||||
|
|
||||||
go 1.24.1
|
go 1.24.1
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/caarlos0/env v3.5.0+incompatible
|
||||||
|
github.com/go-chi/chi/v5 v5.2.1
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/jackc/pgx/v5 v5.7.4
|
||||||
|
github.com/rs/zerolog v1.34.0
|
||||||
|
github.com/stretchr/testify v1.10.0
|
||||||
|
github.com/wagslane/go-password-validator v0.3.0
|
||||||
|
golang.org/x/crypto v0.36.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
|
github.com/kr/text v0.2.0 // indirect
|
||||||
|
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||||
|
github.com/stretchr/objx v0.5.2 // indirect
|
||||||
|
golang.org/x/sys v0.31.0 // indirect
|
||||||
|
golang.org/x/text v0.23.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
||||||
|
67
server/go.sum
Normal file
67
server/go.sum
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
github.com/caarlos0/env v3.5.0+incompatible h1:Yy0UN8o9Wtr/jGHZDpCBLpNrzcFLLM2yixi/rBrKyJs=
|
||||||
|
github.com/caarlos0/env v3.5.0+incompatible/go.mod h1:tdCsowwCzMLdkqRYDlHpZCp2UooDD3MspDBjZ2AD02Y=
|
||||||
|
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||||
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
|
||||||
|
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
||||||
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
|
github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
|
||||||
|
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
|
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||||
|
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
|
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||||
|
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||||
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||||
|
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||||
|
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||||
|
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||||
|
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||||
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/wagslane/go-password-validator v0.3.0 h1:vfxOPzGHkz5S146HDpavl0cw1DSVP061Ry2PX0/ON6I=
|
||||||
|
github.com/wagslane/go-password-validator v0.3.0/go.mod h1:TI1XJ6T5fRdRnHqHt14pvy1tNVnrwe7m3/f1f2fDphQ=
|
||||||
|
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||||
|
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||||
|
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||||
|
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||||
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||||
|
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
|
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||||
|
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
@ -1 +0,0 @@
|
|||||||
package internal
|
|
@ -1 +0,0 @@
|
|||||||
package internal
|
|
@ -1 +0,0 @@
|
|||||||
package internal
|
|
@ -1 +0,0 @@
|
|||||||
package internal
|
|
@ -5,4 +5,5 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
internal.Run()
|
||||||
}
|
}
|
32
server/pkg/data/db.go
Normal file
32
server/pkg/data/db.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.28.0
|
||||||
|
|
||||||
|
package data
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DBTX interface {
|
||||||
|
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
|
||||||
|
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
|
||||||
|
QueryRow(context.Context, string, ...interface{}) pgx.Row
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(db DBTX) *Queries {
|
||||||
|
return &Queries{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Queries struct {
|
||||||
|
db DBTX
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
|
||||||
|
return &Queries{
|
||||||
|
db: tx,
|
||||||
|
}
|
||||||
|
}
|
51
server/pkg/data/models.go
Normal file
51
server/pkg/data/models.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.28.0
|
||||||
|
|
||||||
|
package data
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Note struct {
|
||||||
|
ID uuid.UUID `json:"id"`
|
||||||
|
UserID uuid.UUID `json:"user_id"`
|
||||||
|
CreatedAt *time.Time `json:"created_at"`
|
||||||
|
UpdatedAt *time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type NoteVersion struct {
|
||||||
|
ID uuid.UUID `json:"id"`
|
||||||
|
NoteID uuid.UUID `json:"note_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
VersionNumber int32 `json:"version_number"`
|
||||||
|
ContentHash string `json:"content_hash"`
|
||||||
|
CreatedAt *time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RefreshToken struct {
|
||||||
|
ID uuid.UUID `json:"id"`
|
||||||
|
UserID uuid.UUID `json:"user_id"`
|
||||||
|
TokenHash string `json:"token_hash"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
CreatedAt *time.Time `json:"created_at"`
|
||||||
|
Revoked bool `json:"revoked"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SchemaMigration struct {
|
||||||
|
Version int64 `json:"version"`
|
||||||
|
AppliedAt *time.Time `json:"applied_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID uuid.UUID `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
PasswordHash string `json:"password_hash"`
|
||||||
|
IsAdmin bool `json:"is_admin"`
|
||||||
|
CreatedAt *time.Time `json:"created_at"`
|
||||||
|
UpdatedAt *time.Time `json:"updated_at"`
|
||||||
|
}
|
132
server/pkg/data/note_versions.sql.go
Normal file
132
server/pkg/data/note_versions.sql.go
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.28.0
|
||||||
|
// source: note_versions.sql
|
||||||
|
|
||||||
|
package data
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const createNoteVersion = `-- name: CreateNoteVersion :one
|
||||||
|
INSERT INTO note_versions (note_id, title, content, version_number, content_hash)
|
||||||
|
VALUES (
|
||||||
|
$1,
|
||||||
|
$2,
|
||||||
|
$3,
|
||||||
|
(SELECT COALESCE(MAX(version_number), 0) + 1 FROM note_versions WHERE note_id = $1),
|
||||||
|
encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
|
||||||
|
)
|
||||||
|
RETURNING id, note_id, title, content, version_number, content_hash, created_at
|
||||||
|
`
|
||||||
|
|
||||||
|
type CreateNoteVersionParams struct {
|
||||||
|
NoteID uuid.UUID `json:"note_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateNoteVersion(ctx context.Context, arg CreateNoteVersionParams) (NoteVersion, error) {
|
||||||
|
row := q.db.QueryRow(ctx, createNoteVersion, arg.NoteID, arg.Title, arg.Content)
|
||||||
|
var i NoteVersion
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.NoteID,
|
||||||
|
&i.Title,
|
||||||
|
&i.Content,
|
||||||
|
&i.VersionNumber,
|
||||||
|
&i.ContentHash,
|
||||||
|
&i.CreatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const findDuplicateContent = `-- name: FindDuplicateContent :one
|
||||||
|
SELECT EXISTS(
|
||||||
|
SELECT 1 FROM note_versions
|
||||||
|
WHERE note_id = $1
|
||||||
|
AND content_hash = encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
|
||||||
|
)
|
||||||
|
`
|
||||||
|
|
||||||
|
type FindDuplicateContentParams struct {
|
||||||
|
NoteID uuid.UUID `json:"note_id"`
|
||||||
|
Column2 []byte `json:"column_2"`
|
||||||
|
Column3 []byte `json:"column_3"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) FindDuplicateContent(ctx context.Context, arg FindDuplicateContentParams) (bool, error) {
|
||||||
|
row := q.db.QueryRow(ctx, findDuplicateContent, arg.NoteID, arg.Column2, arg.Column3)
|
||||||
|
var exists bool
|
||||||
|
err := row.Scan(&exists)
|
||||||
|
return exists, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getNoteVersion = `-- name: GetNoteVersion :one
|
||||||
|
SELECT id, note_id, title, content, version_number, content_hash, created_at FROM note_versions
|
||||||
|
WHERE note_id = $1 AND version_number = $2 LIMIT 1
|
||||||
|
`
|
||||||
|
|
||||||
|
type GetNoteVersionParams struct {
|
||||||
|
NoteID uuid.UUID `json:"note_id"`
|
||||||
|
VersionNumber int32 `json:"version_number"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) GetNoteVersion(ctx context.Context, arg GetNoteVersionParams) (NoteVersion, error) {
|
||||||
|
row := q.db.QueryRow(ctx, getNoteVersion, arg.NoteID, arg.VersionNumber)
|
||||||
|
var i NoteVersion
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.NoteID,
|
||||||
|
&i.Title,
|
||||||
|
&i.Content,
|
||||||
|
&i.VersionNumber,
|
||||||
|
&i.ContentHash,
|
||||||
|
&i.CreatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getNoteVersions = `-- name: GetNoteVersions :many
|
||||||
|
SELECT id, note_id, title, content, version_number, content_hash, created_at FROM note_versions
|
||||||
|
WHERE note_id = $1
|
||||||
|
ORDER BY version_number DESC
|
||||||
|
LIMIT $2 OFFSET $3
|
||||||
|
`
|
||||||
|
|
||||||
|
type GetNoteVersionsParams struct {
|
||||||
|
NoteID uuid.UUID `json:"note_id"`
|
||||||
|
Limit int32 `json:"limit"`
|
||||||
|
Offset int32 `json:"offset"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) GetNoteVersions(ctx context.Context, arg GetNoteVersionsParams) ([]NoteVersion, error) {
|
||||||
|
rows, err := q.db.Query(ctx, getNoteVersions, arg.NoteID, arg.Limit, arg.Offset)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []NoteVersion
|
||||||
|
for rows.Next() {
|
||||||
|
var i NoteVersion
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.NoteID,
|
||||||
|
&i.Title,
|
||||||
|
&i.Content,
|
||||||
|
&i.VersionNumber,
|
||||||
|
&i.ContentHash,
|
||||||
|
&i.CreatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
105
server/pkg/data/notes.sql.go
Normal file
105
server/pkg/data/notes.sql.go
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.28.0
|
||||||
|
// source: notes.sql
|
||||||
|
|
||||||
|
package data
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const createNote = `-- name: CreateNote :one
|
||||||
|
INSERT INTO notes (user_id)
|
||||||
|
VALUES ($1)
|
||||||
|
RETURNING id, user_id, created_at, updated_at
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) CreateNote(ctx context.Context, userID uuid.UUID) (Note, error) {
|
||||||
|
row := q.db.QueryRow(ctx, createNote, userID)
|
||||||
|
var i Note
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.UserID,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteNote = `-- name: DeleteNote :exec
|
||||||
|
DELETE FROM notes
|
||||||
|
WHERE id = $1 AND user_id = $2
|
||||||
|
`
|
||||||
|
|
||||||
|
type DeleteNoteParams struct {
|
||||||
|
ID uuid.UUID `json:"id"`
|
||||||
|
UserID uuid.UUID `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) DeleteNote(ctx context.Context, arg DeleteNoteParams) error {
|
||||||
|
_, err := q.db.Exec(ctx, deleteNote, arg.ID, arg.UserID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getNote = `-- name: GetNote :one
|
||||||
|
SELECT id, user_id, created_at, updated_at FROM notes
|
||||||
|
WHERE id = $1 AND user_id = $2 LIMIT 1
|
||||||
|
`
|
||||||
|
|
||||||
|
type GetNoteParams struct {
|
||||||
|
ID uuid.UUID `json:"id"`
|
||||||
|
UserID uuid.UUID `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) GetNote(ctx context.Context, arg GetNoteParams) (Note, error) {
|
||||||
|
row := q.db.QueryRow(ctx, getNote, arg.ID, arg.UserID)
|
||||||
|
var i Note
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.UserID,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const listNotes = `-- name: ListNotes :many
|
||||||
|
SELECT id, user_id, created_at, updated_at FROM notes
|
||||||
|
WHERE user_id = $1
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT $2 OFFSET $3
|
||||||
|
`
|
||||||
|
|
||||||
|
type ListNotesParams struct {
|
||||||
|
UserID uuid.UUID `json:"user_id"`
|
||||||
|
Limit int32 `json:"limit"`
|
||||||
|
Offset int32 `json:"offset"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) ListNotes(ctx context.Context, arg ListNotesParams) ([]Note, error) {
|
||||||
|
rows, err := q.db.Query(ctx, listNotes, arg.UserID, arg.Limit, arg.Offset)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []Note
|
||||||
|
for rows.Next() {
|
||||||
|
var i Note
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.UserID,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
93
server/pkg/data/refresh_tokens.sql.go
Normal file
93
server/pkg/data/refresh_tokens.sql.go
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.28.0
|
||||||
|
// source: refresh_tokens.sql
|
||||||
|
|
||||||
|
package data
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const createRefreshToken = `-- name: CreateRefreshToken :one
|
||||||
|
INSERT INTO refresh_tokens (
|
||||||
|
user_id,
|
||||||
|
token_hash,
|
||||||
|
expires_at
|
||||||
|
) VALUES ($1, $2, $3)
|
||||||
|
RETURNING id, user_id, token_hash, expires_at, created_at, revoked
|
||||||
|
`
|
||||||
|
|
||||||
|
type CreateRefreshTokenParams struct {
|
||||||
|
UserID uuid.UUID `json:"user_id"`
|
||||||
|
TokenHash string `json:"token_hash"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateRefreshToken(ctx context.Context, arg CreateRefreshTokenParams) (RefreshToken, error) {
|
||||||
|
row := q.db.QueryRow(ctx, createRefreshToken, arg.UserID, arg.TokenHash, arg.ExpiresAt)
|
||||||
|
var i RefreshToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.UserID,
|
||||||
|
&i.TokenHash,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.Revoked,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteExpiredRefreshTokens = `-- name: DeleteExpiredRefreshTokens :exec
|
||||||
|
DELETE FROM refresh_tokens
|
||||||
|
WHERE expires_at < NOW()
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteExpiredRefreshTokens(ctx context.Context) error {
|
||||||
|
_, err := q.db.Exec(ctx, deleteExpiredRefreshTokens)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getRefreshTokenByHash = `-- name: GetRefreshTokenByHash :one
|
||||||
|
SELECT id, user_id, token_hash, expires_at, created_at, revoked FROM refresh_tokens
|
||||||
|
WHERE token_hash = $1 LIMIT 1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (RefreshToken, error) {
|
||||||
|
row := q.db.QueryRow(ctx, getRefreshTokenByHash, tokenHash)
|
||||||
|
var i RefreshToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.UserID,
|
||||||
|
&i.TokenHash,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.Revoked,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const revokeAllUserRefreshTokens = `-- name: RevokeAllUserRefreshTokens :exec
|
||||||
|
UPDATE refresh_tokens
|
||||||
|
SET revoked = TRUE
|
||||||
|
WHERE user_id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) RevokeAllUserRefreshTokens(ctx context.Context, userID uuid.UUID) error {
|
||||||
|
_, err := q.db.Exec(ctx, revokeAllUserRefreshTokens, userID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const revokeRefreshToken = `-- name: RevokeRefreshToken :exec
|
||||||
|
UPDATE refresh_tokens
|
||||||
|
SET revoked = TRUE
|
||||||
|
WHERE token_hash = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) RevokeRefreshToken(ctx context.Context, tokenHash string) error {
|
||||||
|
_, err := q.db.Exec(ctx, revokeRefreshToken, tokenHash)
|
||||||
|
return err
|
||||||
|
}
|
132
server/pkg/data/users.sql.go
Normal file
132
server/pkg/data/users.sql.go
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.28.0
|
||||||
|
// source: users.sql
|
||||||
|
|
||||||
|
package data
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const createUser = `-- name: CreateUser :one
|
||||||
|
INSERT INTO users (username, password_hash)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
RETURNING id, username, password_hash, is_admin, created_at, updated_at
|
||||||
|
`
|
||||||
|
|
||||||
|
type CreateUserParams struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
PasswordHash string `json:"password_hash"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) {
|
||||||
|
row := q.db.QueryRow(ctx, createUser, arg.Username, arg.PasswordHash)
|
||||||
|
var i User
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.Username,
|
||||||
|
&i.PasswordHash,
|
||||||
|
&i.IsAdmin,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteUser = `-- name: DeleteUser :exec
|
||||||
|
DELETE FROM users
|
||||||
|
WHERE id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteUser(ctx context.Context, id uuid.UUID) error {
|
||||||
|
_, err := q.db.Exec(ctx, deleteUser, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getUserByID = `-- name: GetUserByID :one
|
||||||
|
SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users
|
||||||
|
WHERE id = $1 LIMIT 1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) {
|
||||||
|
row := q.db.QueryRow(ctx, getUserByID, id)
|
||||||
|
var i User
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.Username,
|
||||||
|
&i.PasswordHash,
|
||||||
|
&i.IsAdmin,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getUserByUsername = `-- name: GetUserByUsername :one
|
||||||
|
SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users
|
||||||
|
WHERE username = $1 LIMIT 1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, error) {
|
||||||
|
row := q.db.QueryRow(ctx, getUserByUsername, username)
|
||||||
|
var i User
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.Username,
|
||||||
|
&i.PasswordHash,
|
||||||
|
&i.IsAdmin,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const listUsers = `-- name: ListUsers :many
|
||||||
|
SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) ListUsers(ctx context.Context) ([]User, error) {
|
||||||
|
rows, err := q.db.Query(ctx, listUsers)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []User
|
||||||
|
for rows.Next() {
|
||||||
|
var i User
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.Username,
|
||||||
|
&i.PasswordHash,
|
||||||
|
&i.IsAdmin,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const updatePassword = `-- name: UpdatePassword :exec
|
||||||
|
UPDATE users
|
||||||
|
SET password_hash = $2, updated_at = NOW()
|
||||||
|
WHERE id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
type UpdatePasswordParams struct {
|
||||||
|
ID uuid.UUID `json:"id"`
|
||||||
|
PasswordHash string `json:"password_hash"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) UpdatePassword(ctx context.Context, arg UpdatePasswordParams) error {
|
||||||
|
_, err := q.db.Exec(ctx, updatePassword, arg.ID, arg.PasswordHash)
|
||||||
|
return err
|
||||||
|
}
|
74
server/pkg/migrate/migrate.go
Normal file
74
server/pkg/migrate/migrate.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
// pkg/migrate/migrate.go
|
||||||
|
package migrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"embed"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Run(ctx context.Context, conn *pgx.Conn, migrationsFS embed.FS) error {
|
||||||
|
// Get already applied migrations
|
||||||
|
rows, _ := conn.Query(ctx, "SELECT version FROM schema_migrations")
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
applied := make(map[int64]bool)
|
||||||
|
for rows.Next() {
|
||||||
|
var version int64
|
||||||
|
if err := rows.Scan(&version); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
applied[version] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
files, err := migrationsFS.ReadDir("migrations")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply the migrations sequentially based on their ordinal number
|
||||||
|
for _, f := range sortMigrations(files) {
|
||||||
|
version, err := strconv.ParseInt(strings.Split(f.Name(), "_")[0], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid migration name: %s", f.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
if applied[version] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run migration
|
||||||
|
sql, err := migrationsFS.ReadFile("migrations/" + f.Name())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.Exec(ctx, string(sql)); err != nil {
|
||||||
|
return fmt.Errorf("migration %d failed: %w", version, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.Exec(ctx,
|
||||||
|
"INSERT INTO schema_migrations (version) VALUES ($1)", version,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort the migration files based on their ordinal number prefix.
|
||||||
|
func sortMigrations(files []fs.DirEntry) []fs.DirEntry {
|
||||||
|
sort.Slice(files, func(i, j int) bool {
|
||||||
|
v1, _ := strconv.ParseInt(strings.Split(files[i].Name(), "_")[0], 10, 64)
|
||||||
|
v2, _ := strconv.ParseInt(strings.Split(files[j].Name(), "_")[0], 10, 64)
|
||||||
|
return v1 < v2
|
||||||
|
})
|
||||||
|
return files
|
||||||
|
}
|
157
server/pkg/service/middleware.go
Normal file
157
server/pkg/service/middleware.go
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
panicRecoveryMsg = "panic recovered"
|
||||||
|
defaultLogMsg = "incoming request"
|
||||||
|
)
|
||||||
|
|
||||||
|
type userCtxKey struct{}
|
||||||
|
|
||||||
|
// Get JWT bearer from request's authorization header, parse it with custom user claims, and
|
||||||
|
// ensure its validity before attaching the claims to the request's context.
|
||||||
|
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tokenString, err := getTokenFromRequest(r)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
return []byte(jwtSecret), nil
|
||||||
|
})
|
||||||
|
if err != nil || !token.Valid {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, ok := token.Claims.(*userClaims)
|
||||||
|
if !ok || claims.TokenType != expectedType {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid token type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(r.Context(), userCtxKey{}, claims)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWT access token parsing, verification, and validation.
|
||||||
|
func requireAccessToken(jwtSecret string) func(http.Handler) http.Handler {
|
||||||
|
return authMiddleware(jwtSecret, "access")
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWT refresh token parsing, verification, and validation.
|
||||||
|
func requireRefreshToken(jwtSecret string) func(http.Handler) http.Handler {
|
||||||
|
return authMiddleware(jwtSecret, "refresh")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the current user is an administrator.
|
||||||
|
func adminOnlyMiddleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
if !ok || !user.Admin {
|
||||||
|
respondError(w, http.StatusForbidden, "Forbidden")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the targeted resource is owned by the current user (i.e. current user's ID matches with
|
||||||
|
// the one stored into the resource).
|
||||||
|
func ownerOnlyMiddleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
requestedID := chi.URLParam(r, "id")
|
||||||
|
if !ok || user.ID != requestedID {
|
||||||
|
respondError(w, http.StatusForbidden, "Forbidden")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append user data into request's context based on user ID as a URL parameter.
|
||||||
|
func userCtx(store UserStore) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userIDStr := chi.URLParam(r, "id")
|
||||||
|
userID, err := uuid.Parse(userIDStr)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusNotFound, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := store.GetUserByID(r.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusNotFound, "User not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(r.Context(), userCtxKey{}, user)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Zerolog compatible logger middleware. Automatically logs and recovers from errors with HTTP 500
|
||||||
|
// response, by default logs to INFO level.
|
||||||
|
func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log := log.With().Logger()
|
||||||
|
|
||||||
|
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
|
||||||
|
|
||||||
|
t1 := time.Now()
|
||||||
|
defer func() {
|
||||||
|
t2 := time.Now()
|
||||||
|
|
||||||
|
// Recover automatically and respond with HTTP 500
|
||||||
|
if rec := recover(); rec != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("type", "error").
|
||||||
|
Timestamp().
|
||||||
|
Interface("recover_info", rec).
|
||||||
|
Msg(panicRecoveryMsg)
|
||||||
|
http.Error(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log a regular HTTP request with some metadata
|
||||||
|
log.Info().
|
||||||
|
Str("type", "access").
|
||||||
|
Timestamp().
|
||||||
|
Fields(map[string]interface{}{
|
||||||
|
"remote_ip": r.RemoteAddr,
|
||||||
|
"url": r.URL.Path,
|
||||||
|
"proto": r.Proto,
|
||||||
|
"method": r.Method,
|
||||||
|
"user_agent": r.Header.Get("User-Agent"),
|
||||||
|
"status": ww.Status(),
|
||||||
|
"latency_ms": float64(t2.Sub(t1).Nanoseconds()) / 1000000.0,
|
||||||
|
"bytes_in": r.Header.Get("Content-Length"),
|
||||||
|
"bytes_out": ww.BytesWritten(),
|
||||||
|
}).
|
||||||
|
Msg(defaultLogMsg)
|
||||||
|
}()
|
||||||
|
|
||||||
|
next.ServeHTTP(ww, r)
|
||||||
|
}
|
||||||
|
return http.HandlerFunc(fn)
|
||||||
|
}
|
||||||
|
}
|
19
server/pkg/service/notes.go
Normal file
19
server/pkg/service/notes.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Mockable database operations interface
|
||||||
|
type NoteStore interface {
|
||||||
|
// TODO: implement
|
||||||
|
}
|
||||||
|
|
||||||
|
type notesResource struct {
|
||||||
|
Notes NoteStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs notesResource) Routes() chi.Router {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
return r
|
||||||
|
}
|
43
server/pkg/service/server.go
Normal file
43
server/pkg/service/server.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Run(conn *pgx.Conn, jwtSecret string, httpPort string) error {
|
||||||
|
q := data.New(conn)
|
||||||
|
r := chi.NewRouter()
|
||||||
|
|
||||||
|
tokenService := tokenService{
|
||||||
|
JWTSecret: jwtSecret,
|
||||||
|
Tokens: q,
|
||||||
|
}
|
||||||
|
usersRouter := usersResource{
|
||||||
|
JWTSecret: jwtSecret,
|
||||||
|
Users: q,
|
||||||
|
}
|
||||||
|
notesRouter := notesResource{}
|
||||||
|
|
||||||
|
// Global middlewares
|
||||||
|
r.Use(middleware.RequestID)
|
||||||
|
r.Use(middleware.RealIP)
|
||||||
|
r.Use(loggerMiddleware(&log.Logger))
|
||||||
|
r.Use(middleware.Recoverer)
|
||||||
|
r.Use(middleware.AllowContentType("application/json"))
|
||||||
|
|
||||||
|
// Routes grouped by functionality
|
||||||
|
r.Post("/auth/refresh", tokenService.RefreshAccessToken) // POST /auth/refresh - new access token for refresh token
|
||||||
|
r.Mount("/users", usersRouter.Routes())
|
||||||
|
r.Mount("/notes", notesRouter.Routes())
|
||||||
|
|
||||||
|
portStr := fmt.Sprintf(":%s", httpPort)
|
||||||
|
log.Info().Msgf("Starting server on %s", portStr)
|
||||||
|
return http.ListenAndServe(portStr, r)
|
||||||
|
}
|
205
server/pkg/service/tokens.go
Normal file
205
server/pkg/service/tokens.go
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
accessTokenDuration = 15 * time.Minute
|
||||||
|
refreshTokenDuration = 7 * 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidToken = errors.New("invalid token")
|
||||||
|
ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header")
|
||||||
|
)
|
||||||
|
|
||||||
|
type userClaims struct {
|
||||||
|
Admin bool `json:"admin"`
|
||||||
|
TokenType string `json:"type"` // "access" or "refresh"
|
||||||
|
jwt.RegisteredClaims // User's UUID should be stored in the subject claim
|
||||||
|
}
|
||||||
|
|
||||||
|
type tokenPair struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mockable database operations interface
|
||||||
|
type TokenStore interface {
|
||||||
|
CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error)
|
||||||
|
GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error)
|
||||||
|
RevokeRefreshToken(ctx context.Context, tokenHash string) error
|
||||||
|
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type tokenService struct {
|
||||||
|
JWTSecret string
|
||||||
|
Tokens TokenStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
|
||||||
|
tokenPair, err := generateTokenPair(userID.String(), isAdmin, ts.JWTSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := sha256.Sum256([]byte(tokenPair.RefreshToken))
|
||||||
|
tokenHash := hex.EncodeToString(hash[:])
|
||||||
|
|
||||||
|
// Store to DB with (almost) identical expiration timestamp
|
||||||
|
expiresAt := time.Now().Add(refreshTokenDuration)
|
||||||
|
_, err = ts.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{
|
||||||
|
UserID: userID,
|
||||||
|
TokenHash: tokenHash,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenPair, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *tokenService) RevokeRefreshToken(ctx context.Context, token string) error {
|
||||||
|
hash := sha256.Sum256([]byte(token))
|
||||||
|
tokenHash := hex.EncodeToString(hash[:])
|
||||||
|
return ts.Tokens.RevokeRefreshToken(ctx, tokenHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *tokenService) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) {
|
||||||
|
hash := sha256.Sum256([]byte(token))
|
||||||
|
tokenHash := hex.EncodeToString(hash[:])
|
||||||
|
|
||||||
|
dbToken, err := ts.Tokens.GetRefreshTokenByHash(ctx, tokenHash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbToken.Revoked || time.Now().After(dbToken.ExpiresAt) {
|
||||||
|
return nil, ErrInvalidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
return &dbToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Get claims from context
|
||||||
|
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
if !ok || claims.TokenType != "refresh" {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to get the token from Authentication header ("Bearer <token>")
|
||||||
|
refreshToken, err := getTokenFromRequest(r)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the refresh token in DB
|
||||||
|
if _, err := ts.ValidateRefreshToken(r.Context(), refreshToken); err != nil {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid refresh token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke the used refresh token
|
||||||
|
if err := ts.RevokeRefreshToken(r.Context(), refreshToken); err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to revoke token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a new pair (access & refresh tokens)
|
||||||
|
userID, err := uuid.Parse(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenPair, err := ts.GenerateTokenPair(r.Context(), userID, claims.Admin)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, tokenPair)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *tokenService) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
||||||
|
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Not authenticated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := uuid.Parse(claims.ID)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ts.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to logout")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, map[string]string{"status": "logged out"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTokenFromRequest(r *http.Request) (string, error) {
|
||||||
|
bearerToken := r.Header.Get("Authorization")
|
||||||
|
bearerFields := strings.Fields(bearerToken)
|
||||||
|
if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" {
|
||||||
|
return bearerFields[1], nil
|
||||||
|
}
|
||||||
|
return "", ErrAuthHeaderInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
|
||||||
|
atClaims := userClaims{
|
||||||
|
Admin: isAdmin,
|
||||||
|
TokenType: "access",
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID,
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenDuration)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, atClaims)
|
||||||
|
|
||||||
|
t, err := accessToken.SignedString([]byte(jwtSecret))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rtClaims := userClaims{
|
||||||
|
Admin: isAdmin,
|
||||||
|
TokenType: "refresh",
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID,
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenDuration)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, rtClaims)
|
||||||
|
|
||||||
|
rt, err := refreshToken.SignedString([]byte(jwtSecret))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokenPair{AccessToken: t, RefreshToken: rt}, nil
|
||||||
|
}
|
180
server/pkg/service/tokens_test.go
Normal file
180
server/pkg/service/tokens_test.go
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockTokenStore struct {
|
||||||
|
CreateRefreshTokenFunc func(context.Context, data.CreateRefreshTokenParams) (data.RefreshToken, error)
|
||||||
|
GetRefreshTokenByHashFunc func(context.Context, string) (data.RefreshToken, error)
|
||||||
|
RevokeRefreshTokenFunc func(context.Context, string) error
|
||||||
|
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTokenStore) CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
||||||
|
return m.CreateRefreshTokenFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTokenStore) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
||||||
|
return m.GetRefreshTokenByHashFunc(ctx, tokenHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTokenStore) RevokeRefreshToken(ctx context.Context, token string) error {
|
||||||
|
return m.RevokeRefreshTokenFunc(ctx, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTokenStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
|
||||||
|
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateTokenPair_Success(t *testing.T) {
|
||||||
|
userID := uuid.New()
|
||||||
|
called := false
|
||||||
|
|
||||||
|
mockStore := &mockTokenStore{
|
||||||
|
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
||||||
|
called = true
|
||||||
|
assert.Equal(t, userID, arg.UserID)
|
||||||
|
return data.RefreshToken{}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := tokenService{
|
||||||
|
JWTSecret: "test-secret",
|
||||||
|
Tokens: mockStore,
|
||||||
|
}
|
||||||
|
|
||||||
|
pair, err := ts.GenerateTokenPair(context.Background(), userID, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, called)
|
||||||
|
assert.NotEmpty(t, pair.AccessToken)
|
||||||
|
assert.NotEmpty(t, pair.RefreshToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateTokenPair_DBError(t *testing.T) {
|
||||||
|
mockStore := &mockTokenStore{
|
||||||
|
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
||||||
|
return data.RefreshToken{}, errors.New("db error")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := tokenService{Tokens: mockStore}
|
||||||
|
_, err := ts.GenerateTokenPair(context.Background(), uuid.New(), false)
|
||||||
|
assert.ErrorContains(t, err, "db error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRefreshToken_Valid(t *testing.T) {
|
||||||
|
token := "valid-jwt-token"
|
||||||
|
hash := sha256.Sum256([]byte(token))
|
||||||
|
expectedHash := hex.EncodeToString(hash[:])
|
||||||
|
|
||||||
|
mockStore := &mockTokenStore{
|
||||||
|
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
||||||
|
assert.Equal(t, expectedHash, tokenHash)
|
||||||
|
return data.RefreshToken{ExpiresAt: time.Now().Add(1 * time.Hour)}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := tokenService{Tokens: mockStore}
|
||||||
|
_, err := ts.ValidateRefreshToken(context.Background(), token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshAccessToken_Success(t *testing.T) {
|
||||||
|
userID := uuid.New()
|
||||||
|
refreshToken := "valid-jwt-token"
|
||||||
|
|
||||||
|
// Expected hash of the test token
|
||||||
|
hash := sha256.Sum256([]byte(refreshToken))
|
||||||
|
expectedHash := hex.EncodeToString(hash[:])
|
||||||
|
|
||||||
|
mockStore := &mockTokenStore{
|
||||||
|
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
||||||
|
assert.Equal(t, expectedHash, tokenHash)
|
||||||
|
|
||||||
|
// Must return an unrevoked token with future expiration
|
||||||
|
return data.RefreshToken{
|
||||||
|
TokenHash: tokenHash,
|
||||||
|
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||||
|
Revoked: false,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
RevokeRefreshTokenFunc: func(ctx context.Context, tokenHash string) error {
|
||||||
|
assert.Equal(t, expectedHash, tokenHash)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
||||||
|
return data.RefreshToken{}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := tokenService{
|
||||||
|
JWTSecret: "test-secret",
|
||||||
|
Tokens: mockStore,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/", nil)
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", refreshToken))
|
||||||
|
req = req.WithContext(context.WithValue(
|
||||||
|
req.Context(),
|
||||||
|
userCtxKey{},
|
||||||
|
&userClaims{
|
||||||
|
Admin: false,
|
||||||
|
TokenType: "refresh",
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID.String(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
ts.RefreshAccessToken(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "access_token", "refresh_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleLogout_Success(t *testing.T) {
|
||||||
|
userID := uuid.New()
|
||||||
|
called := false
|
||||||
|
|
||||||
|
mockStore := &mockTokenStore{
|
||||||
|
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
|
||||||
|
called = true
|
||||||
|
assert.Equal(t, userID, id)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := tokenService{Tokens: mockStore}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/", nil)
|
||||||
|
req = req.WithContext(context.WithValue(
|
||||||
|
req.Context(),
|
||||||
|
userCtxKey{},
|
||||||
|
&userClaims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ID: userID.String(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
ts.HandleLogout(w, req)
|
||||||
|
|
||||||
|
assert.True(t, called)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}
|
268
server/pkg/service/users.go
Normal file
268
server/pkg/service/users.go
Normal file
@ -0,0 +1,268 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Mockable database operations interface
|
||||||
|
type UserStore interface {
|
||||||
|
CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
|
||||||
|
ListUsers(ctx context.Context) ([]data.User, error)
|
||||||
|
GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error)
|
||||||
|
UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error
|
||||||
|
DeleteUser(ctx context.Context, id uuid.UUID) error
|
||||||
|
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type usersResource struct {
|
||||||
|
JWTSecret string
|
||||||
|
Users UserStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs usersResource) Routes() chi.Router {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
|
||||||
|
// Public routes (no tokens required)
|
||||||
|
r.Post("/", rs.Create) // POST /users - registration/signup
|
||||||
|
|
||||||
|
// Protected routes (access token required)
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Use(requireAccessToken(rs.JWTSecret))
|
||||||
|
|
||||||
|
// Admin only general routes
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Use(adminOnlyMiddleware)
|
||||||
|
r.Get("/", rs.List) // GET /users - list all users
|
||||||
|
})
|
||||||
|
|
||||||
|
// User specific routes
|
||||||
|
r.Route("/{id}", func(r chi.Router) {
|
||||||
|
r.Use(userCtx(rs.Users)) // DB -> req. context
|
||||||
|
|
||||||
|
// Admin routes
|
||||||
|
r.Route("/admin", func(r chi.Router) {
|
||||||
|
r.Use(adminOnlyMiddleware)
|
||||||
|
r.Get("/", rs.Get) // GET /users/admin/{id} - get single user
|
||||||
|
r.Delete("/", rs.AdminDelete) // DELETE /users/admin/{id} - delete user
|
||||||
|
})
|
||||||
|
|
||||||
|
// Owner routes
|
||||||
|
r.Route("/owner", func(r chi.Router) {
|
||||||
|
r.Use(ownerOnlyMiddleware)
|
||||||
|
r.Put("/", rs.UpdatePassword) // PUT /users/owner/{id} - update user password
|
||||||
|
r.Delete("/", rs.OwnerDelete) // DELETE /users/owner/{id} - delete user
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs usersResource) Create(w http.ResponseWriter, r *http.Request) {
|
||||||
|
type request struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req request
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedUsername := normalizeUsername(req.Username)
|
||||||
|
if err := validateUsername(normalizedUsername); err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validatePassword(req.Password); err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to create user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := rs.Users.CreateUser(r.Context(), data.CreateUserParams{
|
||||||
|
Username: normalizedUsername,
|
||||||
|
PasswordHash: string(hashedPassword),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if isDuplicateEntry(err) {
|
||||||
|
respondError(w, http.StatusConflict, "Username is already in use")
|
||||||
|
} else {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to create user")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusCreated, map[string]string{
|
||||||
|
"id": user.ID.String(),
|
||||||
|
"username": user.Username,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs usersResource) List(w http.ResponseWriter, r *http.Request) {
|
||||||
|
users, err := rs.Users.ListUsers(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to retrieve users")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output sanitization
|
||||||
|
var output []map[string]interface{}
|
||||||
|
for _, user := range users {
|
||||||
|
output = append(output, map[string]interface{}{
|
||||||
|
"id": user.ID,
|
||||||
|
"username": user.Username,
|
||||||
|
"created_at": user.CreatedAt,
|
||||||
|
"updated_at": user.UpdatedAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "User not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"id": user.ID,
|
||||||
|
"username": user.Username,
|
||||||
|
"created_at": user.CreatedAt,
|
||||||
|
"updated_at": user.UpdatedAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs usersResource) UpdatePassword(w http.ResponseWriter, r *http.Request) {
|
||||||
|
type request struct {
|
||||||
|
OldPassword string `json:"old_password"`
|
||||||
|
NewPassword string `json:"new_password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req request
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "User not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the old password before allowing the update
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.OldPassword)); err != nil {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validatePassword(req.NewPassword); err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to update password")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rs.Users.UpdatePassword(r.Context(), data.UpdatePasswordParams{
|
||||||
|
ID: user.ID,
|
||||||
|
PasswordHash: string(hashedPassword),
|
||||||
|
}); err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to update password")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
|
||||||
|
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs usersResource) AdminDelete(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "User not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rs.Users.DeleteUser(r.Context(), user.ID); err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to delete user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
|
||||||
|
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs usersResource) OwnerDelete(w http.ResponseWriter, r *http.Request) {
|
||||||
|
type request struct {
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req request
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||||
|
if !ok {
|
||||||
|
respondError(w, http.StatusNotFound, "User not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the old password before allowing the deletion
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := rs.Users.DeleteUser(r.Context(), user.ID)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Failed to delete user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
|
||||||
|
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the given error is a PostgreSQL error for `unique_violation`, i.e. whether an entry
|
||||||
|
// with the given details already exists in the database table.
|
||||||
|
func isDuplicateEntry(err error) bool {
|
||||||
|
var pgErr *pgconn.PgError
|
||||||
|
if errors.As(err, &pgErr) {
|
||||||
|
return pgErr.Code == "23505"
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
259
server/pkg/service/users_test.go
Normal file
259
server/pkg/service/users_test.go
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockUserStore struct {
|
||||||
|
CreateUserFunc func(context.Context, data.CreateUserParams) (data.User, error)
|
||||||
|
ListUsersFunc func(context.Context) ([]data.User, error)
|
||||||
|
GetUserByIDFunc func(context.Context, uuid.UUID) (data.User, error)
|
||||||
|
UpdatePasswordFunc func(context.Context, data.UpdatePasswordParams) error
|
||||||
|
DeleteUserFunc func(context.Context, uuid.UUID) error
|
||||||
|
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUserStore) CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
|
||||||
|
return m.CreateUserFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUserStore) ListUsers(ctx context.Context) ([]data.User, error) {
|
||||||
|
return m.ListUsersFunc(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUserStore) GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error) {
|
||||||
|
return m.GetUserByIDFunc(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUserStore) UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error {
|
||||||
|
return m.UpdatePasswordFunc(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUserStore) DeleteUser(ctx context.Context, id uuid.UUID) error {
|
||||||
|
return m.DeleteUserFunc(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUserStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
|
||||||
|
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateUser_Duplicate(t *testing.T) {
|
||||||
|
mockStore := &mockUserStore{
|
||||||
|
CreateUserFunc: func(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
|
||||||
|
return data.User{}, &pgconn.PgError{Code: "23505"}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
||||||
|
|
||||||
|
reqBody := `{"username": "existing", "password": "validPass123!"}`
|
||||||
|
req := httptest.NewRequest("POST", "/", strings.NewReader(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
rs.Create(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusConflict, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "Username is already in use")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateUser_InvalidUsername(t *testing.T) {
|
||||||
|
mockStore := &mockUserStore{} // No DB calls expected
|
||||||
|
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
||||||
|
|
||||||
|
// Test various invalid usernames
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
}{
|
||||||
|
{"Too short", `{"username": "a", "password": "validPass123!"}`},
|
||||||
|
{"Invalid chars", `{"username": "user@name", "password": "validPass123!"}`},
|
||||||
|
{"Empty", `{"username": "", "password": "validPass123!"}`},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rs.Create(w, req)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateUser_InvalidPassword(t *testing.T) {
|
||||||
|
mockStore := &mockUserStore{}
|
||||||
|
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
}{
|
||||||
|
{"too short", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength-1))},
|
||||||
|
{"too long", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", maxPasswordLength+1))},
|
||||||
|
{"low entropy", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength))},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rs.Create(w, req)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListUsers_Success(t *testing.T) {
|
||||||
|
testUsers := []data.User{
|
||||||
|
{ID: uuid.New(), Username: "user1"},
|
||||||
|
{ID: uuid.New(), Username: "user2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockStore := &mockUserStore{
|
||||||
|
ListUsersFunc: func(ctx context.Context) ([]data.User, error) {
|
||||||
|
return testUsers, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rs := usersResource{Users: mockStore}
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
rs.List(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var response []map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Len(t, response, 2)
|
||||||
|
assert.Equal(t, "user1", response[0]["username"])
|
||||||
|
assert.NotContains(t, response[0], "password_hash")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdatePassword_InvalidOldPassword(t *testing.T) {
|
||||||
|
// User with password hash that won't match "wrongpassword"
|
||||||
|
user := data.User{
|
||||||
|
ID: uuid.New(),
|
||||||
|
PasswordHash: "$2a$10$PHhno.bZBF8IEINdFRZAPujMxIN65msElATgJG6FIxZdeWYVLSfFi", // Hash of "correctpassword"
|
||||||
|
}
|
||||||
|
|
||||||
|
mockStore := &mockUserStore{}
|
||||||
|
rs := usersResource{Users: mockStore}
|
||||||
|
|
||||||
|
reqBody := `{"old_password": "wrongpassword", "new_password": "NewValidPass321!"}`
|
||||||
|
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
rs.UpdatePassword(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminDelete_Success(t *testing.T) {
|
||||||
|
user := data.User{ID: uuid.New()}
|
||||||
|
deleteCalled := false
|
||||||
|
revokeCalled := false
|
||||||
|
|
||||||
|
mockStore := &mockUserStore{
|
||||||
|
DeleteUserFunc: func(ctx context.Context, id uuid.UUID) error {
|
||||||
|
deleteCalled = true
|
||||||
|
assert.Equal(t, user.ID, id)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
|
||||||
|
revokeCalled = true
|
||||||
|
assert.Equal(t, user.ID, id)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rs := usersResource{Users: mockStore}
|
||||||
|
req := httptest.NewRequest("DELETE", "/", nil)
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
rs.AdminDelete(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNoContent, w.Code)
|
||||||
|
assert.True(t, deleteCalled)
|
||||||
|
assert.True(t, revokeCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOwnerDelete_InvalidCredentials(t *testing.T) {
|
||||||
|
// Create user with known password hash
|
||||||
|
correctPassword := "CorrectPass123!"
|
||||||
|
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(correctPassword), bcrypt.DefaultCost)
|
||||||
|
user := data.User{
|
||||||
|
ID: uuid.New(),
|
||||||
|
PasswordHash: string(hashedPassword),
|
||||||
|
}
|
||||||
|
|
||||||
|
mockStore := &mockUserStore{}
|
||||||
|
rs := usersResource{Users: mockStore}
|
||||||
|
|
||||||
|
reqBody := `{"password": "wrongpassword"}`
|
||||||
|
req := httptest.NewRequest("DELETE", "/", strings.NewReader(reqBody))
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
rs.OwnerDelete(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUser_NotFound(t *testing.T) {
|
||||||
|
mockStore := &mockUserStore{}
|
||||||
|
rs := usersResource{Users: mockStore}
|
||||||
|
|
||||||
|
// No user in context
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
rs.Get(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdatePassword_DatabaseError(t *testing.T) {
|
||||||
|
// Add user with a valid password to the context
|
||||||
|
oldPassword := "OldValidPass321!"
|
||||||
|
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(oldPassword), bcrypt.DefaultCost)
|
||||||
|
user := data.User{
|
||||||
|
ID: uuid.New(),
|
||||||
|
PasswordHash: string(hashedPassword),
|
||||||
|
}
|
||||||
|
mockStore := &mockUserStore{
|
||||||
|
UpdatePasswordFunc: func(ctx context.Context, arg data.UpdatePasswordParams) error {
|
||||||
|
return errors.New("database error")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rs := usersResource{Users: mockStore}
|
||||||
|
|
||||||
|
reqBody := fmt.Sprintf(`{"old_password": "%s", "new_password": "NewValidPass123!"}`, oldPassword)
|
||||||
|
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
rs.UpdatePassword(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "Failed to update password")
|
||||||
|
}
|
133
server/pkg/service/util.go
Normal file
133
server/pkg/service/util.go
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
passwordvalidator "github.com/wagslane/go-password-validator"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
minPasswordLength = 12 // Entropy checks prevent short passwords anyway
|
||||||
|
maxPasswordLength = 72 // Limitation of bcrypt
|
||||||
|
minPasswordEntropy = 60.0
|
||||||
|
|
||||||
|
minUsernameLength = 3
|
||||||
|
maxUsernameLength = 20
|
||||||
|
|
||||||
|
hibpAPI = "https://api.pwnedpasswords.com/range" // Doesn't require an API key
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
usernameRegex = regexp.MustCompile("^[a-z0-9_]+$")
|
||||||
|
)
|
||||||
|
|
||||||
|
func respondJSON(w http.ResponseWriter, status int, data interface{}) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
json.NewEncoder(w).Encode(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func respondError(w http.ResponseWriter, status int, message string) {
|
||||||
|
respondJSON(w, status, map[string]string{"error": message})
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Client-side check:
|
||||||
|
|
||||||
|
```
|
||||||
|
function estimateEntropy(password: string): number {
|
||||||
|
const pool: number = getCharsetSize(password); // Character diversity (R)
|
||||||
|
const entropy: number = password.length * Math.log2(pool); // E = L * log_2(R)
|
||||||
|
return entropy; // Value (E) that can be compared against a hardcoded threshold (e.g. 60)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Validate the given password using a hybrid approach: length (max. set due to bcrypt's input
|
||||||
|
// upper limit of 72 bytes), entropy, and HIBP API.
|
||||||
|
func validatePassword(password string) error {
|
||||||
|
if len(password) < minPasswordLength {
|
||||||
|
return fmt.Errorf("password must be at least %d characters", minPasswordLength)
|
||||||
|
}
|
||||||
|
if len(password) > maxPasswordLength {
|
||||||
|
return fmt.Errorf("password cannot be longer than %d characters", maxPasswordLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Formatted error message will contain tips to increase the password strength (safe to show)
|
||||||
|
err := passwordvalidator.Validate(password, minPasswordEntropy)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if compromised, _ := isPasswordCompromised(password); compromised {
|
||||||
|
return errors.New("password is compromised")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the first five bytes of the password's SHA-1 hash to HIBP API, then check if the rest of
|
||||||
|
// the hash is present in the APi's response data (k-Anonymity model).
|
||||||
|
func isPasswordCompromised(password string) (bool, error) {
|
||||||
|
hash := sha1.Sum([]byte(password))
|
||||||
|
hashStr := strings.ToUpper(hex.EncodeToString(hash[:]))
|
||||||
|
prefix, suffix := hashStr[:5], hashStr[5:]
|
||||||
|
|
||||||
|
resp, err := http.Get(fmt.Sprintf("%s/%s", hibpAPI, prefix))
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Contains(strings.ToUpper(string(body)), suffix), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize the username by making it lowercase and trimming any leading or trailing whitespace.
|
||||||
|
func normalizeUsername(username string) string {
|
||||||
|
return strings.ToLower(strings.TrimSpace(username))
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Client-side check (additionally input should automatically perform the normalization steps):
|
||||||
|
|
||||||
|
```
|
||||||
|
function validateUsername(username: string): string {
|
||||||
|
const min: number = 3, max: number = 20;
|
||||||
|
if (username.length < min) return "Too short";
|
||||||
|
if (username.length > max) return "Too long";
|
||||||
|
if (!/^[a-zA-Z0-9_]+$/.test(username)) return "Invalid characters";
|
||||||
|
return "Valid";
|
||||||
|
}
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Validate the given username by making sure it only contains alphanumeric characters or
|
||||||
|
// underscores and adheres the hardcoded minimum and maximum length rules.
|
||||||
|
func validateUsername(username string) error {
|
||||||
|
if utf8.RuneCountInString(username) < minUsernameLength {
|
||||||
|
return fmt.Errorf("username must be at least %d characters", minUsernameLength)
|
||||||
|
}
|
||||||
|
if utf8.RuneCountInString(username) > maxUsernameLength {
|
||||||
|
return fmt.Errorf("username cannot be longer than %d characters", maxUsernameLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !usernameRegex.MatchString(username) {
|
||||||
|
return errors.New("username can only contain numbers, letters, and underscores")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
184
server/pkg/service/util_test.go
Normal file
184
server/pkg/service/util_test.go
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
passwordvalidator "github.com/wagslane/go-password-validator"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockHTTPClient struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHTTPClient) Get(url string) (*http.Response, error) {
|
||||||
|
args := m.Called(url)
|
||||||
|
return args.Get(0).(*http.Response), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePassword(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
password string
|
||||||
|
wantErr string
|
||||||
|
mockHTTP func(*MockHTTPClient)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "too short",
|
||||||
|
password: strings.Repeat("a", minPasswordLength-1),
|
||||||
|
wantErr: fmt.Sprintf("password must be at least %d characters", minPasswordLength),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "too long",
|
||||||
|
password: strings.Repeat("a", maxPasswordLength+1),
|
||||||
|
wantErr: fmt.Sprintf("password cannot be longer than %d characters", maxPasswordLength),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "low entropy",
|
||||||
|
password: strings.Repeat("a", minPasswordLength),
|
||||||
|
wantErr: "insecure password", // Error produced by wagslane/go-password-validator
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid password",
|
||||||
|
password: "SecurePassw0rd!123",
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Mock HTTP client if needed
|
||||||
|
if tt.mockHTTP != nil {
|
||||||
|
mockClient := new(MockHTTPClient)
|
||||||
|
tt.mockHTTP(mockClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := validatePassword(tt.password)
|
||||||
|
if tt.wantErr == "" {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
assert.ErrorContains(t, err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsPasswordCompromised(t *testing.T) {
|
||||||
|
t.Run("known compromised", func(t *testing.T) {
|
||||||
|
compromised, err := isPasswordCompromised("password123456")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, compromised)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("randomly generated", func(t *testing.T) {
|
||||||
|
randomStr := genRandomString(12)
|
||||||
|
compromised, err := isPasswordCompromised(randomStr)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, compromised)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordEntropyCalculation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
password string
|
||||||
|
entropy float64
|
||||||
|
}{
|
||||||
|
{"password", 37.6},
|
||||||
|
{"SecurePassw0rd!123", 103.12},
|
||||||
|
{"aaaaaaaaaaaaaaaa", 9.5},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.password, func(t *testing.T) {
|
||||||
|
entropy := passwordvalidator.GetEntropy(tt.password)
|
||||||
|
assert.InDelta(t, tt.entropy, entropy, 1.0)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateUsername(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "too short",
|
||||||
|
input: strings.Repeat("a", minUsernameLength-1),
|
||||||
|
wantErr: fmt.Sprintf("username must be at least %d characters", minUsernameLength),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "too long",
|
||||||
|
input: strings.Repeat("a", maxUsernameLength+1),
|
||||||
|
wantErr: fmt.Sprintf("username cannot be longer than %d characters", maxUsernameLength),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid characters",
|
||||||
|
input: "user@name",
|
||||||
|
wantErr: "username can only contain numbers, letters, and underscores",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid username",
|
||||||
|
input: "valid_user123",
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validateUsername(tt.input)
|
||||||
|
if tt.wantErr == "" {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
assert.ErrorContains(t, err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeUsername(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "trim whitespace",
|
||||||
|
input: " test_user ",
|
||||||
|
want: "test_user",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "lowercase",
|
||||||
|
input: "TestUser",
|
||||||
|
want: "testuser",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case and spaces",
|
||||||
|
input: " UserName123 ",
|
||||||
|
want: "username123",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := normalizeUsername(tt.input)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func genRandomString(length int) string {
|
||||||
|
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
|
seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
b := make([]byte, length)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = charset[seededRand.Intn(len(charset))]
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
49
server/sql/migrations/0001_initial.up.sql
Normal file
49
server/sql/migrations/0001_initial.up.sql
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||||
|
version BIGINT PRIMARY KEY,
|
||||||
|
applied_at TIMESTAMPTZ DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
|
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||||
|
username TEXT UNIQUE NOT NULL,
|
||||||
|
password_hash TEXT NOT NULL,
|
||||||
|
is_admin BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMPTZ DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
token_hash TEXT NOT NULL,
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
|
revoked BOOLEAN NOT NULL DEFAULT false
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS notes (
|
||||||
|
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||||
|
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMPTZ DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS note_versions (
|
||||||
|
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||||
|
note_id UUID NOT NULL REFERENCES notes(id) ON DELETE CASCADE,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
version_number INT NOT NULL,
|
||||||
|
content_hash TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_id ON refresh_tokens(user_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_note_version_unique ON note_versions(note_id, version_number);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_note_versions_note ON note_versions(note_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_note_versions_number ON note_versions(version_number DESC);
|
27
server/sql/queries/note_versions.sql
Normal file
27
server/sql/queries/note_versions.sql
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
-- name: CreateNoteVersion :one
|
||||||
|
INSERT INTO note_versions (note_id, title, content, version_number, content_hash)
|
||||||
|
VALUES (
|
||||||
|
$1,
|
||||||
|
$2,
|
||||||
|
$3,
|
||||||
|
(SELECT COALESCE(MAX(version_number), 0) + 1 FROM note_versions WHERE note_id = $1),
|
||||||
|
encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
|
||||||
|
)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetNoteVersions :many
|
||||||
|
SELECT * FROM note_versions
|
||||||
|
WHERE note_id = $1
|
||||||
|
ORDER BY version_number DESC
|
||||||
|
LIMIT $2 OFFSET $3;
|
||||||
|
|
||||||
|
-- name: GetNoteVersion :one
|
||||||
|
SELECT * FROM note_versions
|
||||||
|
WHERE note_id = $1 AND version_number = $2 LIMIT 1;
|
||||||
|
|
||||||
|
-- name: FindDuplicateContent :one
|
||||||
|
SELECT EXISTS(
|
||||||
|
SELECT 1 FROM note_versions
|
||||||
|
WHERE note_id = $1
|
||||||
|
AND content_hash = encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
|
||||||
|
);
|
18
server/sql/queries/notes.sql
Normal file
18
server/sql/queries/notes.sql
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
-- name: CreateNote :one
|
||||||
|
INSERT INTO notes (user_id)
|
||||||
|
VALUES ($1)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetNote :one
|
||||||
|
SELECT * FROM notes
|
||||||
|
WHERE id = $1 AND user_id = $2 LIMIT 1;
|
||||||
|
|
||||||
|
-- name: ListNotes :many
|
||||||
|
SELECT * FROM notes
|
||||||
|
WHERE user_id = $1
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT $2 OFFSET $3;
|
||||||
|
|
||||||
|
-- name: DeleteNote :exec
|
||||||
|
DELETE FROM notes
|
||||||
|
WHERE id = $1 AND user_id = $2;
|
25
server/sql/queries/refresh_tokens.sql
Normal file
25
server/sql/queries/refresh_tokens.sql
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
-- name: CreateRefreshToken :one
|
||||||
|
INSERT INTO refresh_tokens (
|
||||||
|
user_id,
|
||||||
|
token_hash,
|
||||||
|
expires_at
|
||||||
|
) VALUES ($1, $2, $3)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetRefreshTokenByHash :one
|
||||||
|
SELECT * FROM refresh_tokens
|
||||||
|
WHERE token_hash = $1 LIMIT 1;
|
||||||
|
|
||||||
|
-- name: RevokeRefreshToken :exec
|
||||||
|
UPDATE refresh_tokens
|
||||||
|
SET revoked = TRUE
|
||||||
|
WHERE token_hash = $1;
|
||||||
|
|
||||||
|
-- name: RevokeAllUserRefreshTokens :exec
|
||||||
|
UPDATE refresh_tokens
|
||||||
|
SET revoked = TRUE
|
||||||
|
WHERE user_id = $1;
|
||||||
|
|
||||||
|
-- name: DeleteExpiredRefreshTokens :exec
|
||||||
|
DELETE FROM refresh_tokens
|
||||||
|
WHERE expires_at < NOW();
|
24
server/sql/queries/users.sql
Normal file
24
server/sql/queries/users.sql
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
-- name: CreateUser :one
|
||||||
|
INSERT INTO users (username, password_hash)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: ListUsers :many
|
||||||
|
SELECT * FROM users;
|
||||||
|
|
||||||
|
-- name: GetUserByID :one
|
||||||
|
SELECT * FROM users
|
||||||
|
WHERE id = $1 LIMIT 1;
|
||||||
|
|
||||||
|
-- name: GetUserByUsername :one
|
||||||
|
SELECT * FROM users
|
||||||
|
WHERE username = $1 LIMIT 1;
|
||||||
|
|
||||||
|
-- name: UpdatePassword :exec
|
||||||
|
UPDATE users
|
||||||
|
SET password_hash = $2, updated_at = NOW()
|
||||||
|
WHERE id = $1;
|
||||||
|
|
||||||
|
-- name: DeleteUser :exec
|
||||||
|
DELETE FROM users
|
||||||
|
WHERE id = $1;
|
24
server/sql/sqlc.yaml
Normal file
24
server/sql/sqlc.yaml
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
version: "2"
|
||||||
|
|
||||||
|
sql:
|
||||||
|
- engine: "postgresql"
|
||||||
|
queries: "./queries/"
|
||||||
|
schema: "./migrations/"
|
||||||
|
gen:
|
||||||
|
go:
|
||||||
|
package: "data"
|
||||||
|
out: "../pkg/data"
|
||||||
|
sql_package: "pgx/v5"
|
||||||
|
emit_json_tags: true
|
||||||
|
overrides:
|
||||||
|
- db_type: "timestamptz"
|
||||||
|
go_type:
|
||||||
|
type: "time.Time"
|
||||||
|
- db_type: "timestamptz"
|
||||||
|
nullable: true
|
||||||
|
go_type:
|
||||||
|
type: "*time.Time"
|
||||||
|
- db_type: "uuid"
|
||||||
|
go_type:
|
||||||
|
import: "github.com/google/uuid"
|
||||||
|
type: "UUID"
|
Loading…
x
Reference in New Issue
Block a user