Compare commits

...

7 Commits

30 changed files with 2330 additions and 4 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
TASKS.md
.env

View File

@ -1,3 +1,30 @@
module git.umbrella.haus/ae/notatest
go 1.24.1
require (
github.com/caarlos0/env v3.5.0+incompatible
github.com/go-chi/chi/v5 v5.2.1
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.7.4
github.com/rs/zerolog v1.34.0
github.com/stretchr/testify v1.10.0
github.com/wagslane/go-password-validator v0.3.0
golang.org/x/crypto v0.36.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/stretchr/objx v0.5.2 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

67
server/go.sum Normal file
View File

@ -0,0 +1,67 @@
github.com/caarlos0/env v3.5.0+incompatible h1:Yy0UN8o9Wtr/jGHZDpCBLpNrzcFLLM2yixi/rBrKyJs=
github.com/caarlos0/env v3.5.0+incompatible/go.mod h1:tdCsowwCzMLdkqRYDlHpZCp2UooDD3MspDBjZ2AD02Y=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/wagslane/go-password-validator v0.3.0 h1:vfxOPzGHkz5S146HDpavl0cw1DSVP061Ry2PX0/ON6I=
github.com/wagslane/go-password-validator v0.3.0/go.mod h1:TI1XJ6T5fRdRnHqHt14pvy1tNVnrwe7m3/f1f2fDphQ=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1 +0,0 @@
package internal

View File

@ -1 +0,0 @@
package internal

View File

@ -1 +0,0 @@
package internal

View File

@ -1 +0,0 @@
package internal

View File

@ -5,4 +5,5 @@ import (
)
func main() {
internal.Run()
}

32
server/pkg/data/db.go Normal file
View File

@ -0,0 +1,32 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.28.0
package data
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
type DBTX interface {
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row
}
func New(db DBTX) *Queries {
return &Queries{db: db}
}
type Queries struct {
db DBTX
}
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
return &Queries{
db: tx,
}
}

51
server/pkg/data/models.go Normal file
View File

@ -0,0 +1,51 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.28.0
package data
import (
"time"
"github.com/google/uuid"
)
type Note struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
CreatedAt *time.Time `json:"created_at"`
UpdatedAt *time.Time `json:"updated_at"`
}
type NoteVersion struct {
ID uuid.UUID `json:"id"`
NoteID uuid.UUID `json:"note_id"`
Title string `json:"title"`
Content string `json:"content"`
VersionNumber int32 `json:"version_number"`
ContentHash string `json:"content_hash"`
CreatedAt *time.Time `json:"created_at"`
}
type RefreshToken struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
TokenHash string `json:"token_hash"`
ExpiresAt time.Time `json:"expires_at"`
CreatedAt *time.Time `json:"created_at"`
Revoked bool `json:"revoked"`
}
type SchemaMigration struct {
Version int64 `json:"version"`
AppliedAt *time.Time `json:"applied_at"`
}
type User struct {
ID uuid.UUID `json:"id"`
Username string `json:"username"`
PasswordHash string `json:"password_hash"`
IsAdmin bool `json:"is_admin"`
CreatedAt *time.Time `json:"created_at"`
UpdatedAt *time.Time `json:"updated_at"`
}

View File

