github.com/wreckingballstudiolabs/migrate@v3.5.4+incompatible/database/postgres/postgres.go (about) 1 // +build go1.9 2 3 package postgres 4 5 import ( 6 "context" 7 "database/sql" 8 "fmt" 9 "io" 10 "io/ioutil" 11 nurl "net/url" 12 "strconv" 13 "strings" 14 15 "github.com/golang-migrate/migrate" 16 "github.com/golang-migrate/migrate/database" 17 "github.com/lib/pq" 18 ) 19 20 func init() { 21 db := Postgres{} 22 database.Register("postgres", &db) 23 database.Register("postgresql", &db) 24 } 25 26 var DefaultMigrationsTable = "schema_migrations" 27 28 var ( 29 ErrNilConfig = fmt.Errorf("no config") 30 ErrNoDatabaseName = fmt.Errorf("no database name") 31 ErrNoSchema = fmt.Errorf("no schema") 32 ErrDatabaseDirty = fmt.Errorf("database is dirty") 33 ) 34 35 type Config struct { 36 MigrationsTable string 37 DatabaseName string 38 } 39 40 type Postgres struct { 41 // Locking and unlocking need to use the same connection 42 conn *sql.Conn 43 db *sql.DB 44 isLocked bool 45 46 // Open and WithInstance need to garantuee that config is never nil 47 config *Config 48 } 49 50 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 51 if config == nil { 52 return nil, ErrNilConfig 53 } 54 55 if err := instance.Ping(); err != nil { 56 return nil, err 57 } 58 59 query := `SELECT CURRENT_DATABASE()` 60 var databaseName string 61 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 62 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 63 } 64 65 if len(databaseName) == 0 { 66 return nil, ErrNoDatabaseName 67 } 68 69 config.DatabaseName = databaseName 70 71 if len(config.MigrationsTable) == 0 { 72 config.MigrationsTable = DefaultMigrationsTable 73 } 74 75 conn, err := instance.Conn(context.Background()) 76 77 if err != nil { 78 return nil, err 79 } 80 81 px := &Postgres{ 82 conn: conn, 83 db: instance, 84 config: config, 85 } 86 87 if err := px.ensureVersionTable(); err != nil { 88 return nil, err 89 } 90 91 return px, nil 92 } 93 94 func (p *Postgres) Open(url string) (database.Driver, error) { 95 purl, err := nurl.Parse(url) 96 if err != nil { 97 return nil, err 98 } 99 100 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 101 if err != nil { 102 return nil, err 103 } 104 105 migrationsTable := purl.Query().Get("x-migrations-table") 106 if len(migrationsTable) == 0 { 107 migrationsTable = DefaultMigrationsTable 108 } 109 110 px, err := WithInstance(db, &Config{ 111 DatabaseName: purl.Path, 112 MigrationsTable: migrationsTable, 113 }) 114 if err != nil { 115 return nil, err 116 } 117 118 return px, nil 119 } 120 121 func (p *Postgres) Close() error { 122 connErr := p.conn.Close() 123 dbErr := p.db.Close() 124 if connErr != nil || dbErr != nil { 125 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 126 } 127 return nil 128 } 129 130 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 131 func (p *Postgres) Lock() error { 132 if p.isLocked { 133 return database.ErrLocked 134 } 135 136 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) 137 if err != nil { 138 return err 139 } 140 141 // This will either obtain the lock immediately and return true, 142 // or return false if the lock cannot be acquired immediately. 143 query := `SELECT pg_advisory_lock($1)` 144 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 145 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 146 } 147 148 p.isLocked = true 149 return nil 150 } 151 152 func (p *Postgres) Unlock() error { 153 if !p.isLocked { 154 return nil 155 } 156 157 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) 158 if err != nil { 159 return err 160 } 161 162 query := `SELECT pg_advisory_unlock($1)` 163 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 164 return &database.Error{OrigErr: err, Query: []byte(query)} 165 } 166 p.isLocked = false 167 return nil 168 } 169 170 func (p *Postgres) Run(migration io.Reader) error { 171 migr, err := ioutil.ReadAll(migration) 172 if err != nil { 173 return err 174 } 175 176 // run migration 177 query := string(migr[:]) 178 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 179 if pgErr, ok := err.(*pq.Error); ok { 180 var line uint 181 var col uint 182 var lineColOK bool 183 if pgErr.Position != "" { 184 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 185 line, col, lineColOK = computeLineFromPos(query, int(pos)) 186 } 187 } 188 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 189 if lineColOK { 190 message = fmt.Sprintf("%s (column %d)", message, col) 191 } 192 if pgErr.Detail != "" { 193 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 194 } 195 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 196 } 197 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 198 } 199 200 return nil 201 } 202 203 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 204 // replace crlf with lf 205 s = strings.Replace(s, "\r\n", "\n", -1) 206 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 207 runes := []rune(s) 208 if pos > len(runes) { 209 return 0, 0, false 210 } 211 sel := runes[:pos] 212 line = uint(runesCount(sel, newLine) + 1) 213 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 214 return line, col, true 215 } 216 217 const newLine = '\n' 218 219 func runesCount(input []rune, target rune) int { 220 var count int 221 for _, r := range input { 222 if r == target { 223 count++ 224 } 225 } 226 return count 227 } 228 229 func runesLastIndex(input []rune, target rune) int { 230 for i := len(input) - 1; i >= 0; i-- { 231 if input[i] == target { 232 return i 233 } 234 } 235 return -1 236 } 237 238 func (p *Postgres) SetVersion(version int, dirty bool) error { 239 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 240 if err != nil { 241 return &database.Error{OrigErr: err, Err: "transaction start failed"} 242 } 243 244 query := `TRUNCATE "` + p.config.MigrationsTable + `"` 245 if _, err := tx.Exec(query); err != nil { 246 tx.Rollback() 247 return &database.Error{OrigErr: err, Query: []byte(query)} 248 } 249 250 if version >= 0 { 251 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)` 252 if _, err := tx.Exec(query, version, dirty); err != nil { 253 tx.Rollback() 254 return &database.Error{OrigErr: err, Query: []byte(query)} 255 } 256 } 257 258 if err := tx.Commit(); err != nil { 259 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 260 } 261 262 return nil 263 } 264 265 func (p *Postgres) Version() (version int, dirty bool, err error) { 266 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 267 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 268 switch { 269 case err == sql.ErrNoRows: 270 return database.NilVersion, false, nil 271 272 case err != nil: 273 if e, ok := err.(*pq.Error); ok { 274 if e.Code.Name() == "undefined_table" { 275 return database.NilVersion, false, nil 276 } 277 } 278 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 279 280 default: 281 return version, dirty, nil 282 } 283 } 284 285 func (p *Postgres) Drop() error { 286 // select all tables in current schema 287 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 288 tables, err := p.conn.QueryContext(context.Background(), query) 289 if err != nil { 290 return &database.Error{OrigErr: err, Query: []byte(query)} 291 } 292 defer tables.Close() 293 294 // delete one table after another 295 tableNames := make([]string, 0) 296 for tables.Next() { 297 var tableName string 298 if err := tables.Scan(&tableName); err != nil { 299 return err 300 } 301 if len(tableName) > 0 { 302 tableNames = append(tableNames, tableName) 303 } 304 } 305 306 if len(tableNames) > 0 { 307 // delete one by one ... 308 for _, t := range tableNames { 309 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 310 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 311 return &database.Error{OrigErr: err, Query: []byte(query)} 312 } 313 } 314 if err := p.ensureVersionTable(); err != nil { 315 return err 316 } 317 } 318 319 return nil 320 } 321 322 func (p *Postgres) ensureVersionTable() error { 323 // check if migration table exists 324 var count int 325 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 326 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 327 return &database.Error{OrigErr: err, Query: []byte(query)} 328 } 329 if count == 1 { 330 return nil 331 } 332 333 // if not, create the empty migration table 334 query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` 335 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 336 return &database.Error{OrigErr: err, Query: []byte(query)} 337 } 338 return nil 339 }