github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/mysql/migrations/driver.go (about) 1 package migrations 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "strings" 9 10 mysqlCommon "github.com/authzed/spicedb/internal/datastore/mysql/common" 11 12 "github.com/authzed/spicedb/pkg/datastore" 13 14 "github.com/authzed/spicedb/internal/datastore/common" 15 16 sq "github.com/Masterminds/squirrel" 17 sqlDriver "github.com/go-sql-driver/mysql" 18 19 log "github.com/authzed/spicedb/internal/logging" 20 "github.com/authzed/spicedb/pkg/migrate" 21 ) 22 23 const ( 24 errUnableToInstantiate = "unable to instantiate MySQLDriver: %w" 25 mysqlMissingTableErrorNumber = 1146 26 27 migrationVersionColumnPrefix = "_meta_version_" 28 ) 29 30 var sb = sq.StatementBuilder.PlaceholderFormat(sq.Question) 31 32 // MySQLDriver is an implementation of migrate.Driver for MySQL 33 type MySQLDriver struct { 34 db *sql.DB 35 *tables 36 } 37 38 // NewMySQLDriverFromDSN creates a new migration driver with a connection pool to the database DSN specified. 39 // 40 // URI: [scheme://][user[:[password]]@]host[:port][/schema][?attribute1=value1&attribute2=value2... 41 // See https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html 42 func NewMySQLDriverFromDSN(url string, tablePrefix string, credentialsProvider datastore.CredentialsProvider) (*MySQLDriver, error) { 43 dbConfig, err := sqlDriver.ParseDSN(url) 44 if err != nil { 45 return nil, fmt.Errorf(errUnableToInstantiate, err) 46 } 47 48 err = mysqlCommon.MaybeAddCredentialsProviderHook(dbConfig, credentialsProvider) 49 if err != nil { 50 return nil, fmt.Errorf(errUnableToInstantiate, err) 51 } 52 53 // Call NewConnector with the existing parsed configuration to preserve the BeforeConnect added by the CredentialsProvider 54 connector, err := sqlDriver.NewConnector(dbConfig) 55 if err != nil { 56 return nil, fmt.Errorf(errUnableToInstantiate, err) 57 } 58 59 db := sql.OpenDB(connector) 60 err = sqlDriver.SetLogger(&log.Logger) 61 if err != nil { 62 return nil, fmt.Errorf("unable to set logging to mysql driver: %w", err) 63 } 64 return NewMySQLDriverFromDB(db, tablePrefix), nil 65 } 66 67 // NewMySQLDriverFromDB creates a new migration driver with a connection pool specified upfront. 68 func NewMySQLDriverFromDB(db *sql.DB, tablePrefix string) *MySQLDriver { 69 return &MySQLDriver{db, newTables(tablePrefix)} 70 } 71 72 // revisionToColumnName generates the column name that will denote a given migration revision 73 func revisionToColumnName(revision string) string { 74 return migrationVersionColumnPrefix + revision 75 } 76 77 func columnNameToRevision(columnName string) (string, bool) { 78 if !strings.HasPrefix(columnName, migrationVersionColumnPrefix) { 79 return "", false 80 } 81 return strings.TrimPrefix(columnName, migrationVersionColumnPrefix), true 82 } 83 84 // Version returns the version of the schema to which the connected database 85 // has been migrated. 86 func (driver *MySQLDriver) Version(ctx context.Context) (string, error) { 87 query, args, err := sb.Select("*").From(driver.migrationVersion()).ToSql() 88 if err != nil { 89 return "", fmt.Errorf("unable to generate query for revision: %w", err) 90 } 91 92 rows, err := driver.db.QueryContext(ctx, query, args...) 93 if err != nil { 94 var mysqlError *sqlDriver.MySQLError 95 if errors.As(err, &mysqlError) && mysqlError.Number == mysqlMissingTableErrorNumber { 96 return "", nil 97 } 98 return "", fmt.Errorf("unable to query revision: %w", err) 99 } 100 defer common.LogOnError(ctx, rows.Close) 101 if rows.Err() != nil { 102 return "", fmt.Errorf("unable to load revision row: %w", rows.Err()) 103 } 104 cols, err := rows.Columns() 105 if err != nil { 106 return "", fmt.Errorf("failed to get columns from revision row: %w", err) 107 } 108 109 for _, col := range cols { 110 if revision, ok := columnNameToRevision(col); ok { 111 return revision, nil 112 } 113 } 114 return "", errors.New("no migration version detected") 115 } 116 117 func (driver *MySQLDriver) Conn() Wrapper { 118 return Wrapper{db: driver.db, tables: driver.tables} 119 } 120 121 func (driver *MySQLDriver) RunTx(ctx context.Context, f migrate.TxMigrationFunc[TxWrapper]) error { 122 return BeginTxFunc( 123 ctx, 124 driver.db, 125 &sql.TxOptions{Isolation: sql.LevelSerializable}, 126 func(tx *sql.Tx) error { 127 return f(ctx, TxWrapper{tx, driver.tables}) 128 }, 129 ) 130 } 131 132 // BeginTxFunc is a polyfill for database/sql which implements a closure style transaction lifecycle. 133 // The underlying transaction is aborted if the supplied function returns an error. 134 // The underlying transaction is committed if the supplied function returns nil. 135 func BeginTxFunc(ctx context.Context, db *sql.DB, txOptions *sql.TxOptions, f func(*sql.Tx) error) error { 136 tx, err := db.BeginTx(ctx, txOptions) 137 if err != nil { 138 return err 139 } 140 141 if err := f(tx); err != nil { 142 rerr := tx.Rollback() 143 if rerr != nil { 144 return errors.Join(err, rerr) 145 } 146 147 return err 148 } 149 150 return tx.Commit() 151 } 152 153 // WriteVersion overwrites the _meta_version_ column name which encodes the version 154 // of the database schema. 155 func (driver *MySQLDriver) WriteVersion(ctx context.Context, txWrapper TxWrapper, version, replaced string) error { 156 stmt := fmt.Sprintf("ALTER TABLE %s CHANGE %s %s VARCHAR(255) NOT NULL", 157 driver.tables.migrationVersion(), 158 revisionToColumnName(replaced), 159 revisionToColumnName(version), 160 ) 161 if _, err := txWrapper.tx.ExecContext(ctx, stmt); err != nil { 162 return fmt.Errorf("unable to write version: %w", err) 163 } 164 165 return nil 166 } 167 168 func (driver *MySQLDriver) Close(_ context.Context) error { 169 return driver.db.Close() 170 } 171 172 var _ migrate.Driver[Wrapper, TxWrapper] = &MySQLDriver{}