326 lines
9.4 KiB
Go

package service
import (
"context"
"crypto/sha1"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
"git.umbrella.haus/ae/qnote/internal/data"
"github.com/rs/zerolog/log"
)
// Password policy.
const (
	minPasswordLength  = 12   // Entropy checks prevent short passwords anyway
	maxPasswordLength  = 72   // Limitation of bcrypt (counted in bytes)
	minPasswordEntropy = 60.0 // Minimum score returned by calcPasswordEntropy
)

// Username policy (lengths are counted in characters).
const (
	minUsernameLength = 3
	maxUsernameLength = 20
)

// Upper bound for note expiration dates, in years from now.
const maxFutureExpirationYears = 10

// Pwned Passwords range endpoint. Doesn't require an API key.
const hibpAPI = "https://api.pwnedpasswords.com/range"
// Allowed username alphabet: lower-case ASCII letters, digits, and underscores. Presumably applied
// to usernames already lower-cased by normalizeUsername (confirm at call sites).
var usernameRegex = regexp.MustCompile("^[a-z0-9_]+$")

// Title prefixes that set a note's expiration (consumed by parseTitleExpiration).
var (
	// Absolute format, e.g. @exp:2025-06-15 or @expires:2025-06-15
	dateFormatRegex = regexp.MustCompile(`^@(?:exp|expires):(\d{4}-\d{2}-\d{2})`)

	// Relative format, e.g. @exp:+7d or @expires:+7d (7 days from now). The unit is one of
	// d (days), w (weeks), m (months), or y (years).
	relativeFormatRegex = regexp.MustCompile(`^@(?:exp|expires):\+(\d+)([dwmy])`)
)

