github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/mysql/migrations/driver.go (about)

     1  package migrations
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  	"strings"
     9  
    10  	mysqlCommon "github.com/authzed/spicedb/internal/datastore/mysql/common"
    11  
    12  	"github.com/authzed/spicedb/pkg/datastore"
    13  
    14  	"github.com/authzed/spicedb/internal/datastore/common"
    15  
    16  	sq "github.com/Masterminds/squirrel"
    17  	sqlDriver "github.com/go-sql-driver/mysql"
    18  
    19  	log "github.com/authzed/spicedb/internal/logging"
    20  	"github.com/authzed/spicedb/pkg/migrate"
    21  )
    22  
    23  const (
    24  	errUnableToInstantiate       = "unable to instantiate MySQLDriver: %w"
    25  	mysqlMissingTableErrorNumber = 1146
    26  
    27  	migrationVersionColumnPrefix = "_meta_version_"
    28  )
    29  
    30  var sb = sq.StatementBuilder.PlaceholderFormat(sq.Question)
    31  
// MySQLDriver is an implementation of migrate.Driver for MySQL.
//
// It bundles the connection pool that all version queries and migration
// transactions run against with the table-name set (built by newTables from a
// table prefix) used to locate the migration version table.
type MySQLDriver struct {
	// db is the connection pool used for queries and transactions.
	db *sql.DB
	// tables supplies table names, e.g. migrationVersion().
	*tables
}
    37  
    38  // NewMySQLDriverFromDSN creates a new migration driver with a connection pool to the database DSN specified.
    39  //
    40  // URI: [scheme://][user[:[password]]@]host[:port][/schema][?attribute1=value1&attribute2=value2...
    41  // See https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html
    42  func NewMySQLDriverFromDSN(url string, tablePrefix string, credentialsProvider datastore.CredentialsProvider) (*MySQLDriver, error) {
    43  	dbConfig, err := sqlDriver.ParseDSN(url)
    44  	if err != nil {
    45  		return nil, fmt.Errorf(errUnableToInstantiate, err)
    46  	}
    47  
    48  	err = mysqlCommon.MaybeAddCredentialsProviderHook(dbConfig, credentialsProvider)
    49  	if err != nil {
    50  		return nil, fmt.Errorf(errUnableToInstantiate, err)
    51  	}
    52  
    53  	// Call NewConnector with the existing parsed configuration to preserve the BeforeConnect added by the CredentialsProvider
    54  	connector, err := sqlDriver.NewConnector(dbConfig)
    55  	if err != nil {
    56  		return nil, fmt.Errorf(errUnableToInstantiate, err)
    57  	}
    58  
    59  	db := sql.OpenDB(connector)
    60  	err = sqlDriver.SetLogger(&log.Logger)
    61  	if err != nil {
    62  		return nil, fmt.Errorf("unable to set logging to mysql driver: %w", err)
    63  	}
    64  	return NewMySQLDriverFromDB(db, tablePrefix), nil
    65  }
    66  
    67  // NewMySQLDriverFromDB creates a new migration driver with a connection pool specified upfront.
    68  func NewMySQLDriverFromDB(db *sql.DB, tablePrefix string) *MySQLDriver {
    69  	return &MySQLDriver{db, newTables(tablePrefix)}
    70  }
    71  
    72  // revisionToColumnName generates the column name that will denote a given migration revision
    73  func revisionToColumnName(revision string) string {
    74  	return migrationVersionColumnPrefix + revision
    75  }
    76  
    77  func columnNameToRevision(columnName string) (string, bool) {
    78  	if !strings.HasPrefix(columnName, migrationVersionColumnPrefix) {
    79  		return "", false
    80  	}
    81  	return strings.TrimPrefix(columnName, migrationVersionColumnPrefix), true
    82  }
    83  
    84  // Version returns the version of the schema to which the connected database
    85  // has been migrated.
    86  func (driver *MySQLDriver) Version(ctx context.Context) (string, error) {
    87  	query, args, err := sb.Select("*").From(driver.migrationVersion()).ToSql()
    88  	if err != nil {
    89  		return "", fmt.Errorf("unable to generate query for revision: %w", err)
    90  	}
    91  
    92  	rows, err := driver.db.QueryContext(ctx, query, args...)
    93  	if err != nil {
    94  		var mysqlError *sqlDriver.MySQLError
    95  		if errors.As(err, &mysqlError) && mysqlError.Number == mysqlMissingTableErrorNumber {
    96  			return "", nil
    97  		}
    98  		return "", fmt.Errorf("unable to query revision: %w", err)
    99  	}
   100  	defer common.LogOnError(ctx, rows.Close)
   101  	if rows.Err() != nil {
   102  		return "", fmt.Errorf("unable to load revision row: %w", rows.Err())
   103  	}
   104  	cols, err := rows.Columns()
   105  	if err != nil {
   106  		return "", fmt.Errorf("failed to get columns from revision row: %w", err)
   107  	}
   108  
   109  	for _, col := range cols {
   110  		if revision, ok := columnNameToRevision(col); ok {
   111  			return revision, nil
   112  		}
   113  	}
   114  	return "", errors.New("no migration version detected")
   115  }
   116  
   117  func (driver *MySQLDriver) Conn() Wrapper {
   118  	return Wrapper{db: driver.db, tables: driver.tables}
   119  }
   120  
   121  func (driver *MySQLDriver) RunTx(ctx context.Context, f migrate.TxMigrationFunc[TxWrapper]) error {
   122  	return BeginTxFunc(
   123  		ctx,
   124  		driver.db,
   125  		&sql.TxOptions{Isolation: sql.LevelSerializable},
   126  		func(tx *sql.Tx) error {
   127  			return f(ctx, TxWrapper{tx, driver.tables})
   128  		},
   129  	)
   130  }
   131  
   132  // BeginTxFunc is a polyfill for database/sql which implements a closure style transaction lifecycle.
   133  // The underlying transaction is aborted if the supplied function returns an error.
   134  // The underlying transaction is committed if the supplied function returns nil.
   135  func BeginTxFunc(ctx context.Context, db *sql.DB, txOptions *sql.TxOptions, f func(*sql.Tx) error) error {
   136  	tx, err := db.BeginTx(ctx, txOptions)
   137  	if err != nil {
   138  		return err
   139  	}
   140  
   141  	if err := f(tx); err != nil {
   142  		rerr := tx.Rollback()
   143  		if rerr != nil {
   144  			return errors.Join(err, rerr)
   145  		}
   146  
   147  		return err
   148  	}
   149  
   150  	return tx.Commit()
   151  }
   152  
   153  // WriteVersion overwrites the _meta_version_ column name which encodes the version
   154  // of the database schema.
   155  func (driver *MySQLDriver) WriteVersion(ctx context.Context, txWrapper TxWrapper, version, replaced string) error {
   156  	stmt := fmt.Sprintf("ALTER TABLE %s CHANGE %s %s VARCHAR(255) NOT NULL",
   157  		driver.tables.migrationVersion(),
   158  		revisionToColumnName(replaced),
   159  		revisionToColumnName(version),
   160  	)
   161  	if _, err := txWrapper.tx.ExecContext(ctx, stmt); err != nil {
   162  		return fmt.Errorf("unable to write version: %w", err)
   163  	}
   164  
   165  	return nil
   166  }
   167  
   168  func (driver *MySQLDriver) Close(_ context.Context) error {
   169  	return driver.db.Close()
   170  }
   171  
   172  var _ migrate.Driver[Wrapper, TxWrapper] = &MySQLDriver{}