github.com/solvedata/migrate/v4@v4.8.7-0.20201127053940-c9fba4ce569f/database/postgres/postgres.go (about) 1 // +build go1.9 2 3 package postgres 4 5 import ( 6 "context" 7 "database/sql" 8 "fmt" 9 "io" 10 "io/ioutil" 11 nurl "net/url" 12 "strconv" 13 "strings" 14 15 "github.com/solvedata/migrate/v4" 16 "github.com/solvedata/migrate/v4/database" 17 multierror "github.com/hashicorp/go-multierror" 18 "github.com/lib/pq" 19 ) 20 21 func init() { 22 db := Postgres{} 23 database.Register("postgres", &db) 24 database.Register("postgresql", &db) 25 } 26 27 var DefaultMigrationsTable = "schema_migrations" 28 29 var ( 30 ErrNilConfig = fmt.Errorf("no config") 31 ErrNoDatabaseName = fmt.Errorf("no database name") 32 ErrNoSchema = fmt.Errorf("no schema") 33 ErrDatabaseDirty = fmt.Errorf("database is dirty") 34 ) 35 36 type Config struct { 37 MigrationsTable string 38 DatabaseName string 39 SchemaName string 40 } 41 42 type Postgres struct { 43 // Locking and unlocking need to use the same connection 44 conn *sql.Conn 45 db *sql.DB 46 isLocked bool 47 48 // Open and WithInstance need to guarantee that config is never nil 49 config *Config 50 } 51 52 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 53 if config == nil { 54 return nil, ErrNilConfig 55 } 56 57 if err := instance.Ping(); err != nil { 58 return nil, err 59 } 60 61 if config.DatabaseName == "" { 62 query := `SELECT CURRENT_DATABASE()` 63 var databaseName string 64 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 65 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 66 } 67 68 if len(databaseName) == 0 { 69 return nil, ErrNoDatabaseName 70 } 71 72 config.DatabaseName = databaseName 73 } 74 75 if config.SchemaName == "" { 76 query := `SELECT CURRENT_SCHEMA()` 77 var schemaName string 78 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 79 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 80 } 81 82 if len(schemaName) == 0 { 83 return nil, ErrNoSchema 84 } 85 86 config.SchemaName = schemaName 87 } 88 89 if len(config.MigrationsTable) == 0 { 90 config.MigrationsTable = DefaultMigrationsTable 91 } 92 93 conn, err := instance.Conn(context.Background()) 94 95 if err != nil { 96 return nil, err 97 } 98 99 px := &Postgres{ 100 conn: conn, 101 db: instance, 102 config: config, 103 } 104 105 if err := px.ensureVersionTable(); err != nil { 106 return nil, err 107 } 108 109 return px, nil 110 } 111 112 func (p *Postgres) Open(url string) (database.Driver, error) { 113 purl, err := nurl.Parse(url) 114 if err != nil { 115 return nil, err 116 } 117 118 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 119 if err != nil { 120 return nil, err 121 } 122 123 migrationsTable := purl.Query().Get("x-migrations-table") 124 125 px, err := WithInstance(db, &Config{ 126 DatabaseName: purl.Path, 127 MigrationsTable: migrationsTable, 128 }) 129 130 if err != nil { 131 return nil, err 132 } 133 134 return px, nil 135 } 136 137 func (p *Postgres) Close() error { 138 connErr := p.conn.Close() 139 dbErr := p.db.Close() 140 if connErr != nil || dbErr != nil { 141 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 142 } 143 return nil 144 } 145 146 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 147 func (p *Postgres) Lock() error { 148 if p.isLocked { 149 return database.ErrLocked 150 } 151 152 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName) 153 if err != nil { 154 return err 155 } 156 157 // This will wait indefinitely until the lock can be acquired. 158 query := `SELECT pg_advisory_lock($1)` 159 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 160 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 161 } 162 163 p.isLocked = true 164 return nil 165 } 166 167 func (p *Postgres) Unlock() error { 168 if !p.isLocked { 169 return nil 170 } 171 172 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName) 173 if err != nil { 174 return err 175 } 176 177 query := `SELECT pg_advisory_unlock($1)` 178 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 179 return &database.Error{OrigErr: err, Query: []byte(query)} 180 } 181 p.isLocked = false 182 return nil 183 } 184 185 func (p *Postgres) Run(migration io.Reader) error { 186 migr, err := ioutil.ReadAll(migration) 187 if err != nil { 188 return err 189 } 190 191 // run migration 192 query := string(migr[:]) 193 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 194 if pgErr, ok := err.(*pq.Error); ok { 195 var line uint 196 var col uint 197 var lineColOK bool 198 if pgErr.Position != "" { 199 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 200 line, col, lineColOK = computeLineFromPos(query, int(pos)) 201 } 202 } 203 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 204 if lineColOK { 205 message = fmt.Sprintf("%s (column %d)", message, col) 206 } 207 if pgErr.Detail != "" { 208 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 209 } 210 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 211 } 212 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 213 } 214 215 return nil 216 } 217 218 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 219 // replace crlf with lf 220 s = strings.Replace(s, "\r\n", "\n", -1) 221 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 222 runes := []rune(s) 223 if pos > len(runes) { 224 return 0, 0, false 225 } 226 sel := runes[:pos] 227 line = uint(runesCount(sel, newLine) + 1) 228 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 229 return line, col, true 230 } 231 232 const newLine = '\n' 233 234 func runesCount(input []rune, target rune) int { 235 var count int 236 for _, r := range input { 237 if r == target { 238 count++ 239 } 240 } 241 return count 242 } 243 244 func runesLastIndex(input []rune, target rune) int { 245 for i := len(input) - 1; i >= 0; i-- { 246 if input[i] == target { 247 return i 248 } 249 } 250 return -1 251 } 252 253 func (p *Postgres) SetVersion(version int, dirty bool) error { 254 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 255 if err != nil { 256 return &database.Error{OrigErr: err, Err: "transaction start failed"} 257 } 258 259 query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable) 260 if _, err := tx.Exec(query); err != nil { 261 if errRollback := tx.Rollback(); errRollback != nil { 262 err = multierror.Append(err, errRollback) 263 } 264 return &database.Error{OrigErr: err, Query: []byte(query)} 265 } 266 267 if version >= 0 { 268 query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version, dirty) VALUES ($1, $2)` 269 if _, err := tx.Exec(query, version, dirty); err != nil { 270 if errRollback := tx.Rollback(); errRollback != nil { 271 err = multierror.Append(err, errRollback) 272 } 273 return &database.Error{OrigErr: err, Query: []byte(query)} 274 } 275 } 276 277 if err := tx.Commit(); err != nil { 278 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 279 } 280 281 return nil 282 } 283 284 func (p *Postgres) Version() (version int, dirty bool, err error) { 285 query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1` 286 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 287 switch { 288 case err == sql.ErrNoRows: 289 return database.NilVersion, false, nil 290 291 case err != nil: 292 if e, ok := err.(*pq.Error); ok { 293 if e.Code.Name() == "undefined_table" { 294 return database.NilVersion, false, nil 295 } 296 } 297 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 298 299 default: 300 return version, dirty, nil 301 } 302 } 303 304 func (p *Postgres) Drop() (err error) { 305 // select all tables in current schema 306 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 307 tables, err := p.conn.QueryContext(context.Background(), query) 308 if err != nil { 309 return &database.Error{OrigErr: err, Query: []byte(query)} 310 } 311 defer func() { 312 if errClose := tables.Close(); errClose != nil { 313 err = multierror.Append(err, errClose) 314 } 315 }() 316 317 // delete one table after another 318 tableNames := make([]string, 0) 319 for tables.Next() { 320 var tableName string 321 if err := tables.Scan(&tableName); err != nil { 322 return err 323 } 324 if len(tableName) > 0 { 325 tableNames = append(tableNames, tableName) 326 } 327 } 328 329 if len(tableNames) > 0 { 330 // delete one by one ... 331 for _, t := range tableNames { 332 query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE` 333 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 334 return &database.Error{OrigErr: err, Query: []byte(query)} 335 } 336 } 337 } 338 339 return nil 340 } 341 342 // ensureVersionTable checks if versions table exists and, if not, creates it. 343 // Note that this function locks the database, which deviates from the usual 344 // convention of "caller locks" in the Postgres type. 345 func (p *Postgres) ensureVersionTable() (err error) { 346 if err = p.Lock(); err != nil { 347 return err 348 } 349 350 defer func() { 351 if e := p.Unlock(); e != nil { 352 if err == nil { 353 err = e 354 } else { 355 err = multierror.Append(err, e) 356 } 357 } 358 }() 359 360 query := `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` 361 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 362 return &database.Error{OrigErr: err, Query: []byte(query)} 363 } 364 365 return nil 366 }