github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/crdb/migrations/driver.go (about)

     1  package migrations
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  
     8  	"github.com/jackc/pgx/v5"
     9  	"github.com/jackc/pgx/v5/pgconn"
    10  
    11  	pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common"
    12  	"github.com/authzed/spicedb/pkg/migrate"
    13  )
    14  
    15  const (
    16  	errUnableToInstantiate = "unable to instantiate CRDBDriver: %w"
    17  
    18  	postgresMissingTableErrorCode = "42P01"
    19  
    20  	queryLoadVersion  = "SELECT version_num from schema_version"
    21  	queryWriteVersion = "UPDATE schema_version SET version_num=$1 WHERE version_num=$2"
    22  )
    23  
    24  // CRDBDriver implements a schema migration facility for use in SpiceDB's CRDB
    25  // datastore.
    26  type CRDBDriver struct {
    27  	db *pgx.Conn
    28  }
    29  
    30  // NewCRDBDriver creates a new driver with active connections to the database
    31  // specified.
    32  func NewCRDBDriver(url string) (*CRDBDriver, error) {
    33  	connConfig, err := pgx.ParseConfig(url)
    34  	if err != nil {
    35  		return nil, fmt.Errorf(errUnableToInstantiate, err)
    36  	}
    37  	pgxcommon.ConfigurePGXLogger(connConfig)
    38  	pgxcommon.ConfigureOTELTracer(connConfig)
    39  
    40  	db, err := pgx.ConnectConfig(context.Background(), connConfig)
    41  	if err != nil {
    42  		return nil, fmt.Errorf(errUnableToInstantiate, err)
    43  	}
    44  
    45  	return &CRDBDriver{db}, nil
    46  }
    47  
    48  // Version returns the version of the schema to which the connected database
    49  // has been migrated.
    50  func (apd *CRDBDriver) Version(ctx context.Context) (string, error) {
    51  	var loaded string
    52  
    53  	if err := apd.db.QueryRow(ctx, queryLoadVersion).Scan(&loaded); err != nil {
    54  		var pgErr *pgconn.PgError
    55  		if errors.As(err, &pgErr) && pgErr.Code == postgresMissingTableErrorCode {
    56  			return "", nil
    57  		}
    58  		return "", fmt.Errorf("unable to load alembic revision: %w", err)
    59  	}
    60  
    61  	return loaded, nil
    62  }
    63  
    64  // Conn returns the underlying pgx.Conn instance for this driver
    65  func (apd *CRDBDriver) Conn() *pgx.Conn {
    66  	return apd.db
    67  }
    68  
    69  func (apd *CRDBDriver) RunTx(ctx context.Context, f migrate.TxMigrationFunc[pgx.Tx]) error {
    70  	return pgx.BeginFunc(ctx, apd.db, func(tx pgx.Tx) error {
    71  		return f(ctx, tx)
    72  	})
    73  }
    74  
    75  // Close disposes the driver.
    76  func (apd *CRDBDriver) Close(ctx context.Context) error {
    77  	return apd.db.Close(ctx)
    78  }
    79  
    80  func (apd *CRDBDriver) WriteVersion(ctx context.Context, tx pgx.Tx, version, replaced string) error {
    81  	result, err := tx.Exec(ctx, queryWriteVersion, version, replaced)
    82  	if err != nil {
    83  		return fmt.Errorf("unable to update version row: %w", err)
    84  	}
    85  
    86  	updatedCount := result.RowsAffected()
    87  	if updatedCount != 1 {
    88  		return fmt.Errorf("writing version update affected %d rows, should be 1", updatedCount)
    89  	}
    90  
    91  	return nil
    92  }
    93  
    94  var _ migrate.Driver[*pgx.Conn, pgx.Tx] = &CRDBDriver{}