@ -0,0 +1,132 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.28.0
// source: note_versions.sql
package data
import (
"context"
"github.com/google/uuid"
)
const createNoteVersion = `-- name: CreateNoteVersion :one
INSERT INTO note_versions (note_id, title, content, version_number, content_hash)
VALUES (
$1,
$2,
$3,
(SELECT COALESCE(MAX(version_number), 0) + 1 FROM note_versions WHERE note_id = $1),
encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
)
RETURNING id, note_id, title, content, version_number, content_hash, created_at
`
type CreateNoteVersionParams struct {
NoteID uuid.UUID `json:"note_id"`
Title string `json:"title"`
Content string `json:"content"`
}
func (q *Queries) CreateNoteVersion(ctx context.Context, arg CreateNoteVersionParams) (NoteVersion, error) {
row := q.db.QueryRow(ctx, createNoteVersion, arg.NoteID, arg.Title, arg.Content)
var i NoteVersion
err := row.Scan(
&i.ID,
&i.NoteID,
&i.Title,
&i.Content,
&i.VersionNumber,
&i.ContentHash,
&i.CreatedAt,
)
return i, err
}
const findDuplicateContent = `-- name: FindDuplicateContent :one
SELECT EXISTS(
SELECT 1 FROM note_versions
WHERE note_id = $1
AND content_hash = encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
)
`
type FindDuplicateContentParams struct {
NoteID uuid.UUID `json:"note_id"`
Column2 []byte `json:"column_2"`
Column3 []byte `json:"column_3"`
}
func (q *Queries) FindDuplicateContent(ctx context.Context, arg FindDuplicateContentParams) (bool, error) {
row := q.db.QueryRow(ctx, findDuplicateContent, arg.NoteID, arg.Column2, arg.Column3)
var exists bool
err := row.Scan(&exists)
return exists, err
}
const getNoteVersion = `-- name: GetNoteVersion :one
SELECT id, note_id, title, content, version_number, content_hash, created_at FROM note_versions
WHERE note_id = $1 AND version_number = $2 LIMIT 1
`
type GetNoteVersionParams struct {
NoteID uuid.UUID `json:"note_id"`
VersionNumber int32 `json:"version_number"`
}
func (q *Queries) GetNoteVersion(ctx context.Context, arg GetNoteVersionParams) (NoteVersion, error) {
row := q.db.QueryRow(ctx, getNoteVersion, arg.NoteID, arg.VersionNumber)
var i NoteVersion
err := row.Scan(
&i.ID,
&i.NoteID,
&i.Title,
&i.Content,
&i.VersionNumber,
&i.ContentHash,
&i.CreatedAt,
)
return i, err
}
const getNoteVersions = `-- name: GetNoteVersions :many
SELECT id, note_id, title, content, version_number, content_hash, created_at FROM note_versions
WHERE note_id = $1
ORDER BY version_number DESC
LIMIT $2 OFFSET $3
`
type GetNoteVersionsParams struct {
NoteID uuid.UUID `json:"note_id"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
}
func (q *Queries) GetNoteVersions(ctx context.Context, arg GetNoteVersionsParams) ([]NoteVersion, error) {
rows, err := q.db.Query(ctx, getNoteVersions, arg.NoteID, arg.Limit, arg.Offset)
if err != nil {
return nil, err
}
defer rows.Close()
var items []NoteVersion
for rows.Next() {
var i NoteVersion
if err := rows.Scan(
&i.ID,
&i.NoteID,
&i.Title,
&i.Content,
&i.VersionNumber,
&i.ContentHash,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -0,0 +1,105 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.28.0
// source: notes.sql
package data
import (
"context"
"github.com/google/uuid"
)
const createNote = `-- name: CreateNote :one
INSERT INTO notes (user_id)
VALUES ($1)
RETURNING id, user_id, created_at, updated_at
`
func (q *Queries) CreateNote(ctx context.Context, userID uuid.UUID) (Note, error) {
row := q.db.QueryRow(ctx, createNote, userID)
var i Note
err := row.Scan(
&i.ID,
&i.UserID,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const deleteNote = `-- name: DeleteNote :exec
DELETE FROM notes
WHERE id = $1 AND user_id = $2
`
type DeleteNoteParams struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
}
func (q *Queries) DeleteNote(ctx context.Context, arg DeleteNoteParams) error {
_, err := q.db.Exec(ctx, deleteNote, arg.ID, arg.UserID)
return err
}
const getNote = `-- name: GetNote :one
SELECT id, user_id, created_at, updated_at FROM notes
WHERE id = $1 AND user_id = $2 LIMIT 1
`
type GetNoteParams struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
}
func (q *Queries) GetNote(ctx context.Context, arg GetNoteParams) (Note, error) {
row := q.db.QueryRow(ctx, getNote, arg.ID, arg.UserID)
var i Note
err := row.Scan(
&i.ID,
&i.UserID,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const listNotes = `-- name: ListNotes :many
SELECT id, user_id, created_at, updated_at FROM notes
WHERE user_id = $1
ORDER BY created_at DESC
LIMIT $2 OFFSET $3
`
type ListNotesParams struct {
UserID uuid.UUID `json:"user_id"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
}
func (q *Queries) ListNotes(ctx context.Context, arg ListNotesParams) ([]Note, error) {
rows, err := q.db.Query(ctx, listNotes, arg.UserID, arg.Limit, arg.Offset)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Note
for rows.Next() {
var i Note
if err := rows.Scan(
&i.ID,
&i.UserID,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -0,0 +1,93 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.28.0
// source: refresh_tokens.sql
package data
import (
"context"
"time"
"github.com/google/uuid"
)
const createRefreshToken = `-- name: CreateRefreshToken :one
INSERT INTO refresh_tokens (
user_id,
token_hash,
expires_at
) VALUES ($1, $2, $3)
RETURNING id, user_id, token_hash, expires_at, created_at, revoked
`
type CreateRefreshTokenParams struct {
UserID uuid.UUID `json:"user_id"`
TokenHash string `json:"token_hash"`
ExpiresAt time.Time `json:"expires_at"`
}
func (q *Queries) CreateRefreshToken(ctx context.Context, arg CreateRefreshTokenParams) (RefreshToken, error) {
row := q.db.QueryRow(ctx, createRefreshToken, arg.UserID, arg.TokenHash, arg.ExpiresAt)
var i RefreshToken
err := row.Scan(
&i.ID,
&i.UserID,
&i.TokenHash,
&i.ExpiresAt,
&i.CreatedAt,
&i.Revoked,
)
return i, err
}
const deleteExpiredRefreshTokens = `-- name: DeleteExpiredRefreshTokens :exec
DELETE FROM refresh_tokens
WHERE expires_at < NOW()
`
func (q *Queries) DeleteExpiredRefreshTokens(ctx context.Context) error {
_, err := q.db.Exec(ctx, deleteExpiredRefreshTokens)
return err
}
const getRefreshTokenByHash = `-- name: GetRefreshTokenByHash :one
SELECT id, user_id, token_hash, expires_at, created_at, revoked FROM refresh_tokens
WHERE token_hash = $1 LIMIT 1
`
func (q *Queries) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (RefreshToken, error) {
row := q.db.QueryRow(ctx, getRefreshTokenByHash, tokenHash)
var i RefreshToken
err := row.Scan(
&i.ID,
&i.UserID,
&i.TokenHash,
&i.ExpiresAt,
&i.CreatedAt,
&i.Revoked,
)
return i, err
}
const revokeAllUserRefreshTokens = `-- name: RevokeAllUserRefreshTokens :exec
UPDATE refresh_tokens
SET revoked = TRUE
WHERE user_id = $1
`
func (q *Queries) RevokeAllUserRefreshTokens(ctx context.Context, userID uuid.UUID) error {
_, err := q.db.Exec(ctx, revokeAllUserRefreshTokens, userID)
return err
}
const revokeRefreshToken = `-- name: RevokeRefreshToken :exec
UPDATE refresh_tokens
SET revoked = TRUE
WHERE token_hash = $1
`
func (q *Queries) RevokeRefreshToken(ctx context.Context, tokenHash string) error {
_, err := q.db.Exec(ctx, revokeRefreshToken, tokenHash)
return err
}

View File

@ -0,0 +1,132 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.28.0
// source: users.sql
package data
import (
"context"
"github.com/google/uuid"
)
const createUser = `-- name: CreateUser :one
INSERT INTO users (username, password_hash)
VALUES ($1, $2)
RETURNING id, username, password_hash, is_admin, created_at, updated_at
`
type CreateUserParams struct {
Username string `json:"username"`
PasswordHash string `json:"password_hash"`
}
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) {
row := q.db.QueryRow(ctx, createUser, arg.Username, arg.PasswordHash)
var i User
err := row.Scan(
&i.ID,
&i.Username,
&i.PasswordHash,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const deleteUser = `-- name: DeleteUser :exec
DELETE FROM users
WHERE id = $1
`
func (q *Queries) DeleteUser(ctx context.Context, id uuid.UUID) error {
_, err := q.db.Exec(ctx, deleteUser, id)
return err
}
const getUserByID = `-- name: GetUserByID :one
SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users
WHERE id = $1 LIMIT 1
`
func (q *Queries) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) {
row := q.db.QueryRow(ctx, getUserByID, id)
var i User
err := row.Scan(
&i.ID,
&i.Username,
&i.PasswordHash,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getUserByUsername = `-- name: GetUserByUsername :one
SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users
WHERE username = $1 LIMIT 1
`
func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, error) {
row := q.db.QueryRow(ctx, getUserByUsername, username)
var i User
err := row.Scan(
&i.ID,
&i.Username,
&i.PasswordHash,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const listUsers = `-- name: ListUsers :many
SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users
`
func (q *Queries) ListUsers(ctx context.Context) ([]User, error) {
rows, err := q.db.Query(ctx, listUsers)
if err != nil {
return nil, err
}
defer rows.Close()
var items []User
for rows.Next() {
var i User
if err := rows.Scan(
&i.ID,
&i.Username,
&i.PasswordHash,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updatePassword = `-- name: UpdatePassword :exec
UPDATE users
SET password_hash = $2, updated_at = NOW()
WHERE id = $1
`
type UpdatePasswordParams struct {
ID uuid.UUID `json:"id"`
PasswordHash string `json:"password_hash"`
}
func (q *Queries) UpdatePassword(ctx context.Context, arg UpdatePasswordParams) error {
_, err := q.db.Exec(ctx, updatePassword, arg.ID, arg.PasswordHash)
return err
}

View File

@ -0,0 +1,74 @@
// pkg/migrate/migrate.go
package migrate
import (
"context"
"embed"
"fmt"
"io/fs"
"sort"
"strconv"
"strings"
"github.com/jackc/pgx/v5"
)
func Run(ctx context.Context, conn *pgx.Conn, migrationsFS embed.FS) error {
// Get already applied migrations
rows, _ := conn.Query(ctx, "SELECT version FROM schema_migrations")
defer rows.Close()
applied := make(map[int64]bool)
for rows.Next() {
var version int64
if err := rows.Scan(&version); err != nil {
return err
}
applied[version] = true
}
files, err := migrationsFS.ReadDir("migrations")
if err != nil {
return err
}
// Apply the migrations sequentially based on their ordinal number
for _, f := range sortMigrations(files) {
version, err := strconv.ParseInt(strings.Split(f.Name(), "_")[0], 10, 64)
if err != nil {
return fmt.Errorf("invalid migration name: %s", f.Name())
}
if applied[version] {
continue
}
// Run migration
sql, err := migrationsFS.ReadFile("migrations/" + f.Name())
if err != nil {
return err
}
if _, err := conn.Exec(ctx, string(sql)); err != nil {
return fmt.Errorf("migration %d failed: %w", version, err)
}
if _, err := conn.Exec(ctx,
"INSERT INTO schema_migrations (version) VALUES ($1)", version,
); err != nil {
return err
}
}
return nil
}
// Sort the migration files based on their ordinal number prefix.
func sortMigrations(files []fs.DirEntry) []fs.DirEntry {
sort.Slice(files, func(i, j int) bool {
v1, _ := strconv.ParseInt(strings.Split(files[i].Name(), "_")[0], 10, 64)
v2, _ := strconv.ParseInt(strings.Split(files[j].Name(), "_")[0], 10, 64)
return v1 < v2
})
return files
}

View File

@ -0,0 +1,157 @@
package service
import (
"context"
"fmt"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/rs/zerolog"
)
const (
panicRecoveryMsg = "panic recovered"
defaultLogMsg = "incoming request"
)
type userCtxKey struct{}
// Get JWT bearer from request's authorization header, parse it with custom user claims, and
// ensure its validity before attaching the claims to the request's context.
func authMiddleware(jwtSecret string, expectedType string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenString, err := getTokenFromRequest(r)
if err != nil {
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
}
token, err := jwt.ParseWithClaims(tokenString, &userClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(jwtSecret), nil
})
if err != nil || !token.Valid {
respondError(w, http.StatusUnauthorized, "Invalid token")
return
}
claims, ok := token.Claims.(*userClaims)
if !ok || claims.TokenType != expectedType {
respondError(w, http.StatusUnauthorized, "Invalid token type")
return
}
ctx := context.WithValue(r.Context(), userCtxKey{}, claims)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// JWT access token parsing, verification, and validation.
func requireAccessToken(jwtSecret string) func(http.Handler) http.Handler {
return authMiddleware(jwtSecret, "access")
}
// JWT refresh token parsing, verification, and validation.
func requireRefreshToken(jwtSecret string) func(http.Handler) http.Handler {
return authMiddleware(jwtSecret, "refresh")
}
// Ensure the current user is an administrator.
func adminOnlyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok || !user.Admin {
respondError(w, http.StatusForbidden, "Forbidden")
return
}
next.ServeHTTP(w, r)
})
}
// Ensure the targeted resource is owned by the current user (i.e. current user's ID matches with
// the one stored into the resource).
func ownerOnlyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
requestedID := chi.URLParam(r, "id")
if !ok || user.ID != requestedID {
respondError(w, http.StatusForbidden, "Forbidden")
return
}
next.ServeHTTP(w, r)
})
}
// Append user data into request's context based on user ID as a URL parameter.
func userCtx(store UserStore) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userIDStr := chi.URLParam(r, "id")
userID, err := uuid.Parse(userIDStr)
if err != nil {
respondError(w, http.StatusNotFound, "Invalid user ID")
return
}
user, err := store.GetUserByID(r.Context(), userID)
if err != nil {
respondError(w, http.StatusNotFound, "User not found")
return
}
ctx := context.WithValue(r.Context(), userCtxKey{}, user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// Zerolog compatible logger middleware. Automatically logs and recovers from errors with HTTP 500
// response, by default logs to INFO level.
func loggerMiddleware(log *zerolog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
log := log.With().Logger()
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
t1 := time.Now()
defer func() {
t2 := time.Now()
// Recover automatically and respond with HTTP 500
if rec := recover(); rec != nil {
log.Error().
Str("type", "error").
Timestamp().
Interface("recover_info", rec).
Msg(panicRecoveryMsg)
http.Error(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
// Log a regular HTTP request with some metadata
log.Info().
Str("type", "access").
Timestamp().
Fields(map[string]interface{}{
"remote_ip": r.RemoteAddr,
"url": r.URL.Path,
"proto": r.Proto,
"method": r.Method,
"user_agent": r.Header.Get("User-Agent"),
"status": ww.Status(),
"latency_ms": float64(t2.Sub(t1).Nanoseconds()) / 1000000.0,
"bytes_in": r.Header.Get("Content-Length"),
"bytes_out": ww.BytesWritten(),
}).
Msg(defaultLogMsg)
}()
next.ServeHTTP(ww, r)
}
return http.HandlerFunc(fn)
}
}

View File

@ -0,0 +1,19 @@
package service
import (
"github.com/go-chi/chi/v5"
)
// Mockable database operations interface
type NoteStore interface {
// TODO: implement
}
type notesResource struct {
Notes NoteStore
}
func (rs notesResource) Routes() chi.Router {
r := chi.NewRouter()
return r
}

View File

@ -0,0 +1,43 @@
package service
import (
"fmt"
"net/http"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/jackc/pgx/v5"
"github.com/rs/zerolog/log"
)
func Run(conn *pgx.Conn, jwtSecret string, httpPort string) error {
q := data.New(conn)
r := chi.NewRouter()
tokenService := tokenService{
JWTSecret: jwtSecret,
Tokens: q,
}
usersRouter := usersResource{
JWTSecret: jwtSecret,
Users: q,
}
notesRouter := notesResource{}
// Global middlewares
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(loggerMiddleware(&log.Logger))
r.Use(middleware.Recoverer)
r.Use(middleware.AllowContentType("application/json"))
// Routes grouped by functionality
r.Post("/auth/refresh", tokenService.RefreshAccessToken) // POST /auth/refresh - new access token for refresh token
r.Mount("/users", usersRouter.Routes())
r.Mount("/notes", notesRouter.Routes())
portStr := fmt.Sprintf(":%s", httpPort)
log.Info().Msgf("Starting server on %s", portStr)
return http.ListenAndServe(portStr, r)
}

View File

@ -0,0 +1,205 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net/http"
"strings"
"time"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
const (
accessTokenDuration = 15 * time.Minute
refreshTokenDuration = 7 * 24 * time.Hour
)
var (
ErrInvalidToken = errors.New("invalid token")
ErrAuthHeaderInvalid = errors.New("token couldn't be parsed from authentication header")
)
type userClaims struct {
Admin bool `json:"admin"`
TokenType string `json:"type"` // "access" or "refresh"
jwt.RegisteredClaims // User's UUID should be stored in the subject claim
}
type tokenPair struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
// Mockable database operations interface
type TokenStore interface {
CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error)
GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error)
RevokeRefreshToken(ctx context.Context, tokenHash string) error
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
}
type tokenService struct {
JWTSecret string
Tokens TokenStore
}
func (ts *tokenService) GenerateTokenPair(ctx context.Context, userID uuid.UUID, isAdmin bool) (*tokenPair, error) {
tokenPair, err := generateTokenPair(userID.String(), isAdmin, ts.JWTSecret)
if err != nil {
return nil, err
}
hash := sha256.Sum256([]byte(tokenPair.RefreshToken))
tokenHash := hex.EncodeToString(hash[:])
// Store to DB with (almost) identical expiration timestamp
expiresAt := time.Now().Add(refreshTokenDuration)
_, err = ts.Tokens.CreateRefreshToken(ctx, data.CreateRefreshTokenParams{
UserID: userID,
TokenHash: tokenHash,
ExpiresAt: expiresAt,
})
if err != nil {
return nil, err
}
return tokenPair, nil
}
func (ts *tokenService) RevokeRefreshToken(ctx context.Context, token string) error {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
return ts.Tokens.RevokeRefreshToken(ctx, tokenHash)
}
func (ts *tokenService) ValidateRefreshToken(ctx context.Context, token string) (*data.RefreshToken, error) {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
dbToken, err := ts.Tokens.GetRefreshTokenByHash(ctx, tokenHash)
if err != nil {
return nil, err
}
if dbToken.Revoked || time.Now().After(dbToken.ExpiresAt) {
return nil, ErrInvalidToken
}
return &dbToken, nil
}
func (ts *tokenService) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
// Get claims from context
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok || claims.TokenType != "refresh" {
respondError(w, http.StatusUnauthorized, "Invalid token")
return
}
// Attempt to get the token from Authentication header ("Bearer <token>")
refreshToken, err := getTokenFromRequest(r)
if err != nil {
respondError(w, http.StatusUnauthorized, fmt.Sprintf("Unauthorized: %s", err))
}
// Validate the refresh token in DB
if _, err := ts.ValidateRefreshToken(r.Context(), refreshToken); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid refresh token")
return
}
// Revoke the used refresh token
if err := ts.RevokeRefreshToken(r.Context(), refreshToken); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to revoke token")
return
}
// Generate a new pair (access & refresh tokens)
userID, err := uuid.Parse(claims.Subject)
if err != nil {
respondError(w, http.StatusInternalServerError, "Invalid user ID")
return
}
tokenPair, err := ts.GenerateTokenPair(r.Context(), userID, claims.Admin)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to generate tokens")
return
}
respondJSON(w, http.StatusOK, tokenPair)
}
func (ts *tokenService) HandleLogout(w http.ResponseWriter, r *http.Request) {
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
if !ok {
respondError(w, http.StatusUnauthorized, "Not authenticated")
return
}
userID, err := uuid.Parse(claims.ID)
if err != nil {
respondError(w, http.StatusInternalServerError, "Invalid user ID")
return
}
if err := ts.Tokens.RevokeAllUserRefreshTokens(r.Context(), userID); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to logout")
return
}
respondJSON(w, http.StatusOK, map[string]string{"status": "logged out"})
}
func getTokenFromRequest(r *http.Request) (string, error) {
bearerToken := r.Header.Get("Authorization")
bearerFields := strings.Fields(bearerToken)
if len(bearerFields) == 2 && strings.ToLower(bearerFields[0]) == "bearer" {
return bearerFields[1], nil
}
return "", ErrAuthHeaderInvalid
}
func generateTokenPair(userID string, isAdmin bool, jwtSecret string) (*tokenPair, error) {
atClaims := userClaims{
Admin: isAdmin,
TokenType: "access",
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenDuration)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},
}
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, atClaims)
t, err := accessToken.SignedString([]byte(jwtSecret))
if err != nil {
return nil, err
}
rtClaims := userClaims{
Admin: isAdmin,
TokenType: "refresh",
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenDuration)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},
}
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, rtClaims)
rt, err := refreshToken.SignedString([]byte(jwtSecret))
if err != nil {
return nil, err
}
return &tokenPair{AccessToken: t, RefreshToken: rt}, nil
}

