notatest/server/pkg/migrate/migrate.go

75 lines
1.6 KiB
Go

// pkg/migrate/migrate.go
package migrate
import (
"context"
"embed"
"fmt"
"io/fs"
"sort"
"strconv"
"strings"
"github.com/jackc/pgx/v5"
)
// Run applies every migration under the "migrations" directory of
// migrationsFS whose version is not yet recorded in schema_migrations.
//
// Migration files must be named "<version>_<description>", where <version> is
// a base-10 int64. Directories are ignored. All names are validated, and
// duplicate versions rejected, before anything is executed. Pending
// migrations are then applied in ascending version order.
//
// Each migration script runs in a single transaction together with the INSERT
// that records its version. A failure therefore leaves neither a half-applied
// migration nor an applied-but-unrecorded one. As a consequence, a script must
// not contain statements PostgreSQL refuses to run inside a transaction block
// (e.g. CREATE INDEX CONCURRENTLY) and must not issue its own BEGIN/COMMIT.
//
// If schema_migrations does not exist yet, every migration is treated as
// pending. Presumably one of them creates the table; otherwise recording the
// first version fails.
func Run(ctx context.Context, conn *pgx.Conn, migrationsFS embed.FS) error {
	applied, err := appliedVersions(ctx, conn)
	if err != nil {
		return err
	}

	files, err := migrationsFS.ReadDir("migrations")
	if err != nil {
		return fmt.Errorf("reading migrations directory: %w", err)
	}

	type migration struct {
		version int64
		name    string
	}

	// Validate every file name first so a bad or duplicate name is reported
	// before any migration has been executed.
	var pending []migration
	seen := make(map[int64]string, len(files))
	for _, f := range sortMigrations(files) {
		if f.IsDir() {
			continue
		}
		version, err := migrationVersion(f.Name())
		if err != nil {
			return err
		}
		if prev, dup := seen[version]; dup {
			return fmt.Errorf("duplicate migration version %d: %q and %q", version, prev, f.Name())
		}
		seen[version] = f.Name()
		if !applied[version] {
			pending = append(pending, migration{version: version, name: f.Name()})
		}
	}

	for _, m := range pending {
		// embed.FS paths are always slash-separated, so plain concatenation is correct.
		script, err := migrationsFS.ReadFile("migrations/" + m.name)
		if err != nil {
			return fmt.Errorf("reading migration %q: %w", m.name, err)
		}
		if err := applyMigration(ctx, conn, m.version, string(script)); err != nil {
			return err
		}
	}
	return nil
}

// migrationVersion parses the integer prefix before the first "_" of a
// migration file name, e.g. "0003_add_users.sql" yields 3.
func migrationVersion(name string) (int64, error) {
	prefix, _, _ := strings.Cut(name, "_")
	version, err := strconv.ParseInt(prefix, 10, 64)
	if err != nil {
		return 0, fmt.Errorf("invalid migration name: %s", name)
	}
	return version, nil
}

// appliedVersions returns the set of versions recorded in schema_migrations.
// A missing table yields an empty set, which is the fresh-database case. Any
// other query or scan failure is returned instead of being treated as
// "nothing applied".
func appliedVersions(ctx context.Context, conn *pgx.Conn) (map[int64]bool, error) {
	var exists bool
	if err := conn.QueryRow(ctx,
		"SELECT to_regclass('schema_migrations') IS NOT NULL",
	).Scan(&exists); err != nil {
		return nil, fmt.Errorf("checking for schema_migrations table: %w", err)
	}

	applied := make(map[int64]bool)
	if !exists {
		return applied, nil
	}

	rows, err := conn.Query(ctx, "SELECT version FROM schema_migrations")
	if err != nil {
		return nil, fmt.Errorf("querying applied migrations: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var version int64
		if err := rows.Scan(&version); err != nil {
			return nil, fmt.Errorf("scanning applied migration: %w", err)
		}
		applied[version] = true
	}
	// rows.Next returning false does not distinguish "done" from "failed".
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("reading applied migrations: %w", err)
	}
	return applied, nil
}

// applyMigration executes one migration script and records its version in
// schema_migrations inside a single transaction, rolling back if any step
// fails.
func applyMigration(ctx context.Context, conn *pgx.Conn, version int64, script string) error {
	tx, err := conn.Begin(ctx)
	if err != nil {
		return fmt.Errorf("migration %d: beginning transaction: %w", version, err)
	}
	// Rollback is a no-op once Commit has succeeded. Its error is deliberately
	// ignored: the failure that caused the early return is the one to report.
	defer func() { _ = tx.Rollback(ctx) }()

	if _, err := tx.Exec(ctx, script); err != nil {
		return fmt.Errorf("migration %d failed: %w", version, err)
	}
	if _, err := tx.Exec(ctx,
		"INSERT INTO schema_migrations (version) VALUES ($1)", version,
	); err != nil {
		return fmt.Errorf("recording migration %d: %w", version, err)
	}
	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("committing migration %d: %w", version, err)
	}
	return nil
}
// sortMigrations sorts files in place by the integer version prefixing each
// name (the text before the first "_") and returns the same slice.
//
// A name whose prefix is not a valid base-10 int64 is ordered as version 0.
// Such names are not rejected here; Run reports them when it parses each
// name. The sort is stable, so entries with equal versions keep their
// incoming order and the result is deterministic for a given input.
func sortMigrations(files []fs.DirEntry) []fs.DirEntry {
	versionOf := func(entry fs.DirEntry) int64 {
		prefix, _, _ := strings.Cut(entry.Name(), "_")
		// Parse errors are intentionally ignored: the zero value is the
		// documented fallback ordering for malformed names.
		v, _ := strconv.ParseInt(prefix, 10, 64)
		return v
	}
	sort.SliceStable(files, func(i, j int) bool {
		return versionOf(files[i]) < versionOf(files[j])
	})
	return files
}