github.com/olitvin/migrate/v4@v4.14.3-0.20210330111251-992b37ee04c8/database/postgres/postgres.go (about) 1 // +build go1.9 2 3 package postgres 4 5 import ( 6 "context" 7 "database/sql" 8 "fmt" 9 "io" 10 "io/ioutil" 11 nurl "net/url" 12 "strconv" 13 "strings" 14 "time" 15 16 multierror "github.com/hashicorp/go-multierror" 17 "github.com/lib/pq" 18 "github.com/olitvin/migrate/v4" 19 "github.com/olitvin/migrate/v4/database" 20 "github.com/olitvin/migrate/v4/database/multistmt" 21 ) 22 23 func init() { 24 db := Postgres{} 25 database.Register("postgres", &db) 26 database.Register("postgresql", &db) 27 } 28 29 var ( 30 multiStmtDelimiter = []byte(";") 31 32 DefaultMigrationsTable = "schema_migrations" 33 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 34 ) 35 36 var ( 37 ErrNilConfig = fmt.Errorf("no config") 38 ErrNoDatabaseName = fmt.Errorf("no database name") 39 ErrNoSchema = fmt.Errorf("no schema") 40 ErrDatabaseDirty = fmt.Errorf("database is dirty") 41 ) 42 43 type Config struct { 44 MigrationsTable string 45 DatabaseName string 46 SchemaName string 47 StatementTimeout time.Duration 48 MultiStatementEnabled bool 49 MultiStatementMaxSize int 50 } 51 52 type Postgres struct { 53 // Locking and unlocking need to use the same connection 54 conn *sql.Conn 55 db *sql.DB 56 isLocked bool 57 58 // Open and WithInstance need to guarantee that config is never nil 59 config *Config 60 } 61 62 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 63 if config == nil { 64 return nil, ErrNilConfig 65 } 66 67 if err := instance.Ping(); err != nil { 68 return nil, err 69 } 70 71 if config.DatabaseName == "" { 72 query := `SELECT CURRENT_DATABASE()` 73 var databaseName string 74 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 75 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 76 } 77 78 if len(databaseName) == 0 { 79 return nil, ErrNoDatabaseName 80 } 81 82 config.DatabaseName = databaseName 83 } 84 85 if config.SchemaName == "" { 86 query := `SELECT CURRENT_SCHEMA()` 87 var schemaName string 88 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 89 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 90 } 91 92 if len(schemaName) == 0 { 93 return nil, ErrNoSchema 94 } 95 96 config.SchemaName = schemaName 97 } 98 99 if len(config.MigrationsTable) == 0 { 100 config.MigrationsTable = DefaultMigrationsTable 101 } 102 103 conn, err := instance.Conn(context.Background()) 104 105 if err != nil { 106 return nil, err 107 } 108 109 px := &Postgres{ 110 conn: conn, 111 db: instance, 112 config: config, 113 } 114 115 if err := px.ensureVersionTable(); err != nil { 116 return nil, err 117 } 118 119 return px, nil 120 } 121 122 func (p *Postgres) Open(url string) (database.Driver, error) { 123 purl, err := nurl.Parse(url) 124 if err != nil { 125 return nil, err 126 } 127 128 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 129 if err != nil { 130 return nil, err 131 } 132 133 migrationsTable := purl.Query().Get("x-migrations-table") 134 statementTimeoutString := purl.Query().Get("x-statement-timeout") 135 statementTimeout := 0 136 if statementTimeoutString != "" { 137 statementTimeout, err = strconv.Atoi(statementTimeoutString) 138 if err != nil { 139 return nil, err 140 } 141 } 142 143 multiStatementMaxSize := DefaultMultiStatementMaxSize 144 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 145 multiStatementMaxSize, err = strconv.Atoi(s) 146 if err != nil { 147 return nil, err 148 } 149 if multiStatementMaxSize <= 0 { 150 multiStatementMaxSize = DefaultMultiStatementMaxSize 151 } 152 } 153 154 multiStatementEnabled := false 155 if s := purl.Query().Get("x-multi-statement"); len(s) > 0 { 156 multiStatementEnabled, err = strconv.ParseBool(s) 157 if err != nil { 158 return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err) 159 } 160 } 161 162 px, err := WithInstance(db, &Config{ 163 DatabaseName: purl.Path, 164 MigrationsTable: migrationsTable, 165 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, 166 MultiStatementEnabled: multiStatementEnabled, 167 MultiStatementMaxSize: multiStatementMaxSize, 168 }) 169 170 if err != nil { 171 return nil, err 172 } 173 174 return px, nil 175 } 176 177 func (p *Postgres) Close() error { 178 connErr := p.conn.Close() 179 dbErr := p.db.Close() 180 if connErr != nil || dbErr != nil { 181 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 182 } 183 return nil 184 } 185 186 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 187 func (p *Postgres) Lock() error { 188 if p.isLocked { 189 return database.ErrLocked 190 } 191 192 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName) 193 if err != nil { 194 return err 195 } 196 197 // This will wait indefinitely until the lock can be acquired. 198 query := `SELECT pg_advisory_lock($1)` 199 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 200 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 201 } 202 203 p.isLocked = true 204 return nil 205 } 206 207 func (p *Postgres) Unlock() error { 208 if !p.isLocked { 209 return nil 210 } 211 212 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName) 213 if err != nil { 214 return err 215 } 216 217 query := `SELECT pg_advisory_unlock($1)` 218 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 219 return &database.Error{OrigErr: err, Query: []byte(query)} 220 } 221 p.isLocked = false 222 return nil 223 } 224 225 func (p *Postgres) Run(migration io.Reader) error { 226 if p.config.MultiStatementEnabled { 227 var err error 228 if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { 229 if err = p.runStatement(m); err != nil { 230 return false 231 } 232 return true 233 }); e != nil { 234 return e 235 } 236 return err 237 } 238 migr, err := ioutil.ReadAll(migration) 239 if err != nil { 240 return err 241 } 242 return p.runStatement(migr) 243 } 244 245 func (p *Postgres) runStatement(statement []byte) error { 246 ctx := context.Background() 247 if p.config.StatementTimeout != 0 { 248 var cancel context.CancelFunc 249 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout) 250 defer cancel() 251 } 252 query := string(statement) 253 if strings.TrimSpace(query) == "" { 254 return nil 255 } 256 if _, err := p.conn.ExecContext(ctx, query); err != nil { 257 if pgErr, ok := err.(*pq.Error); ok { 258 var line uint 259 var col uint 260 var lineColOK bool 261 if pgErr.Position != "" { 262 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 263 line, col, lineColOK = computeLineFromPos(query, int(pos)) 264 } 265 } 266 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 267 if lineColOK { 268 message = fmt.Sprintf("%s (column %d)", message, col) 269 } 270 if pgErr.Detail != "" { 271 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 272 } 273 return database.Error{OrigErr: err, Err: message, Query: statement, Line: line} 274 } 275 return database.Error{OrigErr: err, Err: "migration failed", Query: statement} 276 } 277 return nil 278 } 279 280 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 281 // replace crlf with lf 282 s = strings.Replace(s, "\r\n", "\n", -1) 283 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 284 runes := []rune(s) 285 if pos > len(runes) { 286 return 0, 0, false 287 } 288 sel := runes[:pos] 289 line = uint(runesCount(sel, newLine) + 1) 290 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 291 return line, col, true 292 } 293 294 const newLine = '\n' 295 296 func runesCount(input []rune, target rune) int { 297 var count int 298 for _, r := range input { 299 if r == target { 300 count++ 301 } 302 } 303 return count 304 } 305 306 func runesLastIndex(input []rune, target rune) int { 307 for i := len(input) - 1; i >= 0; i-- { 308 if input[i] == target { 309 return i 310 } 311 } 312 return -1 313 } 314 315 func (p *Postgres) SetVersion(version int, dirty bool) error { 316 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 317 if err != nil { 318 return &database.Error{OrigErr: err, Err: "transaction start failed"} 319 } 320 321 query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable) 322 if _, err := tx.Exec(query); err != nil { 323 if errRollback := tx.Rollback(); errRollback != nil { 324 err = multierror.Append(err, errRollback) 325 } 326 return &database.Error{OrigErr: err, Query: []byte(query)} 327 } 328 329 // Also re-write the schema version for nil dirty versions to prevent 330 // empty schema version for failed down migration on the first migration 331 // See: https://github.com/olitvin/migrate/issues/330 332 if version >= 0 || (version == database.NilVersion && dirty) { 333 query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) + 334 ` (version, dirty) VALUES ($1, $2)` 335 if _, err := tx.Exec(query, version, dirty); err != nil { 336 if errRollback := tx.Rollback(); errRollback != nil { 337 err = multierror.Append(err, errRollback) 338 } 339 return &database.Error{OrigErr: err, Query: []byte(query)} 340 } 341 } 342 343 if err := tx.Commit(); err != nil { 344 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 345 } 346 347 return nil 348 } 349 350 func (p *Postgres) Version() (version int, dirty bool, err error) { 351 query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1` 352 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 353 switch { 354 case err == sql.ErrNoRows: 355 return database.NilVersion, false, nil 356 357 case err != nil: 358 if e, ok := err.(*pq.Error); ok { 359 if e.Code.Name() == "undefined_table" { 360 return database.NilVersion, false, nil 361 } 362 } 363 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 364 365 default: 366 return version, dirty, nil 367 } 368 } 369 370 func (p *Postgres) Drop() (err error) { 371 // select all tables in current schema 372 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 373 tables, err := p.conn.QueryContext(context.Background(), query) 374 if err != nil { 375 return &database.Error{OrigErr: err, Query: []byte(query)} 376 } 377 defer func() { 378 if errClose := tables.Close(); errClose != nil { 379 err = multierror.Append(err, errClose) 380 } 381 }() 382 383 // delete one table after another 384 tableNames := make([]string, 0) 385 for tables.Next() { 386 var tableName string 387 if err := tables.Scan(&tableName); err != nil { 388 return err 389 } 390 if len(tableName) > 0 { 391 tableNames = append(tableNames, tableName) 392 } 393 } 394 if err := tables.Err(); err != nil { 395 return &database.Error{OrigErr: err, Query: []byte(query)} 396 } 397 398 if len(tableNames) > 0 { 399 // delete one by one ... 400 for _, t := range tableNames { 401 query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE` 402 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 403 return &database.Error{OrigErr: err, Query: []byte(query)} 404 } 405 } 406 } 407 408 return nil 409 } 410 411 // ensureVersionTable checks if versions table exists and, if not, creates it. 412 // Note that this function locks the database, which deviates from the usual 413 // convention of "caller locks" in the Postgres type. 414 func (p *Postgres) ensureVersionTable() (err error) { 415 if err = p.Lock(); err != nil { 416 return err 417 } 418 419 defer func() { 420 if e := p.Unlock(); e != nil { 421 if err == nil { 422 err = e 423 } else { 424 err = multierror.Append(err, e) 425 } 426 } 427 }() 428 429 // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres 430 // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the 431 // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. 432 // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 433 var count int 434 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 435 row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable) 436 437 err = row.Scan(&count) 438 if err != nil { 439 return &database.Error{OrigErr: err, Query: []byte(query)} 440 } 441 442 if count == 1 { 443 return nil 444 } 445 446 query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` 447 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 448 return &database.Error{OrigErr: err, Query: []byte(query)} 449 } 450 451 return nil 452 }