View File

@ -0,0 +1,180 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
type mockTokenStore struct {
CreateRefreshTokenFunc func(context.Context, data.CreateRefreshTokenParams) (data.RefreshToken, error)
GetRefreshTokenByHashFunc func(context.Context, string) (data.RefreshToken, error)
RevokeRefreshTokenFunc func(context.Context, string) error
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
}
func (m *mockTokenStore) CreateRefreshToken(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
return m.CreateRefreshTokenFunc(ctx, arg)
}
func (m *mockTokenStore) GetRefreshTokenByHash(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
return m.GetRefreshTokenByHashFunc(ctx, tokenHash)
}
func (m *mockTokenStore) RevokeRefreshToken(ctx context.Context, token string) error {
return m.RevokeRefreshTokenFunc(ctx, token)
}
func (m *mockTokenStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
}
func TestGenerateTokenPair_Success(t *testing.T) {
userID := uuid.New()
called := false
mockStore := &mockTokenStore{
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
called = true
assert.Equal(t, userID, arg.UserID)
return data.RefreshToken{}, nil
},
}
ts := tokenService{
JWTSecret: "test-secret",
Tokens: mockStore,
}
pair, err := ts.GenerateTokenPair(context.Background(), userID, false)
assert.NoError(t, err)
assert.True(t, called)
assert.NotEmpty(t, pair.AccessToken)
assert.NotEmpty(t, pair.RefreshToken)
}
func TestGenerateTokenPair_DBError(t *testing.T) {
mockStore := &mockTokenStore{
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
return data.RefreshToken{}, errors.New("db error")
},
}
ts := tokenService{Tokens: mockStore}
_, err := ts.GenerateTokenPair(context.Background(), uuid.New(), false)
assert.ErrorContains(t, err, "db error")
}
func TestValidateRefreshToken_Valid(t *testing.T) {
token := "valid-jwt-token"
hash := sha256.Sum256([]byte(token))
expectedHash := hex.EncodeToString(hash[:])
mockStore := &mockTokenStore{
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
assert.Equal(t, expectedHash, tokenHash)
return data.RefreshToken{ExpiresAt: time.Now().Add(1 * time.Hour)}, nil
},
}
ts := tokenService{Tokens: mockStore}
_, err := ts.ValidateRefreshToken(context.Background(), token)
assert.NoError(t, err)
}
func TestRefreshAccessToken_Success(t *testing.T) {
userID := uuid.New()
refreshToken := "valid-jwt-token"
// Expected hash of the test token
hash := sha256.Sum256([]byte(refreshToken))
expectedHash := hex.EncodeToString(hash[:])
mockStore := &mockTokenStore{
GetRefreshTokenByHashFunc: func(ctx context.Context, tokenHash string) (data.RefreshToken, error) {
assert.Equal(t, expectedHash, tokenHash)
// Must return an unrevoked token with future expiration
return data.RefreshToken{
TokenHash: tokenHash,
ExpiresAt: time.Now().Add(1 * time.Hour),
Revoked: false,
}, nil
},
RevokeRefreshTokenFunc: func(ctx context.Context, tokenHash string) error {
assert.Equal(t, expectedHash, tokenHash)
return nil
},
CreateRefreshTokenFunc: func(ctx context.Context, arg data.CreateRefreshTokenParams) (data.RefreshToken, error) {
return data.RefreshToken{}, nil
},
}
ts := tokenService{
JWTSecret: "test-secret",
Tokens: mockStore,
}
req := httptest.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", refreshToken))
req = req.WithContext(context.WithValue(
req.Context(),
userCtxKey{},
&userClaims{
Admin: false,
TokenType: "refresh",
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID.String(),
},
},
))
w := httptest.NewRecorder()
ts.RefreshAccessToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "access_token", "refresh_token")
}
func TestHandleLogout_Success(t *testing.T) {
userID := uuid.New()
called := false
mockStore := &mockTokenStore{
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
called = true
assert.Equal(t, userID, id)
return nil
},
}
ts := tokenService{Tokens: mockStore}
req := httptest.NewRequest("POST", "/", nil)
req = req.WithContext(context.WithValue(
req.Context(),
userCtxKey{},
&userClaims{
RegisteredClaims: jwt.RegisteredClaims{
ID: userID.String(),
},
},
))
w := httptest.NewRecorder()
ts.HandleLogout(w, req)
assert.True(t, called)
assert.Equal(t, http.StatusOK, w.Code)
}

