github.com/eatigo/migrate@v3.0.2-0.20210729130915-7610befb1b6b+incompatible/database/postgres/postgres.go (about) 1 package postgres 2 3 import ( 4 "database/sql" 5 "fmt" 6 "io" 7 "io/ioutil" 8 nurl "net/url" 9 10 "github.com/lib/pq" 11 "github.com/eatigo/migrate" 12 "github.com/eatigo/migrate/database" 13 ) 14 15 func init() { 16 db := Postgres{} 17 database.Register("postgres", &db) 18 database.Register("postgresql", &db) 19 } 20 21 var DefaultMigrationsTable = "schema_migrations" 22 23 var ( 24 ErrNilConfig = fmt.Errorf("no config") 25 ErrNoDatabaseName = fmt.Errorf("no database name") 26 ErrNoSchema = fmt.Errorf("no schema") 27 ErrDatabaseDirty = fmt.Errorf("database is dirty") 28 ) 29 30 type Config struct { 31 MigrationsTable string 32 DatabaseName string 33 } 34 35 type Postgres struct { 36 db *sql.DB 37 isLocked bool 38 39 // Open and WithInstance need to garantuee that config is never nil 40 config *Config 41 } 42 43 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 44 if config == nil { 45 return nil, ErrNilConfig 46 } 47 48 if err := instance.Ping(); err != nil { 49 return nil, err 50 } 51 52 query := `SELECT CURRENT_DATABASE()` 53 var databaseName string 54 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 55 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 56 } 57 58 if len(databaseName) == 0 { 59 return nil, ErrNoDatabaseName 60 } 61 62 config.DatabaseName = databaseName 63 64 if len(config.MigrationsTable) == 0 { 65 config.MigrationsTable = DefaultMigrationsTable 66 } 67 68 px := &Postgres{ 69 db: instance, 70 config: config, 71 } 72 73 if err := px.ensureVersionTable(); err != nil { 74 return nil, err 75 } 76 77 return px, nil 78 } 79 80 func (p *Postgres) Open(url string) (database.Driver, error) { 81 purl, err := nurl.Parse(url) 82 if err != nil { 83 return nil, err 84 } 85 86 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 87 if err != nil { 88 return nil, err 89 } 90 91 migrationsTable := purl.Query().Get("x-migrations-table") 92 if len(migrationsTable) == 0 { 93 migrationsTable = DefaultMigrationsTable 94 } 95 96 px, err := WithInstance(db, &Config{ 97 DatabaseName: purl.Path, 98 MigrationsTable: migrationsTable, 99 }) 100 if err != nil { 101 return nil, err 102 } 103 104 return px, nil 105 } 106 107 func (p *Postgres) Close() error { 108 return p.db.Close() 109 } 110 111 // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 112 func (p *Postgres) Lock() error { 113 if p.isLocked { 114 return database.ErrLocked 115 } 116 117 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) 118 if err != nil { 119 return err 120 } 121 122 // This will either obtain the lock immediately and return true, 123 // or return false if the lock cannot be acquired immediately. 124 query := `SELECT pg_try_advisory_lock($1)` 125 var success bool 126 if err := p.db.QueryRow(query, aid).Scan(&success); err != nil { 127 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 128 } 129 130 if success { 131 p.isLocked = true 132 return nil 133 } 134 135 return database.ErrLocked 136 } 137 138 func (p *Postgres) Unlock() error { 139 if !p.isLocked { 140 return nil 141 } 142 143 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) 144 if err != nil { 145 return err 146 } 147 148 query := `SELECT pg_advisory_unlock($1)` 149 if _, err := p.db.Exec(query, aid); err != nil { 150 return &database.Error{OrigErr: err, Query: []byte(query)} 151 } 152 p.isLocked = false 153 return nil 154 } 155 156 func (p *Postgres) Run(migration io.Reader) error { 157 migr, err := ioutil.ReadAll(migration) 158 if err != nil { 159 return err 160 } 161 162 // run migration 163 query := string(migr[:]) 164 if _, err := p.db.Exec(query); err != nil { 165 // TODO: cast to postgress error and get line number 166 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 167 } 168 169 return nil 170 } 171 172 func (p *Postgres) SetVersion(version int, dirty bool) error { 173 tx, err := p.db.Begin() 174 if err != nil { 175 return &database.Error{OrigErr: err, Err: "transaction start failed"} 176 } 177 178 query := `TRUNCATE "` + p.config.MigrationsTable + `"` 179 if _, err := tx.Exec(query); err != nil { 180 tx.Rollback() 181 return &database.Error{OrigErr: err, Query: []byte(query)} 182 } 183 184 if version >= 0 { 185 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)` 186 if _, err := tx.Exec(query, version, dirty); err != nil { 187 tx.Rollback() 188 return &database.Error{OrigErr: err, Query: []byte(query)} 189 } 190 } 191 192 if err := tx.Commit(); err != nil { 193 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 194 } 195 196 return nil 197 } 198 199 func (p *Postgres) Version() (version int, dirty bool, err error) { 200 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 201 err = p.db.QueryRow(query).Scan(&version, &dirty) 202 switch { 203 case err == sql.ErrNoRows: 204 return database.NilVersion, false, nil 205 206 case err != nil: 207 if e, ok := err.(*pq.Error); ok { 208 if e.Code.Name() == "undefined_table" { 209 return database.NilVersion, false, nil 210 } 211 } 212 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 213 214 default: 215 return version, dirty, nil 216 } 217 } 218 219 func (p *Postgres) Drop() error { 220 // select all tables in current schema 221 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())` 222 tables, err := p.db.Query(query) 223 if err != nil { 224 return &database.Error{OrigErr: err, Query: []byte(query)} 225 } 226 defer tables.Close() 227 228 // delete one table after another 229 tableNames := make([]string, 0) 230 for tables.Next() { 231 var tableName string 232 if err := tables.Scan(&tableName); err != nil { 233 return err 234 } 235 if len(tableName) > 0 { 236 tableNames = append(tableNames, tableName) 237 } 238 } 239 240 if len(tableNames) > 0 { 241 // delete one by one ... 242 for _, t := range tableNames { 243 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 244 if _, err := p.db.Exec(query); err != nil { 245 return &database.Error{OrigErr: err, Query: []byte(query)} 246 } 247 } 248 if err := p.ensureVersionTable(); err != nil { 249 return err 250 } 251 } 252 253 return nil 254 } 255 256 func (p *Postgres) ensureVersionTable() error { 257 // check if migration table exists 258 var count int 259 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 260 if err := p.db.QueryRow(query, p.config.MigrationsTable).Scan(&count); err != nil { 261 return &database.Error{OrigErr: err, Query: []byte(query)} 262 } 263 if count == 1 { 264 return nil 265 } 266 267 // if not, create the empty migration table 268 query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` 269 if _, err := p.db.Exec(query); err != nil { 270 return &database.Error{OrigErr: err, Query: []byte(query)} 271 } 272 return nil 273 }