github.com/tooolbox/migrate/v4@v4.6.2-0.20200325001913-461b03b92064/database/postgres/postgres.go (about) 1 // +build go1.9 2 3 package postgres 4 5 import ( 6 "context" 7 "database/sql" 8 "fmt" 9 "io" 10 "io/ioutil" 11 nurl "net/url" 12 "strconv" 13 "strings" 14 "time" 15 16 multierror "github.com/hashicorp/go-multierror" 17 "github.com/lib/pq" 18 "github.com/tooolbox/migrate/v4" 19 "github.com/tooolbox/migrate/v4/database" 20 ) 21 22 func init() { 23 db := Postgres{} 24 database.Register("postgres", &db) 25 database.Register("postgresql", &db) 26 } 27 28 var DefaultMigrationsTable = "schema_migrations" 29 30 var ( 31 ErrNilConfig = fmt.Errorf("no config") 32 ErrNoDatabaseName = fmt.Errorf("no database name") 33 ErrNoSchema = fmt.Errorf("no schema") 34 ErrDatabaseDirty = fmt.Errorf("database is dirty") 35 ) 36 37 type Config struct { 38 MigrationsTable string 39 DatabaseName string 40 SchemaName string 41 StatementTimeout time.Duration 42 } 43 44 type Postgres struct { 45 // Locking and unlocking need to use the same connection 46 conn *sql.Conn 47 db *sql.DB 48 isLocked bool 49 50 // Open and WithInstance need to guarantee that config is never nil 51 config *Config 52 } 53 54 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 55 if config == nil { 56 return nil, ErrNilConfig 57 } 58 59 if err := instance.Ping(); err != nil { 60 return nil, err 61 } 62 63 if config.DatabaseName == "" { 64 query := `SELECT CURRENT_DATABASE()` 65 var databaseName string 66 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 67 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 68 } 69 70 if len(databaseName) == 0 { 71 return nil, ErrNoDatabaseName 72 } 73 74 config.DatabaseName = databaseName 75 } 76 77 if config.SchemaName == "" { 78 query := `SELECT CURRENT_SCHEMA()` 79 var schemaName string 80 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 81 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 82 } 83 84 if len(schemaName) == 0 { 85 return nil, ErrNoSchema 86 } 87 88 config.SchemaName = schemaName 89 } 90 91 if len(config.MigrationsTable) == 0 { 92 config.MigrationsTable = DefaultMigrationsTable 93 } 94 95 conn, err := instance.Conn(context.Background()) 96 97 if err != nil { 98 return nil, err 99 } 100 101 px := &Postgres{ 102 conn: conn, 103 db: instance, 104 config: config, 105 } 106 107 if err := px.ensureVersionTable(); err != nil { 108 return nil, err 109 } 110 111 return px, nil 112 } 113 114 func (p *Postgres) Open(url string) (database.Driver, error) { 115 purl, err := nurl.Parse(url) 116 if err != nil { 117 return nil, err 118 } 119 120 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 121 if err != nil { 122 return nil, err 123 } 124 125 migrationsTable := purl.Query().Get("x-migrations-table") 126 statementTimeoutString := purl.Query().Get("x-statement-timeout") 127 statementTimeout := 0 128 if statementTimeoutString != "" { 129 statementTimeout, err = strconv.Atoi(statementTimeoutString) 130 if err != nil { 131 return nil, err 132 } 133 } 134 135 px, err := WithInstance(db, &Config{ 136 DatabaseName: purl.Path, 137 MigrationsTable: migrationsTable, 138 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, 139 }) 140 141 if err != nil { 142 return nil, err 143 } 144 145 return px, nil 146 } 147 148 func (p *Postgres) Close() error { 149 connErr := p.conn.Close() 150 dbErr := p.db.Close() 151 if connErr != nil || dbErr != nil { 152 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 153 } 154 return nil 155 } 156 157 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 158 func (p *Postgres) Lock() error { 159 if p.isLocked { 160 return database.ErrLocked 161 } 162 163 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName) 164 if err != nil { 165 return err 166 } 167 168 // This will wait indefinitely until the lock can be acquired. 169 query := `SELECT pg_advisory_lock($1)` 170 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 171 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 172 } 173 174 p.isLocked = true 175 return nil 176 } 177 178 func (p *Postgres) Unlock() error { 179 if !p.isLocked { 180 return nil 181 } 182 183 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName) 184 if err != nil { 185 return err 186 } 187 188 query := `SELECT pg_advisory_unlock($1)` 189 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 190 return &database.Error{OrigErr: err, Query: []byte(query)} 191 } 192 p.isLocked = false 193 return nil 194 } 195 196 func (p *Postgres) Run(migration io.Reader) error { 197 migr, err := ioutil.ReadAll(migration) 198 if err != nil { 199 return err 200 } 201 ctx := context.Background() 202 if p.config.StatementTimeout != 0 { 203 var cancel context.CancelFunc 204 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout) 205 defer cancel() 206 } 207 // run migration 208 query := string(migr[:]) 209 if _, err := p.conn.ExecContext(ctx, query); err != nil { 210 if pgErr, ok := err.(*pq.Error); ok { 211 var line uint 212 var col uint 213 var lineColOK bool 214 if pgErr.Position != "" { 215 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 216 line, col, lineColOK = computeLineFromPos(query, int(pos)) 217 } 218 } 219 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 220 if lineColOK { 221 message = fmt.Sprintf("%s (column %d)", message, col) 222 } 223 if pgErr.Detail != "" { 224 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 225 } 226 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 227 } 228 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 229 } 230 231 return nil 232 } 233 234 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 235 // replace crlf with lf 236 s = strings.Replace(s, "\r\n", "\n", -1) 237 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 238 runes := []rune(s) 239 if pos > len(runes) { 240 return 0, 0, false 241 } 242 sel := runes[:pos] 243 line = uint(runesCount(sel, newLine) + 1) 244 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 245 return line, col, true 246 } 247 248 const newLine = '\n' 249 250 func runesCount(input []rune, target rune) int { 251 var count int 252 for _, r := range input { 253 if r == target { 254 count++ 255 } 256 } 257 return count 258 } 259 260 func runesLastIndex(input []rune, target rune) int { 261 for i := len(input) - 1; i >= 0; i-- { 262 if input[i] == target { 263 return i 264 } 265 } 266 return -1 267 } 268 269 func (p *Postgres) SetVersion(version int, dirty bool) error { 270 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 271 if err != nil { 272 return &database.Error{OrigErr: err, Err: "transaction start failed"} 273 } 274 275 query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable) 276 if _, err := tx.Exec(query); err != nil { 277 if errRollback := tx.Rollback(); errRollback != nil { 278 err = multierror.Append(err, errRollback) 279 } 280 return &database.Error{OrigErr: err, Query: []byte(query)} 281 } 282 283 if version >= 0 { 284 query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version, dirty) VALUES ($1, $2)` 285 if _, err := tx.Exec(query, version, dirty); err != nil { 286 if errRollback := tx.Rollback(); errRollback != nil { 287 err = multierror.Append(err, errRollback) 288 } 289 return &database.Error{OrigErr: err, Query: []byte(query)} 290 } 291 } 292 293 if err := tx.Commit(); err != nil { 294 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 295 } 296 297 return nil 298 } 299 300 func (p *Postgres) Version() (version int, dirty bool, err error) { 301 query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1` 302 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 303 switch { 304 case err == sql.ErrNoRows: 305 return database.NilVersion, false, nil 306 307 case err != nil: 308 if e, ok := err.(*pq.Error); ok { 309 if e.Code.Name() == "undefined_table" { 310 return database.NilVersion, false, nil 311 } 312 } 313 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 314 315 default: 316 return version, dirty, nil 317 } 318 } 319 320 func (p *Postgres) Drop() (err error) { 321 // select all tables in current schema 322 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 323 tables, err := p.conn.QueryContext(context.Background(), query) 324 if err != nil { 325 return &database.Error{OrigErr: err, Query: []byte(query)} 326 } 327 defer func() { 328 if errClose := tables.Close(); errClose != nil { 329 err = multierror.Append(err, errClose) 330 } 331 }() 332 333 // delete one table after another 334 tableNames := make([]string, 0) 335 for tables.Next() { 336 var tableName string 337 if err := tables.Scan(&tableName); err != nil { 338 return err 339 } 340 if len(tableName) > 0 { 341 tableNames = append(tableNames, tableName) 342 } 343 } 344 345 if len(tableNames) > 0 { 346 // delete one by one ... 347 for _, t := range tableNames { 348 query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE` 349 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 350 return &database.Error{OrigErr: err, Query: []byte(query)} 351 } 352 } 353 } 354 355 return nil 356 } 357 358 // ensureVersionTable checks if versions table exists and, if not, creates it. 359 // Note that this function locks the database, which deviates from the usual 360 // convention of "caller locks" in the Postgres type. 361 func (p *Postgres) ensureVersionTable() (err error) { 362 if err = p.Lock(); err != nil { 363 return err 364 } 365 366 defer func() { 367 if e := p.Unlock(); e != nil { 368 if err == nil { 369 err = e 370 } else { 371 err = multierror.Append(err, e) 372 } 373 } 374 }() 375 376 query := `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` 377 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 378 return &database.Error{OrigErr: err, Query: []byte(query)} 379 } 380 381 return nil 382 }