github.com/nagyist/migrate/v4@v4.14.6/database/redshift/redshift.go (about) 1 // +build go1.9 2 3 package redshift 4 5 import ( 6 "context" 7 "database/sql" 8 "fmt" 9 "io" 10 "io/ioutil" 11 nurl "net/url" 12 "strconv" 13 "strings" 14 15 "github.com/golang-migrate/migrate/v4" 16 "github.com/golang-migrate/migrate/v4/database" 17 "github.com/hashicorp/go-multierror" 18 "github.com/lib/pq" 19 ) 20 21 func init() { 22 db := Redshift{} 23 database.Register("redshift", &db) 24 } 25 26 var DefaultMigrationsTable = "schema_migrations" 27 28 var ( 29 ErrNilConfig = fmt.Errorf("no config") 30 ErrNoDatabaseName = fmt.Errorf("no database name") 31 ) 32 33 type Config struct { 34 MigrationsTable string 35 DatabaseName string 36 } 37 38 type Redshift struct { 39 isLocked bool 40 conn *sql.Conn 41 db *sql.DB 42 43 // Open and WithInstance need to guarantee that config is never nil 44 config *Config 45 } 46 47 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 48 if config == nil { 49 return nil, ErrNilConfig 50 } 51 52 if err := instance.Ping(); err != nil { 53 return nil, err 54 } 55 56 if config.DatabaseName == "" { 57 query := `SELECT CURRENT_DATABASE()` 58 var databaseName string 59 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 60 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 61 } 62 63 if len(databaseName) == 0 { 64 return nil, ErrNoDatabaseName 65 } 66 67 config.DatabaseName = databaseName 68 } 69 70 if len(config.MigrationsTable) == 0 { 71 config.MigrationsTable = DefaultMigrationsTable 72 } 73 74 conn, err := instance.Conn(context.Background()) 75 76 if err != nil { 77 return nil, err 78 } 79 80 px := &Redshift{ 81 conn: conn, 82 db: instance, 83 config: config, 84 } 85 86 if err := px.ensureVersionTable(); err != nil { 87 return nil, err 88 } 89 90 return px, nil 91 } 92 93 func (p *Redshift) Open(url string) (database.Driver, error) { 94 purl, err := nurl.Parse(url) 95 if err != nil { 96 return nil, err 97 } 98 purl.Scheme = "postgres" 99 100 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 101 if err != nil { 102 return nil, err 103 } 104 105 migrationsTable := purl.Query().Get("x-migrations-table") 106 107 px, err := WithInstance(db, &Config{ 108 DatabaseName: purl.Path, 109 MigrationsTable: migrationsTable, 110 }) 111 if err != nil { 112 return nil, err 113 } 114 115 return px, nil 116 } 117 118 func (p *Redshift) Close() error { 119 connErr := p.conn.Close() 120 dbErr := p.db.Close() 121 if connErr != nil || dbErr != nil { 122 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 123 } 124 return nil 125 } 126 127 // Redshift does not support advisory lock functions: https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-functions.html 128 func (p *Redshift) Lock() error { 129 if p.isLocked { 130 return database.ErrLocked 131 } 132 p.isLocked = true 133 return nil 134 } 135 136 func (p *Redshift) Unlock() error { 137 p.isLocked = false 138 return nil 139 } 140 141 func (p *Redshift) Run(migration io.Reader) error { 142 migr, err := ioutil.ReadAll(migration) 143 if err != nil { 144 return err 145 } 146 147 // run migration 148 query := string(migr[:]) 149 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 150 if pgErr, ok := err.(*pq.Error); ok { 151 var line uint 152 var col uint 153 var lineColOK bool 154 if pgErr.Position != "" { 155 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 156 line, col, lineColOK = computeLineFromPos(query, int(pos)) 157 } 158 } 159 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 160 if lineColOK { 161 message = fmt.Sprintf("%s (column %d)", message, col) 162 } 163 if pgErr.Detail != "" { 164 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 165 } 166 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 167 } 168 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 169 } 170 171 return nil 172 } 173 174 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 175 // replace crlf with lf 176 s = strings.Replace(s, "\r\n", "\n", -1) 177 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 178 runes := []rune(s) 179 if pos > len(runes) { 180 return 0, 0, false 181 } 182 sel := runes[:pos] 183 line = uint(runesCount(sel, newLine) + 1) 184 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 185 return line, col, true 186 } 187 188 const newLine = '\n' 189 190 func runesCount(input []rune, target rune) int { 191 var count int 192 for _, r := range input { 193 if r == target { 194 count++ 195 } 196 } 197 return count 198 } 199 200 func runesLastIndex(input []rune, target rune) int { 201 for i := len(input) - 1; i >= 0; i-- { 202 if input[i] == target { 203 return i 204 } 205 } 206 return -1 207 } 208 209 func (p *Redshift) SetVersion(version int, dirty bool) error { 210 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 211 if err != nil { 212 return &database.Error{OrigErr: err, Err: "transaction start failed"} 213 } 214 215 query := `DELETE FROM "` + p.config.MigrationsTable + `"` 216 if _, err := tx.Exec(query); err != nil { 217 if errRollback := tx.Rollback(); errRollback != nil { 218 err = multierror.Append(err, errRollback) 219 } 220 return &database.Error{OrigErr: err, Query: []byte(query)} 221 } 222 223 // Also re-write the schema version for nil dirty versions to prevent 224 // empty schema version for failed down migration on the first migration 225 // See: https://github.com/golang-migrate/migrate/issues/330 226 if version >= 0 || (version == database.NilVersion && dirty) { 227 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)` 228 if _, err := tx.Exec(query, version, dirty); err != nil { 229 if errRollback := tx.Rollback(); errRollback != nil { 230 err = multierror.Append(err, errRollback) 231 } 232 return &database.Error{OrigErr: err, Query: []byte(query)} 233 } 234 } 235 236 if err := tx.Commit(); err != nil { 237 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 238 } 239 240 return nil 241 } 242 243 func (p *Redshift) Version() (version int, dirty bool, err error) { 244 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 245 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 246 switch { 247 case err == sql.ErrNoRows: 248 return database.NilVersion, false, nil 249 250 case err != nil: 251 if e, ok := err.(*pq.Error); ok { 252 if e.Code.Name() == "undefined_table" { 253 return database.NilVersion, false, nil 254 } 255 } 256 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 257 258 default: 259 return version, dirty, nil 260 } 261 } 262 263 func (p *Redshift) Drop() (err error) { 264 // select all tables in current schema 265 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 266 tables, err := p.conn.QueryContext(context.Background(), query) 267 if err != nil { 268 return &database.Error{OrigErr: err, Query: []byte(query)} 269 } 270 defer func() { 271 if errClose := tables.Close(); errClose != nil { 272 err = multierror.Append(err, errClose) 273 } 274 }() 275 276 // delete one table after another 277 tableNames := make([]string, 0) 278 for tables.Next() { 279 var tableName string 280 if err := tables.Scan(&tableName); err != nil { 281 return err 282 } 283 if len(tableName) > 0 { 284 tableNames = append(tableNames, tableName) 285 } 286 } 287 if err := tables.Err(); err != nil { 288 return &database.Error{OrigErr: err, Query: []byte(query)} 289 } 290 291 if len(tableNames) > 0 { 292 // delete one by one ... 293 for _, t := range tableNames { 294 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 295 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 296 return &database.Error{OrigErr: err, Query: []byte(query)} 297 } 298 } 299 } 300 301 return nil 302 } 303 304 // ensureVersionTable checks if versions table exists and, if not, creates it. 305 // Note that this function locks the database, which deviates from the usual 306 // convention of "caller locks" in the Redshift type. 307 func (p *Redshift) ensureVersionTable() (err error) { 308 if err = p.Lock(); err != nil { 309 return err 310 } 311 312 defer func() { 313 if e := p.Unlock(); e != nil { 314 if err == nil { 315 err = e 316 } else { 317 err = multierror.Append(err, e) 318 } 319 } 320 }() 321 322 // check if migration table exists 323 var count int 324 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 325 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 326 return &database.Error{OrigErr: err, Query: []byte(query)} 327 } 328 if count == 1 { 329 return nil 330 } 331 332 // if not, create the empty migration table 333 query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` 334 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 335 return &database.Error{OrigErr: err, Query: []byte(query)} 336 } 337 return nil 338 }