github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/pgx/pgx.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package pgx 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 nurl "net/url" 12 "regexp" 13 "strconv" 14 "strings" 15 "time" 16 17 "go.uber.org/atomic" 18 19 "github.com/golang-migrate/migrate/v4" 20 "github.com/golang-migrate/migrate/v4/database" 21 "github.com/golang-migrate/migrate/v4/database/multistmt" 22 "github.com/hashicorp/go-multierror" 23 "github.com/jackc/pgconn" 24 "github.com/jackc/pgerrcode" 25 _ "github.com/jackc/pgx/v4/stdlib" 26 ) 27 28 func init() { 29 db := Postgres{} 30 database.Register("pgx", &db) 31 } 32 33 var ( 34 multiStmtDelimiter = []byte(";") 35 36 DefaultMigrationsTable = "schema_migrations" 37 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 38 ) 39 40 var ( 41 ErrNilConfig = fmt.Errorf("no config") 42 ErrNoDatabaseName = fmt.Errorf("no database name") 43 ErrNoSchema = fmt.Errorf("no schema") 44 ErrDatabaseDirty = fmt.Errorf("database is dirty") 45 ) 46 47 type Config struct { 48 MigrationsTable string 49 DatabaseName string 50 SchemaName string 51 migrationsSchemaName string 52 migrationsTableName string 53 StatementTimeout time.Duration 54 MigrationsTableQuoted bool 55 MultiStatementEnabled bool 56 MultiStatementMaxSize int 57 } 58 59 type Postgres struct { 60 // Locking and unlocking need to use the same connection 61 conn *sql.Conn 62 db *sql.DB 63 isLocked atomic.Bool 64 65 // Open and WithInstance need to guarantee that config is never nil 66 config *Config 67 } 68 69 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 70 if config == nil { 71 return nil, ErrNilConfig 72 } 73 74 if err := instance.Ping(); err != nil { 75 return nil, err 76 } 77 78 if config.DatabaseName == "" { 79 query := `SELECT CURRENT_DATABASE()` 80 var databaseName string 81 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 82 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 83 } 84 85 if len(databaseName) == 0 { 86 return nil, ErrNoDatabaseName 87 } 88 89 config.DatabaseName = databaseName 90 } 91 92 if config.SchemaName == "" { 93 query := `SELECT CURRENT_SCHEMA()` 94 var schemaName string 95 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 96 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 97 } 98 99 if len(schemaName) == 0 { 100 return nil, ErrNoSchema 101 } 102 103 config.SchemaName = schemaName 104 } 105 106 if len(config.MigrationsTable) == 0 { 107 config.MigrationsTable = DefaultMigrationsTable 108 } 109 110 config.migrationsSchemaName = config.SchemaName 111 config.migrationsTableName = config.MigrationsTable 112 if config.MigrationsTableQuoted { 113 re := regexp.MustCompile(`"(.*?)"`) 114 result := re.FindAllStringSubmatch(config.MigrationsTable, -1) 115 config.migrationsTableName = result[len(result)-1][1] 116 if len(result) == 2 { 117 config.migrationsSchemaName = result[0][1] 118 } else if len(result) > 2 { 119 return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable) 120 } 121 } 122 123 conn, err := instance.Conn(context.Background()) 124 125 if err != nil { 126 return nil, err 127 } 128 129 px := &Postgres{ 130 conn: conn, 131 db: instance, 132 config: config, 133 } 134 135 if err := px.ensureVersionTable(); err != nil { 136 return nil, err 137 } 138 139 return px, nil 140 } 141 142 func (p *Postgres) Open(url string) (database.Driver, error) { 143 purl, err := nurl.Parse(url) 144 if err != nil { 145 return nil, err 146 } 147 148 // Driver is registered as pgx, but connection string must use postgres schema 149 // when making actual connection 150 // i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db 151 purl.Scheme = "postgres" 152 153 db, err := sql.Open("pgx", migrate.FilterCustomQuery(purl).String()) 154 if err != nil { 155 return nil, err 156 } 157 158 migrationsTable := purl.Query().Get("x-migrations-table") 159 migrationsTableQuoted := false 160 if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 { 161 migrationsTableQuoted, err = strconv.ParseBool(s) 162 if err != nil { 163 return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err) 164 } 165 } 166 if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) { 167 return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable) 168 } 169 170 statementTimeoutString := purl.Query().Get("x-statement-timeout") 171 statementTimeout := 0 172 if statementTimeoutString != "" { 173 statementTimeout, err = strconv.Atoi(statementTimeoutString) 174 if err != nil { 175 return nil, err 176 } 177 } 178 179 multiStatementMaxSize := DefaultMultiStatementMaxSize 180 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 181 multiStatementMaxSize, err = strconv.Atoi(s) 182 if err != nil { 183 return nil, err 184 } 185 if multiStatementMaxSize <= 0 { 186 multiStatementMaxSize = DefaultMultiStatementMaxSize 187 } 188 } 189 190 multiStatementEnabled := false 191 if s := purl.Query().Get("x-multi-statement"); len(s) > 0 { 192 multiStatementEnabled, err = strconv.ParseBool(s) 193 if err != nil { 194 return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err) 195 } 196 } 197 198 px, err := WithInstance(db, &Config{ 199 DatabaseName: purl.Path, 200 MigrationsTable: migrationsTable, 201 MigrationsTableQuoted: migrationsTableQuoted, 202 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, 203 MultiStatementEnabled: multiStatementEnabled, 204 MultiStatementMaxSize: multiStatementMaxSize, 205 }) 206 207 if err != nil { 208 return nil, err 209 } 210 211 return px, nil 212 } 213 214 func (p *Postgres) Close() error { 215 connErr := p.conn.Close() 216 dbErr := p.db.Close() 217 if connErr != nil || dbErr != nil { 218 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 219 } 220 return nil 221 } 222 223 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 224 func (p *Postgres) Lock() error { 225 return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { 226 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 227 if err != nil { 228 return err 229 } 230 231 // This will wait indefinitely until the lock can be acquired. 232 query := `SELECT pg_advisory_lock($1)` 233 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 234 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 235 } 236 return nil 237 }) 238 } 239 240 func (p *Postgres) Unlock() error { 241 return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { 242 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 243 if err != nil { 244 return err 245 } 246 247 query := `SELECT pg_advisory_unlock($1)` 248 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 249 return &database.Error{OrigErr: err, Query: []byte(query)} 250 } 251 return nil 252 }) 253 } 254 255 func (p *Postgres) Run(migration io.Reader) error { 256 if p.config.MultiStatementEnabled { 257 var err error 258 if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { 259 if err = p.runStatement(m); err != nil { 260 return false 261 } 262 return true 263 }); e != nil { 264 return e 265 } 266 return err 267 } 268 migr, err := io.ReadAll(migration) 269 if err != nil { 270 return err 271 } 272 return p.runStatement(migr) 273 } 274 275 func (p *Postgres) runStatement(statement []byte) error { 276 ctx := context.Background() 277 if p.config.StatementTimeout != 0 { 278 var cancel context.CancelFunc 279 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout) 280 defer cancel() 281 } 282 query := string(statement) 283 if strings.TrimSpace(query) == "" { 284 return nil 285 } 286 if _, err := p.conn.ExecContext(ctx, query); err != nil { 287 288 if pgErr, ok := err.(*pgconn.PgError); ok { 289 var line uint 290 var col uint 291 var lineColOK bool 292 line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position)) 293 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 294 if lineColOK { 295 message = fmt.Sprintf("%s (column %d)", message, col) 296 } 297 if pgErr.Detail != "" { 298 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 299 } 300 return database.Error{OrigErr: err, Err: message, Query: statement, Line: line} 301 } 302 return database.Error{OrigErr: err, Err: "migration failed", Query: statement} 303 } 304 return nil 305 } 306 307 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 308 // replace crlf with lf 309 s = strings.Replace(s, "\r\n", "\n", -1) 310 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 311 runes := []rune(s) 312 if pos > len(runes) { 313 return 0, 0, false 314 } 315 sel := runes[:pos] 316 line = uint(runesCount(sel, newLine) + 1) 317 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 318 return line, col, true 319 } 320 321 const newLine = '\n' 322 323 func runesCount(input []rune, target rune) int { 324 var count int 325 for _, r := range input { 326 if r == target { 327 count++ 328 } 329 } 330 return count 331 } 332 333 func runesLastIndex(input []rune, target rune) int { 334 for i := len(input) - 1; i >= 0; i-- { 335 if input[i] == target { 336 return i 337 } 338 } 339 return -1 340 } 341 342 func (p *Postgres) SetVersion(version int, dirty bool) error { 343 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 344 if err != nil { 345 return &database.Error{OrigErr: err, Err: "transaction start failed"} 346 } 347 348 query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) 349 if _, err := tx.Exec(query); err != nil { 350 if errRollback := tx.Rollback(); errRollback != nil { 351 err = multierror.Append(err, errRollback) 352 } 353 return &database.Error{OrigErr: err, Query: []byte(query)} 354 } 355 356 // Also re-write the schema version for nil dirty versions to prevent 357 // empty schema version for failed down migration on the first migration 358 // See: https://github.com/golang-migrate/migrate/issues/330 359 if version >= 0 || (version == database.NilVersion && dirty) { 360 query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)` 361 if _, err := tx.Exec(query, version, dirty); err != nil { 362 if errRollback := tx.Rollback(); errRollback != nil { 363 err = multierror.Append(err, errRollback) 364 } 365 return &database.Error{OrigErr: err, Query: []byte(query)} 366 } 367 } 368 369 if err := tx.Commit(); err != nil { 370 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 371 } 372 373 return nil 374 } 375 376 func (p *Postgres) Version() (version int, dirty bool, err error) { 377 query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1` 378 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 379 switch { 380 case err == sql.ErrNoRows: 381 return database.NilVersion, false, nil 382 383 case err != nil: 384 if e, ok := err.(*pgconn.PgError); ok { 385 if e.SQLState() == pgerrcode.UndefinedTable { 386 return database.NilVersion, false, nil 387 } 388 } 389 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 390 391 default: 392 return version, dirty, nil 393 } 394 } 395 396 func (p *Postgres) Drop() (err error) { 397 // select all tables in current schema 398 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 399 tables, err := p.conn.QueryContext(context.Background(), query) 400 if err != nil { 401 return &database.Error{OrigErr: err, Query: []byte(query)} 402 } 403 defer func() { 404 if errClose := tables.Close(); errClose != nil { 405 err = multierror.Append(err, errClose) 406 } 407 }() 408 409 // delete one table after another 410 tableNames := make([]string, 0) 411 for tables.Next() { 412 var tableName string 413 if err := tables.Scan(&tableName); err != nil { 414 return err 415 } 416 if len(tableName) > 0 { 417 tableNames = append(tableNames, tableName) 418 } 419 } 420 if err := tables.Err(); err != nil { 421 return &database.Error{OrigErr: err, Query: []byte(query)} 422 } 423 424 if len(tableNames) > 0 { 425 // delete one by one ... 426 for _, t := range tableNames { 427 query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE` 428 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 429 return &database.Error{OrigErr: err, Query: []byte(query)} 430 } 431 } 432 } 433 434 return nil 435 } 436 437 // ensureVersionTable checks if versions table exists and, if not, creates it. 438 // Note that this function locks the database, which deviates from the usual 439 // convention of "caller locks" in the Postgres type. 440 func (p *Postgres) ensureVersionTable() (err error) { 441 if err = p.Lock(); err != nil { 442 return err 443 } 444 445 defer func() { 446 if e := p.Unlock(); e != nil { 447 if err == nil { 448 err = e 449 } else { 450 err = multierror.Append(err, e) 451 } 452 } 453 }() 454 455 // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres 456 // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the 457 // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. 458 // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 459 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1` 460 row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName) 461 462 var count int 463 err = row.Scan(&count) 464 if err != nil { 465 return &database.Error{OrigErr: err, Query: []byte(query)} 466 } 467 468 if count == 1 { 469 return nil 470 } 471 472 query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` 473 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 474 return &database.Error{OrigErr: err, Query: []byte(query)} 475 } 476 477 return nil 478 } 479 480 // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 481 func quoteIdentifier(name string) string { 482 end := strings.IndexRune(name, 0) 483 if end > -1 { 484 name = name[:end] 485 } 486 return `"` + strings.Replace(name, `"`, `""`, -1) + `"` 487 }