github.com/bdollma-te/migrate/v4@v4.17.0-clickv2/database/pgx/pgx.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package pgx 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 nurl "net/url" 12 "regexp" 13 "strconv" 14 "strings" 15 "time" 16 17 "go.uber.org/atomic" 18 19 "github.com/bdollma-te/migrate/v4" 20 "github.com/bdollma-te/migrate/v4/database" 21 "github.com/bdollma-te/migrate/v4/database/multistmt" 22 "github.com/hashicorp/go-multierror" 23 "github.com/jackc/pgconn" 24 "github.com/jackc/pgerrcode" 25 _ "github.com/jackc/pgx/v4/stdlib" 26 "github.com/lib/pq" 27 ) 28 29 const ( 30 LockStrategyAdvisory = "advisory" 31 LockStrategyTable = "table" 32 ) 33 34 func init() { 35 db := Postgres{} 36 database.Register("pgx", &db) 37 database.Register("pgx4", &db) 38 } 39 40 var ( 41 multiStmtDelimiter = []byte(";") 42 43 DefaultMigrationsTable = "schema_migrations" 44 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 45 DefaultLockTable = "schema_lock" 46 DefaultLockStrategy = LockStrategyAdvisory 47 ) 48 49 var ( 50 ErrNilConfig = fmt.Errorf("no config") 51 ErrNoDatabaseName = fmt.Errorf("no database name") 52 ErrNoSchema = fmt.Errorf("no schema") 53 ErrDatabaseDirty = fmt.Errorf("database is dirty") 54 ) 55 56 type Config struct { 57 MigrationsTable string 58 DatabaseName string 59 SchemaName string 60 LockTable string 61 LockStrategy string 62 migrationsSchemaName string 63 migrationsTableName string 64 StatementTimeout time.Duration 65 MigrationsTableQuoted bool 66 MultiStatementEnabled bool 67 MultiStatementMaxSize int 68 } 69 70 type Postgres struct { 71 // Locking and unlocking need to use the same connection 72 conn *sql.Conn 73 db *sql.DB 74 isLocked atomic.Bool 75 76 // Open and WithInstance need to guarantee that config is never nil 77 config *Config 78 } 79 80 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 81 if config == nil { 82 return nil, ErrNilConfig 83 } 84 85 if err := instance.Ping(); err != nil { 86 return nil, err 87 } 88 89 if config.DatabaseName == "" { 90 query := `SELECT CURRENT_DATABASE()` 91 var databaseName string 92 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 93 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 94 } 95 96 if len(databaseName) == 0 { 97 return nil, ErrNoDatabaseName 98 } 99 100 config.DatabaseName = databaseName 101 } 102 103 if config.SchemaName == "" { 104 query := `SELECT CURRENT_SCHEMA()` 105 var schemaName string 106 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 107 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 108 } 109 110 if len(schemaName) == 0 { 111 return nil, ErrNoSchema 112 } 113 114 config.SchemaName = schemaName 115 } 116 117 if len(config.MigrationsTable) == 0 { 118 config.MigrationsTable = DefaultMigrationsTable 119 } 120 121 if len(config.LockTable) == 0 { 122 config.LockTable = DefaultLockTable 123 } 124 125 if len(config.LockStrategy) == 0 { 126 config.LockStrategy = DefaultLockStrategy 127 } 128 129 config.migrationsSchemaName = config.SchemaName 130 config.migrationsTableName = config.MigrationsTable 131 if config.MigrationsTableQuoted { 132 re := regexp.MustCompile(`"(.*?)"`) 133 result := re.FindAllStringSubmatch(config.MigrationsTable, -1) 134 config.migrationsTableName = result[len(result)-1][1] 135 if len(result) == 2 { 136 config.migrationsSchemaName = result[0][1] 137 } else if len(result) > 2 { 138 return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable) 139 } 140 } 141 142 conn, err := instance.Conn(context.Background()) 143 144 if err != nil { 145 return nil, err 146 } 147 148 px := &Postgres{ 149 conn: conn, 150 db: instance, 151 config: config, 152 } 153 154 if err := px.ensureLockTable(); err != nil { 155 return nil, err 156 } 157 158 if err := px.ensureVersionTable(); err != nil { 159 return nil, err 160 } 161 162 return px, nil 163 } 164 165 func (p *Postgres) Open(url string) (database.Driver, error) { 166 purl, err := nurl.Parse(url) 167 if err != nil { 168 return nil, err 169 } 170 171 // Driver is registered as pgx, but connection string must use postgres schema 172 // when making actual connection 173 // i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db 174 purl.Scheme = "postgres" 175 176 db, err := sql.Open("pgx/v4", migrate.FilterCustomQuery(purl).String()) 177 if err != nil { 178 return nil, err 179 } 180 181 migrationsTable := purl.Query().Get("x-migrations-table") 182 migrationsTableQuoted := false 183 if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 { 184 migrationsTableQuoted, err = strconv.ParseBool(s) 185 if err != nil { 186 return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err) 187 } 188 } 189 if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) { 190 return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable) 191 } 192 193 statementTimeoutString := purl.Query().Get("x-statement-timeout") 194 statementTimeout := 0 195 if statementTimeoutString != "" { 196 statementTimeout, err = strconv.Atoi(statementTimeoutString) 197 if err != nil { 198 return nil, err 199 } 200 } 201 202 multiStatementMaxSize := DefaultMultiStatementMaxSize 203 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 204 multiStatementMaxSize, err = strconv.Atoi(s) 205 if err != nil { 206 return nil, err 207 } 208 if multiStatementMaxSize <= 0 { 209 multiStatementMaxSize = DefaultMultiStatementMaxSize 210 } 211 } 212 213 multiStatementEnabled := false 214 if s := purl.Query().Get("x-multi-statement"); len(s) > 0 { 215 multiStatementEnabled, err = strconv.ParseBool(s) 216 if err != nil { 217 return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err) 218 } 219 } 220 221 lockStrategy := purl.Query().Get("x-lock-strategy") 222 lockTable := purl.Query().Get("x-lock-table") 223 224 px, err := WithInstance(db, &Config{ 225 DatabaseName: purl.Path, 226 MigrationsTable: migrationsTable, 227 MigrationsTableQuoted: migrationsTableQuoted, 228 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, 229 MultiStatementEnabled: multiStatementEnabled, 230 MultiStatementMaxSize: multiStatementMaxSize, 231 LockStrategy: lockStrategy, 232 LockTable: lockTable, 233 }) 234 235 if err != nil { 236 return nil, err 237 } 238 239 return px, nil 240 } 241 242 func (p *Postgres) Close() error { 243 connErr := p.conn.Close() 244 dbErr := p.db.Close() 245 if connErr != nil || dbErr != nil { 246 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 247 } 248 return nil 249 } 250 251 func (p *Postgres) Lock() error { 252 return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { 253 switch p.config.LockStrategy { 254 case LockStrategyAdvisory: 255 return p.applyAdvisoryLock() 256 case LockStrategyTable: 257 return p.applyTableLock() 258 default: 259 return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy) 260 } 261 }) 262 } 263 264 func (p *Postgres) Unlock() error { 265 return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { 266 switch p.config.LockStrategy { 267 case LockStrategyAdvisory: 268 return p.releaseAdvisoryLock() 269 case LockStrategyTable: 270 return p.releaseTableLock() 271 default: 272 return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy) 273 } 274 }) 275 } 276 277 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 278 func (p *Postgres) applyAdvisoryLock() error { 279 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 280 if err != nil { 281 return err 282 } 283 284 // This will wait indefinitely until the lock can be acquired. 285 query := `SELECT pg_advisory_lock($1)` 286 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 287 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 288 } 289 return nil 290 } 291 292 func (p *Postgres) applyTableLock() error { 293 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 294 if err != nil { 295 return &database.Error{OrigErr: err, Err: "transaction start failed"} 296 } 297 defer func() { 298 errRollback := tx.Rollback() 299 if errRollback != nil { 300 err = multierror.Append(err, errRollback) 301 } 302 }() 303 304 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) 305 if err != nil { 306 return err 307 } 308 309 query := "SELECT * FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1" 310 rows, err := tx.Query(query, aid) 311 if err != nil { 312 return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)} 313 } 314 315 defer func() { 316 if errClose := rows.Close(); errClose != nil { 317 err = multierror.Append(err, errClose) 318 } 319 }() 320 321 // If row exists at all, lock is present 322 locked := rows.Next() 323 if locked { 324 return database.ErrLocked 325 } 326 327 query = "INSERT INTO " + pq.QuoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)" 328 if _, err := tx.Exec(query, aid); err != nil { 329 return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)} 330 } 331 332 return tx.Commit() 333 } 334 335 func (p *Postgres) releaseAdvisoryLock() error { 336 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 337 if err != nil { 338 return err 339 } 340 341 query := `SELECT pg_advisory_unlock($1)` 342 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 343 return &database.Error{OrigErr: err, Query: []byte(query)} 344 } 345 346 return nil 347 } 348 349 func (p *Postgres) releaseTableLock() error { 350 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) 351 if err != nil { 352 return err 353 } 354 355 query := "DELETE FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1" 356 if _, err := p.db.Exec(query, aid); err != nil { 357 return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)} 358 } 359 360 return nil 361 } 362 363 func (p *Postgres) Run(migration io.Reader) error { 364 if p.config.MultiStatementEnabled { 365 var err error 366 if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { 367 if err = p.runStatement(m); err != nil { 368 return false 369 } 370 return true 371 }); e != nil { 372 return e 373 } 374 return err 375 } 376 migr, err := io.ReadAll(migration) 377 if err != nil { 378 return err 379 } 380 return p.runStatement(migr) 381 } 382 383 func (p *Postgres) runStatement(statement []byte) error { 384 ctx := context.Background() 385 if p.config.StatementTimeout != 0 { 386 var cancel context.CancelFunc 387 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout) 388 defer cancel() 389 } 390 query := string(statement) 391 if strings.TrimSpace(query) == "" { 392 return nil 393 } 394 if _, err := p.conn.ExecContext(ctx, query); err != nil { 395 396 if pgErr, ok := err.(*pgconn.PgError); ok { 397 var line uint 398 var col uint 399 var lineColOK bool 400 line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position)) 401 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 402 if lineColOK { 403 message = fmt.Sprintf("%s (column %d)", message, col) 404 } 405 if pgErr.Detail != "" { 406 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 407 } 408 return database.Error{OrigErr: err, Err: message, Query: statement, Line: line} 409 } 410 return database.Error{OrigErr: err, Err: "migration failed", Query: statement} 411 } 412 return nil 413 } 414 415 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 416 // replace crlf with lf 417 s = strings.Replace(s, "\r\n", "\n", -1) 418 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 419 runes := []rune(s) 420 if pos > len(runes) { 421 return 0, 0, false 422 } 423 sel := runes[:pos] 424 line = uint(runesCount(sel, newLine) + 1) 425 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 426 return line, col, true 427 } 428 429 const newLine = '\n' 430 431 func runesCount(input []rune, target rune) int { 432 var count int 433 for _, r := range input { 434 if r == target { 435 count++ 436 } 437 } 438 return count 439 } 440 441 func runesLastIndex(input []rune, target rune) int { 442 for i := len(input) - 1; i >= 0; i-- { 443 if input[i] == target { 444 return i 445 } 446 } 447 return -1 448 } 449 450 func (p *Postgres) SetVersion(version int, dirty bool) error { 451 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 452 if err != nil { 453 return &database.Error{OrigErr: err, Err: "transaction start failed"} 454 } 455 456 query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) 457 if _, err := tx.Exec(query); err != nil { 458 if errRollback := tx.Rollback(); errRollback != nil { 459 err = multierror.Append(err, errRollback) 460 } 461 return &database.Error{OrigErr: err, Query: []byte(query)} 462 } 463 464 // Also re-write the schema version for nil dirty versions to prevent 465 // empty schema version for failed down migration on the first migration 466 // See: https://github.com/golang-migrate/migrate/issues/330 467 if version >= 0 || (version == database.NilVersion && dirty) { 468 query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)` 469 if _, err := tx.Exec(query, version, dirty); err != nil { 470 if errRollback := tx.Rollback(); errRollback != nil { 471 err = multierror.Append(err, errRollback) 472 } 473 return &database.Error{OrigErr: err, Query: []byte(query)} 474 } 475 } 476 477 if err := tx.Commit(); err != nil { 478 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 479 } 480 481 return nil 482 } 483 484 func (p *Postgres) Version() (version int, dirty bool, err error) { 485 query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1` 486 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 487 switch { 488 case err == sql.ErrNoRows: 489 return database.NilVersion, false, nil 490 491 case err != nil: 492 if e, ok := err.(*pgconn.PgError); ok { 493 if e.SQLState() == pgerrcode.UndefinedTable { 494 return database.NilVersion, false, nil 495 } 496 } 497 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 498 499 default: 500 return version, dirty, nil 501 } 502 } 503 504 func (p *Postgres) Drop() (err error) { 505 // select all tables in current schema 506 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 507 tables, err := p.conn.QueryContext(context.Background(), query) 508 if err != nil { 509 return &database.Error{OrigErr: err, Query: []byte(query)} 510 } 511 defer func() { 512 if errClose := tables.Close(); errClose != nil { 513 err = multierror.Append(err, errClose) 514 } 515 }() 516 517 // delete one table after another 518 tableNames := make([]string, 0) 519 for tables.Next() { 520 var tableName string 521 if err := tables.Scan(&tableName); err != nil { 522 return err 523 } 524 525 // do not drop lock table 526 if tableName == p.config.LockTable && p.config.LockStrategy == LockStrategyTable { 527 continue 528 } 529 530 if len(tableName) > 0 { 531 tableNames = append(tableNames, tableName) 532 } 533 } 534 if err := tables.Err(); err != nil { 535 return &database.Error{OrigErr: err, Query: []byte(query)} 536 } 537 538 if len(tableNames) > 0 { 539 // delete one by one ... 540 for _, t := range tableNames { 541 query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE` 542 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 543 return &database.Error{OrigErr: err, Query: []byte(query)} 544 } 545 } 546 } 547 548 return nil 549 } 550 551 // ensureVersionTable checks if versions table exists and, if not, creates it. 552 // Note that this function locks the database, which deviates from the usual 553 // convention of "caller locks" in the Postgres type. 554 func (p *Postgres) ensureVersionTable() (err error) { 555 if err = p.Lock(); err != nil { 556 return err 557 } 558 559 defer func() { 560 if e := p.Unlock(); e != nil { 561 if err == nil { 562 err = e 563 } else { 564 err = multierror.Append(err, e) 565 } 566 } 567 }() 568 569 // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres 570 // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the 571 // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. 572 // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 573 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1` 574 row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName) 575 576 var count int 577 err = row.Scan(&count) 578 if err != nil { 579 return &database.Error{OrigErr: err, Query: []byte(query)} 580 } 581 582 if count == 1 { 583 return nil 584 } 585 586 query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` 587 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 588 return &database.Error{OrigErr: err, Query: []byte(query)} 589 } 590 591 return nil 592 } 593 594 func (p *Postgres) ensureLockTable() error { 595 if p.config.LockStrategy != LockStrategyTable { 596 return nil 597 } 598 599 var count int 600 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 601 if err := p.db.QueryRow(query, p.config.LockTable).Scan(&count); err != nil { 602 return &database.Error{OrigErr: err, Query: []byte(query)} 603 } 604 if count == 1 { 605 return nil 606 } 607 608 query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)` 609 if _, err := p.db.Exec(query); err != nil { 610 return &database.Error{OrigErr: err, Query: []byte(query)} 611 } 612 613 return nil 614 } 615 616 // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 617 func quoteIdentifier(name string) string { 618 end := strings.IndexRune(name, 0) 619 if end > -1 { 620 name = name[:end] 621 } 622 return `"` + strings.Replace(name, `"`, `""`, -1) + `"` 623 }