github.com/Elate-DevOps/migrate/v4@v4.0.12/database/pgx/v5/pgx.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package pgx 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 nurl "net/url" 12 "regexp" 13 "strconv" 14 "strings" 15 "time" 16 17 "github.com/Elate-DevOps/migrate/v4" 18 "github.com/Elate-DevOps/migrate/v4/database" 19 "github.com/Elate-DevOps/migrate/v4/database/multistmt" 20 "github.com/hashicorp/go-multierror" 21 "github.com/jackc/pgerrcode" 22 "github.com/jackc/pgx/v5/pgconn" 23 _ "github.com/jackc/pgx/v5/stdlib" 24 ) 25 26 func init() { 27 db := Postgres{} 28 database.Register("pgx5", &db) 29 } 30 31 var ( 32 multiStmtDelimiter = []byte(";") 33 34 DefaultMigrationsTable = "schema_migrations" 35 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 36 ) 37 38 var ( 39 ErrNilConfig = fmt.Errorf("no config") 40 ErrNoDatabaseName = fmt.Errorf("no database name") 41 ErrNoSchema = fmt.Errorf("no schema") 42 ) 43 44 type Config struct { 45 MigrationsTable string 46 DatabaseName string 47 migrationsTableName string 48 StatementTimeout time.Duration 49 MigrationsTableQuoted bool 50 MultiStatementEnabled bool 51 MultiStatementMaxSize int 52 } 53 54 type Postgres struct { 55 // Locking and unlocking need to use the same connection 56 conn *sql.Conn 57 db *sql.DB 58 // isLocked atomic.Bool 59 60 // Open and WithInstance need to guarantee that config is never nil 61 config *Config 62 } 63 64 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 65 if config == nil { 66 return nil, ErrNilConfig 67 } 68 69 if err := instance.Ping(); err != nil { 70 return nil, err 71 } 72 73 if config.DatabaseName == "" { 74 query := `SELECT CURRENT_DATABASE()` 75 var databaseName string 76 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 77 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 78 } 79 80 if len(databaseName) == 0 { 81 return nil, ErrNoDatabaseName 82 } 83 84 config.DatabaseName = databaseName 85 } 86 87 if len(config.MigrationsTable) == 0 { 88 config.MigrationsTable = DefaultMigrationsTable 89 } 90 91 config.migrationsTableName = config.MigrationsTable 92 if config.MigrationsTableQuoted { 93 re := regexp.MustCompile(`"(.*?)"`) 94 result := re.FindAllStringSubmatch(config.MigrationsTable, -1) 95 config.migrationsTableName = result[len(result)-1][1] 96 if len(result) > 1 { 97 return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable) 98 } 99 } 100 101 conn, err := instance.Conn(context.Background()) 102 if err != nil { 103 return nil, err 104 } 105 106 px := &Postgres{ 107 conn: conn, 108 db: instance, 109 config: config, 110 } 111 112 if err := px.ensureVersionTable(); err != nil { 113 return nil, err 114 } 115 116 return px, nil 117 } 118 119 func (p *Postgres) Open(url string) (database.Driver, error) { 120 purl, err := nurl.Parse(url) 121 if err != nil { 122 return nil, err 123 } 124 125 // Driver is registered as pgx, but connection string must use postgres schema 126 // when making actual connection 127 // i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db 128 purl.Scheme = "postgres" 129 130 db, err := sql.Open("pgx/v5", migrate.FilterCustomQuery(purl).String()) 131 if err != nil { 132 return nil, err 133 } 134 135 migrationsTable := purl.Query().Get("x-migrations-table") 136 migrationsTableQuoted := false 137 if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 { 138 migrationsTableQuoted, err = strconv.ParseBool(s) 139 if err != nil { 140 return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err) 141 } 142 } 143 if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) { 144 return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable) 145 } 146 147 statementTimeoutString := purl.Query().Get("x-statement-timeout") 148 statementTimeout := 0 149 if statementTimeoutString != "" { 150 statementTimeout, err = strconv.Atoi(statementTimeoutString) 151 if err != nil { 152 return nil, err 153 } 154 } 155 156 multiStatementMaxSize := DefaultMultiStatementMaxSize 157 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 158 multiStatementMaxSize, err = strconv.Atoi(s) 159 if err != nil { 160 return nil, err 161 } 162 if multiStatementMaxSize <= 0 { 163 multiStatementMaxSize = DefaultMultiStatementMaxSize 164 } 165 } 166 167 multiStatementEnabled := false 168 if s := purl.Query().Get("x-multi-statement"); len(s) > 0 { 169 multiStatementEnabled, err = strconv.ParseBool(s) 170 if err != nil { 171 return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err) 172 } 173 } 174 175 px, err := WithInstance(db, &Config{ 176 DatabaseName: purl.Path, 177 MigrationsTable: migrationsTable, 178 MigrationsTableQuoted: migrationsTableQuoted, 179 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, 180 MultiStatementEnabled: multiStatementEnabled, 181 MultiStatementMaxSize: multiStatementMaxSize, 182 }) 183 if err != nil { 184 return nil, err 185 } 186 187 return px, nil 188 } 189 190 func (p *Postgres) Close() error { 191 connErr := p.conn.Close() 192 dbErr := p.db.Close() 193 if connErr != nil || dbErr != nil { 194 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 195 } 196 return nil 197 } 198 199 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 200 func (p *Postgres) Lock() error { 201 // return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { 202 // aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 203 // if err != nil { 204 // return err 205 // } 206 207 // // This will wait indefinitely until the lock can be acquired. 208 // query := `SELECT pg_advisory_lock($1)` 209 // if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 210 // return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 211 // } 212 // return nil 213 // }) 214 return nil 215 } 216 217 func (p *Postgres) Unlock() error { 218 // return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { 219 // aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 220 // if err != nil { 221 // return err 222 // } 223 224 // query := `SELECT pg_advisory_unlock($1)` 225 // if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 226 // return &database.Error{OrigErr: err, Query: []byte(query)} 227 // } 228 // return nil 229 // }) 230 return nil 231 } 232 233 func (p *Postgres) Run(migration io.Reader) error { 234 if p.config.MultiStatementEnabled { 235 var err error 236 if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { 237 if err = p.runStatement(m); err != nil { 238 return false 239 } 240 return true 241 }); e != nil { 242 return e 243 } 244 return err 245 } 246 migr, err := io.ReadAll(migration) 247 if err != nil { 248 return err 249 } 250 return p.runStatement(migr) 251 } 252 253 func (p *Postgres) runStatement(statement []byte) error { 254 ctx := context.Background() 255 if p.config.StatementTimeout != 0 { 256 var cancel context.CancelFunc 257 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout) 258 defer cancel() 259 } 260 query := string(statement) 261 if strings.TrimSpace(query) == "" { 262 return nil 263 } 264 if _, err := p.conn.ExecContext(ctx, query); err != nil { 265 266 if pgErr, ok := err.(*pgconn.PgError); ok { 267 var line uint 268 var col uint 269 var lineColOK bool 270 line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position)) 271 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 272 if lineColOK { 273 message = fmt.Sprintf("%s (column %d)", message, col) 274 } 275 if pgErr.Detail != "" { 276 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 277 } 278 return database.Error{OrigErr: err, Err: message, Query: statement, Line: line} 279 } 280 return database.Error{OrigErr: err, Err: "migration failed", Query: statement} 281 } 282 return nil 283 } 284 285 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 286 // replace crlf with lf 287 s = strings.Replace(s, "\r\n", "\n", -1) 288 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 289 runes := []rune(s) 290 if pos > len(runes) { 291 return 0, 0, false 292 } 293 sel := runes[:pos] 294 line = uint(runesCount(sel, newLine) + 1) 295 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 296 return line, col, true 297 } 298 299 const newLine = '\n' 300 301 func runesCount(input []rune, target rune) int { 302 var count int 303 for _, r := range input { 304 if r == target { 305 count++ 306 } 307 } 308 return count 309 } 310 311 func runesLastIndex(input []rune, target rune) int { 312 for i := len(input) - 1; i >= 0; i-- { 313 if input[i] == target { 314 return i 315 } 316 } 317 return -1 318 } 319 320 func (p *Postgres) SetVersion(version int, dirty bool) error { 321 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 322 if err != nil { 323 return &database.Error{OrigErr: err, Err: "transaction start failed"} 324 } 325 326 query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsTableName) 327 if _, err := tx.Exec(query); err != nil { 328 if errRollback := tx.Rollback(); errRollback != nil { 329 err = multierror.Append(err, errRollback) 330 } 331 return &database.Error{OrigErr: err, Query: []byte(query)} 332 } 333 334 // Also re-write the schema version for nil dirty versions to prevent 335 // empty schema version for failed down migration on the first migration 336 // See: https://github.com/golang-migrate/migrate/issues/330 337 if version >= 0 || (version == database.NilVersion && dirty) { 338 query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)` 339 if _, err := tx.Exec(query, version, dirty); err != nil { 340 if errRollback := tx.Rollback(); errRollback != nil { 341 err = multierror.Append(err, errRollback) 342 } 343 return &database.Error{OrigErr: err, Query: []byte(query)} 344 } 345 } 346 347 if err := tx.Commit(); err != nil { 348 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 349 } 350 351 return nil 352 } 353 354 func (p *Postgres) Version() (version int, dirty bool, err error) { 355 query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1` 356 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 357 switch { 358 case err == sql.ErrNoRows: 359 return database.NilVersion, false, nil 360 361 case err != nil: 362 if e, ok := err.(*pgconn.PgError); ok { 363 if e.SQLState() == pgerrcode.UndefinedTable { 364 return database.NilVersion, false, nil 365 } 366 } 367 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 368 369 default: 370 return version, dirty, nil 371 } 372 } 373 374 func (p *Postgres) Drop() (err error) { 375 // select all tables in current schema 376 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 377 tables, err := p.conn.QueryContext(context.Background(), query) 378 if err != nil { 379 return &database.Error{OrigErr: err, Query: []byte(query)} 380 } 381 defer func() { 382 if errClose := tables.Close(); errClose != nil { 383 err = multierror.Append(err, errClose) 384 } 385 }() 386 387 // delete one table after another 388 tableNames := make([]string, 0) 389 for tables.Next() { 390 var tableName string 391 if err := tables.Scan(&tableName); err != nil { 392 return err 393 } 394 if len(tableName) > 0 { 395 tableNames = append(tableNames, tableName) 396 } 397 } 398 if err := tables.Err(); err != nil { 399 return &database.Error{OrigErr: err, Query: []byte(query)} 400 } 401 402 if len(tableNames) > 0 { 403 // delete one by one ... 404 for _, t := range tableNames { 405 query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE` 406 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 407 return &database.Error{OrigErr: err, Query: []byte(query)} 408 } 409 } 410 } 411 412 return nil 413 } 414 415 // ensureVersionTable checks if versions table exists and, if not, creates it. 416 // Note that this function locks the database, which deviates from the usual 417 // convention of "caller locks" in the Postgres type. 418 func (p *Postgres) ensureVersionTable() (err error) { 419 if err = p.Lock(); err != nil { 420 return err 421 } 422 423 defer func() { 424 if e := p.Unlock(); e != nil { 425 if err == nil { 426 err = e 427 } else { 428 err = multierror.Append(err, e) 429 } 430 } 431 }() 432 433 // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres 434 // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the 435 // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. 436 // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 437 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 LIMIT 1` 438 row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsTableName) 439 440 var count int 441 err = row.Scan(&count) 442 if err != nil { 443 return &database.Error{OrigErr: err, Query: []byte(query)} 444 } 445 446 if count == 1 { 447 return nil 448 } 449 450 query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` 451 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 452 return &database.Error{OrigErr: err, Query: []byte(query)} 453 } 454 455 return nil 456 } 457 458 // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 459 func quoteIdentifier(name string) string { 460 end := strings.IndexRune(name, 0) 461 if end > -1 { 462 name = name[:end] 463 } 464 return `"` + strings.Replace(name, `"`, `""`, -1) + `"` 465 }