github.com/qdentity/migrate@v3.3.0+incompatible/database/postgres/postgres.go (about) 1 // +build go1.9 2 3 package postgres 4 5 import ( 6 "database/sql" 7 "fmt" 8 "io" 9 "io/ioutil" 10 nurl "net/url" 11 12 "context" 13 "github.com/golang-migrate/migrate" 14 "github.com/golang-migrate/migrate/database" 15 "github.com/lib/pq" 16 ) 17 18 func init() { 19 db := Postgres{} 20 database.Register("postgres", &db) 21 database.Register("postgresql", &db) 22 } 23 24 var DefaultMigrationsTable = "schema_migrations" 25 26 var ( 27 ErrNilConfig = fmt.Errorf("no config") 28 ErrNoDatabaseName = fmt.Errorf("no database name") 29 ErrNoSchema = fmt.Errorf("no schema") 30 ErrDatabaseDirty = fmt.Errorf("database is dirty") 31 ) 32 33 type Config struct { 34 MigrationsTable string 35 DatabaseName string 36 } 37 38 type Postgres struct { 39 // Locking and unlocking need to use the same connection 40 conn *sql.Conn 41 isLocked bool 42 43 // Open and WithInstance need to garantuee that config is never nil 44 config *Config 45 } 46 47 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 48 if config == nil { 49 return nil, ErrNilConfig 50 } 51 52 if err := instance.Ping(); err != nil { 53 return nil, err 54 } 55 56 query := `SELECT CURRENT_DATABASE()` 57 var databaseName string 58 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 59 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 60 } 61 62 if len(databaseName) == 0 { 63 return nil, ErrNoDatabaseName 64 } 65 66 config.DatabaseName = databaseName 67 68 if len(config.MigrationsTable) == 0 { 69 config.MigrationsTable = DefaultMigrationsTable 70 } 71 72 conn, err := instance.Conn(context.Background()) 73 74 if err != nil { 75 return nil, err 76 } 77 78 px := &Postgres{ 79 conn: conn, 80 config: config, 81 } 82 83 if err := px.ensureVersionTable(); err != nil { 84 return nil, err 85 } 86 87 return px, nil 88 } 89 90 func (p *Postgres) Open(url string) (database.Driver, error) { 91 purl, err := nurl.Parse(url) 92 if err != nil { 93 return nil, err 94 } 95 96 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 97 if err != nil { 98 return nil, err 99 } 100 101 migrationsTable := purl.Query().Get("x-migrations-table") 102 if len(migrationsTable) == 0 { 103 migrationsTable = DefaultMigrationsTable 104 } 105 106 px, err := WithInstance(db, &Config{ 107 DatabaseName: purl.Path, 108 MigrationsTable: migrationsTable, 109 }) 110 if err != nil { 111 return nil, err 112 } 113 114 return px, nil 115 } 116 117 func (p *Postgres) Close() error { 118 return p.conn.Close() 119 } 120 121 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 122 func (p *Postgres) Lock() error { 123 if p.isLocked { 124 return database.ErrLocked 125 } 126 127 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) 128 if err != nil { 129 return err 130 } 131 132 // This will either obtain the lock immediately and return true, 133 // or return false if the lock cannot be acquired immediately. 134 query := `SELECT pg_advisory_lock($1)` 135 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 136 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 137 } 138 139 p.isLocked = true 140 return nil 141 } 142 143 func (p *Postgres) Unlock() error { 144 if !p.isLocked { 145 return nil 146 } 147 148 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) 149 if err != nil { 150 return err 151 } 152 153 query := `SELECT pg_advisory_unlock($1)` 154 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 155 return &database.Error{OrigErr: err, Query: []byte(query)} 156 } 157 p.isLocked = false 158 return nil 159 } 160 161 func (p *Postgres) Run(migration io.Reader) error { 162 migr, err := ioutil.ReadAll(migration) 163 if err != nil { 164 return err 165 } 166 167 // run migration 168 query := string(migr[:]) 169 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 170 // TODO: cast to postgress error and get line number 171 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 172 } 173 174 return nil 175 } 176 177 func (p *Postgres) SetVersion(version int, dirty bool) error { 178 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 179 if err != nil { 180 return &database.Error{OrigErr: err, Err: "transaction start failed"} 181 } 182 183 query := `TRUNCATE "` + p.config.MigrationsTable + `"` 184 if _, err := tx.Exec(query); err != nil { 185 tx.Rollback() 186 return &database.Error{OrigErr: err, Query: []byte(query)} 187 } 188 189 if version >= 0 { 190 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)` 191 if _, err := tx.Exec(query, version, dirty); err != nil { 192 tx.Rollback() 193 return &database.Error{OrigErr: err, Query: []byte(query)} 194 } 195 } 196 197 if err := tx.Commit(); err != nil { 198 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 199 } 200 201 return nil 202 } 203 204 func (p *Postgres) Version() (version int, dirty bool, err error) { 205 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 206 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 207 switch { 208 case err == sql.ErrNoRows: 209 return database.NilVersion, false, nil 210 211 case err != nil: 212 if e, ok := err.(*pq.Error); ok { 213 if e.Code.Name() == "undefined_table" { 214 return database.NilVersion, false, nil 215 } 216 } 217 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 218 219 default: 220 return version, dirty, nil 221 } 222 } 223 224 func (p *Postgres) Drop() error { 225 // select all tables in current schema 226 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())` 227 tables, err := p.conn.QueryContext(context.Background(), query) 228 if err != nil { 229 return &database.Error{OrigErr: err, Query: []byte(query)} 230 } 231 defer tables.Close() 232 233 // delete one table after another 234 tableNames := make([]string, 0) 235 for tables.Next() { 236 var tableName string 237 if err := tables.Scan(&tableName); err != nil { 238 return err 239 } 240 if len(tableName) > 0 { 241 tableNames = append(tableNames, tableName) 242 } 243 } 244 245 if len(tableNames) > 0 { 246 // delete one by one ... 247 for _, t := range tableNames { 248 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 249 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 250 return &database.Error{OrigErr: err, Query: []byte(query)} 251 } 252 } 253 if err := p.ensureVersionTable(); err != nil { 254 return err 255 } 256 } 257 258 return nil 259 } 260 261 func (p *Postgres) ensureVersionTable() error { 262 // check if migration table exists 263 var count int 264 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 265 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 266 return &database.Error{OrigErr: err, Query: []byte(query)} 267 } 268 if count == 1 { 269 return nil 270 } 271 272 // if not, create the empty migration table 273 query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` 274 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 275 return &database.Error{OrigErr: err, Query: []byte(query)} 276 } 277 return nil 278 }