github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/postgres/postgres.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package postgres 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 nurl "net/url" 12 "regexp" 13 "strconv" 14 "strings" 15 "time" 16 17 "go.uber.org/atomic" 18 19 "github.com/golang-migrate/migrate/v4" 20 "github.com/golang-migrate/migrate/v4/database" 21 "github.com/golang-migrate/migrate/v4/database/multistmt" 22 "github.com/hashicorp/go-multierror" 23 "github.com/lib/pq" 24 ) 25 26 func init() { 27 db := Postgres{} 28 database.Register("postgres", &db) 29 database.Register("postgresql", &db) 30 } 31 32 var ( 33 multiStmtDelimiter = []byte(";") 34 35 DefaultMigrationsTable = "schema_migrations" 36 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 37 ) 38 39 var ( 40 ErrNilConfig = fmt.Errorf("no config") 41 ErrNoDatabaseName = fmt.Errorf("no database name") 42 ErrNoSchema = fmt.Errorf("no schema") 43 ErrDatabaseDirty = fmt.Errorf("database is dirty") 44 ) 45 46 type Config struct { 47 MigrationsTable string 48 MigrationsTableQuoted bool 49 MultiStatementEnabled bool 50 DatabaseName string 51 SchemaName string 52 migrationsSchemaName string 53 migrationsTableName string 54 StatementTimeout time.Duration 55 MultiStatementMaxSize int 56 } 57 58 type Postgres struct { 59 // Locking and unlocking need to use the same connection 60 conn *sql.Conn 61 db *sql.DB 62 isLocked atomic.Bool 63 64 // Open and WithInstance need to guarantee that config is never nil 65 config *Config 66 } 67 68 func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) { 69 if config == nil { 70 return nil, ErrNilConfig 71 } 72 73 if err := conn.PingContext(ctx); err != nil { 74 return nil, err 75 } 76 77 if config.DatabaseName == "" { 78 query := `SELECT CURRENT_DATABASE()` 79 var databaseName string 80 if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil { 81 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 82 } 83 84 if len(databaseName) == 0 { 85 return nil, ErrNoDatabaseName 86 } 87 88 config.DatabaseName = databaseName 89 } 90 91 if config.SchemaName == "" { 92 query := `SELECT CURRENT_SCHEMA()` 93 var schemaName string 94 if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil { 95 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 96 } 97 98 if len(schemaName) == 0 { 99 return nil, ErrNoSchema 100 } 101 102 config.SchemaName = schemaName 103 } 104 105 if len(config.MigrationsTable) == 0 { 106 config.MigrationsTable = DefaultMigrationsTable 107 } 108 109 config.migrationsSchemaName = config.SchemaName 110 config.migrationsTableName = config.MigrationsTable 111 if config.MigrationsTableQuoted { 112 re := regexp.MustCompile(`"(.*?)"`) 113 result := re.FindAllStringSubmatch(config.MigrationsTable, -1) 114 config.migrationsTableName = result[len(result)-1][1] 115 if len(result) == 2 { 116 config.migrationsSchemaName = result[0][1] 117 } else if len(result) > 2 { 118 return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable) 119 } 120 } 121 122 px := &Postgres{ 123 conn: conn, 124 config: config, 125 } 126 127 if err := px.ensureVersionTable(); err != nil { 128 return nil, err 129 } 130 131 return px, nil 132 } 133 134 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 135 ctx := context.Background() 136 137 if err := instance.Ping(); err != nil { 138 return nil, err 139 } 140 141 conn, err := instance.Conn(ctx) 142 if err != nil { 143 return nil, err 144 } 145 146 px, err := WithConnection(ctx, conn, config) 147 if err != nil { 148 return nil, err 149 } 150 px.db = instance 151 return px, nil 152 } 153 154 func (p *Postgres) Open(url string) (database.Driver, error) { 155 purl, err := nurl.Parse(url) 156 if err != nil { 157 return nil, err 158 } 159 160 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 161 if err != nil { 162 return nil, err 163 } 164 165 migrationsTable := purl.Query().Get("x-migrations-table") 166 migrationsTableQuoted := false 167 if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 { 168 migrationsTableQuoted, err = strconv.ParseBool(s) 169 if err != nil { 170 return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err) 171 } 172 } 173 if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) { 174 return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable) 175 } 176 177 statementTimeoutString := purl.Query().Get("x-statement-timeout") 178 statementTimeout := 0 179 if statementTimeoutString != "" { 180 statementTimeout, err = strconv.Atoi(statementTimeoutString) 181 if err != nil { 182 return nil, err 183 } 184 } 185 186 multiStatementMaxSize := DefaultMultiStatementMaxSize 187 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 188 multiStatementMaxSize, err = strconv.Atoi(s) 189 if err != nil { 190 return nil, err 191 } 192 if multiStatementMaxSize <= 0 { 193 multiStatementMaxSize = DefaultMultiStatementMaxSize 194 } 195 } 196 197 multiStatementEnabled := false 198 if s := purl.Query().Get("x-multi-statement"); len(s) > 0 { 199 multiStatementEnabled, err = strconv.ParseBool(s) 200 if err != nil { 201 return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err) 202 } 203 } 204 205 px, err := WithInstance(db, &Config{ 206 DatabaseName: purl.Path, 207 MigrationsTable: migrationsTable, 208 MigrationsTableQuoted: migrationsTableQuoted, 209 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, 210 MultiStatementEnabled: multiStatementEnabled, 211 MultiStatementMaxSize: multiStatementMaxSize, 212 }) 213 214 if err != nil { 215 return nil, err 216 } 217 218 return px, nil 219 } 220 221 func (p *Postgres) Close() error { 222 connErr := p.conn.Close() 223 var dbErr error 224 if p.db != nil { 225 dbErr = p.db.Close() 226 } 227 228 if connErr != nil || dbErr != nil { 229 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 230 } 231 return nil 232 } 233 234 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 235 func (p *Postgres) Lock() error { 236 return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { 237 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 238 if err != nil { 239 return err 240 } 241 242 // This will wait indefinitely until the lock can be acquired. 243 query := `SELECT pg_advisory_lock($1)` 244 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 245 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 246 } 247 248 return nil 249 }) 250 } 251 252 func (p *Postgres) Unlock() error { 253 return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { 254 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 255 if err != nil { 256 return err 257 } 258 259 query := `SELECT pg_advisory_unlock($1)` 260 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 261 return &database.Error{OrigErr: err, Query: []byte(query)} 262 } 263 return nil 264 }) 265 } 266 267 func (p *Postgres) Run(migration io.Reader) error { 268 if p.config.MultiStatementEnabled { 269 var err error 270 if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { 271 if err = p.runStatement(m); err != nil { 272 return false 273 } 274 return true 275 }); e != nil { 276 return e 277 } 278 return err 279 } 280 migr, err := io.ReadAll(migration) 281 if err != nil { 282 return err 283 } 284 return p.runStatement(migr) 285 } 286 287 func (p *Postgres) runStatement(statement []byte) error { 288 ctx := context.Background() 289 if p.config.StatementTimeout != 0 { 290 var cancel context.CancelFunc 291 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout) 292 defer cancel() 293 } 294 query := string(statement) 295 if strings.TrimSpace(query) == "" { 296 return nil 297 } 298 if _, err := p.conn.ExecContext(ctx, query); err != nil { 299 if pgErr, ok := err.(*pq.Error); ok { 300 var line uint 301 var col uint 302 var lineColOK bool 303 if pgErr.Position != "" { 304 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 305 line, col, lineColOK = computeLineFromPos(query, int(pos)) 306 } 307 } 308 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 309 if lineColOK { 310 message = fmt.Sprintf("%s (column %d)", message, col) 311 } 312 if pgErr.Detail != "" { 313 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 314 } 315 return database.Error{OrigErr: err, Err: message, Query: statement, Line: line} 316 } 317 return database.Error{OrigErr: err, Err: "migration failed", Query: statement} 318 } 319 return nil 320 } 321 322 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 323 // replace crlf with lf 324 s = strings.Replace(s, "\r\n", "\n", -1) 325 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 326 runes := []rune(s) 327 if pos > len(runes) { 328 return 0, 0, false 329 } 330 sel := runes[:pos] 331 line = uint(runesCount(sel, newLine) + 1) 332 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 333 return line, col, true 334 } 335 336 const newLine = '\n' 337 338 func runesCount(input []rune, target rune) int { 339 var count int 340 for _, r := range input { 341 if r == target { 342 count++ 343 } 344 } 345 return count 346 } 347 348 func runesLastIndex(input []rune, target rune) int { 349 for i := len(input) - 1; i >= 0; i-- { 350 if input[i] == target { 351 return i 352 } 353 } 354 return -1 355 } 356 357 func (p *Postgres) SetVersion(version int, dirty bool) error { 358 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 359 if err != nil { 360 return &database.Error{OrigErr: err, Err: "transaction start failed"} 361 } 362 363 query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) 364 if _, err := tx.Exec(query); err != nil { 365 if errRollback := tx.Rollback(); errRollback != nil { 366 err = multierror.Append(err, errRollback) 367 } 368 return &database.Error{OrigErr: err, Query: []byte(query)} 369 } 370 371 // Also re-write the schema version for nil dirty versions to prevent 372 // empty schema version for failed down migration on the first migration 373 // See: https://github.com/golang-migrate/migrate/issues/330 374 if version >= 0 || (version == database.NilVersion && dirty) { 375 query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)` 376 if _, err := tx.Exec(query, version, dirty); err != nil { 377 if errRollback := tx.Rollback(); errRollback != nil { 378 err = multierror.Append(err, errRollback) 379 } 380 return &database.Error{OrigErr: err, Query: []byte(query)} 381 } 382 } 383 384 if err := tx.Commit(); err != nil { 385 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 386 } 387 388 return nil 389 } 390 391 func (p *Postgres) Version() (version int, dirty bool, err error) { 392 query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1` 393 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 394 switch { 395 case err == sql.ErrNoRows: 396 return database.NilVersion, false, nil 397 398 case err != nil: 399 if e, ok := err.(*pq.Error); ok { 400 if e.Code.Name() == "undefined_table" { 401 return database.NilVersion, false, nil 402 } 403 } 404 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 405 406 default: 407 return version, dirty, nil 408 } 409 } 410 411 func (p *Postgres) Drop() (err error) { 412 // select all tables in current schema 413 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 414 tables, err := p.conn.QueryContext(context.Background(), query) 415 if err != nil { 416 return &database.Error{OrigErr: err, Query: []byte(query)} 417 } 418 defer func() { 419 if errClose := tables.Close(); errClose != nil { 420 err = multierror.Append(err, errClose) 421 } 422 }() 423 424 // delete one table after another 425 tableNames := make([]string, 0) 426 for tables.Next() { 427 var tableName string 428 if err := tables.Scan(&tableName); err != nil { 429 return err 430 } 431 if len(tableName) > 0 { 432 tableNames = append(tableNames, tableName) 433 } 434 } 435 if err := tables.Err(); err != nil { 436 return &database.Error{OrigErr: err, Query: []byte(query)} 437 } 438 439 if len(tableNames) > 0 { 440 // delete one by one ... 441 for _, t := range tableNames { 442 query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE` 443 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 444 return &database.Error{OrigErr: err, Query: []byte(query)} 445 } 446 } 447 } 448 449 return nil 450 } 451 452 // ensureVersionTable checks if versions table exists and, if not, creates it. 453 // Note that this function locks the database, which deviates from the usual 454 // convention of "caller locks" in the Postgres type. 455 func (p *Postgres) ensureVersionTable() (err error) { 456 if err = p.Lock(); err != nil { 457 return err 458 } 459 460 defer func() { 461 if e := p.Unlock(); e != nil { 462 if err == nil { 463 err = e 464 } else { 465 err = multierror.Append(err, e) 466 } 467 } 468 }() 469 470 // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres 471 // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the 472 // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. 473 // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 474 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1` 475 row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName) 476 477 var count int 478 err = row.Scan(&count) 479 if err != nil { 480 return &database.Error{OrigErr: err, Query: []byte(query)} 481 } 482 483 if count == 1 { 484 return nil 485 } 486 487 query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` 488 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 489 return &database.Error{OrigErr: err, Query: []byte(query)} 490 } 491 492 return nil 493 }