github.com/nokia/migrate/v4@v4.16.0/database/postgres/postgres.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package postgres 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 "io/ioutil" 12 nurl "net/url" 13 "regexp" 14 "strconv" 15 "strings" 16 "time" 17 18 "go.uber.org/atomic" 19 20 multierror "github.com/hashicorp/go-multierror" 21 "github.com/lib/pq" 22 "github.com/nokia/migrate/v4" 23 "github.com/nokia/migrate/v4/database" 24 "github.com/nokia/migrate/v4/database/multistmt" 25 "github.com/nokia/migrate/v4/source" 26 ) 27 28 func init() { 29 db := Postgres{} 30 database.Register("postgres", &db) 31 database.Register("postgresql", &db) 32 } 33 34 var ( 35 multiStmtDelimiter = []byte(";") 36 37 DefaultMigrationsTable = "schema_migrations" 38 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 39 ) 40 41 var ( 42 ErrNilConfig = fmt.Errorf("no config") 43 ErrNoDatabaseName = fmt.Errorf("no database name") 44 ErrNoSchema = fmt.Errorf("no schema") 45 ErrDatabaseDirty = fmt.Errorf("database is dirty") 46 ) 47 48 type Config struct { 49 MigrationsTable string 50 MigrationsTableQuoted bool 51 MultiStatementEnabled bool 52 DatabaseName string 53 SchemaName string 54 migrationsSchemaName string 55 migrationsTableName string 56 StatementTimeout time.Duration 57 MultiStatementMaxSize int 58 } 59 60 type Postgres struct { 61 // Locking and unlocking need to use the same connection 62 conn *sql.Conn 63 db *sql.DB 64 isLocked atomic.Bool 65 66 // Open and WithInstance need to guarantee that config is never nil 67 config *Config 68 } 69 70 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 71 if config == nil { 72 return nil, ErrNilConfig 73 } 74 75 if err := instance.Ping(); err != nil { 76 return nil, err 77 } 78 79 if config.DatabaseName == "" { 80 query := `SELECT CURRENT_DATABASE()` 81 var databaseName string 82 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 83 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 84 } 85 86 if len(databaseName) == 0 { 87 return nil, ErrNoDatabaseName 88 } 89 90 config.DatabaseName = databaseName 91 } 92 93 if config.SchemaName == "" { 94 query := `SELECT CURRENT_SCHEMA()` 95 var schemaName string 96 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 97 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 98 } 99 100 if len(schemaName) == 0 { 101 return nil, ErrNoSchema 102 } 103 104 config.SchemaName = schemaName 105 } 106 107 if len(config.MigrationsTable) == 0 { 108 config.MigrationsTable = DefaultMigrationsTable 109 } 110 111 config.migrationsSchemaName = config.SchemaName 112 config.migrationsTableName = config.MigrationsTable 113 if config.MigrationsTableQuoted { 114 re := regexp.MustCompile(`"(.*?)"`) 115 result := re.FindAllStringSubmatch(config.MigrationsTable, -1) 116 config.migrationsTableName = result[len(result)-1][1] 117 if len(result) == 2 { 118 config.migrationsSchemaName = result[0][1] 119 } else if len(result) > 2 { 120 return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable) 121 } 122 } 123 124 conn, err := instance.Conn(context.Background()) 125 if err != nil { 126 return nil, err 127 } 128 129 px := &Postgres{ 130 conn: conn, 131 db: instance, 132 config: config, 133 } 134 135 if err := px.ensureVersionTable(); err != nil { 136 return nil, err 137 } 138 139 return px, nil 140 } 141 142 func (p *Postgres) Open(url string) (database.Driver, error) { 143 purl, err := nurl.Parse(url) 144 if err != nil { 145 return nil, err 146 } 147 148 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 149 if err != nil { 150 return nil, err 151 } 152 153 migrationsTable := purl.Query().Get("x-migrations-table") 154 migrationsTableQuoted := false 155 if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 { 156 migrationsTableQuoted, err = strconv.ParseBool(s) 157 if err != nil { 158 return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err) 159 } 160 } 161 if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) { 162 return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable) 163 } 164 165 statementTimeoutString := purl.Query().Get("x-statement-timeout") 166 statementTimeout := 0 167 if statementTimeoutString != "" { 168 statementTimeout, err = strconv.Atoi(statementTimeoutString) 169 if err != nil { 170 return nil, err 171 } 172 } 173 174 multiStatementMaxSize := DefaultMultiStatementMaxSize 175 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 176 multiStatementMaxSize, err = strconv.Atoi(s) 177 if err != nil { 178 return nil, err 179 } 180 if multiStatementMaxSize <= 0 { 181 multiStatementMaxSize = DefaultMultiStatementMaxSize 182 } 183 } 184 185 multiStatementEnabled := false 186 if s := purl.Query().Get("x-multi-statement"); len(s) > 0 { 187 multiStatementEnabled, err = strconv.ParseBool(s) 188 if err != nil { 189 return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err) 190 } 191 } 192 193 px, err := WithInstance(db, &Config{ 194 DatabaseName: purl.Path, 195 MigrationsTable: migrationsTable, 196 MigrationsTableQuoted: migrationsTableQuoted, 197 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, 198 MultiStatementEnabled: multiStatementEnabled, 199 MultiStatementMaxSize: multiStatementMaxSize, 200 }) 201 if err != nil { 202 return nil, err 203 } 204 205 return px, nil 206 } 207 208 func (p *Postgres) Close() error { 209 connErr := p.conn.Close() 210 dbErr := p.db.Close() 211 if connErr != nil || dbErr != nil { 212 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 213 } 214 return nil 215 } 216 217 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 218 func (p *Postgres) Lock() error { 219 return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { 220 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 221 if err != nil { 222 return err 223 } 224 225 // This will wait indefinitely until the lock can be acquired. 226 query := `SELECT pg_advisory_lock($1)` 227 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 228 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 229 } 230 231 return nil 232 }) 233 } 234 235 func (p *Postgres) Unlock() error { 236 return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { 237 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 238 if err != nil { 239 return err 240 } 241 242 query := `SELECT pg_advisory_unlock($1)` 243 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 244 return &database.Error{OrigErr: err, Query: []byte(query)} 245 } 246 return nil 247 }) 248 } 249 250 func (p *Postgres) Run(migration io.Reader) error { 251 if p.config.MultiStatementEnabled { 252 var err error 253 if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { 254 if err = p.runStatement(m); err != nil { 255 return false 256 } 257 return true 258 }); e != nil { 259 return e 260 } 261 return err 262 } 263 migr, err := ioutil.ReadAll(migration) 264 if err != nil { 265 return err 266 } 267 return p.runStatement(migr) 268 } 269 270 func (p *Postgres) RunFunctionMigration(fn source.MigrationFunc) error { 271 return database.ErrNotImpl 272 } 273 274 func (p *Postgres) runStatement(statement []byte) error { 275 ctx := context.Background() 276 if p.config.StatementTimeout != 0 { 277 var cancel context.CancelFunc 278 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout) 279 defer cancel() 280 } 281 query := string(statement) 282 if strings.TrimSpace(query) == "" { 283 return nil 284 } 285 if _, err := p.conn.ExecContext(ctx, query); err != nil { 286 if pgErr, ok := err.(*pq.Error); ok { 287 var line uint 288 var col uint 289 var lineColOK bool 290 if pgErr.Position != "" { 291 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 292 line, col, lineColOK = computeLineFromPos(query, int(pos)) 293 } 294 } 295 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 296 if lineColOK { 297 message = fmt.Sprintf("%s (column %d)", message, col) 298 } 299 if pgErr.Detail != "" { 300 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 301 } 302 return database.Error{OrigErr: err, Err: message, Query: statement, Line: line} 303 } 304 return database.Error{OrigErr: err, Err: "migration failed", Query: statement} 305 } 306 return nil 307 } 308 309 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 310 // replace crlf with lf 311 s = strings.Replace(s, "\r\n", "\n", -1) 312 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 313 runes := []rune(s) 314 if pos > len(runes) { 315 return 0, 0, false 316 } 317 sel := runes[:pos] 318 line = uint(runesCount(sel, newLine) + 1) 319 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 320 return line, col, true 321 } 322 323 const newLine = '\n' 324 325 func runesCount(input []rune, target rune) int { 326 var count int 327 for _, r := range input { 328 if r == target { 329 count++ 330 } 331 } 332 return count 333 } 334 335 func runesLastIndex(input []rune, target rune) int { 336 for i := len(input) - 1; i >= 0; i-- { 337 if input[i] == target { 338 return i 339 } 340 } 341 return -1 342 } 343 344 func (p *Postgres) SetVersion(version int, dirty bool) error { 345 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 346 if err != nil { 347 return &database.Error{OrigErr: err, Err: "transaction start failed"} 348 } 349 350 query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) 351 if _, err := tx.Exec(query); err != nil { 352 if errRollback := tx.Rollback(); errRollback != nil { 353 err = multierror.Append(err, errRollback) 354 } 355 return &database.Error{OrigErr: err, Query: []byte(query)} 356 } 357 358 // Also re-write the schema version for nil dirty versions to prevent 359 // empty schema version for failed down migration on the first migration 360 // See: https://github.com/nokia/migrate/issues/330 361 if version >= 0 || (version == database.NilVersion && dirty) { 362 query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)` 363 if _, err := tx.Exec(query, version, dirty); err != nil { 364 if errRollback := tx.Rollback(); errRollback != nil { 365 err = multierror.Append(err, errRollback) 366 } 367 return &database.Error{OrigErr: err, Query: []byte(query)} 368 } 369 } 370 371 if err := tx.Commit(); err != nil { 372 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 373 } 374 375 return nil 376 } 377 378 func (p *Postgres) Version() (version int, dirty bool, err error) { 379 query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1` 380 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 381 switch { 382 case err == sql.ErrNoRows: 383 return database.NilVersion, false, nil 384 385 case err != nil: 386 if e, ok := err.(*pq.Error); ok { 387 if e.Code.Name() == "undefined_table" { 388 return database.NilVersion, false, nil 389 } 390 } 391 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 392 393 default: 394 return version, dirty, nil 395 } 396 } 397 398 func (p *Postgres) Drop() (err error) { 399 // select all tables in current schema 400 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 401 tables, err := p.conn.QueryContext(context.Background(), query) 402 if err != nil { 403 return &database.Error{OrigErr: err, Query: []byte(query)} 404 } 405 defer func() { 406 if errClose := tables.Close(); errClose != nil { 407 err = multierror.Append(err, errClose) 408 } 409 }() 410 411 // delete one table after another 412 tableNames := make([]string, 0) 413 for tables.Next() { 414 var tableName string 415 if err := tables.Scan(&tableName); err != nil { 416 return err 417 } 418 if len(tableName) > 0 { 419 tableNames = append(tableNames, tableName) 420 } 421 } 422 if err := tables.Err(); err != nil { 423 return &database.Error{OrigErr: err, Query: []byte(query)} 424 } 425 426 if len(tableNames) > 0 { 427 // delete one by one ... 428 for _, t := range tableNames { 429 query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE` 430 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 431 return &database.Error{OrigErr: err, Query: []byte(query)} 432 } 433 } 434 } 435 436 return nil 437 } 438 439 // ensureVersionTable checks if versions table exists and, if not, creates it. 440 // Note that this function locks the database, which deviates from the usual 441 // convention of "caller locks" in the Postgres type. 442 func (p *Postgres) ensureVersionTable() (err error) { 443 if err = p.Lock(); err != nil { 444 return err 445 } 446 447 defer func() { 448 if e := p.Unlock(); e != nil { 449 if err == nil { 450 err = e 451 } else { 452 err = multierror.Append(err, e) 453 } 454 } 455 }() 456 457 // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres 458 // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the 459 // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. 460 // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 461 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1` 462 row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName) 463 464 var count int 465 err = row.Scan(&count) 466 if err != nil { 467 return &database.Error{OrigErr: err, Query: []byte(query)} 468 } 469 470 if count == 1 { 471 return nil 472 } 473 474 query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` 475 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 476 return &database.Error{OrigErr: err, Query: []byte(query)} 477 } 478 479 return nil 480 }