github.com/olitvin/migrate/v4@v4.14.3-0.20210330111251-992b37ee04c8/database/pgx/pgx.go (about) 1 // +build go1.9 2 3 package pgx 4 5 import ( 6 "context" 7 "database/sql" 8 "fmt" 9 "io" 10 "io/ioutil" 11 nurl "net/url" 12 "strconv" 13 "strings" 14 "time" 15 16 multierror "github.com/hashicorp/go-multierror" 17 "github.com/jackc/pgconn" 18 "github.com/jackc/pgerrcode" 19 _ "github.com/jackc/pgx/v4/stdlib" 20 "github.com/olitvin/migrate/v4" 21 "github.com/olitvin/migrate/v4/database" 22 "github.com/olitvin/migrate/v4/database/multistmt" 23 ) 24 25 func init() { 26 db := Postgres{} 27 database.Register("pgx", &db) 28 } 29 30 var ( 31 multiStmtDelimiter = []byte(";") 32 33 DefaultMigrationsTable = "schema_migrations" 34 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 35 ) 36 37 var ( 38 ErrNilConfig = fmt.Errorf("no config") 39 ErrNoDatabaseName = fmt.Errorf("no database name") 40 ErrNoSchema = fmt.Errorf("no schema") 41 ErrDatabaseDirty = fmt.Errorf("database is dirty") 42 ) 43 44 type Config struct { 45 MigrationsTable string 46 DatabaseName string 47 SchemaName string 48 StatementTimeout time.Duration 49 MultiStatementEnabled bool 50 MultiStatementMaxSize int 51 } 52 53 type Postgres struct { 54 // Locking and unlocking need to use the same connection 55 conn *sql.Conn 56 db *sql.DB 57 isLocked bool 58 59 // Open and WithInstance need to guarantee that config is never nil 60 config *Config 61 } 62 63 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 64 if config == nil { 65 return nil, ErrNilConfig 66 } 67 68 if err := instance.Ping(); err != nil { 69 return nil, err 70 } 71 72 if config.DatabaseName == "" { 73 query := `SELECT CURRENT_DATABASE()` 74 var databaseName string 75 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 76 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 77 } 78 79 if len(databaseName) == 0 { 80 return nil, ErrNoDatabaseName 81 } 82 83 config.DatabaseName = databaseName 84 } 85 86 if config.SchemaName == "" { 87 query := `SELECT CURRENT_SCHEMA()` 88 var schemaName string 89 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 90 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 91 } 92 93 if len(schemaName) == 0 { 94 return nil, ErrNoSchema 95 } 96 97 config.SchemaName = schemaName 98 } 99 100 if len(config.MigrationsTable) == 0 { 101 config.MigrationsTable = DefaultMigrationsTable 102 } 103 104 conn, err := instance.Conn(context.Background()) 105 106 if err != nil { 107 return nil, err 108 } 109 110 px := &Postgres{ 111 conn: conn, 112 db: instance, 113 config: config, 114 } 115 116 if err := px.ensureVersionTable(); err != nil { 117 return nil, err 118 } 119 120 return px, nil 121 } 122 123 func (p *Postgres) Open(url string) (database.Driver, error) { 124 purl, err := nurl.Parse(url) 125 if err != nil { 126 return nil, err 127 } 128 129 // Driver is registered as pgx, but connection string must use postgres schema 130 // when making actual connection 131 // i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db 132 purl.Scheme = "postgres" 133 134 db, err := sql.Open("pgx", migrate.FilterCustomQuery(purl).String()) 135 if err != nil { 136 return nil, err 137 } 138 139 migrationsTable := purl.Query().Get("x-migrations-table") 140 statementTimeoutString := purl.Query().Get("x-statement-timeout") 141 statementTimeout := 0 142 if statementTimeoutString != "" { 143 statementTimeout, err = strconv.Atoi(statementTimeoutString) 144 if err != nil { 145 return nil, err 146 } 147 } 148 149 multiStatementMaxSize := DefaultMultiStatementMaxSize 150 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 151 multiStatementMaxSize, err = strconv.Atoi(s) 152 if err != nil { 153 return nil, err 154 } 155 if multiStatementMaxSize <= 0 { 156 multiStatementMaxSize = DefaultMultiStatementMaxSize 157 } 158 } 159 160 multiStatementEnabled := false 161 if s := purl.Query().Get("x-multi-statement"); len(s) > 0 { 162 multiStatementEnabled, err = strconv.ParseBool(s) 163 if err != nil { 164 return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err) 165 } 166 } 167 168 px, err := WithInstance(db, &Config{ 169 DatabaseName: purl.Path, 170 MigrationsTable: migrationsTable, 171 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, 172 MultiStatementEnabled: multiStatementEnabled, 173 MultiStatementMaxSize: multiStatementMaxSize, 174 }) 175 176 if err != nil { 177 return nil, err 178 } 179 180 return px, nil 181 } 182 183 func (p *Postgres) Close() error { 184 connErr := p.conn.Close() 185 dbErr := p.db.Close() 186 if connErr != nil || dbErr != nil { 187 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 188 } 189 return nil 190 } 191 192 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 193 func (p *Postgres) Lock() error { 194 if p.isLocked { 195 return database.ErrLocked 196 } 197 198 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName) 199 if err != nil { 200 return err 201 } 202 203 // This will wait indefinitely until the lock can be acquired. 204 query := `SELECT pg_advisory_lock($1)` 205 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 206 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 207 } 208 209 p.isLocked = true 210 return nil 211 } 212 213 func (p *Postgres) Unlock() error { 214 if !p.isLocked { 215 return nil 216 } 217 218 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName) 219 if err != nil { 220 return err 221 } 222 223 query := `SELECT pg_advisory_unlock($1)` 224 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 225 return &database.Error{OrigErr: err, Query: []byte(query)} 226 } 227 p.isLocked = false 228 return nil 229 } 230 231 func (p *Postgres) Run(migration io.Reader) error { 232 if p.config.MultiStatementEnabled { 233 var err error 234 if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { 235 if err = p.runStatement(m); err != nil { 236 return false 237 } 238 return true 239 }); e != nil { 240 return e 241 } 242 return err 243 } 244 migr, err := ioutil.ReadAll(migration) 245 if err != nil { 246 return err 247 } 248 return p.runStatement(migr) 249 } 250 251 func (p *Postgres) runStatement(statement []byte) error { 252 ctx := context.Background() 253 if p.config.StatementTimeout != 0 { 254 var cancel context.CancelFunc 255 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout) 256 defer cancel() 257 } 258 query := string(statement) 259 if strings.TrimSpace(query) == "" { 260 return nil 261 } 262 if _, err := p.conn.ExecContext(ctx, query); err != nil { 263 264 if pgErr, ok := err.(*pgconn.PgError); ok { 265 var line uint 266 var col uint 267 var lineColOK bool 268 line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position)) 269 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 270 if lineColOK { 271 message = fmt.Sprintf("%s (column %d)", message, col) 272 } 273 if pgErr.Detail != "" { 274 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 275 } 276 return database.Error{OrigErr: err, Err: message, Query: statement, Line: line} 277 } 278 return database.Error{OrigErr: err, Err: "migration failed", Query: statement} 279 } 280 return nil 281 } 282 283 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 284 // replace crlf with lf 285 s = strings.Replace(s, "\r\n", "\n", -1) 286 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 287 runes := []rune(s) 288 if pos > len(runes) { 289 return 0, 0, false 290 } 291 sel := runes[:pos] 292 line = uint(runesCount(sel, newLine) + 1) 293 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 294 return line, col, true 295 } 296 297 const newLine = '\n' 298 299 func runesCount(input []rune, target rune) int { 300 var count int 301 for _, r := range input { 302 if r == target { 303 count++ 304 } 305 } 306 return count 307 } 308 309 func runesLastIndex(input []rune, target rune) int { 310 for i := len(input) - 1; i >= 0; i-- { 311 if input[i] == target { 312 return i 313 } 314 } 315 return -1 316 } 317 318 func (p *Postgres) SetVersion(version int, dirty bool) error { 319 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 320 if err != nil { 321 return &database.Error{OrigErr: err, Err: "transaction start failed"} 322 } 323 324 query := `TRUNCATE ` + quoteIdentifier(p.config.MigrationsTable) 325 if _, err := tx.Exec(query); err != nil { 326 if errRollback := tx.Rollback(); errRollback != nil { 327 err = multierror.Append(err, errRollback) 328 } 329 return &database.Error{OrigErr: err, Query: []byte(query)} 330 } 331 332 // Also re-write the schema version for nil dirty versions to prevent 333 // empty schema version for failed down migration on the first migration 334 // See: https://github.com/olitvin/migrate/issues/330 335 if version >= 0 || (version == database.NilVersion && dirty) { 336 query = `INSERT INTO ` + quoteIdentifier(p.config.MigrationsTable) + 337 ` (version, dirty) VALUES ($1, $2)` 338 if _, err := tx.Exec(query, version, dirty); err != nil { 339 if errRollback := tx.Rollback(); errRollback != nil { 340 err = multierror.Append(err, errRollback) 341 } 342 return &database.Error{OrigErr: err, Query: []byte(query)} 343 } 344 } 345 346 if err := tx.Commit(); err != nil { 347 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 348 } 349 350 return nil 351 } 352 353 func (p *Postgres) Version() (version int, dirty bool, err error) { 354 query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1` 355 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 356 switch { 357 case err == sql.ErrNoRows: 358 return database.NilVersion, false, nil 359 360 case err != nil: 361 if e, ok := err.(*pgconn.PgError); ok { 362 if e.SQLState() == pgerrcode.UndefinedTable { 363 return database.NilVersion, false, nil 364 } 365 } 366 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 367 368 default: 369 return version, dirty, nil 370 } 371 } 372 373 func (p *Postgres) Drop() (err error) { 374 // select all tables in current schema 375 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 376 tables, err := p.conn.QueryContext(context.Background(), query) 377 if err != nil { 378 return &database.Error{OrigErr: err, Query: []byte(query)} 379 } 380 defer func() { 381 if errClose := tables.Close(); errClose != nil { 382 err = multierror.Append(err, errClose) 383 } 384 }() 385 386 // delete one table after another 387 tableNames := make([]string, 0) 388 for tables.Next() { 389 var tableName string 390 if err := tables.Scan(&tableName); err != nil { 391 return err 392 } 393 if len(tableName) > 0 { 394 tableNames = append(tableNames, tableName) 395 } 396 } 397 if err := tables.Err(); err != nil { 398 return &database.Error{OrigErr: err, Query: []byte(query)} 399 } 400 401 if len(tableNames) > 0 { 402 // delete one by one ... 403 for _, t := range tableNames { 404 query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE` 405 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 406 return &database.Error{OrigErr: err, Query: []byte(query)} 407 } 408 } 409 } 410 411 return nil 412 } 413 414 // ensureVersionTable checks if versions table exists and, if not, creates it. 415 // Note that this function locks the database, which deviates from the usual 416 // convention of "caller locks" in the Postgres type. 417 func (p *Postgres) ensureVersionTable() (err error) { 418 if err = p.Lock(); err != nil { 419 return err 420 } 421 422 defer func() { 423 if e := p.Unlock(); e != nil { 424 if err == nil { 425 err = e 426 } else { 427 err = multierror.Append(err, e) 428 } 429 } 430 }() 431 432 // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres 433 // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the 434 // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. 435 // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 436 var count int 437 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 438 row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable) 439 440 err = row.Scan(&count) 441 if err != nil { 442 return &database.Error{OrigErr: err, Query: []byte(query)} 443 } 444 445 if count == 1 { 446 return nil 447 } 448 449 query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` 450 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 451 return &database.Error{OrigErr: err, Query: []byte(query)} 452 } 453 454 return nil 455 } 456 457 // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 458 func quoteIdentifier(name string) string { 459 end := strings.IndexRune(name, 0) 460 if end > -1 { 461 name = name[:end] 462 } 463 return `"` + strings.Replace(name, `"`, `""`, -1) + `"` 464 }