From 2b9b14210cd6dd5e9cc82378586b5e3b3dba62dd Mon Sep 17 00:00:00 2001 From: ae Date: Sun, 4 May 2025 11:12:39 +0300 Subject: [PATCH] feat: exp. parsing from note (version) title --- server/internal/service/notes.go | 46 +++++++- server/internal/service/service.go | 35 ++++-- server/internal/service/util.go | 130 +++++++++++++++++++++ server/internal/service/util_test.go | 166 +++++++++++++++++++++++++++ 4 files changed, 367 insertions(+), 10 deletions(-) diff --git a/server/internal/service/notes.go b/server/internal/service/notes.go index 39e5873..4b58223 100644 --- a/server/internal/service/notes.go +++ b/server/internal/service/notes.go @@ -9,6 +9,8 @@ import ( "git.umbrella.haus/ae/qnote/internal/data" "github.com/go-chi/chi/v5" "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/rs/zerolog/log" ) const ( @@ -38,8 +40,10 @@ type NoteStore interface { // Chi HTTP router for notes related CRUD actions. type notesResource struct { - Config SvcConfig - Notes NoteStore + Config SvcConfig + Notes NoteStore + RawQueries *data.Queries + DB *pgx.Conn } func (rs notesResource) Routes() chi.Router { @@ -250,6 +254,22 @@ func (rs *notesResource) CreateVersion(w http.ResponseWriter, r *http.Request) { return } + // Attempt to parse the expiration date from the title + expiresAt, err := parseTitleExpiration(req.Title) + if err != nil && err != ErrNoExpirationDateFound { + log.Error().Err(err).Msg("Failed parsing expiration date from note title") + } + + tx, err := rs.DB.Begin(r.Context()) + if err != nil { + log.Error().Err(err).Msg("Failed to begin transaction") + respondError(w, http.StatusInternalServerError, "Database error") + return + } + defer tx.Rollback(r.Context()) + + qtx := rs.RawQueries.WithTx(tx) + /* The SQL query handles de-duplication checks and "intelligent" versioning increments, so we don't have to worry about them here (`latest_version` = highest version number that exists @@ -263,17 +283,37 @@ func (rs *notesResource) CreateVersion(w http.ResponseWriter, r *http.Request) { - Sync `current_version` with `latest_version` */ - err := rs.Notes.CreateNoteVersion(r.Context(), data.CreateNoteVersionParams{ + if expiresAt != nil { + err = qtx.SetNoteExpiration(r.Context(), data.SetNoteExpirationParams{ + ID: fullNote.NoteID, + UserID: fullNote.OwnerID, + ExpiresAt: expiresAt, + }) + if err != nil { + log.Error().Err(err).Msg("Failed to set note expiration") + respondError(w, http.StatusInternalServerError, "Failed to set note expiration") + return + } + } + + err = qtx.CreateNoteVersion(r.Context(), data.CreateNoteVersionParams{ NoteID: fullNote.NoteID, Title: *req.Title, Content: *req.Content, ContentHash: sha1ContentHash(*req.Title, *req.Content), }) if err != nil { + log.Error().Err(err).Msg("Failed to create new note version") respondError(w, http.StatusInternalServerError, "Failed to create note version") return } + if err = tx.Commit(r.Context()); err != nil { + log.Error().Err(err).Msg("Failed to commit transaction") + respondError(w, http.StatusInternalServerError, "Database error") + return + } + w.WriteHeader(http.StatusNoContent) } diff --git a/server/internal/service/service.go b/server/internal/service/service.go index 893edef..ebd3670 100644 --- a/server/internal/service/service.go +++ b/server/internal/service/service.go @@ -1,6 +1,7 @@ package service import ( + "context" "net/http" "time" @@ -43,8 +44,10 @@ func Run(conn *pgx.Conn, q *data.Queries, config SvcConfig) error { Tokens: q, } notesRouter := notesResource{ - Config: config, - Notes: q, + Config: config, + Notes: q, // Wrapped (to be unit testable with mock DB) + RawQueries: q, // Passed separately to allow tx. usage + DB: conn, } // Global middlewares @@ -63,20 +66,38 @@ func Run(conn *pgx.Conn, q *data.Queries, config SvcConfig) error { r.Use(middleware.Recoverer) r.Use(middleware.AllowContentType("application/json")) + // Cleanup workers + scheduleTokenCleanup(context.Background(), q) + // Routes grouped by functionality (we must prefix the API routes with `/api` // as the domain will be the same for the front and back ends) r.Route("/api", func(r chi.Router) { r.Mount("/auth", authRouter.Routes()) r.Mount("/notes", notesRouter.Routes()) - r.Get("/ping", ping) + r.Get("/ping", func(w http.ResponseWriter, r *http.Request) { + respondJSON(w, http.StatusOK, map[string]string{ + "message": "pong", + }) + }) }) log.Info().Msg("Starting server on :8080") return http.ListenAndServe(":8080", r) } -func ping(w http.ResponseWriter, r *http.Request) { - respondJSON(w, http.StatusOK, map[string]string{ - "message": "pong", - }) +// Start worker that automatically cleans up the `notes` (cascading to `note_versions`) and +// `refresh_tokens` tables from expired (or revoked) entries. The tasks run once during +// initialization and then once an hour until the backend is shutdown. +func scheduleTokenCleanup(ctx context.Context, q *data.Queries) { + cleanupNotes(ctx, q) + cleanupRefreshTokens(ctx, q) + + ticker := time.NewTicker(1 * time.Hour) + go func() { + for range ticker.C { + cleanupCtx := context.Background() + cleanupNotes(cleanupCtx, q) + cleanupRefreshTokens(cleanupCtx, q) + } + }() } diff --git a/server/internal/service/util.go b/server/internal/service/util.go index 25da334..8ea7258 100644 --- a/server/internal/service/util.go +++ b/server/internal/service/util.go @@ -1,6 +1,7 @@ package service import ( + "context" "crypto/sha1" "encoding/hex" "encoding/json" @@ -12,8 +13,12 @@ import ( "regexp" "strconv" "strings" + "time" "unicode" "unicode/utf8" + + "git.umbrella.haus/ae/qnote/internal/data" + "github.com/rs/zerolog/log" ) const ( @@ -24,11 +29,25 @@ const ( minUsernameLength = 3 maxUsernameLength = 20 + maxFutureExpirationYears = 10 + hibpAPI = "https://api.pwnedpasswords.com/range" // Doesn't require an API key ) var ( usernameRegex = regexp.MustCompile("^[a-z0-9_]+$") + + // Format: @exp:2025-06-15 or @expires:2025-06-15 + dateFormatRegex = regexp.MustCompile(`^@(?:exp|expires):(\d{4}-\d{2}-\d{2})`) + + // Format: @exp:+7d or @expires:+7d (7 days from now), + // supports d (days), w (weeks), m (months), y (years) + relativeFormatRegex = regexp.MustCompile(`^@(?:exp|expires):\+(\d+)([dwmy])`) + + ErrNoExpirationDateFound = errors.New("no expiration date found") + ErrInvalidExpirationDate = errors.New("invalid expiration date format") + ErrPastExpirationDate = errors.New("expiration date cannot be in the past") + ErrExpirationTooFar = fmt.Errorf("expiration date too far in the future (max. %d years)", maxFutureExpirationYears) ) func respondJSON(w http.ResponseWriter, status int, data any) { @@ -186,3 +205,114 @@ func sha1ContentHash(title, content string) string { return hashStr } + +func parseTitleExpiration(title *string) (*time.Time, error) { + // Absolute date format: '@exp:YYYY-MM-DD' (or '@expires:') + if match := dateFormatRegex.FindStringSubmatch(*title); match != nil { + dateStr := match[1] + + expiresAt, err := time.Parse("2006-01-02", dateStr) + if err != nil { + return nil, ErrInvalidExpirationDate + } + + if err := validateExpirationDate(expiresAt); err != nil { + return nil, err + } + + // Set midnight at the end of the specified day (+0000 UTC) + expiresAt = time.Date(expiresAt.Year(), expiresAt.Month(), expiresAt.Day(), 23, 59, 59, 0, time.UTC) + + return &expiresAt, nil + } + + if match := relativeFormatRegex.FindStringSubmatch(*title); match != nil { + amount := match[1] + unit := match[2] + + var amountInt int + _, err := fmt.Sscanf(amount, "%d", &amountInt) + if err != nil || amountInt <= 0 { + return nil, ErrInvalidExpirationDate + } + + now := time.Now() + var expiresAt time.Time + + switch unit { + case "d": + expiresAt = now.AddDate(0, 0, amountInt) + case "w": + expiresAt = now.AddDate(0, 0, amountInt*7) + case "m": + expiresAt = now.AddDate(0, amountInt, 0) + case "y": + expiresAt = now.AddDate(amountInt, 0, 0) + default: + return nil, ErrInvalidExpirationDate + } + + if err := validateExpirationDate(expiresAt); err != nil { + return nil, err + } + + // Set midnight at the end of the specified day (+0000 UTC) + expiresAt = time.Date(expiresAt.Year(), expiresAt.Month(), expiresAt.Day(), 23, 59, 59, 0, time.UTC) + + return &expiresAt, nil + } + + return nil, ErrNoExpirationDateFound +} + +func validateExpirationDate(date time.Time) error { + now := time.Now() + + if date.Before(now) { + return ErrPastExpirationDate + } + + maxDate := now.AddDate(maxFutureExpirationYears, 0, 0) + if date.After(maxDate) { + return ErrExpirationTooFar + } + + return nil +} + +func cleanupNotes(ctx context.Context, q *data.Queries) { + expiredNotes, err := q.ListExpiredNotes(ctx) + if err != nil { + log.Error().Err(err).Msg("Failed querying expired notes") + return + } + + if len(expiredNotes) == 0 { + return + } + + // Log what we're about to delete to be able to track potential bugs in the expiration implementation + for _, note := range expiredNotes { + log.Debug().Msgf("Deleting expired note: %s (ID: %s, UID: %s), expired at %s", + note.Title, note.NoteID, note.OwnerID, note.ExpiresAt.Format(time.RFC3339)) + } + + if err = q.DeleteExpiredNotes(ctx); err != nil { + log.Error().Err(err).Msg("Failed deleting expired notes") + return + } + + log.Info().Msgf("Successfully deleted %d expired notes during scheduled cleanup", len(expiredNotes)) +} + +func cleanupRefreshTokens(ctx context.Context, q *data.Queries) { + rowsAffected, err := q.DeleteExpiredRefreshTokens(ctx) + if err != nil { + log.Error().Err(err).Msg("Failed cleaning up refresh tokens") + return + } + + if rowsAffected > 0 { + log.Info().Msgf("Cleaned up %d expired/revoked refresh tokens during scheduled cleanup", rowsAffected) + } +} diff --git a/server/internal/service/util_test.go b/server/internal/service/util_test.go index 64c2220..206eacc 100644 --- a/server/internal/service/util_test.go +++ b/server/internal/service/util_test.go @@ -1,6 +1,7 @@ package service import ( + "errors" "fmt" "math/rand" "net/http" @@ -290,6 +291,170 @@ func TestSHA1ContentHash(t *testing.T) { } } +func TestParseTitleAbsoluteExpiration(t *testing.T) { + threeDaysLater := time.Now().AddDate(0, 0, 3) + threeDaysInPast := time.Now().AddDate(0, 0, -3) + overMaxYearsLater := time.Now().AddDate(maxFutureExpirationYears+1, 0, 0) + + tests := []struct { + name string + title string + expected *time.Time + err error + }{ + { + name: "Valid absolute date", + title: fmt.Sprintf("@exp:%s Task", formatAbsDate(t, threeDaysLater)), + expected: timePtr(t, createEndOfDay(t, threeDaysLater)), + err: nil, + }, + { + name: "Valid absolute date with expires keyword", + title: fmt.Sprintf("@expires:%s Task", formatAbsDate(t, threeDaysLater)), + expected: timePtr(t, createEndOfDay(t, threeDaysLater)), + err: nil, + }, + { + name: "Absolute date in the past", + title: fmt.Sprintf("@exp:%s Task", formatAbsDate(t, threeDaysInPast)), + expected: nil, + err: ErrPastExpirationDate, + }, + { + name: "Absolute date too far in the future", + title: fmt.Sprintf("@exp:%s Task", formatAbsDate(t, overMaxYearsLater)), + expected: nil, + err: ErrExpirationTooFar, + }, + { + name: "Invalid absolute date format", + title: "@exp:2028-13-31 Task", // Invalid month + expected: nil, + err: ErrInvalidExpirationDate, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := parseTitleExpiration(&tc.title) + if tc.err != nil { + if !errors.Is(err, tc.err) { + t.Errorf("Expected error %s, got %s", tc.err, err) + } + } else if err != nil { + t.Errorf("Unexpected error: %s", err) + } + + if tc.expected == nil && result != nil { + t.Errorf("Expected nil result, got %s", *result) + } else if tc.expected != nil && result != nil { + if !tc.expected.Equal(*result) { + t.Errorf("Expected %s, got %s", *tc.expected, *result) + } + } + }) + } +} + +func TestParseTitleRelativeExpiration(t *testing.T) { + threeDaysLater := time.Now().AddDate(0, 0, 3) + threeWeeksLater := time.Now().AddDate(0, 0, 3*7) + threeMonthsLater := time.Now().AddDate(0, 3, 0) + threeYearsLater := time.Now().AddDate(3, 0, 0) + + tests := []struct { + name string + title string + expected *time.Time + err error + }{ + { + name: "Valid relative date format with days", + title: "@exp:+3d Task", + expected: timePtr(t, createEndOfDay(t, threeDaysLater)), + err: nil, + }, + { + name: "Valid relative date format with weeks", + title: "@exp:+3w Task", + expected: timePtr(t, createEndOfDay(t, threeWeeksLater)), + err: nil, + }, + { + name: "Valid relative date format with months", + title: "@exp:+3m Task", + expected: timePtr(t, createEndOfDay(t, threeMonthsLater)), + err: nil, + }, + { + name: "Valid relative date format with years", + title: "@exp:+3y Task", + expected: timePtr(t, createEndOfDay(t, threeYearsLater)), + err: nil, + }, + { + name: "Invalid relative amount (zero)", + title: "@exp:+0d Task", + expected: nil, + err: ErrInvalidExpirationDate, + }, + { + name: "Invalid relative amount (negative)", + title: "@exp:-1d Task", + expected: nil, + err: ErrNoExpirationDateFound, // Doesn't match either of the RegExs + }, + { + name: "Invalid relative unit", + title: "@exp:+30a Task", + expected: nil, + err: ErrNoExpirationDateFound, // Doesn't match either of the RegExs + }, + { + name: "Relative date too far in the future", + title: fmt.Sprintf("@exp:+%dy Task", maxFutureExpirationYears+1), + expected: nil, + err: ErrExpirationTooFar, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := parseTitleExpiration(&tc.title) + if tc.err != nil { + if !errors.Is(err, tc.err) { + t.Errorf("Expected error %s, got %s", tc.err, err) + } + } else if err != nil { + t.Errorf("Unexpected error: %s", err) + } + + if tc.expected == nil && result != nil { + t.Errorf("Expected nil result, got %s", *result) + } else if tc.expected != nil && result != nil { + if !tc.expected.Equal(*result) { + t.Errorf("Expected %s, got %s", *tc.expected, *result) + } + } + }) + } +} + +func createEndOfDay(t *testing.T, tm time.Time) time.Time { + t.Helper() + return time.Date(tm.Year(), tm.Month(), tm.Day(), 23, 59, 59, 0, time.UTC) +} + +func timePtr(t *testing.T, tm time.Time) *time.Time { + t.Helper() + return &tm +} + +func formatAbsDate(t *testing.T, tm time.Time) string { + t.Helper() + return tm.Format("2006-01-02") +} + func genRandomString(t *testing.T, length int) string { t.Helper() @@ -299,5 +464,6 @@ func genRandomString(t *testing.T, length int) string { for i := range b { b[i] = charset[seededRand.Intn(len(charset))] } + return string(b) }