268
server/pkg/service/users.go Normal file
View File

@ -0,0 +1,268 @@
package service
import (
"context"
"encoding/json"
"errors"
"net/http"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgconn"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
)
// Mockable database operations interface
type UserStore interface {
CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
ListUsers(ctx context.Context) ([]data.User, error)
GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error)
UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error
DeleteUser(ctx context.Context, id uuid.UUID) error
RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error
}
type usersResource struct {
JWTSecret string
Users UserStore
}
func (rs usersResource) Routes() chi.Router {
r := chi.NewRouter()
// Public routes (no tokens required)
r.Post("/", rs.Create) // POST /users - registration/signup
// Protected routes (access token required)
r.Group(func(r chi.Router) {
r.Use(requireAccessToken(rs.JWTSecret))
// Admin only general routes
r.Group(func(r chi.Router) {
r.Use(adminOnlyMiddleware)
r.Get("/", rs.List) // GET /users - list all users
})
// User specific routes
r.Route("/{id}", func(r chi.Router) {
r.Use(userCtx(rs.Users)) // DB -> req. context
// Admin routes
r.Route("/admin", func(r chi.Router) {
r.Use(adminOnlyMiddleware)
r.Get("/", rs.Get) // GET /users/admin/{id} - get single user
r.Delete("/", rs.AdminDelete) // DELETE /users/admin/{id} - delete user
})
// Owner routes
r.Route("/owner", func(r chi.Router) {
r.Use(ownerOnlyMiddleware)
r.Put("/", rs.UpdatePassword) // PUT /users/owner/{id} - update user password
r.Delete("/", rs.OwnerDelete) // DELETE /users/owner/{id} - delete user
})
})
})
return r
}
func (rs usersResource) Create(w http.ResponseWriter, r *http.Request) {
type request struct {
Username string `json:"username"`
Password string `json:"password"`
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
normalizedUsername := normalizeUsername(req.Username)
if err := validateUsername(normalizedUsername); err != nil {
respondError(w, http.StatusBadRequest, err.Error())
return
}
if err := validatePassword(req.Password); err != nil {
respondError(w, http.StatusBadRequest, err.Error())
return
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to create user")
return
}
user, err := rs.Users.CreateUser(r.Context(), data.CreateUserParams{
Username: normalizedUsername,
PasswordHash: string(hashedPassword),
})
if err != nil {
if isDuplicateEntry(err) {
respondError(w, http.StatusConflict, "Username is already in use")
} else {
respondError(w, http.StatusInternalServerError, "Failed to create user")
}
return
}
respondJSON(w, http.StatusCreated, map[string]string{
"id": user.ID.String(),
"username": user.Username,
})
}
func (rs usersResource) List(w http.ResponseWriter, r *http.Request) {
users, err := rs.Users.ListUsers(r.Context())
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to retrieve users")
return
}
// Output sanitization
var output []map[string]interface{}
for _, user := range users {
output = append(output, map[string]interface{}{
"id": user.ID,
"username": user.Username,
"created_at": user.CreatedAt,
"updated_at": user.UpdatedAt,
})
}
respondJSON(w, http.StatusOK, output)
}
func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(data.User)
if !ok {
respondError(w, http.StatusNotFound, "User not found")
return
}
respondJSON(w, http.StatusOK, map[string]interface{}{
"id": user.ID,
"username": user.Username,
"created_at": user.CreatedAt,
"updated_at": user.UpdatedAt,
})
}
func (rs usersResource) UpdatePassword(w http.ResponseWriter, r *http.Request) {
type request struct {
OldPassword string `json:"old_password"`
NewPassword string `json:"new_password"`
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
user, ok := r.Context().Value(userCtxKey{}).(data.User)
if !ok {
respondError(w, http.StatusNotFound, "User not found")
return
}
// Verify the old password before allowing the update
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.OldPassword)); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
if err := validatePassword(req.NewPassword); err != nil {
respondError(w, http.StatusBadRequest, err.Error())
return
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to update password")
return
}
if err := rs.Users.UpdatePassword(r.Context(), data.UpdatePasswordParams{
ID: user.ID,
PasswordHash: string(hashedPassword),
}); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to update password")
return
}
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
}
w.WriteHeader(http.StatusNoContent)
}
func (rs usersResource) AdminDelete(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(userCtxKey{}).(data.User)
if !ok {
respondError(w, http.StatusNotFound, "User not found")
return
}
if err := rs.Users.DeleteUser(r.Context(), user.ID); err != nil {
respondError(w, http.StatusInternalServerError, "Failed to delete user")
return
}
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
}
w.WriteHeader(http.StatusNoContent)
}
func (rs usersResource) OwnerDelete(w http.ResponseWriter, r *http.Request) {
type request struct {
Password string `json:"password"`
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, http.StatusBadRequest, "Invalid request body")
return
}
user, ok := r.Context().Value(userCtxKey{}).(data.User)
if !ok {
respondError(w, http.StatusNotFound, "User not found")
return
}
// Verify the old password before allowing the deletion
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
respondError(w, http.StatusUnauthorized, "Invalid credentials")
return
}
err := rs.Users.DeleteUser(r.Context(), user.ID)
if err != nil {
respondError(w, http.StatusInternalServerError, "Failed to delete user")
return
}
if err := rs.Users.RevokeAllUserRefreshTokens(r.Context(), user.ID); err != nil {
log.Error().Msgf("Failed to revoke refresh tokens: %s", err)
}
w.WriteHeader(http.StatusNoContent)
}
// Check if the given error is a PostgreSQL error for `unique_violation`, i.e. whether an entry
// with the given details already exists in the database table.
func isDuplicateEntry(err error) bool {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
return pgErr.Code == "23505"
}
return false
}

