75 lines
1.6 KiB
Go
75 lines
1.6 KiB
Go
// pkg/migrate/migrate.go
|
|
package migrate
|
|
|
|
import (
	"context"
	"embed"
	"errors"
	"fmt"
	"io/fs"
	"sort"
	"strconv"
	"strings"

	"github.com/jackc/pgx/v5"
)
|
|
|
|
func Run(ctx context.Context, conn *pgx.Conn, migrationsFS embed.FS) error {
|
|
// Get already applied migrations
|
|
rows, _ := conn.Query(ctx, "SELECT version FROM schema_migrations")
|
|
defer rows.Close()
|
|
|
|
applied := make(map[int64]bool)
|
|
for rows.Next() {
|
|
var version int64
|
|
if err := rows.Scan(&version); err != nil {
|
|
return err
|
|
}
|
|
applied[version] = true
|
|
}
|
|
|
|
files, err := migrationsFS.ReadDir("migrations")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Apply the migrations sequentially based on their ordinal number
|
|
for _, f := range sortMigrations(files) {
|
|
version, err := strconv.ParseInt(strings.Split(f.Name(), "_")[0], 10, 64)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid migration name: %s", f.Name())
|
|
}
|
|
|
|
if applied[version] {
|
|
continue
|
|
}
|
|
|
|
// Run migration
|
|
sql, err := migrationsFS.ReadFile("migrations/" + f.Name())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := conn.Exec(ctx, string(sql)); err != nil {
|
|
return fmt.Errorf("migration %d failed: %w", version, err)
|
|
}
|
|
|
|
if _, err := conn.Exec(ctx,
|
|
"INSERT INTO schema_migrations (version) VALUES ($1)", version,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// sortMigrations sorts files in place and returns the same slice. The sort
// key is the integer version prefix of each name (the text before the first
// "_").
//
// If a prefix does not parse, the name is keyed by whatever strconv.ParseInt
// returns on failure: 0 for malformed input, or the int64 limit for
// out-of-range input. Run reports such names as errors. Entries with equal
// keys are ordered by full file name, which makes the result deterministic.
// This matters for names such as "002_x.sql" and "2_y.sql", which share a key.
func sortMigrations(files []fs.DirEntry) []fs.DirEntry {
	versionOf := func(name string) int64 {
		prefix, _, _ := strings.Cut(name, "_")
		// Parse failures are deliberately ignored; see the doc comment.
		v, _ := strconv.ParseInt(prefix, 10, 64)
		return v
	}

	sort.Slice(files, func(i, j int) bool {
		ni, nj := files[i].Name(), files[j].Name()
		vi, vj := versionOf(ni), versionOf(nj)
		if vi != vj {
			return vi < vj
		}
		return ni < nj
	})
	return files
}
|