Compare commits
7 Commits
ac169851d3
...
3257b19313
Author | SHA1 | Date | |
---|---|---|---|
3257b19313 | |||
41d1336f58 | |||
9ba182d925 | |||
66fde0a700 | |||
6569a399e3 | |||
de72ea53e1 | |||
e0bdf32bfd |
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
||||
TASKS.md
|
||||
.env
|
||||
|
@ -1,3 +1,30 @@
|
||||
module git.umbrella.haus/ae/notatest
|
||||
|
||||
go 1.24.1
|
||||
|
||||
require (
|
||||
github.com/caarlos0/env v3.5.0+incompatible
|
||||
github.com/go-chi/chi/v5 v5.2.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.4
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/wagslane/go-password-validator v0.3.0
|
||||
golang.org/x/crypto v0.36.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
67
server/go.sum
Normal file
67
server/go.sum
Normal file
@ -0,0 +1,67 @@
|
||||
github.com/caarlos0/env v3.5.0+incompatible h1:Yy0UN8o9Wtr/jGHZDpCBLpNrzcFLLM2yixi/rBrKyJs=
|
||||
github.com/caarlos0/env v3.5.0+incompatible/go.mod h1:tdCsowwCzMLdkqRYDlHpZCp2UooDD3MspDBjZ2AD02Y=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
|
||||
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
|
||||
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/wagslane/go-password-validator v0.3.0 h1:vfxOPzGHkz5S146HDpavl0cw1DSVP061Ry2PX0/ON6I=
|
||||
github.com/wagslane/go-password-validator v0.3.0/go.mod h1:TI1XJ6T5fRdRnHqHt14pvy1tNVnrwe7m3/f1f2fDphQ=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
@ -1 +0,0 @@
|
||||
package internal
|
@ -1 +0,0 @@
|
||||
package internal
|
@ -1 +0,0 @@
|
||||
package internal
|
@ -1 +0,0 @@
|
||||
package internal
|
@ -5,4 +5,5 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
internal.Run()
|
||||
}
|
32
server/pkg/data/db.go
Normal file
32
server/pkg/data/db.go
Normal file
@ -0,0 +1,32 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.28.0
|
||||
|
||||
package data
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
type DBTX interface {
|
||||
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
|
||||
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
|
||||
QueryRow(context.Context, string, ...interface{}) pgx.Row
|
||||
}
|
||||
|
||||
func New(db DBTX) *Queries {
|
||||
return &Queries{db: db}
|
||||
}
|
||||
|
||||
type Queries struct {
|
||||
db DBTX
|
||||
}
|
||||
|
||||
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
|
||||
return &Queries{
|
||||
db: tx,
|
||||
}
|
||||
}
|
51
server/pkg/data/models.go
Normal file
51
server/pkg/data/models.go
Normal file
@ -0,0 +1,51 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.28.0
|
||||
|
||||
package data
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Note struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
CreatedAt *time.Time `json:"created_at"`
|
||||
UpdatedAt *time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type NoteVersion struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
NoteID uuid.UUID `json:"note_id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
VersionNumber int32 `json:"version_number"`
|
||||
ContentHash string `json:"content_hash"`
|
||||
CreatedAt *time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type RefreshToken struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
TokenHash string `json:"token_hash"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
CreatedAt *time.Time `json:"created_at"`
|
||||
Revoked bool `json:"revoked"`
|
||||
}
|
||||
|
||||
type SchemaMigration struct {
|
||||
Version int64 `json:"version"`
|
||||
AppliedAt *time.Time `json:"applied_at"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Username string `json:"username"`
|
||||
PasswordHash string `json:"password_hash"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
CreatedAt *time.Time `json:"created_at"`
|
||||
UpdatedAt *time.Time `json:"updated_at"`
|
||||
}
|
132
server/pkg/data/note_versions.sql.go
Normal file
132
server/pkg/data/note_versions.sql.go
Normal file
@ -0,0 +1,132 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.28.0
|
||||
// source: note_versions.sql
|
||||
|
||||
package data
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const createNoteVersion = `-- name: CreateNoteVersion :one
|
||||
INSERT INTO note_versions (note_id, title, content, version_number, content_hash)
|
||||
VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
(SELECT COALESCE(MAX(version_number), 0) + 1 FROM note_versions WHERE note_id = $1),
|
||||
encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
|
||||
)
|
||||
RETURNING id, note_id, title, content, version_number, content_hash, created_at
|
||||
`
|
||||
|
||||
type CreateNoteVersionParams struct {
|
||||
NoteID uuid.UUID `json:"note_id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateNoteVersion(ctx context.Context, arg CreateNoteVersionParams) (NoteVersion, error) {
|
||||
row := q.db.QueryRow(ctx, createNoteVersion, arg.NoteID, arg.Title, arg.Content)
|
||||
var i NoteVersion
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.NoteID,
|
||||
&i.Title,
|
||||
&i.Content,
|
||||
&i.VersionNumber,
|
||||
&i.ContentHash,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const findDuplicateContent = `-- name: FindDuplicateContent :one
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM note_versions
|
||||
WHERE note_id = $1
|
||||
AND content_hash = encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
|
||||
)
|
||||
`
|
||||
|
||||
type FindDuplicateContentParams struct {
|
||||
NoteID uuid.UUID `json:"note_id"`
|
||||
Column2 []byte `json:"column_2"`
|
||||
Column3 []byte `json:"column_3"`
|
||||
}
|
||||
|
||||
func (q *Queries) FindDuplicateContent(ctx context.Context, arg FindDuplicateContentParams) (bool, error) {
|
||||
row := q.db.QueryRow(ctx, findDuplicateContent, arg.NoteID, arg.Column2, arg.Column3)
|
||||
var exists bool
|
||||
err := row.Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
const getNoteVersion = `-- name: GetNoteVersion :one
|
||||
SELECT id, note_id, title, content, version_number, content_hash, created_at FROM note_versions
|
||||
WHERE note_id = $1 AND version_number = $2 LIMIT 1
|
||||
`
|
||||
|
||||
type GetNoteVersionParams struct {
|
||||
NoteID uuid.UUID `json:"note_id"`
|
||||
VersionNumber int32 `json:"version_number"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetNoteVersion(ctx context.Context, arg GetNoteVersionParams) (NoteVersion, error) {
|
||||
row := q.db.QueryRow(ctx, getNoteVersion, arg.NoteID, arg.VersionNumber)
|
||||
var i NoteVersion
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.NoteID,
|
||||
&i.Title,
|
||||
&i.Content,
|
||||
&i.VersionNumber,
|
||||
&i.ContentHash,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getNoteVersions = `-- name: GetNoteVersions :many
|
||||
SELECT id, note_id, title, content, version_number, content_hash, created_at FROM note_versions
|
||||
WHERE note_id = $1
|
||||
ORDER BY version_number DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
`
|
||||
|
||||
type GetNoteVersionsParams struct {
|
||||
NoteID uuid.UUID `json:"note_id"`
|
||||
Limit int32 `json:"limit"`
|
||||
Offset int32 `json:"offset"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetNoteVersions(ctx context.Context, arg GetNoteVersionsParams) ([]NoteVersion, error) {
|
||||
rows, err := q.db.Query(ctx, getNoteVersions, arg.NoteID, arg.Limit, arg.Offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []NoteVersion
|
||||
for rows.Next() {
|
||||
var i NoteVersion
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.NoteID,
|
||||
&i.Title,
|
||||
&i.Content,
|
||||
&i.VersionNumber,
|
||||
&i.ContentHash,
|
||||
&i.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
105
server/pkg/data/notes.sql.go
Normal file
105
server/pkg/data/notes.sql.go
Normal file
@ -0,0 +1,105 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.28.0
|
||||
// source: notes.sql
|
||||
|
||||
package data
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const createNote = `-- name: CreateNote :one
|
||||
INSERT INTO notes (user_id)
|
||||
VALUES ($1)
|
||||
RETURNING id, user_id, created_at, updated_at
|
||||
`
|
||||
|
||||
func (q *Queries) CreateNote(ctx context.Context, userID uuid.UUID) (Note, error) {
|
||||
row := q.db.QueryRow(ctx, createNote, userID)
|
||||
var i Note
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteNote = `-- name: DeleteNote :exec
|
||||
DELETE FROM notes
|
||||
WHERE id = $1 AND user_id = $2
|
||||
`
|
||||
|
||||
type DeleteNoteParams struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteNote(ctx context.Context, arg DeleteNoteParams) error {
|
||||
_, err := q.db.Exec(ctx, deleteNote, arg.ID, arg.UserID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getNote = `-- name: GetNote :one
|
||||
SELECT id, user_id, created_at, updated_at FROM notes
|
||||
WHERE id = $1 AND user_id = $2 LIMIT 1
|
||||
`
|
||||
|
||||
type GetNoteParams struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetNote(ctx context.Context, arg GetNoteParams) (Note, error) {
|
||||
row := q.db.QueryRow(ctx, getNote, arg.ID, arg.UserID)
|
||||
var i Note
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listNotes = `-- name: ListNotes :many
|
||||
SELECT id, user_id, created_at, updated_at FROM notes
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
`
|
||||
|
||||
type ListNotesParams struct {
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Limit int32 `json:"limit"`
|
||||
Offset int32 `json:"offset"`
|
||||
}
|
||||
|
||||
func (q *Queries) ListNotes(ctx context.Context, arg ListNotesParams) ([]Note, error) {
|
||||
rows, err := q.db.Query(ctx, listNotes, arg.UserID, arg.Limit, arg.Offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Note
|
||||
for rows.Next() {
|
||||
var i Note
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
93
server/pkg/data/refresh_tokens.sql.go
Normal file
93
server/pkg/data/refresh_tokens.sql.go
Normal file
@ -0,0 +1,93 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.28.0
|
||||
// source: refresh_tokens.sql
|
||||
|
||||
package data
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const createRefreshToken = `-- name: CreateRefreshToken :one
|
||||
INSERT INTO refresh_tokens (
|
||||
user_id,
|
||||
token_hash,
|
||||
expires_at
|
||||
) VALUES ($1, $2, $3)
|
||||
RETURNING id, user_id, token_hash, expires_at, created_at, revoked
|
||||
`
|
||||
|
||||
type CreateRefreshTokenParams struct {
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
TokenHash string `json:"token_hash"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateRefreshToken(ctx context.Context, arg CreateRefreshTokenParams) (RefreshToken, error) {
|
||||
row := q.db.QueryRow(ctx, createRefreshToken, arg.UserID, arg.TokenHash, arg.ExpiresAt)
|
||||
var i RefreshToken
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.TokenHash,
|
||||
&i.ExpiresAt,
|
||||
&i.CreatedAt,
|
||||
&i.Revoked,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteExpiredRefreshTokens = `-- name: DeleteExpiredRefreshTokens :exec
|
||||
DELETE FROM refresh_tokens
|
||||
WHERE expires_at < NOW()
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteExpiredRefreshTokens(ctx context.Context) error {
|
||||
_, err := q.db.Exec(ctx, deleteExpiredRefreshTokens)
|
||||
return err
|
||||
}
|
||||
|
||||
const getRefreshTokenByHash = `-- name: GetRefreshTokenByHash :one
|
||||
SELECT id, user_id, token_hash, expires_at, created_at, revoked FROM refresh_tokens
|
||||
WHERE token_hash = $1 LIMIT 1
|
||||
`
|
||||
|
||||
func (q *Queries) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (RefreshToken, error) {
|
||||
row := q.db.QueryRow(ctx, getRefreshTokenByHash, tokenHash)
|
||||
var i RefreshToken
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.TokenHash,
|
||||
&i.ExpiresAt,
|
||||
&i.CreatedAt,
|
||||
&i.Revoked,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const revokeAllUserRefreshTokens = `-- name: RevokeAllUserRefreshTokens :exec
|
||||
UPDATE refresh_tokens
|
||||
SET revoked = TRUE
|
||||
WHERE user_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) RevokeAllUserRefreshTokens(ctx context.Context, userID uuid.UUID) error {
|
||||
_, err := q.db.Exec(ctx, revokeAllUserRefreshTokens, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
const revokeRefreshToken = `-- name: RevokeRefreshToken :exec
|
||||
UPDATE refresh_tokens
|
||||
SET revoked = TRUE
|
||||
WHERE token_hash = $1
|
||||
`
|
||||
|
||||
func (q *Queries) RevokeRefreshToken(ctx context.Context, tokenHash string) error {
|
||||
_, err := q.db.Exec(ctx, revokeRefreshToken, tokenHash)
|
||||
return err
|
||||
}
|
132
server/pkg/data/users.sql.go
Normal file
132
server/pkg/data/users.sql.go
Normal file
@ -0,0 +1,132 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.28.0
|
||||
// source: users.sql
|
||||
|
||||
package data
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const createUser = `-- name: CreateUser :one
|
||||
INSERT INTO users (username, password_hash)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id, username, password_hash, is_admin, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateUserParams struct {
|
||||
Username string `json:"username"`
|
||||
PasswordHash string `json:"password_hash"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) {
|
||||
row := q.db.QueryRow(ctx, createUser, arg.Username, arg.PasswordHash)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Username,
|
||||
&i.PasswordHash,
|
||||
&i.IsAdmin,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteUser = `-- name: DeleteUser :exec
|
||||
DELETE FROM users
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteUser(ctx context.Context, id uuid.UUID) error {
|
||||
_, err := q.db.Exec(ctx, deleteUser, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const getUserByID = `-- name: GetUserByID :one
|
||||
SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users
|
||||
WHERE id = $1 LIMIT 1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) {
|
||||
row := q.db.QueryRow(ctx, getUserByID, id)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Username,
|
||||
&i.PasswordHash,
|
||||
&i.IsAdmin,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserByUsername = `-- name: GetUserByUsername :one
|
||||
SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users
|
||||
WHERE username = $1 LIMIT 1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, error) {
|
||||
row := q.db.QueryRow(ctx, getUserByUsername, username)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Username,
|
||||
&i.PasswordHash,
|
||||
&i.IsAdmin,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listUsers = `-- name: ListUsers :many
|
||||
SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users
|
||||
`
|
||||
|
||||
func (q *Queries) ListUsers(ctx context.Context) ([]User, error) {
|
||||
rows, err := q.db.Query(ctx, listUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []User
|
||||
for rows.Next() {
|
||||
var i User
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Username,
|
||||
&i.PasswordHash,
|
||||
&i.IsAdmin,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updatePassword = `-- name: UpdatePassword :exec
|
||||
UPDATE users
|
||||
SET password_hash = $2, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdatePasswordParams struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
PasswordHash string `json:"password_hash"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdatePassword(ctx context.Context, arg UpdatePasswordParams) error {
|
||||
_, err := q.db.Exec(ctx, updatePassword, arg.ID, arg.PasswordHash)
|
||||
return err
|
||||
}
|
74
server/pkg/migrate/migrate.go
Normal file
74
server/pkg/migrate/migrate.go
Normal file
@ -0,0 +1,74 @@
|
||||
// pkg/migrate/migrate.go
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
func Run(ctx context.Context, conn *pgx.Conn, migrationsFS embed.FS) error {
|
||||
// Get already applied migrations
|
||||
rows, _ := conn.Query(ctx, "SELECT version FROM schema_migrations")
|
||||
defer rows.Close()
|
||||
|
||||
applied := make(map[int64]bool)
|
||||
for rows.Next() {
|
||||
var version int64
|
||||
if err := rows.Scan(&version); err != nil {
|
||||
return err
|
||||
}
|
||||
applied[version] = true
|
||||
}
|
||||
|
||||
files, err := migrationsFS.ReadDir("migrations")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Apply the migrations sequentially based on their ordinal number
|
||||
for _, f := range sortMigrations(files) {
|
||||
version, err := strconv.ParseInt(strings.Split(f.Name(), "_")[0], 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid migration name: %s", f.Name())
|
||||
}
|
||||
|
||||
if applied[version] {
|
||||
continue
|
||||
}
|
||||
|
||||
// Run migration
|
||||
sql, err := migrationsFS.ReadFile("migrations/" + f.Name())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := conn.Exec(ctx, string(sql)); err != nil {
|
||||
return fmt.Errorf("migration %d failed: %w", version, err)
|
||||
}
|
||||
|
||||
if _, err := conn.Exec(ctx,
|
||||
"INSERT INTO schema_migrations (version) VALUES ($1)", version,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sort the migration files based on their ordinal number prefix.
|
||||
func sortMigrations(files []fs.DirEntry) []fs.DirEntry {
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
v1, _ := strconv.ParseInt(strings.Split(files[i].Name(), "_")[0], 10, 64)
|
||||
v2, _ := strconv.ParseInt(strings.Split(files[j].Name(), "_")[0], 10, 64)
|
||||
return v1 < v2
|
||||
})
|
||||
return files
|
||||
}
|
157
server/pkg/service/middleware.go
Normal file
157
server/pkg/service/middleware.go
Normal file
@ -0,0 +1,157 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
const (
|
||||
panicRecoveryMsg = "panic recovered"
|
||||
defaultLogMsg = "incoming request"
|
||||
)
|
||||
|
||||
type userCtxKey struct{}
|
||||
|
||||
// Get JWT bearer from request's authorization header, parse it with custom user claims, and
|
||||
// ensure its validity before attaching the claims to the request's context.
|
||||
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tokenString, err := getTokenFromRequest(r)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(jwtSecret), nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
respondError(w, http.StatusUnauthorized, "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*userClaims)
|
||||
if !ok || claims.TokenType != expectedType {
|
||||
respondError(w, http.StatusUnauthorized, "Invalid token type")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), userCtxKey{}, claims)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// JWT access token parsing, verification, and validation.
|
||||
func requireAccessToken(jwtSecret string) func(http.Handler) http.Handler {
|
||||
return authMiddleware(jwtSecret, "access")
|
||||
}
|
||||
|
||||
// JWT refresh token parsing, verification, and validation.
|
||||
func requireRefreshToken(jwtSecret string) func(http.Handler) http.Handler {
|
||||
return authMiddleware(jwtSecret, "refresh")
|
||||
}
|
||||
|
||||
// Ensure the current user is an administrator.
|
||||
func adminOnlyMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
if !ok || !user.Admin {
|
||||
respondError(w, http.StatusForbidden, "Forbidden")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Ensure the targeted resource is owned by the current user (i.e. current user's ID matches with
|
||||
// the one stored into the resource).
|
||||
func ownerOnlyMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
requestedID := chi.URLParam(r, "id")
|
||||
if !ok || user.ID != requestedID {
|
||||
respondError(w, http.StatusForbidden, "Forbidden")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Append user data into request's context based on user ID as a URL parameter.
|
||||
func userCtx(store UserStore) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
userIDStr := chi.URLParam(r, "id")
|
||||
userID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusNotFound, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := store.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), userCtxKey{}, user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Zerolog compatible logger middleware. Automatically logs and recovers from errors with HTTP 500
|
||||
// response, by default logs to INFO level.
|
||||
func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
log := log.With().Logger()
|
||||
|
||||
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
|
||||
|
||||
t1 := time.Now()
|
||||
defer func() {
|
||||
t2 := time.Now()
|
||||
|
||||
// Recover automatically and respond with HTTP 500
|
||||
if rec := recover(); rec != nil {
|
||||
log.Error().
|
||||
Str("type", "error").
|
||||
Timestamp().
|
||||
Interface("recover_info", rec).
|
||||
Msg(panicRecoveryMsg)
|
||||
http.Error(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// Log a regular HTTP request with some metadata
|
||||
log.Info().
|
||||
Str("type", "access").
|
||||
Timestamp().
|
||||
Fields(map[string]interface{}{
|
||||
"remote_ip": r.RemoteAddr,
|
||||
"url": r.URL.Path,
|
||||
"proto": r.Proto,
|
||||
"method": r.Method,
|
||||
"user_agent": r.Header.Get("User-Agent"),
|
||||
"status": ww.Status(),
|
||||
"latency_ms": float64(t2.Sub(t1).Nanoseconds()) / 1000000.0,
|
||||
"bytes_in": r.Header.Get("Content-Length"),
|
||||
"bytes_out": ww.BytesWritten(),
|
||||
}).
|
||||
Msg(defaultLogMsg)
|
||||
}()
|
||||
|
||||
next.ServeHTTP(ww, r)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
}
|
19
server/pkg/service/notes.go
Normal file
19
server/pkg/service/notes.go
Normal file
@ -0,0 +1,19 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// Mockable database operations interface
|
||||
type NoteStore interface {
|
||||
// TODO: implement
|
||||
}
|
||||
|
||||
type notesResource struct {
|
||||
Notes NoteStore
|
||||
}
|
||||
|
||||
func (rs notesResource) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
return r
|
||||
}
|
43
server/pkg/service/server.go
Normal file
43
server/pkg/service/server.go
Normal file
@ -0,0 +1,43 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func Run(conn *pgx.Conn, jwtSecret string, httpPort string) error {
|
||||
q := data.New(conn)
|
||||
r := chi.NewRouter()
|
||||
|
||||
tokenService := tokenService{
|
||||
JWTSecret: jwtSecret,
|
||||
Tokens: q,
|
||||
}
|
||||
usersRouter := usersResource{
|
||||
JWTSecret: jwtSecret,
|
||||
Users: q,
|
||||
}
|
||||
notesRouter := notesResource{}
|
||||
|
||||
// Global middlewares
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(loggerMiddleware(&log.Logger))
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(middleware.AllowContentType("application/json"))
|
||||
|
||||
// Routes grouped by functionality
|
||||
r.Post("/auth/refresh", tokenService.RefreshAccessToken) // POST /auth/refresh - new access token for refresh token
|
||||
r.Mount("/users", usersRouter.Routes())
|
||||
r.Mount("/notes", notesRouter.Routes())
|
||||
|
||||
portStr := fmt.Sprintf(":%s", httpPort)
|
||||
log.Info().Msgf("Starting server on %s", portStr)
|
||||
return http.ListenAndServe(portStr, r)
|
||||
}
|
205
server/pkg/service/tokens.go
Normal file
205
server/pkg/service/tokens.go
Normal file
@ -0,0 +1,205 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
accessTokenDuration = 15 * time.Minute
|
||||
refreshTokenDuration = 7 * 24 * time.Hour
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header")
|
||||
)
|
||||
|
||||
type userClaims struct {
|
||||
Admin bool `json:"admin"`
|
||||
TokenType string `json:"type"` // "access" or "refresh"
|
||||
jwt.RegisteredClaims // User's UUID should be stored in the subject claim
|
||||
}
|
||||
|
||||
type tokenPair struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
// Mockable database operations interface
|
||||
type TokenStore interface {
|
||||
CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error)
|
||||
GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error)
|
||||
RevokeRefreshToken(ctx context.Context, tokenHash string) error
|
||||
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
|
||||
}
|
||||
|
||||
type tokenService struct {
|
||||
JWTSecret string
|
||||
Tokens TokenStore
|
||||
}
|
||||
|
||||
func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
|
||||
tokenPair, err := generateTokenPair(userID.String(), isAdmin, ts.JWTSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(tokenPair.RefreshToken))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
|
||||
// Store to DB with (almost) identical expiration timestamp
|
||||
expiresAt := time.Now().Add(refreshTokenDuration)
|
||||
_, err = ts.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{
|
||||
UserID: userID,
|
||||
TokenHash: tokenHash,
|
||||
ExpiresAt: expiresAt,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tokenPair, nil
|
||||
}
|
||||
|
||||
func (ts *tokenService) RevokeRefreshToken(ctx context.Context, token string) error {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
return ts.Tokens.RevokeRefreshToken(ctx, tokenHash)
|
||||
}
|
||||
|
||||
func (ts *tokenService) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
|
||||
dbToken, err := ts.Tokens.GetRefreshTokenByHash(ctx, tokenHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dbToken.Revoked || time.Now().After(dbToken.ExpiresAt) {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
return &dbToken, nil
|
||||
}
|
||||
|
||||
func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
|
||||
// Get claims from context
|
||||
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
if !ok || claims.TokenType != "refresh" {
|
||||
respondError(w, http.StatusUnauthorized, "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt to get the token from Authentication header ("Bearer <token>")
|
||||
refreshToken, err := getTokenFromRequest(r)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
|
||||
}
|
||||
|
||||
// Validate the refresh token in DB
|
||||
if _, err := ts.ValidateRefreshToken(r.Context(), refreshToken); err != nil {
|
||||
respondError(w, http.StatusUnauthorized, "Invalid refresh token")
|
||||
return
|
||||
}
|
||||
|
||||
// Revoke the used refresh token
|
||||
if err := ts.RevokeRefreshToken(r.Context(), refreshToken); err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to revoke token")
|
||||
return
|
||||
}
|
||||
|
||||
// Generate a new pair (access & refresh tokens)
|
||||
userID, err := uuid.Parse(claims.Subject)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := ts.GenerateTokenPair(r.Context(), userID, claims.Admin)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusOK, tokenPair)
|
||||
}
|
||||
|
||||
func (ts *tokenService) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
if !ok {
|
||||
respondError(w, http.StatusUnauthorized, "Not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(claims.ID)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := ts.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to logout")
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusOK, map[string]string{"status": "logged out"})
|
||||
}
|
||||
|
||||
func getTokenFromRequest(r *http.Request) (string, error) {
|
||||
bearerToken := r.Header.Get("Authorization")
|
||||
bearerFields := strings.Fields(bearerToken)
|
||||
if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" {
|
||||
return bearerFields[1], nil
|
||||
}
|
||||
return "", ErrAuthHeaderInvalid
|
||||
}
|
||||
|
||||
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
|
||||
atClaims := userClaims{
|
||||
Admin: isAdmin,
|
||||
TokenType: "access",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenDuration)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, atClaims)
|
||||
|
||||
t, err := accessToken.SignedString([]byte(jwtSecret))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rtClaims := userClaims{
|
||||
Admin: isAdmin,
|
||||
TokenType: "refresh",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenDuration)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, rtClaims)
|
||||
|
||||
rt, err := refreshToken.SignedString([]byte(jwtSecret))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tokenPair{AccessToken: t, RefreshToken: rt}, nil
|
||||
}
|
180
server/pkg/service/tokens_test.go
Normal file
180
server/pkg/service/tokens_test.go
Normal file
@ -0,0 +1,180 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockTokenStore struct {
|
||||
CreateRefreshTokenFunc func(context.Context, data.CreateRefreshTokenParams) (data.RefreshToken, error)
|
||||
GetRefreshTokenByHashFunc func(context.Context, string) (data.RefreshToken, error)
|
||||
RevokeRefreshTokenFunc func(context.Context, string) error
|
||||
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
|
||||
}
|
||||
|
||||
func (m *mockTokenStore) CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
||||
return m.CreateRefreshTokenFunc(ctx, arg)
|
||||
}
|
||||
|
||||
func (m *mockTokenStore) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
||||
return m.GetRefreshTokenByHashFunc(ctx, tokenHash)
|
||||
}
|
||||
|
||||
func (m *mockTokenStore) RevokeRefreshToken(ctx context.Context, token string) error {
|
||||
return m.RevokeRefreshTokenFunc(ctx, token)
|
||||
}
|
||||
|
||||
func (m *mockTokenStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
|
||||
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
|
||||
}
|
||||
|
||||
func TestGenerateTokenPair_Success(t *testing.T) {
|
||||
userID := uuid.New()
|
||||
called := false
|
||||
|
||||
mockStore := &mockTokenStore{
|
||||
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
||||
called = true
|
||||
assert.Equal(t, userID, arg.UserID)
|
||||
return data.RefreshToken{}, nil
|
||||
},
|
||||
}
|
||||
|
||||
ts := tokenService{
|
||||
JWTSecret: "test-secret",
|
||||
Tokens: mockStore,
|
||||
}
|
||||
|
||||
pair, err := ts.GenerateTokenPair(context.Background(), userID, false)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
assert.NotEmpty(t, pair.AccessToken)
|
||||
assert.NotEmpty(t, pair.RefreshToken)
|
||||
}
|
||||
|
||||
func TestGenerateTokenPair_DBError(t *testing.T) {
|
||||
mockStore := &mockTokenStore{
|
||||
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
||||
return data.RefreshToken{}, errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
ts := tokenService{Tokens: mockStore}
|
||||
_, err := ts.GenerateTokenPair(context.Background(), uuid.New(), false)
|
||||
assert.ErrorContains(t, err, "db error")
|
||||
}
|
||||
|
||||
func TestValidateRefreshToken_Valid(t *testing.T) {
|
||||
token := "valid-jwt-token"
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
expectedHash := hex.EncodeToString(hash[:])
|
||||
|
||||
mockStore := &mockTokenStore{
|
||||
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
||||
assert.Equal(t, expectedHash, tokenHash)
|
||||
return data.RefreshToken{ExpiresAt: time.Now().Add(1 * time.Hour)}, nil
|
||||
},
|
||||
}
|
||||
|
||||
ts := tokenService{Tokens: mockStore}
|
||||
_, err := ts.ValidateRefreshToken(context.Background(), token)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestRefreshAccessToken_Success(t *testing.T) {
|
||||
userID := uuid.New()
|
||||
refreshToken := "valid-jwt-token"
|
||||
|
||||
// Expected hash of the test token
|
||||
hash := sha256.Sum256([]byte(refreshToken))
|
||||
expectedHash := hex.EncodeToString(hash[:])
|
||||
|
||||
mockStore := &mockTokenStore{
|
||||
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
|
||||
assert.Equal(t, expectedHash, tokenHash)
|
||||
|
||||
// Must return an unrevoked token with future expiration
|
||||
return data.RefreshToken{
|
||||
TokenHash: tokenHash,
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
Revoked: false,
|
||||
}, nil
|
||||
},
|
||||
RevokeRefreshTokenFunc: func(ctx context.Context, tokenHash string) error {
|
||||
assert.Equal(t, expectedHash, tokenHash)
|
||||
return nil
|
||||
},
|
||||
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
|
||||
return data.RefreshToken{}, nil
|
||||
},
|
||||
}
|
||||
|
||||
ts := tokenService{
|
||||
JWTSecret: "test-secret",
|
||||
Tokens: mockStore,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", refreshToken))
|
||||
req = req.WithContext(context.WithValue(
|
||||
req.Context(),
|
||||
userCtxKey{},
|
||||
&userClaims{
|
||||
Admin: false,
|
||||
TokenType: "refresh",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID.String(),
|
||||
},
|
||||
},
|
||||
))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
ts.RefreshAccessToken(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "access_token", "refresh_token")
|
||||
}
|
||||
|
||||
func TestHandleLogout_Success(t *testing.T) {
|
||||
userID := uuid.New()
|
||||
called := false
|
||||
|
||||
mockStore := &mockTokenStore{
|
||||
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
|
||||
called = true
|
||||
assert.Equal(t, userID, id)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
ts := tokenService{Tokens: mockStore}
|
||||
|
||||
req := httptest.NewRequest("POST", "/", nil)
|
||||
req = req.WithContext(context.WithValue(
|
||||
req.Context(),
|
||||
userCtxKey{},
|
||||
&userClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ID: userID.String(),
|
||||
},
|
||||
},
|
||||
))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
ts.HandleLogout(w, req)
|
||||
|
||||
assert.True(t, called)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
268
server/pkg/service/users.go
Normal file
268
server/pkg/service/users.go
Normal file
@ -0,0 +1,268 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Mockable database operations interface
|
||||
type UserStore interface {
|
||||
CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
|
||||
ListUsers(ctx context.Context) ([]data.User, error)
|
||||
GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error)
|
||||
UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error
|
||||
DeleteUser(ctx context.Context, id uuid.UUID) error
|
||||
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
|
||||
}
|
||||
|
||||
type usersResource struct {
|
||||
JWTSecret string
|
||||
Users UserStore
|
||||
}
|
||||
|
||||
func (rs usersResource) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
|
||||
// Public routes (no tokens required)
|
||||
r.Post("/", rs.Create) // POST /users - registration/signup
|
||||
|
||||
// Protected routes (access token required)
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireAccessToken(rs.JWTSecret))
|
||||
|
||||
// Admin only general routes
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(adminOnlyMiddleware)
|
||||
r.Get("/", rs.List) // GET /users - list all users
|
||||
})
|
||||
|
||||
// User specific routes
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
r.Use(userCtx(rs.Users)) // DB -> req. context
|
||||
|
||||
// Admin routes
|
||||
r.Route("/admin", func(r chi.Router) {
|
||||
r.Use(adminOnlyMiddleware)
|
||||
r.Get("/", rs.Get) // GET /users/admin/{id} - get single user
|
||||
r.Delete("/", rs.AdminDelete) // DELETE /users/admin/{id} - delete user
|
||||
})
|
||||
|
||||
// Owner routes
|
||||
r.Route("/owner", func(r chi.Router) {
|
||||
r.Use(ownerOnlyMiddleware)
|
||||
r.Put("/", rs.UpdatePassword) // PUT /users/owner/{id} - update user password
|
||||
r.Delete("/", rs.OwnerDelete) // DELETE /users/owner/{id} - delete user
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (rs usersResource) Create(w http.ResponseWriter, r *http.Request) {
|
||||
type request struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var req request
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
normalizedUsername := normalizeUsername(req.Username)
|
||||
if err := validateUsername(normalizedUsername); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := validatePassword(req.Password); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to create user")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := rs.Users.CreateUser(r.Context(), data.CreateUserParams{
|
||||
Username: normalizedUsername,
|
||||
PasswordHash: string(hashedPassword),
|
||||
})
|
||||
if err != nil {
|
||||
if isDuplicateEntry(err) {
|
||||
respondError(w, http.StatusConflict, "Username is already in use")
|
||||
} else {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to create user")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusCreated, map[string]string{
|
||||
"id": user.ID.String(),
|
||||
"username": user.Username,
|
||||
})
|
||||
}
|
||||
|
||||
func (rs usersResource) List(w http.ResponseWriter, r *http.Request) {
|
||||
users, err := rs.Users.ListUsers(r.Context())
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to retrieve users")
|
||||
return
|
||||
}
|
||||
|
||||
// Output sanitization
|
||||
var output []map[string]interface{}
|
||||
for _, user := range users {
|
||||
output = append(output, map[string]interface{}{
|
||||
"id": user.ID,
|
||||
"username": user.Username,
|
||||
"created_at": user.CreatedAt,
|
||||
"updated_at": user.UpdatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusOK, output)
|
||||
}
|
||||
|
||||
func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||
if !ok {
|
||||
respondError(w, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"id": user.ID,
|
||||
"username": user.Username,
|
||||
"created_at": user.CreatedAt,
|
||||
"updated_at": user.UpdatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
func (rs usersResource) UpdatePassword(w http.ResponseWriter, r *http.Request) {
|
||||
type request struct {
|
||||
OldPassword string `json:"old_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
var req request
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||
if !ok {
|
||||
respondError(w, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the old password before allowing the update
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.OldPassword)); err != nil {
|
||||
respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
||||
return
|
||||
}
|
||||
|
||||
if err := validatePassword(req.NewPassword); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to update password")
|
||||
return
|
||||
}
|
||||
|
||||
if err := rs.Users.UpdatePassword(r.Context(), data.UpdatePasswordParams{
|
||||
ID: user.ID,
|
||||
PasswordHash: string(hashedPassword),
|
||||
}); err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to update password")
|
||||
return
|
||||
}
|
||||
|
||||
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
|
||||
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (rs usersResource) AdminDelete(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||
if !ok {
|
||||
respondError(w, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
if err := rs.Users.DeleteUser(r.Context(), user.ID); err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to delete user")
|
||||
return
|
||||
}
|
||||
|
||||
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
|
||||
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (rs usersResource) OwnerDelete(w http.ResponseWriter, r *http.Request) {
|
||||
type request struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var req request
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||
if !ok {
|
||||
respondError(w, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the old password before allowing the deletion
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
|
||||
respondError(w, http.StatusUnauthorized, "Invalid credentials")
|
||||
return
|
||||
}
|
||||
|
||||
err := rs.Users.DeleteUser(r.Context(), user.ID)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, "Failed to delete user")
|
||||
return
|
||||
}
|
||||
|
||||
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
|
||||
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Check if the given error is a PostgreSQL error for `unique_violation`, i.e. whether an entry
|
||||
// with the given details already exists in the database table.
|
||||
func isDuplicateEntry(err error) bool {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) {
|
||||
return pgErr.Code == "23505"
|
||||
}
|
||||
return false
|
||||
}
|
259
server/pkg/service/users_test.go
Normal file
259
server/pkg/service/users_test.go
Normal file
@ -0,0 +1,259 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type mockUserStore struct {
|
||||
CreateUserFunc func(context.Context, data.CreateUserParams) (data.User, error)
|
||||
ListUsersFunc func(context.Context) ([]data.User, error)
|
||||
GetUserByIDFunc func(context.Context, uuid.UUID) (data.User, error)
|
||||
UpdatePasswordFunc func(context.Context, data.UpdatePasswordParams) error
|
||||
DeleteUserFunc func(context.Context, uuid.UUID) error
|
||||
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
|
||||
}
|
||||
|
||||
func (m *mockUserStore) CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
|
||||
return m.CreateUserFunc(ctx, arg)
|
||||
}
|
||||
|
||||
func (m *mockUserStore) ListUsers(ctx context.Context) ([]data.User, error) {
|
||||
return m.ListUsersFunc(ctx)
|
||||
}
|
||||
|
||||
func (m *mockUserStore) GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error) {
|
||||
return m.GetUserByIDFunc(ctx, id)
|
||||
}
|
||||
|
||||
func (m *mockUserStore) UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error {
|
||||
return m.UpdatePasswordFunc(ctx, arg)
|
||||
}
|
||||
|
||||
func (m *mockUserStore) DeleteUser(ctx context.Context, id uuid.UUID) error {
|
||||
return m.DeleteUserFunc(ctx, id)
|
||||
}
|
||||
|
||||
func (m *mockUserStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
|
||||
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
|
||||
}
|
||||
|
||||
func TestCreateUser_Duplicate(t *testing.T) {
|
||||
mockStore := &mockUserStore{
|
||||
CreateUserFunc: func(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
|
||||
return data.User{}, &pgconn.PgError{Code: "23505"}
|
||||
},
|
||||
}
|
||||
|
||||
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
||||
|
||||
reqBody := `{"username": "existing", "password": "validPass123!"}`
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rs.Create(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusConflict, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "Username is already in use")
|
||||
}
|
||||
|
||||
func TestCreateUser_InvalidUsername(t *testing.T) {
|
||||
mockStore := &mockUserStore{} // No DB calls expected
|
||||
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
||||
|
||||
// Test various invalid usernames
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{"Too short", `{"username": "a", "password": "validPass123!"}`},
|
||||
{"Invalid chars", `{"username": "user@name", "password": "validPass123!"}`},
|
||||
{"Empty", `{"username": "", "password": "validPass123!"}`},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
|
||||
w := httptest.NewRecorder()
|
||||
rs.Create(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateUser_InvalidPassword(t *testing.T) {
|
||||
mockStore := &mockUserStore{}
|
||||
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{"too short", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength-1))},
|
||||
{"too long", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", maxPasswordLength+1))},
|
||||
{"low entropy", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength))},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
|
||||
w := httptest.NewRecorder()
|
||||
rs.Create(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUsers_Success(t *testing.T) {
|
||||
testUsers := []data.User{
|
||||
{ID: uuid.New(), Username: "user1"},
|
||||
{ID: uuid.New(), Username: "user2"},
|
||||
}
|
||||
|
||||
mockStore := &mockUserStore{
|
||||
ListUsersFunc: func(ctx context.Context) ([]data.User, error) {
|
||||
return testUsers, nil
|
||||
},
|
||||
}
|
||||
|
||||
rs := usersResource{Users: mockStore}
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rs.List(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response []map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assert.Len(t, response, 2)
|
||||
assert.Equal(t, "user1", response[0]["username"])
|
||||
assert.NotContains(t, response[0], "password_hash")
|
||||
}
|
||||
|
||||
func TestUpdatePassword_InvalidOldPassword(t *testing.T) {
|
||||
// User with password hash that won't match "wrongpassword"
|
||||
user := data.User{
|
||||
ID: uuid.New(),
|
||||
PasswordHash: "$2a$10$PHhno.bZBF8IEINdFRZAPujMxIN65msElATgJG6FIxZdeWYVLSfFi", // Hash of "correctpassword"
|
||||
}
|
||||
|
||||
mockStore := &mockUserStore{}
|
||||
rs := usersResource{Users: mockStore}
|
||||
|
||||
reqBody := `{"old_password": "wrongpassword", "new_password": "NewValidPass321!"}`
|
||||
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
|
||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rs.UpdatePassword(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
|
||||
func TestAdminDelete_Success(t *testing.T) {
|
||||
user := data.User{ID: uuid.New()}
|
||||
deleteCalled := false
|
||||
revokeCalled := false
|
||||
|
||||
mockStore := &mockUserStore{
|
||||
DeleteUserFunc: func(ctx context.Context, id uuid.UUID) error {
|
||||
deleteCalled = true
|
||||
assert.Equal(t, user.ID, id)
|
||||
return nil
|
||||
},
|
||||
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
|
||||
revokeCalled = true
|
||||
assert.Equal(t, user.ID, id)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
rs := usersResource{Users: mockStore}
|
||||
req := httptest.NewRequest("DELETE", "/", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rs.AdminDelete(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, w.Code)
|
||||
assert.True(t, deleteCalled)
|
||||
assert.True(t, revokeCalled)
|
||||
}
|
||||
|
||||
func TestOwnerDelete_InvalidCredentials(t *testing.T) {
|
||||
// Create user with known password hash
|
||||
correctPassword := "CorrectPass123!"
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(correctPassword), bcrypt.DefaultCost)
|
||||
user := data.User{
|
||||
ID: uuid.New(),
|
||||
PasswordHash: string(hashedPassword),
|
||||
}
|
||||
|
||||
mockStore := &mockUserStore{}
|
||||
rs := usersResource{Users: mockStore}
|
||||
|
||||
reqBody := `{"password": "wrongpassword"}`
|
||||
req := httptest.NewRequest("DELETE", "/", strings.NewReader(reqBody))
|
||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rs.OwnerDelete(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
|
||||
func TestGetUser_NotFound(t *testing.T) {
|
||||
mockStore := &mockUserStore{}
|
||||
rs := usersResource{Users: mockStore}
|
||||
|
||||
// No user in context
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rs.Get(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestUpdatePassword_DatabaseError(t *testing.T) {
|
||||
// Add user with a valid password to the context
|
||||
oldPassword := "OldValidPass321!"
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(oldPassword), bcrypt.DefaultCost)
|
||||
user := data.User{
|
||||
ID: uuid.New(),
|
||||
PasswordHash: string(hashedPassword),
|
||||
}
|
||||
mockStore := &mockUserStore{
|
||||
UpdatePasswordFunc: func(ctx context.Context, arg data.UpdatePasswordParams) error {
|
||||
return errors.New("database error")
|
||||
},
|
||||
}
|
||||
|
||||
rs := usersResource{Users: mockStore}
|
||||
|
||||
reqBody := fmt.Sprintf(`{"old_password": "%s", "new_password": "NewValidPass123!"}`, oldPassword)
|
||||
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
|
||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rs.UpdatePassword(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "Failed to update password")
|
||||
}
|
133
server/pkg/service/util.go
Normal file
133
server/pkg/service/util.go
Normal file
@ -0,0 +1,133 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
passwordvalidator "github.com/wagslane/go-password-validator"
|
||||
)
|
||||
|
||||
const (
|
||||
minPasswordLength = 12 // Entropy checks prevent short passwords anyway
|
||||
maxPasswordLength = 72 // Limitation of bcrypt
|
||||
minPasswordEntropy = 60.0
|
||||
|
||||
minUsernameLength = 3
|
||||
maxUsernameLength = 20
|
||||
|
||||
hibpAPI = "https://api.pwnedpasswords.com/range" // Doesn't require an API key
|
||||
)
|
||||
|
||||
var (
|
||||
usernameRegex = regexp.MustCompile("^[a-z0-9_]+$")
|
||||
)
|
||||
|
||||
func respondJSON(w http.ResponseWriter, status int, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
func respondError(w http.ResponseWriter, status int, message string) {
|
||||
respondJSON(w, status, map[string]string{"error": message})
|
||||
}
|
||||
|
||||
/*
|
||||
Client-side check:
|
||||
|
||||
```
|
||||
function estimateEntropy(password: string): number {
|
||||
const pool: number = getCharsetSize(password); // Character diversity (R)
|
||||
const entropy: number = password.length * Math.log2(pool); // E = L * log_2(R)
|
||||
return entropy; // Value (E) that can be compared against a hardcoded threshold (e.g. 60)
|
||||
}
|
||||
```
|
||||
*/
|
||||
|
||||
// Validate the given password using a hybrid approach: length (max. set due to bcrypt's input
|
||||
// upper limit of 72 bytes), entropy, and HIBP API.
|
||||
func validatePassword(password string) error {
|
||||
if len(password) < minPasswordLength {
|
||||
return fmt.Errorf("password must be at least %d characters", minPasswordLength)
|
||||
}
|
||||
if len(password) > maxPasswordLength {
|
||||
return fmt.Errorf("password cannot be longer than %d characters", maxPasswordLength)
|
||||
}
|
||||
|
||||
// Formatted error message will contain tips to increase the password strength (safe to show)
|
||||
err := passwordvalidator.Validate(password, minPasswordEntropy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if compromised, _ := isPasswordCompromised(password); compromised {
|
||||
return errors.New("password is compromised")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send the first five bytes of the password's SHA-1 hash to HIBP API, then check if the rest of
|
||||
// the hash is present in the APi's response data (k-Anonymity model).
|
||||
func isPasswordCompromised(password string) (bool, error) {
|
||||
hash := sha1.Sum([]byte(password))
|
||||
hashStr := strings.ToUpper(hex.EncodeToString(hash[:]))
|
||||
prefix, suffix := hashStr[:5], hashStr[5:]
|
||||
|
||||
resp, err := http.Get(fmt.Sprintf("%s/%s", hibpAPI, prefix))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return strings.Contains(strings.ToUpper(string(body)), suffix), nil
|
||||
}
|
||||
|
||||
// Normalize the username by making it lowercase and trimming any leading or trailing whitespace.
|
||||
func normalizeUsername(username string) string {
|
||||
return strings.ToLower(strings.TrimSpace(username))
|
||||
}
|
||||
|
||||
/*
|
||||
Client-side check (additionally input should automatically perform the normalization steps):
|
||||
|
||||
```
|
||||
function validateUsername(username: string): string {
|
||||
const min: number = 3, max: number = 20;
|
||||
if (username.length < min) return "Too short";
|
||||
if (username.length > max) return "Too long";
|
||||
if (!/^[a-zA-Z0-9_]+$/.test(username)) return "Invalid characters";
|
||||
return "Valid";
|
||||
}
|
||||
```
|
||||
*/
|
||||
|
||||
// Validate the given username by making sure it only contains alphanumeric characters or
|
||||
// underscores and adheres the hardcoded minimum and maximum length rules.
|
||||
func validateUsername(username string) error {
|
||||
if utf8.RuneCountInString(username) < minUsernameLength {
|
||||
return fmt.Errorf("username must be at least %d characters", minUsernameLength)
|
||||
}
|
||||
if utf8.RuneCountInString(username) > maxUsernameLength {
|
||||
return fmt.Errorf("username cannot be longer than %d characters", maxUsernameLength)
|
||||
}
|
||||
|
||||
if !usernameRegex.MatchString(username) {
|
||||
return errors.New("username can only contain numbers, letters, and underscores")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
184
server/pkg/service/util_test.go
Normal file
184
server/pkg/service/util_test.go
Normal file
@ -0,0 +1,184 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
passwordvalidator "github.com/wagslane/go-password-validator"
|
||||
)
|
||||
|
||||
type MockHTTPClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockHTTPClient) Get(url string) (*http.Response, error) {
|
||||
args := m.Called(url)
|
||||
return args.Get(0).(*http.Response), args.Error(1)
|
||||
}
|
||||
|
||||
func TestValidatePassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
wantErr string
|
||||
mockHTTP func(*MockHTTPClient)
|
||||
}{
|
||||
{
|
||||
name: "too short",
|
||||
password: strings.Repeat("a", minPasswordLength-1),
|
||||
wantErr: fmt.Sprintf("password must be at least %d characters", minPasswordLength),
|
||||
},
|
||||
{
|
||||
name: "too long",
|
||||
password: strings.Repeat("a", maxPasswordLength+1),
|
||||
wantErr: fmt.Sprintf("password cannot be longer than %d characters", maxPasswordLength),
|
||||
},
|
||||
{
|
||||
name: "low entropy",
|
||||
password: strings.Repeat("a", minPasswordLength),
|
||||
wantErr: "insecure password", // Error produced by wagslane/go-password-validator
|
||||
},
|
||||
{
|
||||
name: "valid password",
|
||||
password: "SecurePassw0rd!123",
|
||||
wantErr: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Mock HTTP client if needed
|
||||
if tt.mockHTTP != nil {
|
||||
mockClient := new(MockHTTPClient)
|
||||
tt.mockHTTP(mockClient)
|
||||
}
|
||||
|
||||
err := validatePassword(tt.password)
|
||||
if tt.wantErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPasswordCompromised(t *testing.T) {
|
||||
t.Run("known compromised", func(t *testing.T) {
|
||||
compromised, err := isPasswordCompromised("password123456")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, compromised)
|
||||
})
|
||||
|
||||
t.Run("randomly generated", func(t *testing.T) {
|
||||
randomStr := genRandomString(12)
|
||||
compromised, err := isPasswordCompromised(randomStr)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, compromised)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPasswordEntropyCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
password string
|
||||
entropy float64
|
||||
}{
|
||||
{"password", 37.6},
|
||||
{"SecurePassw0rd!123", 103.12},
|
||||
{"aaaaaaaaaaaaaaaa", 9.5},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.password, func(t *testing.T) {
|
||||
entropy := passwordvalidator.GetEntropy(tt.password)
|
||||
assert.InDelta(t, tt.entropy, entropy, 1.0)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "too short",
|
||||
input: strings.Repeat("a", minUsernameLength-1),
|
||||
wantErr: fmt.Sprintf("username must be at least %d characters", minUsernameLength),
|
||||
},
|
||||
{
|
||||
name: "too long",
|
||||
input: strings.Repeat("a", maxUsernameLength+1),
|
||||
wantErr: fmt.Sprintf("username cannot be longer than %d characters", maxUsernameLength),
|
||||
},
|
||||
{
|
||||
name: "invalid characters",
|
||||
input: "user@name",
|
||||
wantErr: "username can only contain numbers, letters, and underscores",
|
||||
},
|
||||
{
|
||||
name: "valid username",
|
||||
input: "valid_user123",
|
||||
wantErr: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateUsername(tt.input)
|
||||
if tt.wantErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "trim whitespace",
|
||||
input: " test_user ",
|
||||
want: "test_user",
|
||||
},
|
||||
{
|
||||
name: "lowercase",
|
||||
input: "TestUser",
|
||||
want: "testuser",
|
||||
},
|
||||
{
|
||||
name: "mixed case and spaces",
|
||||
input: " UserName123 ",
|
||||
want: "username123",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeUsername(tt.input)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func genRandomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[seededRand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
49
server/sql/migrations/0001_initial.up.sql
Normal file
49
server/sql/migrations/0001_initial.up.sql
Normal file
@ -0,0 +1,49 @@
|
||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
||||
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version BIGINT PRIMARY KEY,
|
||||
applied_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT NOT NULL,
|
||||
is_admin BOOLEAN NOT NULL DEFAULT false,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
token_hash TEXT NOT NULL,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
revoked BOOLEAN NOT NULL DEFAULT false
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS notes (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS note_versions (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
note_id UUID NOT NULL REFERENCES notes(id) ON DELETE CASCADE,
|
||||
title TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
version_number INT NOT NULL,
|
||||
content_hash TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username);
|
||||
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_id ON refresh_tokens(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_note_version_unique ON note_versions(note_id, version_number);
|
||||
CREATE INDEX IF NOT EXISTS idx_note_versions_note ON note_versions(note_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_note_versions_number ON note_versions(version_number DESC);
|
27
server/sql/queries/note_versions.sql
Normal file
27
server/sql/queries/note_versions.sql
Normal file
@ -0,0 +1,27 @@
|
||||
-- name: CreateNoteVersion :one
|
||||
INSERT INTO note_versions (note_id, title, content, version_number, content_hash)
|
||||
VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
(SELECT COALESCE(MAX(version_number), 0) + 1 FROM note_versions WHERE note_id = $1),
|
||||
encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetNoteVersions :many
|
||||
SELECT * FROM note_versions
|
||||
WHERE note_id = $1
|
||||
ORDER BY version_number DESC
|
||||
LIMIT $2 OFFSET $3;
|
||||
|
||||
-- name: GetNoteVersion :one
|
||||
SELECT * FROM note_versions
|
||||
WHERE note_id = $1 AND version_number = $2 LIMIT 1;
|
||||
|
||||
-- name: FindDuplicateContent :one
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM note_versions
|
||||
WHERE note_id = $1
|
||||
AND content_hash = encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
|
||||
);
|
18
server/sql/queries/notes.sql
Normal file
18
server/sql/queries/notes.sql
Normal file
@ -0,0 +1,18 @@
|
||||
-- name: CreateNote :one
|
||||
INSERT INTO notes (user_id)
|
||||
VALUES ($1)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetNote :one
|
||||
SELECT * FROM notes
|
||||
WHERE id = $1 AND user_id = $2 LIMIT 1;
|
||||
|
||||
-- name: ListNotes :many
|
||||
SELECT * FROM notes
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2 OFFSET $3;
|
||||
|
||||
-- name: DeleteNote :exec
|
||||
DELETE FROM notes
|
||||
WHERE id = $1 AND user_id = $2;
|
25
server/sql/queries/refresh_tokens.sql
Normal file
25
server/sql/queries/refresh_tokens.sql
Normal file
@ -0,0 +1,25 @@
|
||||
-- name: CreateRefreshToken :one
|
||||
INSERT INTO refresh_tokens (
|
||||
user_id,
|
||||
token_hash,
|
||||
expires_at
|
||||
) VALUES ($1, $2, $3)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetRefreshTokenByHash :one
|
||||
SELECT * FROM refresh_tokens
|
||||
WHERE token_hash = $1 LIMIT 1;
|
||||
|
||||
-- name: RevokeRefreshToken :exec
|
||||
UPDATE refresh_tokens
|
||||
SET revoked = TRUE
|
||||
WHERE token_hash = $1;
|
||||
|
||||
-- name: RevokeAllUserRefreshTokens :exec
|
||||
UPDATE refresh_tokens
|
||||
SET revoked = TRUE
|
||||
WHERE user_id = $1;
|
||||
|
||||
-- name: DeleteExpiredRefreshTokens :exec
|
||||
DELETE FROM refresh_tokens
|
||||
WHERE expires_at < NOW();
|
24
server/sql/queries/users.sql
Normal file
24
server/sql/queries/users.sql
Normal file
@ -0,0 +1,24 @@
|
||||
-- name: CreateUser :one
|
||||
INSERT INTO users (username, password_hash)
|
||||
VALUES ($1, $2)
|
||||
RETURNING *;
|
||||
|
||||
-- name: ListUsers :many
|
||||
SELECT * FROM users;
|
||||
|
||||
-- name: GetUserByID :one
|
||||
SELECT * FROM users
|
||||
WHERE id = $1 LIMIT 1;
|
||||
|
||||
-- name: GetUserByUsername :one
|
||||
SELECT * FROM users
|
||||
WHERE username = $1 LIMIT 1;
|
||||
|
||||
-- name: UpdatePassword :exec
|
||||
UPDATE users
|
||||
SET password_hash = $2, updated_at = NOW()
|
||||
WHERE id = $1;
|
||||
|
||||
-- name: DeleteUser :exec
|
||||
DELETE FROM users
|
||||
WHERE id = $1;
|
24
server/sql/sqlc.yaml
Normal file
24
server/sql/sqlc.yaml
Normal file
@ -0,0 +1,24 @@
|
||||
version: "2"
|
||||
|
||||
sql:
|
||||
- engine: "postgresql"
|
||||
queries: "./queries/"
|
||||
schema: "./migrations/"
|
||||
gen:
|
||||
go:
|
||||
package: "data"
|
||||
out: "../pkg/data"
|
||||
sql_package: "pgx/v5"
|
||||
emit_json_tags: true
|
||||
overrides:
|
||||
- db_type: "timestamptz"
|
||||
go_type:
|
||||
type: "time.Time"
|
||||
- db_type: "timestamptz"
|
||||
nullable: true
|
||||
go_type:
|
||||
type: "*time.Time"
|
||||
- db_type: "uuid"
|
||||
go_type:
|
||||
import: "github.com/google/uuid"
|
||||
type: "UUID"
|
Loading…
x
Reference in New Issue
Block a user