github.com/matcornic/migrate@v3.3.2-0.20180717234201-feea45c20506+incompatible/database/postgres/postgres.go (about) 1 // +build go1.9 2 3 package postgres 4 5 import ( 6 "context" 7 "database/sql" 8 "fmt" 9 "io" 10 "io/ioutil" 11 nurl "net/url" 12 "strconv" 13 "strings" 14 15 "github.com/golang-migrate/migrate" 16 "github.com/golang-migrate/migrate/database" 17 "github.com/lib/pq" 18 ) 19 20 func init() { 21 db := Postgres{} 22 database.Register("postgres", &db) 23 database.Register("postgresql", &db) 24 } 25 26 var DefaultMigrationsTable = "schema_migrations" 27 28 var ( 29 ErrNilConfig = fmt.Errorf("no config") 30 ErrNoDatabaseName = fmt.Errorf("no database name") 31 ErrNoSchema = fmt.Errorf("no schema") 32 ErrDatabaseDirty = fmt.Errorf("database is dirty") 33 ) 34 35 type Config struct { 36 MigrationsTable string 37 DatabaseName string 38 } 39 40 type Postgres struct { 41 // Locking and unlocking need to use the same connection 42 conn *sql.Conn 43 isLocked bool 44 45 // Open and WithInstance need to garantuee that config is never nil 46 config *Config 47 } 48 49 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 50 if config == nil { 51 return nil, ErrNilConfig 52 } 53 54 if err := instance.Ping(); err != nil { 55 return nil, err 56 } 57 58 query := `SELECT CURRENT_DATABASE()` 59 var databaseName string 60 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 61 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 62 } 63 64 if len(databaseName) == 0 { 65 return nil, ErrNoDatabaseName 66 } 67 68 config.DatabaseName = databaseName 69 70 if len(config.MigrationsTable) == 0 { 71 config.MigrationsTable = DefaultMigrationsTable 72 } 73 74 conn, err := instance.Conn(context.Background()) 75 76 if err != nil { 77 return nil, err 78 } 79 80 px := &Postgres{ 81 conn: conn, 82 config: config, 83 } 84 85 if err := px.ensureVersionTable(); err != nil { 86 return nil, err 87 } 88 89 return px, nil 90 } 91 92 func (p *Postgres) Open(url string) (database.Driver, error) { 93 purl, err := nurl.Parse(url) 94 if err != nil { 95 return nil, err 96 } 97 98 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 99 if err != nil { 100 return nil, err 101 } 102 103 migrationsTable := purl.Query().Get("x-migrations-table") 104 if len(migrationsTable) == 0 { 105 migrationsTable = DefaultMigrationsTable 106 } 107 108 px, err := WithInstance(db, &Config{ 109 DatabaseName: purl.Path, 110 MigrationsTable: migrationsTable, 111 }) 112 if err != nil { 113 return nil, err 114 } 115 116 return px, nil 117 } 118 119 func (p *Postgres) Close() error { 120 return p.conn.Close() 121 } 122 123 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 124 func (p *Postgres) Lock() error { 125 if p.isLocked { 126 return database.ErrLocked 127 } 128 129 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) 130 if err != nil { 131 return err 132 } 133 134 // This will either obtain the lock immediately and return true, 135 // or return false if the lock cannot be acquired immediately. 136 query := `SELECT pg_advisory_lock($1)` 137 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 138 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 139 } 140 141 p.isLocked = true 142 return nil 143 } 144 145 func (p *Postgres) Unlock() error { 146 if !p.isLocked { 147 return nil 148 } 149 150 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) 151 if err != nil { 152 return err 153 } 154 155 query := `SELECT pg_advisory_unlock($1)` 156 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 157 return &database.Error{OrigErr: err, Query: []byte(query)} 158 } 159 p.isLocked = false 160 return nil 161 } 162 163 func (p *Postgres) Run(migration io.Reader) error { 164 migr, err := ioutil.ReadAll(migration) 165 if err != nil { 166 return err 167 } 168 169 // run migration 170 query := string(migr[:]) 171 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 172 if pgErr, ok := err.(*pq.Error); ok { 173 var line uint 174 var col uint 175 var lineColOK bool 176 if pgErr.Position != "" { 177 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 178 line, col, lineColOK = computeLineFromPos(query, int(pos)) 179 } 180 } 181 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 182 if lineColOK { 183 message = fmt.Sprintf("%s (column %d)", message, col) 184 } 185 if pgErr.Detail != "" { 186 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 187 } 188 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 189 } 190 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 191 } 192 193 return nil 194 } 195 196 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 197 // replace crlf with lf 198 s = strings.Replace(s, "\r\n", "\n", -1) 199 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 200 runes := []rune(s) 201 if pos > len(runes) { 202 return 0, 0, false 203 } 204 sel := runes[:pos] 205 line = uint(runesCount(sel, newLine) + 1) 206 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 207 return line, col, true 208 } 209 210 const newLine = '\n' 211 212 func runesCount(input []rune, target rune) int { 213 var count int 214 for _, r := range input { 215 if r == target { 216 count++ 217 } 218 } 219 return count 220 } 221 222 func runesLastIndex(input []rune, target rune) int { 223 for i := len(input) - 1; i >= 0; i-- { 224 if input[i] == target { 225 return i 226 } 227 } 228 return -1 229 } 230 231 func (p *Postgres) SetVersion(version int, dirty bool) error { 232 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 233 if err != nil { 234 return &database.Error{OrigErr: err, Err: "transaction start failed"} 235 } 236 237 query := `TRUNCATE "` + p.config.MigrationsTable + `"` 238 if _, err := tx.Exec(query); err != nil { 239 tx.Rollback() 240 return &database.Error{OrigErr: err, Query: []byte(query)} 241 } 242 243 if version >= 0 { 244 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)` 245 if _, err := tx.Exec(query, version, dirty); err != nil { 246 tx.Rollback() 247 return &database.Error{OrigErr: err, Query: []byte(query)} 248 } 249 } 250 251 if err := tx.Commit(); err != nil { 252 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 253 } 254 255 return nil 256 } 257 258 func (p *Postgres) Version() (version int, dirty bool, err error) { 259 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 260 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 261 switch { 262 case err == sql.ErrNoRows: 263 return database.NilVersion, false, nil 264 265 case err != nil: 266 if e, ok := err.(*pq.Error); ok { 267 if e.Code.Name() == "undefined_table" { 268 return database.NilVersion, false, nil 269 } 270 } 271 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 272 273 default: 274 return version, dirty, nil 275 } 276 } 277 278 func (p *Postgres) Drop() error { 279 // select all tables in current schema 280 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())` 281 tables, err := p.conn.QueryContext(context.Background(), query) 282 if err != nil { 283 return &database.Error{OrigErr: err, Query: []byte(query)} 284 } 285 defer tables.Close() 286 287 // delete one table after another 288 tableNames := make([]string, 0) 289 for tables.Next() { 290 var tableName string 291 if err := tables.Scan(&tableName); err != nil { 292 return err 293 } 294 if len(tableName) > 0 { 295 tableNames = append(tableNames, tableName) 296 } 297 } 298 299 if len(tableNames) > 0 { 300 // delete one by one ... 301 for _, t := range tableNames { 302 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 303 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 304 return &database.Error{OrigErr: err, Query: []byte(query)} 305 } 306 } 307 if err := p.ensureVersionTable(); err != nil { 308 return err 309 } 310 } 311 312 return nil 313 } 314 315 func (p *Postgres) ensureVersionTable() error { 316 // check if migration table exists 317 var count int 318 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 319 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 320 return &database.Error{OrigErr: err, Query: []byte(query)} 321 } 322 if count == 1 { 323 return nil 324 } 325 326 // if not, create the empty migration table 327 query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` 328 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 329 return &database.Error{OrigErr: err, Query: []byte(query)} 330 } 331 return nil 332 }