github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/yugabytedb/yugabytedb.go (about) 1 package yugabytedb 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "io" 8 "net/url" 9 "regexp" 10 "strconv" 11 "time" 12 13 "github.com/cenkalti/backoff/v4" 14 "github.com/golang-migrate/migrate/v4" 15 "github.com/golang-migrate/migrate/v4/database" 16 "github.com/hashicorp/go-multierror" 17 "github.com/jackc/pgconn" 18 "github.com/jackc/pgerrcode" 19 "github.com/lib/pq" 20 "go.uber.org/atomic" 21 ) 22 23 const ( 24 DefaultMaxRetryInterval = time.Second * 15 25 DefaultMaxRetryElapsedTime = time.Second * 30 26 DefaultMaxRetries = 10 27 DefaultMigrationsTable = "migrations" 28 DefaultLockTable = "migrations_locks" 29 ) 30 31 var ( 32 ErrNilConfig = errors.New("no config") 33 ErrNoDatabaseName = errors.New("no database name") 34 ErrMaxRetriesExceeded = errors.New("max retries exceeded") 35 ) 36 37 func init() { 38 db := YugabyteDB{} 39 database.Register("yugabyte", &db) 40 database.Register("yugabytedb", &db) 41 database.Register("ysql", &db) 42 } 43 44 type Config struct { 45 MigrationsTable string 46 LockTable string 47 ForceLock bool 48 DatabaseName string 49 MaxRetryInterval time.Duration 50 MaxRetryElapsedTime time.Duration 51 MaxRetries int 52 } 53 54 type YugabyteDB struct { 55 db *sql.DB 56 isLocked atomic.Bool 57 58 // Open and WithInstance need to guarantee that config is never nil 59 config *Config 60 } 61 62 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 63 if config == nil { 64 return nil, ErrNilConfig 65 } 66 67 if err := instance.Ping(); err != nil { 68 return nil, err 69 } 70 71 if config.DatabaseName == "" { 72 query := `SELECT current_database()` 73 var databaseName string 74 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 75 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 76 } 77 78 if len(databaseName) == 0 { 79 return nil, ErrNoDatabaseName 80 } 81 82 config.DatabaseName = databaseName 83 } 84 85 if len(config.MigrationsTable) == 0 { 86 config.MigrationsTable = DefaultMigrationsTable 87 } 88 89 if len(config.LockTable) == 0 { 90 config.LockTable = DefaultLockTable 91 } 92 93 if config.MaxRetryInterval == 0 { 94 config.MaxRetryInterval = DefaultMaxRetryInterval 95 } 96 97 if config.MaxRetryElapsedTime == 0 { 98 config.MaxRetryElapsedTime = DefaultMaxRetryElapsedTime 99 } 100 101 if config.MaxRetries == 0 { 102 config.MaxRetries = DefaultMaxRetries 103 } 104 105 px := &YugabyteDB{ 106 db: instance, 107 config: config, 108 } 109 110 // ensureVersionTable is a locking operation, so we need to ensureLockTable before we ensureVersionTable. 111 if err := px.ensureLockTable(); err != nil { 112 return nil, err 113 } 114 115 if err := px.ensureVersionTable(); err != nil { 116 return nil, err 117 } 118 119 return px, nil 120 } 121 122 func (c *YugabyteDB) Open(dbURL string) (database.Driver, error) { 123 purl, err := url.Parse(dbURL) 124 if err != nil { 125 return nil, err 126 } 127 128 // As YugabyteDB uses the postgres protocol, and 'postgres' is already a registered database, we need to replace the 129 // connect prefix, with the actual protocol, so that the library can differentiate between the implementations 130 re := regexp.MustCompile("^(yugabyte(db)?|ysql)") 131 connectString := re.ReplaceAllString(migrate.FilterCustomQuery(purl).String(), "postgres") 132 133 db, err := sql.Open("postgres", connectString) 134 if err != nil { 135 return nil, err 136 } 137 138 migrationsTable := purl.Query().Get("x-migrations-table") 139 if len(migrationsTable) == 0 { 140 migrationsTable = DefaultMigrationsTable 141 } 142 143 lockTable := purl.Query().Get("x-lock-table") 144 if len(lockTable) == 0 { 145 lockTable = DefaultLockTable 146 } 147 148 forceLockQuery := purl.Query().Get("x-force-lock") 149 forceLock, err := strconv.ParseBool(forceLockQuery) 150 if err != nil { 151 forceLock = false 152 } 153 154 maxIntervalStr := purl.Query().Get("x-max-retry-interval") 155 maxInterval, err := time.ParseDuration(maxIntervalStr) 156 if err != nil { 157 maxInterval = DefaultMaxRetryInterval 158 } 159 160 maxElapsedTimeStr := purl.Query().Get("x-max-retry-elapsed-time") 161 maxElapsedTime, err := time.ParseDuration(maxElapsedTimeStr) 162 if err != nil { 163 maxElapsedTime = DefaultMaxRetryElapsedTime 164 } 165 166 maxRetriesStr := purl.Query().Get("x-max-retries") 167 maxRetries, err := strconv.Atoi(maxRetriesStr) 168 if err != nil { 169 maxRetries = DefaultMaxRetries 170 } 171 172 px, err := WithInstance(db, &Config{ 173 DatabaseName: purl.Path, 174 MigrationsTable: migrationsTable, 175 LockTable: lockTable, 176 ForceLock: forceLock, 177 MaxRetryInterval: maxInterval, 178 MaxRetryElapsedTime: maxElapsedTime, 179 MaxRetries: maxRetries, 180 }) 181 if err != nil { 182 return nil, err 183 } 184 185 return px, nil 186 } 187 188 func (c *YugabyteDB) Close() error { 189 return c.db.Close() 190 } 191 192 // Locking is done manually with a separate lock table. Implementing advisory locks in YugabyteDB is being discussed 193 // See: https://github.com/yugabyte/yugabyte-db/issues/3642 194 func (c *YugabyteDB) Lock() error { 195 return database.CasRestoreOnErr(&c.isLocked, false, true, database.ErrLocked, func() (err error) { 196 return c.doTxWithRetry(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}, func(tx *sql.Tx) (err error) { 197 aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName) 198 if err != nil { 199 return err 200 } 201 202 query := "SELECT * FROM " + c.config.LockTable + " WHERE lock_id = $1" 203 rows, err := tx.Query(query, aid) 204 if err != nil { 205 return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)} 206 } 207 defer func() { 208 if errClose := rows.Close(); errClose != nil { 209 err = multierror.Append(err, errClose) 210 } 211 }() 212 213 // If row exists at all, lock is present 214 locked := rows.Next() 215 if locked && !c.config.ForceLock { 216 return database.ErrLocked 217 } 218 219 query = "INSERT INTO " + c.config.LockTable + " (lock_id) VALUES ($1)" 220 if _, err := tx.Exec(query, aid); err != nil { 221 return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)} 222 } 223 224 return nil 225 }) 226 }) 227 } 228 229 // Locking is done manually with a separate lock table. Implementing advisory locks in YugabyteDB is being discussed 230 // See: https://github.com/yugabyte/yugabyte-db/issues/3642 231 func (c *YugabyteDB) Unlock() error { 232 return database.CasRestoreOnErr(&c.isLocked, true, false, database.ErrNotLocked, func() (err error) { 233 aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName) 234 if err != nil { 235 return err 236 } 237 238 // In the event of an implementation (non-migration) error, it is possible for the lock to not be released. Until 239 // a better locking mechanism is added, a manual purging of the lock table may be required in such circumstances 240 query := "DELETE FROM " + c.config.LockTable + " WHERE lock_id = $1" 241 if _, err := c.db.Exec(query, aid); err != nil { 242 if e, ok := err.(*pq.Error); ok { 243 // 42P01 is "UndefinedTableError" in YugabyteDB 244 // https://github.com/yugabyte/yugabyte-db/blob/9c6b8e6beb56eed8eeb357178c0c6b837eb49896/src/postgres/src/backend/utils/errcodes.txt#L366 245 if e.Code == "42P01" { 246 // On drops, the lock table is fully removed; This is fine, and is a valid "unlocked" state for the schema 247 return nil 248 } 249 } 250 251 return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)} 252 } 253 254 return nil 255 }) 256 } 257 258 func (c *YugabyteDB) Run(migration io.Reader) error { 259 migr, err := io.ReadAll(migration) 260 if err != nil { 261 return err 262 } 263 264 // run migration 265 query := string(migr[:]) 266 if _, err := c.db.Exec(query); err != nil { 267 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 268 } 269 270 return nil 271 } 272 273 func (c *YugabyteDB) SetVersion(version int, dirty bool) error { 274 return c.doTxWithRetry(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}, func(tx *sql.Tx) error { 275 if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil { 276 return err 277 } 278 279 // Also re-write the schema version for nil dirty versions to prevent 280 // empty schema version for failed down migration on the first migration 281 // See: https://github.com/golang-migrate/migrate/issues/330 282 if version >= 0 || (version == database.NilVersion && dirty) { 283 if _, err := tx.Exec(`INSERT INTO "`+c.config.MigrationsTable+`" (version, dirty) VALUES ($1, $2)`, version, dirty); err != nil { 284 return err 285 } 286 } 287 288 return nil 289 }) 290 } 291 292 func (c *YugabyteDB) Version() (version int, dirty bool, err error) { 293 query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1` 294 err = c.db.QueryRow(query).Scan(&version, &dirty) 295 296 switch { 297 case err == sql.ErrNoRows: 298 return database.NilVersion, false, nil 299 300 case err != nil: 301 if e, ok := err.(*pq.Error); ok { 302 // 42P01 is "UndefinedTableError" in YugabyteDB 303 // https://github.com/yugabyte/yugabyte-db/blob/9c6b8e6beb56eed8eeb357178c0c6b837eb49896/src/postgres/src/backend/utils/errcodes.txt#L366 304 if e.Code == "42P01" { 305 return database.NilVersion, false, nil 306 } 307 } 308 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 309 310 default: 311 return version, dirty, nil 312 } 313 } 314 315 func (c *YugabyteDB) Drop() (err error) { 316 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 317 tables, err := c.db.Query(query) 318 if err != nil { 319 return &database.Error{OrigErr: err, Query: []byte(query)} 320 } 321 defer func() { 322 if errClose := tables.Close(); errClose != nil { 323 err = multierror.Append(err, errClose) 324 } 325 }() 326 327 // delete one table after another 328 tableNames := make([]string, 0) 329 for tables.Next() { 330 var tableName string 331 if err := tables.Scan(&tableName); err != nil { 332 return err 333 } 334 if len(tableName) > 0 { 335 tableNames = append(tableNames, tableName) 336 } 337 } 338 if err := tables.Err(); err != nil { 339 return &database.Error{OrigErr: err, Query: []byte(query)} 340 } 341 342 if len(tableNames) > 0 { 343 for _, t := range tableNames { 344 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 345 if _, err := c.db.Exec(query); err != nil { 346 return &database.Error{OrigErr: err, Query: []byte(query)} 347 } 348 } 349 } 350 351 return nil 352 } 353 354 // ensureVersionTable checks if versions table exists and, if not, creates it. 355 // Note that this function locks the database 356 func (c *YugabyteDB) ensureVersionTable() (err error) { 357 if err = c.Lock(); err != nil { 358 return err 359 } 360 361 defer func() { 362 if e := c.Unlock(); e != nil { 363 if err == nil { 364 err = e 365 } else { 366 err = multierror.Append(err, e) 367 } 368 } 369 }() 370 371 // check if migration table exists 372 var count int 373 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 374 if err := c.db.QueryRow(query, c.config.MigrationsTable).Scan(&count); err != nil { 375 return &database.Error{OrigErr: err, Query: []byte(query)} 376 } 377 if count == 1 { 378 return nil 379 } 380 381 // if not, create the empty migration table 382 query = `CREATE TABLE "` + c.config.MigrationsTable + `" (version INT NOT NULL PRIMARY KEY, dirty BOOL NOT NULL)` 383 if _, err := c.db.Exec(query); err != nil { 384 return &database.Error{OrigErr: err, Query: []byte(query)} 385 } 386 return nil 387 } 388 389 func (c *YugabyteDB) ensureLockTable() error { 390 // check if lock table exists 391 var count int 392 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 393 if err := c.db.QueryRow(query, c.config.LockTable).Scan(&count); err != nil { 394 return &database.Error{OrigErr: err, Query: []byte(query)} 395 } 396 if count == 1 { 397 return nil 398 } 399 400 // if not, create the empty lock table 401 query = `CREATE TABLE "` + c.config.LockTable + `" (lock_id TEXT NOT NULL PRIMARY KEY)` 402 if _, err := c.db.Exec(query); err != nil { 403 return &database.Error{OrigErr: err, Query: []byte(query)} 404 } 405 406 return nil 407 } 408 409 func (c *YugabyteDB) doTxWithRetry( 410 ctx context.Context, 411 txOpts *sql.TxOptions, 412 fn func(tx *sql.Tx) error, 413 ) error { 414 backOff := c.newBackoff(ctx) 415 416 return backoff.Retry(func() error { 417 tx, err := c.db.BeginTx(ctx, txOpts) 418 if err != nil { 419 return backoff.Permanent(err) 420 } 421 422 // If we've tried to commit the transaction Rollback just returns sql.ErrTxDone. 423 //nolint:errcheck 424 defer tx.Rollback() 425 426 if err := fn(tx); err != nil { 427 if errIsRetryable(err) { 428 return err 429 } 430 431 return backoff.Permanent(err) 432 } 433 434 if err := tx.Commit(); err != nil { 435 if errIsRetryable(err) { 436 return err 437 } 438 439 return backoff.Permanent(err) 440 } 441 442 return nil 443 }, backOff) 444 } 445 446 func (c *YugabyteDB) newBackoff(ctx context.Context) backoff.BackOff { 447 if ctx == nil { 448 ctx = context.Background() 449 } 450 451 retrier := backoff.WithMaxRetries(backoff.WithContext(&backoff.ExponentialBackOff{ 452 InitialInterval: backoff.DefaultInitialInterval, 453 RandomizationFactor: backoff.DefaultRandomizationFactor, 454 Multiplier: backoff.DefaultMultiplier, 455 MaxInterval: c.config.MaxRetryInterval, 456 MaxElapsedTime: c.config.MaxRetryElapsedTime, 457 Stop: backoff.Stop, 458 Clock: backoff.SystemClock, 459 }, ctx), uint64(c.config.MaxRetries)) 460 461 retrier.Reset() 462 463 return retrier 464 } 465 466 func errIsRetryable(err error) bool { 467 var pgErr *pgconn.PgError 468 if !errors.As(err, &pgErr) { 469 return false 470 } 471 472 // Assume that it's safe to retry 08006 and XX000 because we check for lock existence 473 // before creating and lock ID is primary key. Version field in migrations table is primary key too 474 // and delete all versions is an idempotent operation. 475 return pgErr.Code == pgerrcode.SerializationFailure || // optimistic locking conflict 476 pgErr.Code == pgerrcode.DeadlockDetected || 477 pgErr.Code == pgerrcode.ConnectionFailure || // node down, need to reconnect 478 pgErr.Code == pgerrcode.InternalError // may happen during HA 479 }