View File

@ -0,0 +1,259 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
)
type mockUserStore struct {
CreateUserFunc func(context.Context, data.CreateUserParams) (data.User, error)
ListUsersFunc func(context.Context) ([]data.User, error)
GetUserByIDFunc func(context.Context, uuid.UUID) (data.User, error)
UpdatePasswordFunc func(context.Context, data.UpdatePasswordParams) error
DeleteUserFunc func(context.Context, uuid.UUID) error
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
}
func (m *mockUserStore) CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
return m.CreateUserFunc(ctx, arg)
}
func (m *mockUserStore) ListUsers(ctx context.Context) ([]data.User, error) {
return m.ListUsersFunc(ctx)
}
func (m *mockUserStore) GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error) {
return m.GetUserByIDFunc(ctx, id)
}
func (m *mockUserStore) UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error {
return m.UpdatePasswordFunc(ctx, arg)
}
func (m *mockUserStore) DeleteUser(ctx context.Context, id uuid.UUID) error {
return m.DeleteUserFunc(ctx, id)
}
func (m *mockUserStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
}
func TestCreateUser_Duplicate(t *testing.T) {
mockStore := &mockUserStore{
CreateUserFunc: func(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
return data.User{}, &pgconn.PgError{Code: "23505"}
},
}
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
reqBody := `{"username": "existing", "password": "validPass123!"}`
req := httptest.NewRequest("POST", "/", strings.NewReader(reqBody))
w := httptest.NewRecorder()
rs.Create(w, req)
assert.Equal(t, http.StatusConflict, w.Code)
assert.Contains(t, w.Body.String(), "Username is already in use")
}
func TestCreateUser_InvalidUsername(t *testing.T) {
mockStore := &mockUserStore{} // No DB calls expected
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
// Test various invalid usernames
tests := []struct {
name string
body string
}{
{"Too short", `{"username": "a", "password": "validPass123!"}`},
{"Invalid chars", `{"username": "user@name", "password": "validPass123!"}`},
{"Empty", `{"username": "", "password": "validPass123!"}`},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
w := httptest.NewRecorder()
rs.Create(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
})
}
}
func TestCreateUser_InvalidPassword(t *testing.T) {
mockStore := &mockUserStore{}
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
tests := []struct {
name string
body string
}{
{"too short", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength-1))},
{"too long", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", maxPasswordLength+1))},
{"low entropy", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength))},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
w := httptest.NewRecorder()
rs.Create(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
})
}
}
func TestListUsers_Success(t *testing.T) {
testUsers := []data.User{
{ID: uuid.New(), Username: "user1"},
{ID: uuid.New(), Username: "user2"},
}
mockStore := &mockUserStore{
ListUsersFunc: func(ctx context.Context) ([]data.User, error) {
return testUsers, nil
},
}
rs := usersResource{Users: mockStore}
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
rs.List(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response []map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatal(err)
}
assert.Len(t, response, 2)
assert.Equal(t, "user1", response[0]["username"])
assert.NotContains(t, response[0], "password_hash")
}
func TestUpdatePassword_InvalidOldPassword(t *testing.T) {
// User with password hash that won't match "wrongpassword"
user := data.User{
ID: uuid.New(),
PasswordHash: "$2a$10$PHhno.bZBF8IEINdFRZAPujMxIN65msElATgJG6FIxZdeWYVLSfFi", // Hash of "correctpassword"
}
mockStore := &mockUserStore{}
rs := usersResource{Users: mockStore}
reqBody := `{"old_password": "wrongpassword", "new_password": "NewValidPass321!"}`
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
w := httptest.NewRecorder()
rs.UpdatePassword(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func TestAdminDelete_Success(t *testing.T) {
user := data.User{ID: uuid.New()}
deleteCalled := false
revokeCalled := false
mockStore := &mockUserStore{
DeleteUserFunc: func(ctx context.Context, id uuid.UUID) error {
deleteCalled = true
assert.Equal(t, user.ID, id)
return nil
},
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
revokeCalled = true
assert.Equal(t, user.ID, id)
return nil
},
}
rs := usersResource{Users: mockStore}
req := httptest.NewRequest("DELETE", "/", nil)
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
w := httptest.NewRecorder()
rs.AdminDelete(w, req)
assert.Equal(t, http.StatusNoContent, w.Code)
assert.True(t, deleteCalled)
assert.True(t, revokeCalled)
}
func TestOwnerDelete_InvalidCredentials(t *testing.T) {
// Create user with known password hash
correctPassword := "CorrectPass123!"
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(correctPassword), bcrypt.DefaultCost)
user := data.User{
ID: uuid.New(),
PasswordHash: string(hashedPassword),
}
mockStore := &mockUserStore{}
rs := usersResource{Users: mockStore}
reqBody := `{"password": "wrongpassword"}`
req := httptest.NewRequest("DELETE", "/", strings.NewReader(reqBody))
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
w := httptest.NewRecorder()
rs.OwnerDelete(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func TestGetUser_NotFound(t *testing.T) {
mockStore := &mockUserStore{}
rs := usersResource{Users: mockStore}
// No user in context
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
rs.Get(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestUpdatePassword_DatabaseError(t *testing.T) {
// Add user with a valid password to the context
oldPassword := "OldValidPass321!"
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(oldPassword), bcrypt.DefaultCost)
user := data.User{
ID: uuid.New(),
PasswordHash: string(hashedPassword),
}
mockStore := &mockUserStore{
UpdatePasswordFunc: func(ctx context.Context, arg data.UpdatePasswordParams) error {
return errors.New("database error")
},
}
rs := usersResource{Users: mockStore}
reqBody := fmt.Sprintf(`{"old_password": "%s", "new_password": "NewValidPass123!"}`, oldPassword)
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
w := httptest.NewRecorder()
rs.UpdatePassword(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Contains(t, w.Body.String(), "Failed to update password")
}

133
server/pkg/service/util.go Normal file
View File

@ -0,0 +1,133 @@
package service
import (
"crypto/sha1"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"unicode/utf8"
passwordvalidator "github.com/wagslane/go-password-validator"
)
const (
minPasswordLength = 12 // Entropy checks prevent short passwords anyway
maxPasswordLength = 72 // Limitation of bcrypt
minPasswordEntropy = 60.0
minUsernameLength = 3
maxUsernameLength = 20
hibpAPI = "https://api.pwnedpasswords.com/range" // Doesn't require an API key
)
var (
usernameRegex = regexp.MustCompile("^[a-z0-9_]+$")
)
func respondJSON(w http.ResponseWriter, status int, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(data)
}
func respondError(w http.ResponseWriter, status int, message string) {
respondJSON(w, status, map[string]string{"error": message})
}
/*
Client-side check:
```
function estimateEntropy(password: string): number {
const pool: number = getCharsetSize(password); // Character diversity (R)
const entropy: number = password.length * Math.log2(pool); // E = L * log_2(R)
return entropy; // Value (E) that can be compared against a hardcoded threshold (e.g. 60)
}
```
*/
// Validate the given password using a hybrid approach: length (max. set due to bcrypt's input
// upper limit of 72 bytes), entropy, and HIBP API.
func validatePassword(password string) error {
if len(password) < minPasswordLength {
return fmt.Errorf("password must be at least %d characters", minPasswordLength)
}
if len(password) > maxPasswordLength {
return fmt.Errorf("password cannot be longer than %d characters", maxPasswordLength)
}
// Formatted error message will contain tips to increase the password strength (safe to show)
err := passwordvalidator.Validate(password, minPasswordEntropy)
if err != nil {
return err
}
if compromised, _ := isPasswordCompromised(password); compromised {
return errors.New("password is compromised")
}
return nil
}
// Send the first five bytes of the password's SHA-1 hash to HIBP API, then check if the rest of
// the hash is present in the APi's response data (k-Anonymity model).
func isPasswordCompromised(password string) (bool, error) {
hash := sha1.Sum([]byte(password))
hashStr := strings.ToUpper(hex.EncodeToString(hash[:]))
prefix, suffix := hashStr[:5], hashStr[5:]
resp, err := http.Get(fmt.Sprintf("%s/%s", hibpAPI, prefix))
if err != nil {
return false, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return false, err
}
return strings.Contains(strings.ToUpper(string(body)), suffix), nil
}
// Normalize the username by making it lowercase and trimming any leading or trailing whitespace.
func normalizeUsername(username string) string {
return strings.ToLower(strings.TrimSpace(username))
}
/*
Client-side check (additionally input should automatically perform the normalization steps):
```
function validateUsername(username: string): string {
const min: number = 3, max: number = 20;
if (username.length < min) return "Too short";
if (username.length > max) return "Too long";
if (!/^[a-zA-Z0-9_]+$/.test(username)) return "Invalid characters";
return "Valid";
}
```
*/
// Validate the given username by making sure it only contains alphanumeric characters or
// underscores and adheres the hardcoded minimum and maximum length rules.
func validateUsername(username string) error {
if utf8.RuneCountInString(username) < minUsernameLength {
return fmt.Errorf("username must be at least %d characters", minUsernameLength)
}
if utf8.RuneCountInString(username) > maxUsernameLength {
return fmt.Errorf("username cannot be longer than %d characters", maxUsernameLength)
}
if !usernameRegex.MatchString(username) {
return errors.New("username can only contain numbers, letters, and underscores")
}
return nil
}

View File

@ -0,0 +1,184 @@
package service
import (
"fmt"
"math/rand"
"net/http"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
passwordvalidator "github.com/wagslane/go-password-validator"
)
type MockHTTPClient struct {
mock.Mock
}
func (m *MockHTTPClient) Get(url string) (*http.Response, error) {
args := m.Called(url)
return args.Get(0).(*http.Response), args.Error(1)
}
func TestValidatePassword(t *testing.T) {
tests := []struct {
name string
password string
wantErr string
mockHTTP func(*MockHTTPClient)
}{
{
name: "too short",
password: strings.Repeat("a", minPasswordLength-1),
wantErr: fmt.Sprintf("password must be at least %d characters", minPasswordLength),
},
{
name: "too long",
password: strings.Repeat("a", maxPasswordLength+1),
wantErr: fmt.Sprintf("password cannot be longer than %d characters", maxPasswordLength),
},
{
name: "low entropy",
password: strings.Repeat("a", minPasswordLength),
wantErr: "insecure password", // Error produced by wagslane/go-password-validator
},
{
name: "valid password",
password: "SecurePassw0rd!123",
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Mock HTTP client if needed
if tt.mockHTTP != nil {
mockClient := new(MockHTTPClient)
tt.mockHTTP(mockClient)
}
err := validatePassword(tt.password)
if tt.wantErr == "" {
assert.NoError(t, err)
} else {
assert.ErrorContains(t, err, tt.wantErr)
}
})
}
}
func TestIsPasswordCompromised(t *testing.T) {
t.Run("known compromised", func(t *testing.T) {
compromised, err := isPasswordCompromised("password123456")
assert.NoError(t, err)
assert.True(t, compromised)
})
t.Run("randomly generated", func(t *testing.T) {
randomStr := genRandomString(12)
compromised, err := isPasswordCompromised(randomStr)
assert.NoError(t, err)
assert.False(t, compromised)
})
}
func TestPasswordEntropyCalculation(t *testing.T) {
tests := []struct {
password string
entropy float64
}{
{"password", 37.6},
{"SecurePassw0rd!123", 103.12},
{"aaaaaaaaaaaaaaaa", 9.5},
}
for _, tt := range tests {
t.Run(tt.password, func(t *testing.T) {
entropy := passwordvalidator.GetEntropy(tt.password)
assert.InDelta(t, tt.entropy, entropy, 1.0)
})
}
}
func TestValidateUsername(t *testing.T) {
tests := []struct {
name string
input string
wantErr string
}{
{
name: "too short",
input: strings.Repeat("a", minUsernameLength-1),
wantErr: fmt.Sprintf("username must be at least %d characters", minUsernameLength),
},
{
name: "too long",
input: strings.Repeat("a", maxUsernameLength+1),
wantErr: fmt.Sprintf("username cannot be longer than %d characters", maxUsernameLength),
},
{
name: "invalid characters",
input: "user@name",
wantErr: "username can only contain numbers, letters, and underscores",
},
{
name: "valid username",
input: "valid_user123",
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateUsername(tt.input)
if tt.wantErr == "" {
assert.NoError(t, err)
} else {
assert.ErrorContains(t, err, tt.wantErr)
}
})
}
}
func TestNormalizeUsername(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "trim whitespace",
input: " test_user ",
want: "test_user",
},
{
name: "lowercase",
input: "TestUser",
want: "testuser",
},
{
name: "mixed case and spaces",
input: " UserName123 ",
want: "username123",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeUsername(tt.input)
assert.Equal(t, tt.want, got)
})
}
}
func genRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return string(b)
}

