github.com/seashell-org/golang-migrate/v4@v4.15.3-0.20220722221203-6ab6c6c062d1/database/postgres/postgres.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package postgres 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 "io/ioutil" 12 nurl "net/url" 13 "regexp" 14 "strconv" 15 "strings" 16 "time" 17 18 "go.uber.org/atomic" 19 20 "github.com/hashicorp/go-multierror" 21 "github.com/lib/pq" 22 migrate "github.com/seashell-org/golang-migrate/v4" 23 "github.com/seashell-org/golang-migrate/v4/database" 24 "github.com/seashell-org/golang-migrate/v4/database/multistmt" 25 ) 26 27 func init() { 28 db := Postgres{} 29 database.Register("postgres", &db) 30 database.Register("postgresql", &db) 31 } 32 33 var ( 34 multiStmtDelimiter = []byte(";") 35 36 DefaultMigrationsTable = "schema_migrations" 37 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 38 ) 39 40 var ( 41 ErrNilConfig = fmt.Errorf("no config") 42 ErrNoDatabaseName = fmt.Errorf("no database name") 43 ErrNoSchema = fmt.Errorf("no schema") 44 ErrDatabaseDirty = fmt.Errorf("database is dirty") 45 ) 46 47 type Config struct { 48 MigrationsTable string 49 MigrationsTableQuoted bool 50 MultiStatementEnabled bool 51 DatabaseName string 52 SchemaName string 53 migrationsSchemaName string 54 migrationsTableName string 55 StatementTimeout time.Duration 56 MultiStatementMaxSize int 57 } 58 59 type Postgres struct { 60 // Locking and unlocking need to use the same connection 61 conn *sql.Conn 62 db *sql.DB 63 isLocked atomic.Bool 64 65 // Open and WithInstance need to guarantee that config is never nil 66 config *Config 67 } 68 69 func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) { 70 if config == nil { 71 return nil, ErrNilConfig 72 } 73 74 if err := conn.PingContext(ctx); err != nil { 75 return nil, err 76 } 77 78 if config.DatabaseName == "" { 79 query := `SELECT CURRENT_DATABASE()` 80 var databaseName string 81 if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil { 82 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 83 } 84 85 if len(databaseName) == 0 { 86 return nil, ErrNoDatabaseName 87 } 88 89 config.DatabaseName = databaseName 90 } 91 92 if config.SchemaName == "" { 93 query := `SELECT CURRENT_SCHEMA()` 94 var schemaName string 95 if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil { 96 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 97 } 98 99 if len(schemaName) == 0 { 100 return nil, ErrNoSchema 101 } 102 103 config.SchemaName = schemaName 104 } 105 106 if len(config.MigrationsTable) == 0 { 107 config.MigrationsTable = DefaultMigrationsTable 108 } 109 110 config.migrationsSchemaName = config.SchemaName 111 config.migrationsTableName = config.MigrationsTable 112 if config.MigrationsTableQuoted { 113 re := regexp.MustCompile(`"(.*?)"`) 114 result := re.FindAllStringSubmatch(config.MigrationsTable, -1) 115 config.migrationsTableName = result[len(result)-1][1] 116 if len(result) == 2 { 117 config.migrationsSchemaName = result[0][1] 118 } else if len(result) > 2 { 119 return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable) 120 } 121 } 122 123 px := &Postgres{ 124 conn: conn, 125 config: config, 126 } 127 128 if err := px.ensureVersionTable(); err != nil { 129 return nil, err 130 } 131 132 return px, nil 133 } 134 135 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 136 ctx := context.Background() 137 138 if err := instance.Ping(); err != nil { 139 return nil, err 140 } 141 142 conn, err := instance.Conn(ctx) 143 if err != nil { 144 return nil, err 145 } 146 147 px, err := WithConnection(ctx, conn, config) 148 if err != nil { 149 return nil, err 150 } 151 px.db = instance 152 return px, nil 153 } 154 155 func (p *Postgres) Open(url string) (database.Driver, error) { 156 purl, err := nurl.Parse(url) 157 if err != nil { 158 return nil, err 159 } 160 161 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 162 if err != nil { 163 return nil, err 164 } 165 166 migrationsTable := purl.Query().Get("x-migrations-table") 167 migrationsTableQuoted := false 168 if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 { 169 migrationsTableQuoted, err = strconv.ParseBool(s) 170 if err != nil { 171 return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err) 172 } 173 } 174 if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) { 175 return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable) 176 } 177 178 statementTimeoutString := purl.Query().Get("x-statement-timeout") 179 statementTimeout := 0 180 if statementTimeoutString != "" { 181 statementTimeout, err = strconv.Atoi(statementTimeoutString) 182 if err != nil { 183 return nil, err 184 } 185 } 186 187 multiStatementMaxSize := DefaultMultiStatementMaxSize 188 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 189 multiStatementMaxSize, err = strconv.Atoi(s) 190 if err != nil { 191 return nil, err 192 } 193 if multiStatementMaxSize <= 0 { 194 multiStatementMaxSize = DefaultMultiStatementMaxSize 195 } 196 } 197 198 multiStatementEnabled := false 199 if s := purl.Query().Get("x-multi-statement"); len(s) > 0 { 200 multiStatementEnabled, err = strconv.ParseBool(s) 201 if err != nil { 202 return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err) 203 } 204 } 205 206 px, err := WithInstance(db, &Config{ 207 DatabaseName: purl.Path, 208 MigrationsTable: migrationsTable, 209 MigrationsTableQuoted: migrationsTableQuoted, 210 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, 211 MultiStatementEnabled: multiStatementEnabled, 212 MultiStatementMaxSize: multiStatementMaxSize, 213 }) 214 215 if err != nil { 216 return nil, err 217 } 218 219 return px, nil 220 } 221 222 func (p *Postgres) Close() error { 223 connErr := p.conn.Close() 224 var dbErr error 225 if p.db != nil { 226 dbErr = p.db.Close() 227 } 228 229 if connErr != nil || dbErr != nil { 230 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 231 } 232 return nil 233 } 234 235 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 236 func (p *Postgres) Lock() error { 237 return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { 238 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 239 if err != nil { 240 return err 241 } 242 243 // This will wait indefinitely until the lock can be acquired. 244 query := `SELECT pg_advisory_lock($1)` 245 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 246 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 247 } 248 249 return nil 250 }) 251 } 252 253 func (p *Postgres) Unlock() error { 254 return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { 255 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 256 if err != nil { 257 return err 258 } 259 260 query := `SELECT pg_advisory_unlock($1)` 261 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 262 return &database.Error{OrigErr: err, Query: []byte(query)} 263 } 264 return nil 265 }) 266 } 267 268 func (p *Postgres) Run(migration io.Reader) error { 269 if p.config.MultiStatementEnabled { 270 var err error 271 if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { 272 if err = p.runStatement(m); err != nil { 273 return false 274 } 275 return true 276 }); e != nil { 277 return e 278 } 279 return err 280 } 281 migr, err := ioutil.ReadAll(migration) 282 if err != nil { 283 return err 284 } 285 return p.runStatement(migr) 286 } 287 288 func (p *Postgres) runStatement(statement []byte) error { 289 ctx := context.Background() 290 if p.config.StatementTimeout != 0 { 291 var cancel context.CancelFunc 292 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout) 293 defer cancel() 294 } 295 query := string(statement) 296 if strings.TrimSpace(query) == "" { 297 return nil 298 } 299 if _, err := p.conn.ExecContext(ctx, query); err != nil { 300 if pgErr, ok := err.(*pq.Error); ok { 301 var line uint 302 var col uint 303 var lineColOK bool 304 if pgErr.Position != "" { 305 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 306 line, col, lineColOK = computeLineFromPos(query, int(pos)) 307 } 308 } 309 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 310 if lineColOK { 311 message = fmt.Sprintf("%s (column %d)", message, col) 312 } 313 if pgErr.Detail != "" { 314 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 315 } 316 return database.Error{OrigErr: err, Err: message, Query: statement, Line: line} 317 } 318 return database.Error{OrigErr: err, Err: "migration failed", Query: statement} 319 } 320 return nil 321 } 322 323 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 324 // replace crlf with lf 325 s = strings.Replace(s, "\r\n", "\n", -1) 326 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 327 runes := []rune(s) 328 if pos > len(runes) { 329 return 0, 0, false 330 } 331 sel := runes[:pos] 332 line = uint(runesCount(sel, newLine) + 1) 333 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 334 return line, col, true 335 } 336 337 const newLine = '\n' 338 339 func runesCount(input []rune, target rune) int { 340 var count int 341 for _, r := range input { 342 if r == target { 343 count++ 344 } 345 } 346 return count 347 } 348 349 func runesLastIndex(input []rune, target rune) int { 350 for i := len(input) - 1; i >= 0; i-- { 351 if input[i] == target { 352 return i 353 } 354 } 355 return -1 356 } 357 358 func (p *Postgres) SetVersion(version int, dirty bool) error { 359 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 360 if err != nil { 361 return &database.Error{OrigErr: err, Err: "transaction start failed"} 362 } 363 364 query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) 365 if _, err := tx.Exec(query); err != nil { 366 if errRollback := tx.Rollback(); errRollback != nil { 367 err = multierror.Append(err, errRollback) 368 } 369 return &database.Error{OrigErr: err, Query: []byte(query)} 370 } 371 372 // Also re-write the schema version for nil dirty versions to prevent 373 // empty schema version for failed down migration on the first migration 374 // See: https://github.com/seashell-org/golang-migrate/issues/330 375 if version >= 0 || (version == database.NilVersion && dirty) { 376 query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)` 377 if _, err := tx.Exec(query, version, dirty); err != nil { 378 if errRollback := tx.Rollback(); errRollback != nil { 379 err = multierror.Append(err, errRollback) 380 } 381 return &database.Error{OrigErr: err, Query: []byte(query)} 382 } 383 } 384 385 if err := tx.Commit(); err != nil { 386 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 387 } 388 389 return nil 390 } 391 392 func (p *Postgres) Version() (version int, dirty bool, err error) { 393 query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1` 394 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 395 switch { 396 case err == sql.ErrNoRows: 397 return database.NilVersion, false, nil 398 399 case err != nil: 400 if e, ok := err.(*pq.Error); ok { 401 if e.Code.Name() == "undefined_table" { 402 return database.NilVersion, false, nil 403 } 404 } 405 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 406 407 default: 408 return version, dirty, nil 409 } 410 } 411 412 func (p *Postgres) Drop() (err error) { 413 // select all tables in current schema 414 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 415 tables, err := p.conn.QueryContext(context.Background(), query) 416 if err != nil { 417 return &database.Error{OrigErr: err, Query: []byte(query)} 418 } 419 defer func() { 420 if errClose := tables.Close(); errClose != nil { 421 err = multierror.Append(err, errClose) 422 } 423 }() 424 425 // delete one table after another 426 tableNames := make([]string, 0) 427 for tables.Next() { 428 var tableName string 429 if err := tables.Scan(&tableName); err != nil { 430 return err 431 } 432 if len(tableName) > 0 { 433 tableNames = append(tableNames, tableName) 434 } 435 } 436 if err := tables.Err(); err != nil { 437 return &database.Error{OrigErr: err, Query: []byte(query)} 438 } 439 440 if len(tableNames) > 0 { 441 // delete one by one ... 442 for _, t := range tableNames { 443 query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE` 444 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 445 return &database.Error{OrigErr: err, Query: []byte(query)} 446 } 447 } 448 } 449 450 return nil 451 } 452 453 // ensureVersionTable checks if versions table exists and, if not, creates it. 454 // Note that this function locks the database, which deviates from the usual 455 // convention of "caller locks" in the Postgres type. 456 func (p *Postgres) ensureVersionTable() (err error) { 457 if err = p.Lock(); err != nil { 458 return err 459 } 460 461 defer func() { 462 if e := p.Unlock(); e != nil { 463 if err == nil { 464 err = e 465 } else { 466 err = multierror.Append(err, e) 467 } 468 } 469 }() 470 471 // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres 472 // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the 473 // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. 474 // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 475 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1` 476 row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName) 477 478 var count int 479 err = row.Scan(&count) 480 if err != nil { 481 return &database.Error{OrigErr: err, Query: []byte(query)} 482 } 483 484 if count == 1 { 485 return nil 486 } 487 488 query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` 489 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 490 return &database.Error{OrigErr: err, Query: []byte(query)} 491 } 492 493 return nil 494 }