github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/postgres/migrations/driver.go (about) 1 package migrations 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 8 "github.com/jackc/pgx/v5" 9 "github.com/jackc/pgx/v5/pgconn" 10 "go.opentelemetry.io/otel" 11 12 log "github.com/authzed/spicedb/internal/logging" 13 14 pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" 15 "github.com/authzed/spicedb/pkg/datastore" 16 "github.com/authzed/spicedb/pkg/migrate" 17 ) 18 19 const postgresMissingTableErrorCode = "42P01" 20 21 var tracer = otel.Tracer("spicedb/internal/datastore/common") 22 23 // AlembicPostgresDriver implements a schema migration facility for use in 24 // SpiceDB's Postgres datastore. 25 // 26 // It is compatible with the popular Python library, Alembic 27 type AlembicPostgresDriver struct { 28 db *pgx.Conn 29 } 30 31 // NewAlembicPostgresDriver creates a new driver with active connections to the database specified. 32 func NewAlembicPostgresDriver(ctx context.Context, url string, credentialsProvider datastore.CredentialsProvider) (*AlembicPostgresDriver, error) { 33 ctx, span := tracer.Start(ctx, "NewAlembicPostgresDriver") 34 defer span.End() 35 36 connConfig, err := pgx.ParseConfig(url) 37 if err != nil { 38 return nil, err 39 } 40 pgxcommon.ConfigurePGXLogger(connConfig) 41 pgxcommon.ConfigureOTELTracer(connConfig) 42 43 if credentialsProvider != nil { 44 log.Ctx(ctx).Debug().Str("name", credentialsProvider.Name()).Msg("using credentials provider") 45 connConfig.User, connConfig.Password, err = credentialsProvider.Get(ctx, fmt.Sprintf("%s:%d", connConfig.Host, connConfig.Port), connConfig.User) 46 if err != nil { 47 return nil, err 48 } 49 } 50 51 db, err := pgx.ConnectConfig(ctx, connConfig) 52 if err != nil { 53 return nil, err 54 } 55 56 return &AlembicPostgresDriver{db}, nil 57 } 58 59 // Conn returns the underlying pgx.Conn instance for this driver 60 func (apd *AlembicPostgresDriver) Conn() *pgx.Conn { 61 return apd.db 62 } 63 64 func (apd *AlembicPostgresDriver) RunTx(ctx context.Context, f migrate.TxMigrationFunc[pgx.Tx]) error { 65 return pgx.BeginFunc(ctx, apd.db, func(tx pgx.Tx) error { 66 return f(ctx, tx) 67 }) 68 } 69 70 // Version returns the version of the schema to which the connected database 71 // has been migrated. 72 func (apd *AlembicPostgresDriver) Version(ctx context.Context) (string, error) { 73 var loaded string 74 75 if err := apd.db.QueryRow(ctx, "SELECT version_num from alembic_version").Scan(&loaded); err != nil { 76 var pgErr *pgconn.PgError 77 if errors.As(err, &pgErr) && pgErr.Code == postgresMissingTableErrorCode { 78 return "", nil 79 } 80 return "", fmt.Errorf("unable to load alembic revision: %w", err) 81 } 82 83 return loaded, nil 84 } 85 86 // Close disposes the driver. 87 func (apd *AlembicPostgresDriver) Close(ctx context.Context) error { 88 return apd.db.Close(ctx) 89 } 90 91 func (apd *AlembicPostgresDriver) WriteVersion(ctx context.Context, tx pgx.Tx, version, replaced string) error { 92 result, err := tx.Exec( 93 ctx, 94 "UPDATE alembic_version SET version_num=$1 WHERE version_num=$2", 95 version, 96 replaced, 97 ) 98 if err != nil { 99 return fmt.Errorf("unable to update version row: %w", err) 100 } 101 102 updatedCount := result.RowsAffected() 103 if updatedCount != 1 { 104 return fmt.Errorf("writing version update affected %d rows, should be 1", updatedCount) 105 } 106 107 return nil 108 } 109 110 var _ migrate.Driver[*pgx.Conn, pgx.Tx] = &AlembicPostgresDriver{}