github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/postgres/migrations/driver.go (about)

     1  package migrations
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  
     8  	"github.com/jackc/pgx/v5"
     9  	"github.com/jackc/pgx/v5/pgconn"
    10  	"go.opentelemetry.io/otel"
    11  
    12  	log "github.com/authzed/spicedb/internal/logging"
    13  
    14  	pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common"
    15  	"github.com/authzed/spicedb/pkg/datastore"
    16  	"github.com/authzed/spicedb/pkg/migrate"
    17  )
    18  
// postgresMissingTableErrorCode is the error code that Version compares
// against PgError.Code to recognize a query against a table that does not
// exist (Postgres SQLSTATE 42P01, "undefined_table").
const postgresMissingTableErrorCode = "42P01"

// tracer is the OpenTelemetry tracer used to start spans in this package.
// NOTE(review): the instrumentation name refers to datastore/common rather
// than this package's import path — confirm whether that is intentional.
var tracer = otel.Tracer("spicedb/internal/datastore/common")
    22  
// AlembicPostgresDriver implements a schema migration facility for use in
// SpiceDB's Postgres datastore.
//
// It is compatible with the popular Python library, Alembic: the current
// schema revision is read from and written to the version_num column of the
// alembic_version table.
type AlembicPostgresDriver struct {
	// db is the single connection on which the version query is issued and
	// on which migration transactions are begun.
	db *pgx.Conn
}
    30  
    31  // NewAlembicPostgresDriver creates a new driver with active connections to the database specified.
    32  func NewAlembicPostgresDriver(ctx context.Context, url string, credentialsProvider datastore.CredentialsProvider) (*AlembicPostgresDriver, error) {
    33  	ctx, span := tracer.Start(ctx, "NewAlembicPostgresDriver")
    34  	defer span.End()
    35  
    36  	connConfig, err := pgx.ParseConfig(url)
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  	pgxcommon.ConfigurePGXLogger(connConfig)
    41  	pgxcommon.ConfigureOTELTracer(connConfig)
    42  
    43  	if credentialsProvider != nil {
    44  		log.Ctx(ctx).Debug().Str("name", credentialsProvider.Name()).Msg("using credentials provider")
    45  		connConfig.User, connConfig.Password, err = credentialsProvider.Get(ctx, fmt.Sprintf("%s:%d", connConfig.Host, connConfig.Port), connConfig.User)
    46  		if err != nil {
    47  			return nil, err
    48  		}
    49  	}
    50  
    51  	db, err := pgx.ConnectConfig(ctx, connConfig)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	return &AlembicPostgresDriver{db}, nil
    57  }
    58  
// Conn returns the underlying pgx.Conn instance for this driver. The returned
// pointer is the driver's own connection, not a copy or a new connection.
func (apd *AlembicPostgresDriver) Conn() *pgx.Conn {
	return apd.db
}
    63  
    64  func (apd *AlembicPostgresDriver) RunTx(ctx context.Context, f migrate.TxMigrationFunc[pgx.Tx]) error {
    65  	return pgx.BeginFunc(ctx, apd.db, func(tx pgx.Tx) error {
    66  		return f(ctx, tx)
    67  	})
    68  }
    69  
    70  // Version returns the version of the schema to which the connected database
    71  // has been migrated.
    72  func (apd *AlembicPostgresDriver) Version(ctx context.Context) (string, error) {
    73  	var loaded string
    74  
    75  	if err := apd.db.QueryRow(ctx, "SELECT version_num from alembic_version").Scan(&loaded); err != nil {
    76  		var pgErr *pgconn.PgError
    77  		if errors.As(err, &pgErr) && pgErr.Code == postgresMissingTableErrorCode {
    78  			return "", nil
    79  		}
    80  		return "", fmt.Errorf("unable to load alembic revision: %w", err)
    81  	}
    82  
    83  	return loaded, nil
    84  }
    85  
// Close disposes the driver by closing its underlying connection, returning
// any error reported by that close.
func (apd *AlembicPostgresDriver) Close(ctx context.Context) error {
	return apd.db.Close(ctx)
}
    90  
    91  func (apd *AlembicPostgresDriver) WriteVersion(ctx context.Context, tx pgx.Tx, version, replaced string) error {
    92  	result, err := tx.Exec(
    93  		ctx,
    94  		"UPDATE alembic_version SET version_num=$1 WHERE version_num=$2",
    95  		version,
    96  		replaced,
    97  	)
    98  	if err != nil {
    99  		return fmt.Errorf("unable to update version row: %w", err)
   100  	}
   101  
   102  	updatedCount := result.RowsAffected()
   103  	if updatedCount != 1 {
   104  		return fmt.Errorf("writing version update affected %d rows, should be 1", updatedCount)
   105  	}
   106  
   107  	return nil
   108  }
   109  
   110  var _ migrate.Driver[*pgx.Conn, pgx.Tx] = &AlembicPostgresDriver{}