github.com/nokia/migrate/v4@v4.16.0/database/pgx/pgx.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package pgx 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 "io/ioutil" 12 nurl "net/url" 13 "regexp" 14 "strconv" 15 "strings" 16 "time" 17 18 "go.uber.org/atomic" 19 20 multierror "github.com/hashicorp/go-multierror" 21 "github.com/jackc/pgconn" 22 "github.com/jackc/pgerrcode" 23 _ "github.com/jackc/pgx/v4/stdlib" 24 "github.com/nokia/migrate/v4" 25 "github.com/nokia/migrate/v4/database" 26 "github.com/nokia/migrate/v4/database/multistmt" 27 "github.com/nokia/migrate/v4/source" 28 ) 29 30 func init() { 31 db := Postgres{} 32 database.Register("pgx", &db) 33 } 34 35 var ( 36 multiStmtDelimiter = []byte(";") 37 38 DefaultMigrationsTable = "schema_migrations" 39 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 40 ) 41 42 var ( 43 ErrNilConfig = fmt.Errorf("no config") 44 ErrNoDatabaseName = fmt.Errorf("no database name") 45 ErrNoSchema = fmt.Errorf("no schema") 46 ErrDatabaseDirty = fmt.Errorf("database is dirty") 47 ) 48 49 type Config struct { 50 MigrationsTable string 51 DatabaseName string 52 SchemaName string 53 migrationsSchemaName string 54 migrationsTableName string 55 StatementTimeout time.Duration 56 MigrationsTableQuoted bool 57 MultiStatementEnabled bool 58 MultiStatementMaxSize int 59 } 60 61 type Postgres struct { 62 // Locking and unlocking need to use the same connection 63 conn *sql.Conn 64 db *sql.DB 65 isLocked atomic.Bool 66 67 // Open and WithInstance need to guarantee that config is never nil 68 config *Config 69 } 70 71 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 72 if config == nil { 73 return nil, ErrNilConfig 74 } 75 76 if err := instance.Ping(); err != nil { 77 return nil, err 78 } 79 80 if config.DatabaseName == "" { 81 query := `SELECT CURRENT_DATABASE()` 82 var databaseName string 83 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 84 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 85 } 86 87 if len(databaseName) == 0 { 88 return nil, ErrNoDatabaseName 89 } 90 91 config.DatabaseName = databaseName 92 } 93 94 if config.SchemaName == "" { 95 query := `SELECT CURRENT_SCHEMA()` 96 var schemaName string 97 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 98 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 99 } 100 101 if len(schemaName) == 0 { 102 return nil, ErrNoSchema 103 } 104 105 config.SchemaName = schemaName 106 } 107 108 if len(config.MigrationsTable) == 0 { 109 config.MigrationsTable = DefaultMigrationsTable 110 } 111 112 config.migrationsSchemaName = config.SchemaName 113 config.migrationsTableName = config.MigrationsTable 114 if config.MigrationsTableQuoted { 115 re := regexp.MustCompile(`"(.*?)"`) 116 result := re.FindAllStringSubmatch(config.MigrationsTable, -1) 117 config.migrationsTableName = result[len(result)-1][1] 118 if len(result) == 2 { 119 config.migrationsSchemaName = result[0][1] 120 } else if len(result) > 2 { 121 return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable) 122 } 123 } 124 125 conn, err := instance.Conn(context.Background()) 126 if err != nil { 127 return nil, err 128 } 129 130 px := &Postgres{ 131 conn: conn, 132 db: instance, 133 config: config, 134 } 135 136 if err := px.ensureVersionTable(); err != nil { 137 return nil, err 138 } 139 140 return px, nil 141 } 142 143 func (p *Postgres) Open(url string) (database.Driver, error) { 144 purl, err := nurl.Parse(url) 145 if err != nil { 146 return nil, err 147 } 148 149 // Driver is registered as pgx, but connection string must use postgres schema 150 // when making actual connection 151 // i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db 152 purl.Scheme = "postgres" 153 154 db, err := sql.Open("pgx", migrate.FilterCustomQuery(purl).String()) 155 if err != nil { 156 return nil, err 157 } 158 159 migrationsTable := purl.Query().Get("x-migrations-table") 160 migrationsTableQuoted := false 161 if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 { 162 migrationsTableQuoted, err = strconv.ParseBool(s) 163 if err != nil { 164 return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err) 165 } 166 } 167 if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) { 168 return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable) 169 } 170 171 statementTimeoutString := purl.Query().Get("x-statement-timeout") 172 statementTimeout := 0 173 if statementTimeoutString != "" { 174 statementTimeout, err = strconv.Atoi(statementTimeoutString) 175 if err != nil { 176 return nil, err 177 } 178 } 179 180 multiStatementMaxSize := DefaultMultiStatementMaxSize 181 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 182 multiStatementMaxSize, err = strconv.Atoi(s) 183 if err != nil { 184 return nil, err 185 } 186 if multiStatementMaxSize <= 0 { 187 multiStatementMaxSize = DefaultMultiStatementMaxSize 188 } 189 } 190 191 multiStatementEnabled := false 192 if s := purl.Query().Get("x-multi-statement"); len(s) > 0 { 193 multiStatementEnabled, err = strconv.ParseBool(s) 194 if err != nil { 195 return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err) 196 } 197 } 198 199 px, err := WithInstance(db, &Config{ 200 DatabaseName: purl.Path, 201 MigrationsTable: migrationsTable, 202 MigrationsTableQuoted: migrationsTableQuoted, 203 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, 204 MultiStatementEnabled: multiStatementEnabled, 205 MultiStatementMaxSize: multiStatementMaxSize, 206 }) 207 if err != nil { 208 return nil, err 209 } 210 211 return px, nil 212 } 213 214 func (p *Postgres) Close() error { 215 connErr := p.conn.Close() 216 dbErr := p.db.Close() 217 if connErr != nil || dbErr != nil { 218 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 219 } 220 return nil 221 } 222 223 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 224 func (p *Postgres) Lock() error { 225 return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { 226 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 227 if err != nil { 228 return err 229 } 230 231 // This will wait indefinitely until the lock can be acquired. 232 query := `SELECT pg_advisory_lock($1)` 233 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 234 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 235 } 236 return nil 237 }) 238 } 239 240 func (p *Postgres) Unlock() error { 241 return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { 242 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 243 if err != nil { 244 return err 245 } 246 247 query := `SELECT pg_advisory_unlock($1)` 248 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 249 return &database.Error{OrigErr: err, Query: []byte(query)} 250 } 251 return nil 252 }) 253 } 254 255 func (p *Postgres) Run(migration io.Reader) error { 256 if p.config.MultiStatementEnabled { 257 var err error 258 if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { 259 if err = p.runStatement(m); err != nil { 260 return false 261 } 262 return true 263 }); e != nil { 264 return e 265 } 266 return err 267 } 268 migr, err := ioutil.ReadAll(migration) 269 if err != nil { 270 return err 271 } 272 return p.runStatement(migr) 273 } 274 275 func (p *Postgres) RunFunctionMigration(fn source.MigrationFunc) error { 276 return database.ErrNotImpl 277 } 278 279 func (p *Postgres) runStatement(statement []byte) error { 280 ctx := context.Background() 281 if p.config.StatementTimeout != 0 { 282 var cancel context.CancelFunc 283 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout) 284 defer cancel() 285 } 286 query := string(statement) 287 if strings.TrimSpace(query) == "" { 288 return nil 289 } 290 if _, err := p.conn.ExecContext(ctx, query); err != nil { 291 292 if pgErr, ok := err.(*pgconn.PgError); ok { 293 var line uint 294 var col uint 295 var lineColOK bool 296 line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position)) 297 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 298 if lineColOK { 299 message = fmt.Sprintf("%s (column %d)", message, col) 300 } 301 if pgErr.Detail != "" { 302 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 303 } 304 return database.Error{OrigErr: err, Err: message, Query: statement, Line: line} 305 } 306 return database.Error{OrigErr: err, Err: "migration failed", Query: statement} 307 } 308 return nil 309 } 310 311 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 312 // replace crlf with lf 313 s = strings.Replace(s, "\r\n", "\n", -1) 314 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 315 runes := []rune(s) 316 if pos > len(runes) { 317 return 0, 0, false 318 } 319 sel := runes[:pos] 320 line = uint(runesCount(sel, newLine) + 1) 321 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 322 return line, col, true 323 } 324 325 const newLine = '\n' 326 327 func runesCount(input []rune, target rune) int { 328 var count int 329 for _, r := range input { 330 if r == target { 331 count++ 332 } 333 } 334 return count 335 } 336 337 func runesLastIndex(input []rune, target rune) int { 338 for i := len(input) - 1; i >= 0; i-- { 339 if input[i] == target { 340 return i 341 } 342 } 343 return -1 344 } 345 346 func (p *Postgres) SetVersion(version int, dirty bool) error { 347 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 348 if err != nil { 349 return &database.Error{OrigErr: err, Err: "transaction start failed"} 350 } 351 352 query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) 353 if _, err := tx.Exec(query); err != nil { 354 if errRollback := tx.Rollback(); errRollback != nil { 355 err = multierror.Append(err, errRollback) 356 } 357 return &database.Error{OrigErr: err, Query: []byte(query)} 358 } 359 360 // Also re-write the schema version for nil dirty versions to prevent 361 // empty schema version for failed down migration on the first migration 362 // See: https://github.com/nokia/migrate/issues/330 363 if version >= 0 || (version == database.NilVersion && dirty) { 364 query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)` 365 if _, err := tx.Exec(query, version, dirty); err != nil { 366 if errRollback := tx.Rollback(); errRollback != nil { 367 err = multierror.Append(err, errRollback) 368 } 369 return &database.Error{OrigErr: err, Query: []byte(query)} 370 } 371 } 372 373 if err := tx.Commit(); err != nil { 374 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 375 } 376 377 return nil 378 } 379 380 func (p *Postgres) Version() (version int, dirty bool, err error) { 381 query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1` 382 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 383 switch { 384 case err == sql.ErrNoRows: 385 return database.NilVersion, false, nil 386 387 case err != nil: 388 if e, ok := err.(*pgconn.PgError); ok { 389 if e.SQLState() == pgerrcode.UndefinedTable { 390 return database.NilVersion, false, nil 391 } 392 } 393 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 394 395 default: 396 return version, dirty, nil 397 } 398 } 399 400 func (p *Postgres) Drop() (err error) { 401 // select all tables in current schema 402 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 403 tables, err := p.conn.QueryContext(context.Background(), query) 404 if err != nil { 405 return &database.Error{OrigErr: err, Query: []byte(query)} 406 } 407 defer func() { 408 if errClose := tables.Close(); errClose != nil { 409 err = multierror.Append(err, errClose) 410 } 411 }() 412 413 // delete one table after another 414 tableNames := make([]string, 0) 415 for tables.Next() { 416 var tableName string 417 if err := tables.Scan(&tableName); err != nil { 418 return err 419 } 420 if len(tableName) > 0 { 421 tableNames = append(tableNames, tableName) 422 } 423 } 424 if err := tables.Err(); err != nil { 425 return &database.Error{OrigErr: err, Query: []byte(query)} 426 } 427 428 if len(tableNames) > 0 { 429 // delete one by one ... 430 for _, t := range tableNames { 431 query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE` 432 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 433 return &database.Error{OrigErr: err, Query: []byte(query)} 434 } 435 } 436 } 437 438 return nil 439 } 440 441 // ensureVersionTable checks if versions table exists and, if not, creates it. 442 // Note that this function locks the database, which deviates from the usual 443 // convention of "caller locks" in the Postgres type. 444 func (p *Postgres) ensureVersionTable() (err error) { 445 if err = p.Lock(); err != nil { 446 return err 447 } 448 449 defer func() { 450 if e := p.Unlock(); e != nil { 451 if err == nil { 452 err = e 453 } else { 454 err = multierror.Append(err, e) 455 } 456 } 457 }() 458 459 // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres 460 // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the 461 // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. 462 // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 463 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1` 464 row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName) 465 466 var count int 467 err = row.Scan(&count) 468 if err != nil { 469 return &database.Error{OrigErr: err, Query: []byte(query)} 470 } 471 472 if count == 1 { 473 return nil 474 } 475 476 query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` 477 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 478 return &database.Error{OrigErr: err, Query: []byte(query)} 479 } 480 481 return nil 482 } 483 484 // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 485 func quoteIdentifier(name string) string { 486 end := strings.IndexRune(name, 0) 487 if end > -1 { 488 name = name[:end] 489 } 490 return `"` + strings.Replace(name, `"`, `""`, -1) + `"` 491 }