// Sentinel errors returned by parseTitleExpiration and validateExpirationDate.
var (
	ErrNoExpirationDateFound = errors.New("no expiration date found")
	ErrInvalidExpirationDate = errors.New("invalid expiration date format")
	ErrPastExpirationDate    = errors.New("expiration date cannot be in the past")
	ErrExpirationTooFar      = fmt.Errorf("expiration date too far in the future (max. %d years)", maxFutureExpirationYears)
)
// Write data as a JSON response with the given HTTP status code and an "application/json"
// Content-Type.
//
// The payload is marshaled before anything is written. If encoding fails (e.g. data contains a
// channel or func), the client still gets a proper 500 JSON error. Otherwise it would receive
// the intended status with an empty body, because the header would already be sent.
func respondJSON(w http.ResponseWriter, status int, data any) {
	payload, err := json.Marshal(data)
	if err != nil {
		status = http.StatusInternalServerError
		payload = []byte(`{"error":"internal server error"}`)
	}
	// json.Encoder (used previously) terminates every value with a newline; keep that framing.
	payload = append(payload, '\n')

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// Once the status line is sent, a write error only means the client went away, and there is
	// nothing left to report to it.
	_, _ = w.Write(payload)
}
// Send a JSON error body of the shape {"error": message} with the given HTTP status code.
func respondError(w http.ResponseWriter, status int, message string) {
	payload := map[string]string{"error": message}
	respondJSON(w, status, payload)
}
// Validate the given password using a hybrid approach:
//   - length: the minimum is counted in characters (runes); the maximum is counted in bytes,
//     because bcrypt's input is limited to 72 bytes,
//   - a conservative entropy estimate (calcPasswordEntropy),
//   - the HIBP Pwned Passwords API (isPasswordCompromised).
//
// The HIBP lookup is best-effort: if the API call fails, the password is not rejected on that
// basis, so an outage of the third-party service doesn't block password changes.
func validatePassword(password string) error {
	// Count runes, not bytes: otherwise e.g. 4 CJK characters (12 bytes) would pass a
	// 12-"character" minimum.
	if utf8.RuneCountInString(password) < minPasswordLength {
		return fmt.Errorf("password must be at least %d characters", minPasswordLength)
	}
	// bcrypt's limit applies to the encoded bytes, so this check intentionally uses len().
	if len(password) > maxPasswordLength {
		return fmt.Errorf("password cannot be longer than %d bytes", maxPasswordLength)
	}

	// Simple entropy approximation
	if calcPasswordEntropy(password) < minPasswordEntropy {
		return errors.New("password is too weak, try adding more uppercase letters, digits, and symbols")
	}

	// Deliberately fail open: a lookup error is ignored and only a positive match rejects.
	if compromised, err := isPasswordCompromised(password); err == nil && compromised {
		return errors.New("password was found in database leaks")
	}
	return nil
}
// Approximate given password's entropy in bits. The estimate is deliberately conservative: it
// takes the minimum of a character-pool estimate and a diversity-adjusted estimate. The latter
// punishes passwords with a lot of repetition (small set of unique characters) relatively
// harshly.
//
// Length is measured in characters (runes), not bytes, so multi-byte characters don't inflate
// the score. Returns 0 for an empty password or one with no recognized character class.
func calcPasswordEntropy(password string) float64 {
	var hasLower, hasUpper, hasDigit, hasSymbol bool
	uniqueChars := make(map[rune]struct{})
	length := 0 // in runes

	for _, c := range password {
		length++
		uniqueChars[c] = struct{}{}
		switch {
		case unicode.IsLower(c):
			hasLower = true
		case unicode.IsUpper(c):
			hasUpper = true
		case unicode.IsDigit(c):
			hasDigit = true
		case !unicode.IsLetter(c):
			// Anything that is neither a letter nor a digit counts as a symbol (a broader
			// collection than in the frontend). Letters without case add to no pool.
			hasSymbol = true
		}
	}

	poolSize := 0
	if hasLower {
		poolSize += 26
	}
	if hasUpper {
		poolSize += 26
	}
	if hasDigit {
		poolSize += 10
	}
	if hasSymbol {
		poolSize += 40
	}
	if poolSize == 0 {
		return 0
	}

	bitsPerChar := math.Log2(float64(poolSize))
	// Truncating bits-per-char to a whole number makes this estimate more conservative.
	basicEntropy := float64(length * int(bitsPerChar))
	diversityAdjustedEntropy := bitsPerChar + float64(length-1)*math.Log2(float64(len(uniqueChars)))
	return math.Min(basicEntropy, diversityAdjustedEntropy)
}
// Check the password against the HIBP Pwned Passwords API using its k-anonymity model. Only the
// first five hex characters (20 bits) of the password's SHA-1 hash are sent. The remaining 35
// hex characters are then looked up in the API's response, a list of "SUFFIX:COUNT" lines.
//
// Returns an error if the API cannot be reached, answers with a non-200 status, or the response
// cannot be read. Callers decide how to treat that; in this file validatePassword ignores the error.
func isPasswordCompromised(password string) (bool, error) {
	hash := sha1.Sum([]byte(password))
	hashStr := strings.ToUpper(hex.EncodeToString(hash[:]))
	prefix, suffix := hashStr[:5], hashStr[5:]

	// The default client has no timeout; don't let a slow or unreachable third-party API stall
	// password validation indefinitely. Connections are still pooled by the default transport.
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Get(hibpAPI + "/" + prefix)
	if err != nil {
		return false, fmt.Errorf("querying HIBP range API: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return false, fmt.Errorf("HIBP range API returned status %d", resp.StatusCode)
	}

	// A range response is a few hundred short lines; cap the read so that an unexpected,
	// huge body cannot exhaust memory.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return false, fmt.Errorf("reading HIBP range response: %w", err)
	}

	for _, line := range strings.Split(string(body), "\n") {
		lineSuffix, countStr, ok := strings.Cut(strings.TrimSpace(line), ":")
		if !ok || !strings.EqualFold(lineSuffix, suffix) {
			continue
		}
		// Padding entries (served when the Add-Padding request header is set) report a count of
		// 0 and don't indicate a real leak; skip them defensively.
		if count, err := strconv.Atoi(countStr); err == nil && count == 0 {
			continue
		}
		return true, nil
	}
	return false, nil
}
// Canonicalize a username for storage and lookup: strip surrounding whitespace, then lower-case
// the rest.
func normalizeUsername(username string) string {
	trimmed := strings.TrimSpace(username)
	return strings.ToLower(trimmed)
}
// Validate the given username: its length in characters must be within
// [minUsernameLength, maxUsernameLength], and it must match usernameRegex (lower-case ASCII
// letters, digits, and underscores only). The length rules are checked before the alphabet.
func validateUsername(username string) error {
	switch length := utf8.RuneCountInString(username); {
	case length < minUsernameLength:
		return fmt.Errorf("username must be at least %d characters", minUsernameLength)
	case length > maxUsernameLength:
		return fmt.Errorf("username cannot be longer than %d characters", maxUsernameLength)
	case !usernameRegex.MatchString(username):
		return errors.New("username can only contain numbers, letters, and underscores")
	}
	return nil
}
// Parse `limit` and `offset` 32-bit integer URL parameters from the given request. Defaults to
// limit of 50 and offset 0 if parameters are missing/invalid. A value is invalid if it isn't a
// base-10 integer, doesn't fit in an int32, or is out of range (limit must be > 0, offset >= 0).
//
// NOTE(review): limit has no upper bound; consider capping it to protect the database.
func getPaginationParams(r *http.Request) (limit int32, offset int32) {
	limit, offset = 50, 0
	query := r.URL.Query() // parse the raw query string only once

	// Parse with bitSize 32 so out-of-range values are rejected. Converting an Atoi result to
	// int32 would instead silently wrap, e.g. 2147483648 to a negative limit/offset.
	if l, err := strconv.ParseInt(query.Get("limit"), 10, 32); err == nil && l > 0 {
		limit = int32(l)
	}
	if o, err := strconv.ParseInt(query.Get("offset"), 10, 32); err == nil && o >= 0 {
		offset = int32(o)
	}
	return limit, offset
}
// Return the upper-case hex SHA-1 digest of title immediately followed by content.
//
// NOTE(review): plain concatenation is ambiguous, e.g. ("ab", "c") and ("a", "bc") produce the
// same hash. If these hashes are persisted, changing the scheme would invalidate existing values,
// so it is kept as-is here.
func sha1ContentHash(title, content string) string {
	digest := sha1.Sum([]byte(title + content))
	return strings.ToUpper(hex.EncodeToString(digest[:]))
}
// Parse either an absolute (e.g. '2006-01-02') or a relative expiration date (e.g. '+3d') from
// the beginning of the title prefixed with either '@exp:' or '@expires:'. The actual note
// expiration will be set to the end of that particular date (+0000 UTC).
func parseTitleExpiration(title *string) (*time.Time, error) {
// Absolute date format: '@exp:YYYY-MM-DD' (or '@expires:')
if match := dateFormatRegex.FindStringSubmatch(*title); match != nil {
dateStr := match[1]
expiresAt, err := time.Parse("2006-01-02", dateStr)
if err != nil {
return nil, ErrInvalidExpirationDate
}
if err := validateExpirationDate(expiresAt); err != nil {
return nil, err
}
// Set midnight at the end of the specified day (+0000 UTC)
expiresAt = time.Date(expiresAt.Year(), expiresAt.Month(), expiresAt.Day(), 23, 59, 59, 0, time.UTC)
return &expiresAt, nil
}
if match := relativeFormatRegex.FindStringSubmatch(*title); match != nil {
amount := match[1]
unit := match[2]
var amountInt int
_, err := fmt.Sscanf(amount, "%d", &amountInt)
if err != nil || amountInt <= 0 {
return nil, ErrInvalidExpirationDate
}
now := time.Now()
var expiresAt time.Time
switch unit {
case "d":
expiresAt = now.AddDate(0, 0, amountInt)
case "w":
expiresAt = now.AddDate(0, 0, amountInt*7)
case "m":
expiresAt = now.AddDate(0, amountInt, 0)
case "y":
expiresAt = now.AddDate(amountInt, 0, 0)
default:
return nil, ErrInvalidExpirationDate
}
if err := validateExpirationDate(expiresAt); err != nil {
return nil, err
}
// Set midnight at the end of the specified day (+0000 UTC)
expiresAt = time.Date(expiresAt.Year(), expiresAt.Month(), expiresAt.Day(), 23, 59, 59, 0, time.UTC)
return &expiresAt, nil
}
return nil, ErrNoExpirationDateFound
}
// Check that date lies within [now, now + maxFutureExpirationYears years]. Returns
// ErrPastExpirationDate if it is before now, ErrExpirationTooFar if it is beyond the upper
// bound, and nil otherwise.
func validateExpirationDate(date time.Time) error {
	now := time.Now()
	switch {
	case date.Before(now):
		return ErrPastExpirationDate
	case date.After(now.AddDate(maxFutureExpirationYears, 0, 0)):
		return ErrExpirationTooFar
	default:
		return nil
	}
}
// Scheduled cleanup of expired notes. It lists the currently expired notes and logs each one at
// debug level, so that bugs in the expiration implementation can be traced. It then deletes them
// and logs the count. Query or delete failures are logged and end the run early.
//
// NOTE(review): listing and deleting are two separate queries. If DeleteExpiredNotes removes
// whatever is expired at the time it runs (presumably; confirm in the SQL), a note expiring in
// between is deleted without being logged and isn't included in the reported count.
func cleanupNotes(ctx context.Context, q *data.Queries) {
	expired, err := q.ListExpiredNotes(ctx)
	switch {
	case err != nil:
		log.Error().Err(err).Msg("Failed querying expired notes")
		return
	case len(expired) == 0:
		return
	}

	for _, n := range expired {
		log.Debug().Msgf("Deleting expired note: %s (ID: %s, UID: %s), expired at %s",
			n.Title, n.NoteID, n.OwnerID, n.ExpiresAt.Format(time.RFC3339))
	}

	if err := q.DeleteExpiredNotes(ctx); err != nil {
		log.Error().Err(err).Msg("Failed deleting expired notes")
		return
	}
	log.Info().Msgf("Successfully deleted %d expired notes during scheduled cleanup", len(expired))
}
// Scheduled cleanup of refresh tokens: delete the expired and revoked ones via
// DeleteExpiredRefreshTokens. Logs an error on failure, and logs the count only when at least
// one row was removed.
func cleanupRefreshTokens(ctx context.Context, q *data.Queries) {
	deleted, err := q.DeleteExpiredRefreshTokens(ctx)
	if err != nil {
		log.Error().Err(err).Msg("Failed cleaning up refresh tokens")
		return
	}
	if deleted <= 0 {
		return
	}
	log.Info().Msgf("Cleaned up %d expired/revoked refresh tokens during scheduled cleanup", deleted)
}