github.com/fr-nvriep/migrate/v4@v4.3.2/database/postgres/postgres.go (about) 1 // +build go1.9 2 3 package postgres 4 5 import ( 6 "context" 7 "database/sql" 8 "fmt" 9 "io" 10 "io/ioutil" 11 nurl "net/url" 12 "strconv" 13 "strings" 14 15 "github.com/fr-nvriep/migrate/v4" 16 "github.com/fr-nvriep/migrate/v4/database" 17 multierror "github.com/hashicorp/go-multierror" 18 "github.com/lib/pq" 19 ) 20 21 func init() { 22 db := Postgres{} 23 database.Register("postgres", &db) 24 database.Register("postgresql", &db) 25 } 26 27 var DefaultMigrationsTable = "schema_migrations" 28 29 var ( 30 ErrNilConfig = fmt.Errorf("no config") 31 ErrNoDatabaseName = fmt.Errorf("no database name") 32 ErrNoSchema = fmt.Errorf("no schema") 33 ErrDatabaseDirty = fmt.Errorf("database is dirty") 34 ) 35 36 type Config struct { 37 MigrationsTable string 38 DatabaseName string 39 SchemaName string 40 } 41 42 type Postgres struct { 43 // Locking and unlocking need to use the same connection 44 conn *sql.Conn 45 db *sql.DB 46 isLocked bool 47 48 // Open and WithInstance need to guarantee that config is never nil 49 config *Config 50 } 51 52 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 53 if config == nil { 54 return nil, ErrNilConfig 55 } 56 57 if err := instance.Ping(); err != nil { 58 return nil, err 59 } 60 61 query := `SELECT CURRENT_DATABASE()` 62 var databaseName string 63 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 64 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 65 } 66 67 if len(databaseName) == 0 { 68 return nil, ErrNoDatabaseName 69 } 70 71 config.DatabaseName = databaseName 72 73 query = `SELECT CURRENT_SCHEMA()` 74 var schemaName string 75 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 76 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 77 } 78 79 if len(schemaName) == 0 { 80 return nil, ErrNoSchema 81 } 82 83 config.SchemaName = schemaName 84 85 if len(config.MigrationsTable) == 0 { 86 config.MigrationsTable = DefaultMigrationsTable 87 } 88 89 conn, err := instance.Conn(context.Background()) 90 91 if err != nil { 92 return nil, err 93 } 94 95 px := &Postgres{ 96 conn: conn, 97 db: instance, 98 config: config, 99 } 100 101 if err := px.ensureVersionTable(); err != nil { 102 return nil, err 103 } 104 105 return px, nil 106 } 107 108 func (p *Postgres) Open(url string) (database.Driver, error) { 109 purl, err := nurl.Parse(url) 110 if err != nil { 111 return nil, err 112 } 113 114 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 115 if err != nil { 116 return nil, err 117 } 118 119 migrationsTable := purl.Query().Get("x-migrations-table") 120 121 px, err := WithInstance(db, &Config{ 122 DatabaseName: purl.Path, 123 MigrationsTable: migrationsTable, 124 }) 125 126 if err != nil { 127 return nil, err 128 } 129 130 return px, nil 131 } 132 133 func (p *Postgres) Close() error { 134 connErr := p.conn.Close() 135 dbErr := p.db.Close() 136 if connErr != nil || dbErr != nil { 137 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 138 } 139 return nil 140 } 141 142 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 143 func (p *Postgres) Lock() error { 144 if p.isLocked { 145 return database.ErrLocked 146 } 147 148 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName) 149 if err != nil { 150 return err 151 } 152 153 // This will either obtain the lock immediately and return true, 154 // or return false if the lock cannot be acquired immediately. 155 query := `SELECT pg_advisory_lock($1)` 156 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 157 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 158 } 159 160 p.isLocked = true 161 return nil 162 } 163 164 func (p *Postgres) Unlock() error { 165 if !p.isLocked { 166 return nil 167 } 168 169 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName) 170 if err != nil { 171 return err 172 } 173 174 query := `SELECT pg_advisory_unlock($1)` 175 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 176 return &database.Error{OrigErr: err, Query: []byte(query)} 177 } 178 p.isLocked = false 179 return nil 180 } 181 182 func (p *Postgres) Run(migration io.Reader) error { 183 migr, err := ioutil.ReadAll(migration) 184 if err != nil { 185 return err 186 } 187 188 // run migration 189 query := string(migr[:]) 190 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 191 if pgErr, ok := err.(*pq.Error); ok { 192 var line uint 193 var col uint 194 var lineColOK bool 195 if pgErr.Position != "" { 196 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 197 line, col, lineColOK = computeLineFromPos(query, int(pos)) 198 } 199 } 200 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 201 if lineColOK { 202 message = fmt.Sprintf("%s (column %d)", message, col) 203 } 204 if pgErr.Detail != "" { 205 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 206 } 207 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 208 } 209 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 210 } 211 212 return nil 213 } 214 215 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 216 // replace crlf with lf 217 s = strings.Replace(s, "\r\n", "\n", -1) 218 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 219 runes := []rune(s) 220 if pos > len(runes) { 221 return 0, 0, false 222 } 223 sel := runes[:pos] 224 line = uint(runesCount(sel, newLine) + 1) 225 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 226 return line, col, true 227 } 228 229 const newLine = '\n' 230 231 func runesCount(input []rune, target rune) int { 232 var count int 233 for _, r := range input { 234 if r == target { 235 count++ 236 } 237 } 238 return count 239 } 240 241 func runesLastIndex(input []rune, target rune) int { 242 for i := len(input) - 1; i >= 0; i-- { 243 if input[i] == target { 244 return i 245 } 246 } 247 return -1 248 } 249 250 func (p *Postgres) SetVersion(version int, dirty bool) error { 251 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 252 if err != nil { 253 return &database.Error{OrigErr: err, Err: "transaction start failed"} 254 } 255 256 query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable) 257 if _, err := tx.Exec(query); err != nil { 258 if errRollback := tx.Rollback(); errRollback != nil { 259 err = multierror.Append(err, errRollback) 260 } 261 return &database.Error{OrigErr: err, Query: []byte(query)} 262 } 263 264 if version >= 0 { 265 query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version, dirty) VALUES ($1, $2)` 266 if _, err := tx.Exec(query, version, dirty); err != nil { 267 if errRollback := tx.Rollback(); errRollback != nil { 268 err = multierror.Append(err, errRollback) 269 } 270 return &database.Error{OrigErr: err, Query: []byte(query)} 271 } 272 } 273 274 if err := tx.Commit(); err != nil { 275 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 276 } 277 278 return nil 279 } 280 281 func (p *Postgres) Version() (version int, dirty bool, err error) { 282 query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1` 283 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 284 switch { 285 case err == sql.ErrNoRows: 286 return database.NilVersion, false, nil 287 288 case err != nil: 289 if e, ok := err.(*pq.Error); ok { 290 if e.Code.Name() == "undefined_table" { 291 return database.NilVersion, false, nil 292 } 293 } 294 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 295 296 default: 297 return version, dirty, nil 298 } 299 } 300 301 func (p *Postgres) Drop() (err error) { 302 // select all tables in current schema 303 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 304 tables, err := p.conn.QueryContext(context.Background(), query) 305 if err != nil { 306 return &database.Error{OrigErr: err, Query: []byte(query)} 307 } 308 defer func() { 309 if errClose := tables.Close(); errClose != nil { 310 err = multierror.Append(err, errClose) 311 } 312 }() 313 314 // delete one table after another 315 tableNames := make([]string, 0) 316 for tables.Next() { 317 var tableName string 318 if err := tables.Scan(&tableName); err != nil { 319 return err 320 } 321 if len(tableName) > 0 { 322 tableNames = append(tableNames, tableName) 323 } 324 } 325 326 if len(tableNames) > 0 { 327 // delete one by one ... 328 for _, t := range tableNames { 329 query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE` 330 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 331 return &database.Error{OrigErr: err, Query: []byte(query)} 332 } 333 } 334 } 335 336 return nil 337 } 338 339 // ensureVersionTable checks if versions table exists and, if not, creates it. 340 // Note that this function locks the database, which deviates from the usual 341 // convention of "caller locks" in the Postgres type. 342 func (p *Postgres) ensureVersionTable() (err error) { 343 if err = p.Lock(); err != nil { 344 return err 345 } 346 347 defer func() { 348 if e := p.Unlock(); e != nil { 349 if err == nil { 350 err = e 351 } else { 352 err = multierror.Append(err, e) 353 } 354 } 355 }() 356 357 // check if migration table exists 358 var count int 359 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 360 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 361 return &database.Error{OrigErr: err, Query: []byte(query)} 362 } 363 if count == 1 { 364 return nil 365 } 366 367 query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` 368 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 369 return &database.Error{OrigErr: err, Query: []byte(query)} 370 } 371 372 return nil 373 }