github.com/nagyist/migrate/v4@v4.14.6/database/postgres/postgres.go (about) 1 // +build go1.9 2 3 package postgres 4 5 import ( 6 "context" 7 "database/sql" 8 "fmt" 9 "io" 10 "io/ioutil" 11 nurl "net/url" 12 "strconv" 13 "strings" 14 "time" 15 16 multierror "github.com/hashicorp/go-multierror" 17 "github.com/lib/pq" 18 "github.com/nagyist/migrate/v4" 19 "github.com/nagyist/migrate/v4/database" 20 ) 21 22 func init() { 23 db := Postgres{} 24 database.Register("postgres", &db) 25 database.Register("postgresql", &db) 26 } 27 28 var DefaultMigrationsTable = "schema_migrations" 29 30 var ( 31 ErrNilConfig = fmt.Errorf("no config") 32 ErrNoDatabaseName = fmt.Errorf("no database name") 33 ErrNoSchema = fmt.Errorf("no schema") 34 ErrDatabaseDirty = fmt.Errorf("database is dirty") 35 ) 36 37 type Config struct { 38 MigrationsTableHasSchema bool 39 MigrationsTable string 40 DatabaseName string 41 SchemaName string 42 StatementTimeout time.Duration 43 } 44 45 type Postgres struct { 46 // Locking and unlocking need to use the same connection 47 conn *sql.Conn 48 db *sql.DB 49 isLocked bool 50 51 // Open and WithInstance need to guarantee that config is never nil 52 config *Config 53 } 54 55 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 56 if config == nil { 57 return nil, ErrNilConfig 58 } 59 60 if err := instance.Ping(); err != nil { 61 return nil, err 62 } 63 64 if config.DatabaseName == "" { 65 query := `SELECT CURRENT_DATABASE()` 66 var databaseName string 67 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 68 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 69 } 70 71 if len(databaseName) == 0 { 72 return nil, ErrNoDatabaseName 73 } 74 75 config.DatabaseName = databaseName 76 } 77 78 if config.SchemaName == "" { 79 query := `SELECT CURRENT_SCHEMA()` 80 var schemaName string 81 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 82 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 83 } 84 85 if len(schemaName) == 0 { 86 return nil, ErrNoSchema 87 } 88 89 config.SchemaName = schemaName 90 } 91 92 if len(config.MigrationsTable) == 0 { 93 config.MigrationsTable = DefaultMigrationsTable 94 } 95 96 conn, err := instance.Conn(context.Background()) 97 98 if err != nil { 99 return nil, err 100 } 101 102 px := &Postgres{ 103 conn: conn, 104 db: instance, 105 config: config, 106 } 107 108 if err := px.ensureVersionTable(); err != nil { 109 return nil, err 110 } 111 112 return px, nil 113 } 114 115 func (p *Postgres) Open(url string) (database.Driver, error) { 116 purl, err := nurl.Parse(url) 117 if err != nil { 118 return nil, err 119 } 120 121 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 122 if err != nil { 123 return nil, err 124 } 125 126 migrationsTable := purl.Query().Get("x-migrations-table") 127 statementTimeoutString := purl.Query().Get("x-statement-timeout") 128 migrationsTableHasSchemaString := purl.Query().Get("x-migrations-table-has-schema") 129 statementTimeout := 0 130 if statementTimeoutString != "" { 131 statementTimeout, err = strconv.Atoi(statementTimeoutString) 132 if err != nil { 133 return nil, err 134 } 135 } 136 137 migrationsTableHasSchema := false 138 if migrationsTableHasSchemaString != "" { 139 migrationsTableHasSchema, err = strconv.ParseBool(migrationsTableHasSchemaString) 140 if err != nil { 141 return nil, err 142 } 143 } 144 145 px, err := WithInstance(db, &Config{ 146 DatabaseName: purl.Path, 147 MigrationsTableHasSchema: migrationsTableHasSchema, 148 MigrationsTable: migrationsTable, 149 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, 150 }) 151 152 if err != nil { 153 return nil, err 154 } 155 156 return px, nil 157 } 158 159 func (p *Postgres) Close() error { 160 connErr := p.conn.Close() 161 dbErr := p.db.Close() 162 if connErr != nil || dbErr != nil { 163 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 164 } 165 return nil 166 } 167 168 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 169 func (p *Postgres) Lock() error { 170 if p.isLocked { 171 return database.ErrLocked 172 } 173 174 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName) 175 if err != nil { 176 return err 177 } 178 179 // This will wait indefinitely until the lock can be acquired. 180 query := `SELECT pg_advisory_lock($1)` 181 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 182 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 183 } 184 185 p.isLocked = true 186 return nil 187 } 188 189 func (p *Postgres) Unlock() error { 190 if !p.isLocked { 191 return nil 192 } 193 194 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName) 195 if err != nil { 196 return err 197 } 198 199 query := `SELECT pg_advisory_unlock($1)` 200 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 201 return &database.Error{OrigErr: err, Query: []byte(query)} 202 } 203 p.isLocked = false 204 return nil 205 } 206 207 func (p *Postgres) Run(migration io.Reader) error { 208 migr, err := ioutil.ReadAll(migration) 209 if err != nil { 210 return err 211 } 212 ctx := context.Background() 213 if p.config.StatementTimeout != 0 { 214 var cancel context.CancelFunc 215 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout) 216 defer cancel() 217 } 218 // run migration 219 query := string(migr[:]) 220 if _, err := p.conn.ExecContext(ctx, query); err != nil { 221 if pgErr, ok := err.(*pq.Error); ok { 222 var line uint 223 var col uint 224 var lineColOK bool 225 if pgErr.Position != "" { 226 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 227 line, col, lineColOK = computeLineFromPos(query, int(pos)) 228 } 229 } 230 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 231 if lineColOK { 232 message = fmt.Sprintf("%s (column %d)", message, col) 233 } 234 if pgErr.Detail != "" { 235 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 236 } 237 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 238 } 239 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 240 } 241 242 return nil 243 } 244 245 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 246 // replace crlf with lf 247 s = strings.Replace(s, "\r\n", "\n", -1) 248 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 249 runes := []rune(s) 250 if pos > len(runes) { 251 return 0, 0, false 252 } 253 sel := runes[:pos] 254 line = uint(runesCount(sel, newLine) + 1) 255 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 256 return line, col, true 257 } 258 259 const newLine = '\n' 260 261 func runesCount(input []rune, target rune) int { 262 var count int 263 for _, r := range input { 264 if r == target { 265 count++ 266 } 267 } 268 return count 269 } 270 271 func runesLastIndex(input []rune, target rune) int { 272 for i := len(input) - 1; i >= 0; i-- { 273 if input[i] == target { 274 return i 275 } 276 } 277 return -1 278 } 279 280 func (p *Postgres) quoteIdentifier(name string) string { 281 if p.config.MigrationsTableHasSchema { 282 firstDotPosition := strings.Index(name, ".") 283 if firstDotPosition != -1 { 284 return pq.QuoteIdentifier(name[0:firstDotPosition]) + "." + pq.QuoteIdentifier(name[firstDotPosition+1:]) 285 } 286 } 287 return pq.QuoteIdentifier(name) 288 } 289 290 func (p *Postgres) SetVersion(version int, dirty bool) error { 291 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 292 if err != nil { 293 return &database.Error{OrigErr: err, Err: "transaction start failed"} 294 } 295 296 query := `TRUNCATE ` + p.quoteIdentifier(p.config.MigrationsTable) 297 if _, err := tx.Exec(query); err != nil { 298 if errRollback := tx.Rollback(); errRollback != nil { 299 err = multierror.Append(err, errRollback) 300 } 301 return &database.Error{OrigErr: err, Query: []byte(query)} 302 } 303 304 // Also re-write the schema version for nil dirty versions to prevent 305 // empty schema version for failed down migration on the first migration 306 // See: https://github.com/golang-migrate/migrate/issues/330 307 if version >= 0 || (version == database.NilVersion && dirty) { 308 query = `INSERT INTO ` + p.quoteIdentifier(p.config.MigrationsTable) + 309 ` (version, dirty) VALUES ($1, $2)` 310 if _, err := tx.Exec(query, version, dirty); err != nil { 311 if errRollback := tx.Rollback(); errRollback != nil { 312 err = multierror.Append(err, errRollback) 313 } 314 return &database.Error{OrigErr: err, Query: []byte(query)} 315 } 316 } 317 318 if err := tx.Commit(); err != nil { 319 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 320 } 321 322 return nil 323 } 324 325 func (p *Postgres) Version() (version int, dirty bool, err error) { 326 query := `SELECT version, dirty FROM ` + p.quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1` 327 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 328 switch { 329 case err == sql.ErrNoRows: 330 return database.NilVersion, false, nil 331 332 case err != nil: 333 if e, ok := err.(*pq.Error); ok { 334 if e.Code.Name() == "undefined_table" { 335 return database.NilVersion, false, nil 336 } 337 } 338 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 339 340 default: 341 return version, dirty, nil 342 } 343 } 344 345 func (p *Postgres) Drop() (err error) { 346 // select all tables in current schema 347 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 348 tables, err := p.conn.QueryContext(context.Background(), query) 349 if err != nil { 350 return &database.Error{OrigErr: err, Query: []byte(query)} 351 } 352 defer func() { 353 if errClose := tables.Close(); errClose != nil { 354 err = multierror.Append(err, errClose) 355 } 356 }() 357 358 // delete one table after another 359 tableNames := make([]string, 0) 360 for tables.Next() { 361 var tableName string 362 if err := tables.Scan(&tableName); err != nil { 363 return err 364 } 365 if len(tableName) > 0 { 366 tableNames = append(tableNames, tableName) 367 } 368 } 369 if err := tables.Err(); err != nil { 370 return &database.Error{OrigErr: err, Query: []byte(query)} 371 } 372 373 if len(tableNames) > 0 { 374 // delete one by one ... 375 for _, t := range tableNames { 376 query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE` 377 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 378 return &database.Error{OrigErr: err, Query: []byte(query)} 379 } 380 } 381 } 382 383 return nil 384 } 385 386 // ensureVersionTable checks if versions table exists and, if not, creates it. 387 // Note that this function locks the database, which deviates from the usual 388 // convention of "caller locks" in the Postgres type. 389 func (p *Postgres) ensureVersionTable() (err error) { 390 if err = p.Lock(); err != nil { 391 return err 392 } 393 394 defer func() { 395 if e := p.Unlock(); e != nil { 396 if err == nil { 397 err = e 398 } else { 399 err = multierror.Append(err, e) 400 } 401 } 402 }() 403 404 query := `CREATE TABLE IF NOT EXISTS ` + p.quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` 405 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 406 return &database.Error{OrigErr: err, Query: []byte(query)} 407 } 408 409 return nil 410 }