View File

@ -0,0 +1,49 @@
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
CREATE TABLE IF NOT EXISTS schema_migrations (
version BIGINT PRIMARY KEY,
applied_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS users (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
username TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
is_admin BOOLEAN NOT NULL DEFAULT false,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS refresh_tokens (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
token_hash TEXT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW(),
revoked BOOLEAN NOT NULL DEFAULT false
);
CREATE TABLE IF NOT EXISTS notes (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS note_versions (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
note_id UUID NOT NULL REFERENCES notes(id) ON DELETE CASCADE,
title TEXT NOT NULL,
content TEXT NOT NULL,
version_number INT NOT NULL,
content_hash TEXT NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username);
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_id ON refresh_tokens(user_id);
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);
CREATE UNIQUE INDEX IF NOT EXISTS idx_note_version_unique ON note_versions(note_id, version_number);
CREATE INDEX IF NOT EXISTS idx_note_versions_note ON note_versions(note_id);
CREATE INDEX IF NOT EXISTS idx_note_versions_number ON note_versions(version_number DESC);

View File

@ -0,0 +1,27 @@
-- name: CreateNoteVersion :one
INSERT INTO note_versions (note_id, title, content, version_number, content_hash)
VALUES (
$1,
$2,
$3,
(SELECT COALESCE(MAX(version_number), 0) + 1 FROM note_versions WHERE note_id = $1),
encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
)
RETURNING *;
-- name: GetNoteVersions :many
SELECT * FROM note_versions
WHERE note_id = $1
ORDER BY version_number DESC
LIMIT $2 OFFSET $3;
-- name: GetNoteVersion :one
SELECT * FROM note_versions
WHERE note_id = $1 AND version_number = $2 LIMIT 1;
-- name: FindDuplicateContent :one
SELECT EXISTS(
SELECT 1 FROM note_versions
WHERE note_id = $1
AND content_hash = encode(sha256($2::bytea || '\n'::bytea || $3::bytea), 'hex')
);

View File

@ -0,0 +1,18 @@
-- name: CreateNote :one
INSERT INTO notes (user_id)
VALUES ($1)
RETURNING *;
-- name: GetNote :one
SELECT * FROM notes
WHERE id = $1 AND user_id = $2 LIMIT 1;
-- name: ListNotes :many
SELECT * FROM notes
WHERE user_id = $1
ORDER BY created_at DESC
LIMIT $2 OFFSET $3;
-- name: DeleteNote :exec
DELETE FROM notes
WHERE id = $1 AND user_id = $2;

View File

@ -0,0 +1,25 @@
-- name: CreateRefreshToken :one
INSERT INTO refresh_tokens (
user_id,
token_hash,
expires_at
) VALUES ($1, $2, $3)
RETURNING *;
-- name: GetRefreshTokenByHash :one
SELECT * FROM refresh_tokens
WHERE token_hash = $1 LIMIT 1;
-- name: RevokeRefreshToken :exec
UPDATE refresh_tokens
SET revoked = TRUE
WHERE token_hash = $1;
-- name: RevokeAllUserRefreshTokens :exec
UPDATE refresh_tokens
SET revoked = TRUE
WHERE user_id = $1;
-- name: DeleteExpiredRefreshTokens :exec
DELETE FROM refresh_tokens
WHERE expires_at < NOW();

View File

@ -0,0 +1,24 @@
-- name: CreateUser :one
INSERT INTO users (username, password_hash)
VALUES ($1, $2)
RETURNING *;
-- name: ListUsers :many
SELECT * FROM users;
-- name: GetUserByID :one
SELECT * FROM users
WHERE id = $1 LIMIT 1;
-- name: GetUserByUsername :one
SELECT * FROM users
WHERE username = $1 LIMIT 1;
-- name: UpdatePassword :exec
UPDATE users
SET password_hash = $2, updated_at = NOW()
WHERE id = $1;
-- name: DeleteUser :exec
DELETE FROM users
WHERE id = $1;

24
server/sql/sqlc.yaml Normal file
View File

@ -0,0 +1,24 @@
version: "2"
sql:
- engine: "postgresql"
queries: "./queries/"
schema: "./migrations/"
gen:
go:
package: "data"
out: "../pkg/data"
sql_package: "pgx/v5"
emit_json_tags: true
overrides:
- db_type: "timestamptz"
go_type:
type: "time.Time"
- db_type: "timestamptz"
nullable: true
go_type:
type: "*time.Time"
- db_type: "uuid"
go_type:
import: "github.com/google/uuid"
type: "UUID"