github.com/bdollma-te/migrate/v4@v4.17.0-clickv2/database/pgx/v5/pgx.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package pgx 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 nurl "net/url" 12 "regexp" 13 "strconv" 14 "strings" 15 "time" 16 17 "go.uber.org/atomic" 18 19 "github.com/bdollma-te/migrate/v4" 20 "github.com/bdollma-te/migrate/v4/database" 21 "github.com/bdollma-te/migrate/v4/database/multistmt" 22 "github.com/hashicorp/go-multierror" 23 "github.com/jackc/pgerrcode" 24 "github.com/jackc/pgx/v5/pgconn" 25 _ "github.com/jackc/pgx/v5/stdlib" 26 ) 27 28 func init() { 29 db := Postgres{} 30 database.Register("pgx5", &db) 31 } 32 33 var ( 34 multiStmtDelimiter = []byte(";") 35 36 DefaultMigrationsTable = "schema_migrations" 37 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 38 ) 39 40 var ( 41 ErrNilConfig = fmt.Errorf("no config") 42 ErrNoDatabaseName = fmt.Errorf("no database name") 43 ErrNoSchema = fmt.Errorf("no schema") 44 ) 45 46 type Config struct { 47 MigrationsTable string 48 DatabaseName string 49 SchemaName string 50 migrationsSchemaName string 51 migrationsTableName string 52 StatementTimeout time.Duration 53 MigrationsTableQuoted bool 54 MultiStatementEnabled bool 55 MultiStatementMaxSize int 56 } 57 58 type Postgres struct { 59 // Locking and unlocking need to use the same connection 60 conn *sql.Conn 61 db *sql.DB 62 isLocked atomic.Bool 63 64 // Open and WithInstance need to guarantee that config is never nil 65 config *Config 66 } 67 68 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 69 if config == nil { 70 return nil, ErrNilConfig 71 } 72 73 if err := instance.Ping(); err != nil { 74 return nil, err 75 } 76 77 if config.DatabaseName == "" { 78 query := `SELECT CURRENT_DATABASE()` 79 var databaseName string 80 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 81 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 82 } 83 84 if len(databaseName) == 0 { 85 return nil, ErrNoDatabaseName 86 } 87 88 config.DatabaseName = databaseName 89 } 90 91 if config.SchemaName == "" { 92 query := `SELECT CURRENT_SCHEMA()` 93 var schemaName string 94 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 95 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 96 } 97 98 if len(schemaName) == 0 { 99 return nil, ErrNoSchema 100 } 101 102 config.SchemaName = schemaName 103 } 104 105 if len(config.MigrationsTable) == 0 { 106 config.MigrationsTable = DefaultMigrationsTable 107 } 108 109 config.migrationsSchemaName = config.SchemaName 110 config.migrationsTableName = config.MigrationsTable 111 if config.MigrationsTableQuoted { 112 re := regexp.MustCompile(`"(.*?)"`) 113 result := re.FindAllStringSubmatch(config.MigrationsTable, -1) 114 config.migrationsTableName = result[len(result)-1][1] 115 if len(result) == 2 { 116 config.migrationsSchemaName = result[0][1] 117 } else if len(result) > 2 { 118 return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable) 119 } 120 } 121 122 conn, err := instance.Conn(context.Background()) 123 124 if err != nil { 125 return nil, err 126 } 127 128 px := &Postgres{ 129 conn: conn, 130 db: instance, 131 config: config, 132 } 133 134 if err := px.ensureVersionTable(); err != nil { 135 return nil, err 136 } 137 138 return px, nil 139 } 140 141 func (p *Postgres) Open(url string) (database.Driver, error) { 142 purl, err := nurl.Parse(url) 143 if err != nil { 144 return nil, err 145 } 146 147 // Driver is registered as pgx, but connection string must use postgres schema 148 // when making actual connection 149 // i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db 150 purl.Scheme = "postgres" 151 152 db, err := sql.Open("pgx/v5", migrate.FilterCustomQuery(purl).String()) 153 if err != nil { 154 return nil, err 155 } 156 157 migrationsTable := purl.Query().Get("x-migrations-table") 158 migrationsTableQuoted := false 159 if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 { 160 migrationsTableQuoted, err = strconv.ParseBool(s) 161 if err != nil { 162 return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err) 163 } 164 } 165 if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) { 166 return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable) 167 } 168 169 statementTimeoutString := purl.Query().Get("x-statement-timeout") 170 statementTimeout := 0 171 if statementTimeoutString != "" { 172 statementTimeout, err = strconv.Atoi(statementTimeoutString) 173 if err != nil { 174 return nil, err 175 } 176 } 177 178 multiStatementMaxSize := DefaultMultiStatementMaxSize 179 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 180 multiStatementMaxSize, err = strconv.Atoi(s) 181 if err != nil { 182 return nil, err 183 } 184 if multiStatementMaxSize <= 0 { 185 multiStatementMaxSize = DefaultMultiStatementMaxSize 186 } 187 } 188 189 multiStatementEnabled := false 190 if s := purl.Query().Get("x-multi-statement"); len(s) > 0 { 191 multiStatementEnabled, err = strconv.ParseBool(s) 192 if err != nil { 193 return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err) 194 } 195 } 196 197 px, err := WithInstance(db, &Config{ 198 DatabaseName: purl.Path, 199 MigrationsTable: migrationsTable, 200 MigrationsTableQuoted: migrationsTableQuoted, 201 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, 202 MultiStatementEnabled: multiStatementEnabled, 203 MultiStatementMaxSize: multiStatementMaxSize, 204 }) 205 206 if err != nil { 207 return nil, err 208 } 209 210 return px, nil 211 } 212 213 func (p *Postgres) Close() error { 214 connErr := p.conn.Close() 215 dbErr := p.db.Close() 216 if connErr != nil || dbErr != nil { 217 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 218 } 219 return nil 220 } 221 222 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 223 func (p *Postgres) Lock() error { 224 return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { 225 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 226 if err != nil { 227 return err 228 } 229 230 // This will wait indefinitely until the lock can be acquired. 231 query := `SELECT pg_advisory_lock($1)` 232 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 233 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 234 } 235 return nil 236 }) 237 } 238 239 func (p *Postgres) Unlock() error { 240 return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { 241 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 242 if err != nil { 243 return err 244 } 245 246 query := `SELECT pg_advisory_unlock($1)` 247 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 248 return &database.Error{OrigErr: err, Query: []byte(query)} 249 } 250 return nil 251 }) 252 } 253 254 func (p *Postgres) Run(migration io.Reader) error { 255 if p.config.MultiStatementEnabled { 256 var err error 257 if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { 258 if err = p.runStatement(m); err != nil { 259 return false 260 } 261 return true 262 }); e != nil { 263 return e 264 } 265 return err 266 } 267 migr, err := io.ReadAll(migration) 268 if err != nil { 269 return err 270 } 271 return p.runStatement(migr) 272 } 273 274 func (p *Postgres) runStatement(statement []byte) error { 275 ctx := context.Background() 276 if p.config.StatementTimeout != 0 { 277 var cancel context.CancelFunc 278 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout) 279 defer cancel() 280 } 281 query := string(statement) 282 if strings.TrimSpace(query) == "" { 283 return nil 284 } 285 if _, err := p.conn.ExecContext(ctx, query); err != nil { 286 287 if pgErr, ok := err.(*pgconn.PgError); ok { 288 var line uint 289 var col uint 290 var lineColOK bool 291 line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position)) 292 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 293 if lineColOK { 294 message = fmt.Sprintf("%s (column %d)", message, col) 295 } 296 if pgErr.Detail != "" { 297 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 298 } 299 return database.Error{OrigErr: err, Err: message, Query: statement, Line: line} 300 } 301 return database.Error{OrigErr: err, Err: "migration failed", Query: statement} 302 } 303 return nil 304 } 305 306 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 307 // replace crlf with lf 308 s = strings.Replace(s, "\r\n", "\n", -1) 309 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 310 runes := []rune(s) 311 if pos > len(runes) { 312 return 0, 0, false 313 } 314 sel := runes[:pos] 315 line = uint(runesCount(sel, newLine) + 1) 316 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 317 return line, col, true 318 } 319 320 const newLine = '\n' 321 322 func runesCount(input []rune, target rune) int { 323 var count int 324 for _, r := range input { 325 if r == target { 326 count++ 327 } 328 } 329 return count 330 } 331 332 func runesLastIndex(input []rune, target rune) int { 333 for i := len(input) - 1; i >= 0; i-- { 334 if input[i] == target { 335 return i 336 } 337 } 338 return -1 339 } 340 341 func (p *Postgres) SetVersion(version int, dirty bool) error { 342 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 343 if err != nil { 344 return &database.Error{OrigErr: err, Err: "transaction start failed"} 345 } 346 347 query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) 348 if _, err := tx.Exec(query); err != nil { 349 if errRollback := tx.Rollback(); errRollback != nil { 350 err = multierror.Append(err, errRollback) 351 } 352 return &database.Error{OrigErr: err, Query: []byte(query)} 353 } 354 355 // Also re-write the schema version for nil dirty versions to prevent 356 // empty schema version for failed down migration on the first migration 357 // See: https://github.com/golang-migrate/migrate/issues/330 358 if version >= 0 || (version == database.NilVersion && dirty) { 359 query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)` 360 if _, err := tx.Exec(query, version, dirty); err != nil { 361 if errRollback := tx.Rollback(); errRollback != nil { 362 err = multierror.Append(err, errRollback) 363 } 364 return &database.Error{OrigErr: err, Query: []byte(query)} 365 } 366 } 367 368 if err := tx.Commit(); err != nil { 369 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 370 } 371 372 return nil 373 } 374 375 func (p *Postgres) Version() (version int, dirty bool, err error) { 376 query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1` 377 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 378 switch { 379 case err == sql.ErrNoRows: 380 return database.NilVersion, false, nil 381 382 case err != nil: 383 if e, ok := err.(*pgconn.PgError); ok { 384 if e.SQLState() == pgerrcode.UndefinedTable { 385 return database.NilVersion, false, nil 386 } 387 } 388 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 389 390 default: 391 return version, dirty, nil 392 } 393 } 394 395 func (p *Postgres) Drop() (err error) { 396 // select all tables in current schema 397 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 398 tables, err := p.conn.QueryContext(context.Background(), query) 399 if err != nil { 400 return &database.Error{OrigErr: err, Query: []byte(query)} 401 } 402 defer func() { 403 if errClose := tables.Close(); errClose != nil { 404 err = multierror.Append(err, errClose) 405 } 406 }() 407 408 // delete one table after another 409 tableNames := make([]string, 0) 410 for tables.Next() { 411 var tableName string 412 if err := tables.Scan(&tableName); err != nil { 413 return err 414 } 415 if len(tableName) > 0 { 416 tableNames = append(tableNames, tableName) 417 } 418 } 419 if err := tables.Err(); err != nil { 420 return &database.Error{OrigErr: err, Query: []byte(query)} 421 } 422 423 if len(tableNames) > 0 { 424 // delete one by one ... 425 for _, t := range tableNames { 426 query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE` 427 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 428 return &database.Error{OrigErr: err, Query: []byte(query)} 429 } 430 } 431 } 432 433 return nil 434 } 435 436 // ensureVersionTable checks if versions table exists and, if not, creates it. 437 // Note that this function locks the database, which deviates from the usual 438 // convention of "caller locks" in the Postgres type. 439 func (p *Postgres) ensureVersionTable() (err error) { 440 if err = p.Lock(); err != nil { 441 return err 442 } 443 444 defer func() { 445 if e := p.Unlock(); e != nil { 446 if err == nil { 447 err = e 448 } else { 449 err = multierror.Append(err, e) 450 } 451 } 452 }() 453 454 // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres 455 // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the 456 // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. 457 // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 458 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1` 459 row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName) 460 461 var count int 462 err = row.Scan(&count) 463 if err != nil { 464 return &database.Error{OrigErr: err, Query: []byte(query)} 465 } 466 467 if count == 1 { 468 return nil 469 } 470 471 query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` 472 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 473 return &database.Error{OrigErr: err, Query: []byte(query)} 474 } 475 476 return nil 477 } 478 479 // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 480 func quoteIdentifier(name string) string { 481 end := strings.IndexRune(name, 0) 482 if end > -1 { 483 name = name[:end] 484 } 485 return `"` + strings.Replace(name, `"`, `""`, -1) + `"` 486 }