github.com/fr-nvriep/migrate/v4@v4.3.2/database/redshift/redshift.go (about) 1 // +build go1.9 2 3 package redshift 4 5 import ( 6 "context" 7 "database/sql" 8 "fmt" 9 "io" 10 "io/ioutil" 11 nurl "net/url" 12 "strconv" 13 "strings" 14 15 "github.com/fr-nvriep/migrate/v4" 16 "github.com/fr-nvriep/migrate/v4/database" 17 "github.com/hashicorp/go-multierror" 18 "github.com/lib/pq" 19 ) 20 21 func init() { 22 db := Redshift{} 23 database.Register("redshift", &db) 24 } 25 26 var DefaultMigrationsTable = "schema_migrations" 27 28 var ( 29 ErrNilConfig = fmt.Errorf("no config") 30 ErrNoDatabaseName = fmt.Errorf("no database name") 31 ) 32 33 type Config struct { 34 MigrationsTable string 35 DatabaseName string 36 } 37 38 type Redshift struct { 39 isLocked bool 40 conn *sql.Conn 41 db *sql.DB 42 43 // Open and WithInstance need to guarantee that config is never nil 44 config *Config 45 } 46 47 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 48 if config == nil { 49 return nil, ErrNilConfig 50 } 51 52 if err := instance.Ping(); err != nil { 53 return nil, err 54 } 55 56 query := `SELECT CURRENT_DATABASE()` 57 var databaseName string 58 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 59 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 60 } 61 62 if len(databaseName) == 0 { 63 return nil, ErrNoDatabaseName 64 } 65 66 config.DatabaseName = databaseName 67 68 if len(config.MigrationsTable) == 0 { 69 config.MigrationsTable = DefaultMigrationsTable 70 } 71 72 conn, err := instance.Conn(context.Background()) 73 74 if err != nil { 75 return nil, err 76 } 77 78 px := &Redshift{ 79 conn: conn, 80 db: instance, 81 config: config, 82 } 83 84 if err := px.ensureVersionTable(); err != nil { 85 return nil, err 86 } 87 88 return px, nil 89 } 90 91 func (p *Redshift) Open(url string) (database.Driver, error) { 92 purl, err := nurl.Parse(url) 93 if err != nil { 94 return nil, err 95 } 96 purl.Scheme = "postgres" 97 98 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 99 if err != nil { 100 return nil, err 101 } 102 103 migrationsTable := purl.Query().Get("x-migrations-table") 104 105 px, err := WithInstance(db, &Config{ 106 DatabaseName: purl.Path, 107 MigrationsTable: migrationsTable, 108 }) 109 if err != nil { 110 return nil, err 111 } 112 113 return px, nil 114 } 115 116 func (p *Redshift) Close() error { 117 connErr := p.conn.Close() 118 dbErr := p.db.Close() 119 if connErr != nil || dbErr != nil { 120 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 121 } 122 return nil 123 } 124 125 // Redshift does not support advisory lock functions: https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-functions.html 126 func (p *Redshift) Lock() error { 127 if p.isLocked { 128 return database.ErrLocked 129 } 130 p.isLocked = true 131 return nil 132 } 133 134 func (p *Redshift) Unlock() error { 135 p.isLocked = false 136 return nil 137 } 138 139 func (p *Redshift) Run(migration io.Reader) error { 140 migr, err := ioutil.ReadAll(migration) 141 if err != nil { 142 return err 143 } 144 145 // run migration 146 query := string(migr[:]) 147 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 148 if pgErr, ok := err.(*pq.Error); ok { 149 var line uint 150 var col uint 151 var lineColOK bool 152 if pgErr.Position != "" { 153 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 154 line, col, lineColOK = computeLineFromPos(query, int(pos)) 155 } 156 } 157 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 158 if lineColOK { 159 message = fmt.Sprintf("%s (column %d)", message, col) 160 } 161 if pgErr.Detail != "" { 162 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 163 } 164 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 165 } 166 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 167 } 168 169 return nil 170 } 171 172 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 173 // replace crlf with lf 174 s = strings.Replace(s, "\r\n", "\n", -1) 175 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 176 runes := []rune(s) 177 if pos > len(runes) { 178 return 0, 0, false 179 } 180 sel := runes[:pos] 181 line = uint(runesCount(sel, newLine) + 1) 182 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 183 return line, col, true 184 } 185 186 const newLine = '\n' 187 188 func runesCount(input []rune, target rune) int { 189 var count int 190 for _, r := range input { 191 if r == target { 192 count++ 193 } 194 } 195 return count 196 } 197 198 func runesLastIndex(input []rune, target rune) int { 199 for i := len(input) - 1; i >= 0; i-- { 200 if input[i] == target { 201 return i 202 } 203 } 204 return -1 205 } 206 207 func (p *Redshift) SetVersion(version int, dirty bool) error { 208 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 209 if err != nil { 210 return &database.Error{OrigErr: err, Err: "transaction start failed"} 211 } 212 213 query := `DELETE FROM "` + p.config.MigrationsTable + `"` 214 if _, err := tx.Exec(query); err != nil { 215 if errRollback := tx.Rollback(); errRollback != nil { 216 err = multierror.Append(err, errRollback) 217 } 218 return &database.Error{OrigErr: err, Query: []byte(query)} 219 } 220 221 if version >= 0 { 222 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)` 223 if _, err := tx.Exec(query, version, dirty); err != nil { 224 if errRollback := tx.Rollback(); errRollback != nil { 225 err = multierror.Append(err, errRollback) 226 } 227 return &database.Error{OrigErr: err, Query: []byte(query)} 228 } 229 } 230 231 if err := tx.Commit(); err != nil { 232 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 233 } 234 235 return nil 236 } 237 238 func (p *Redshift) Version() (version int, dirty bool, err error) { 239 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 240 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 241 switch { 242 case err == sql.ErrNoRows: 243 return database.NilVersion, false, nil 244 245 case err != nil: 246 if e, ok := err.(*pq.Error); ok { 247 if e.Code.Name() == "undefined_table" { 248 return database.NilVersion, false, nil 249 } 250 } 251 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 252 253 default: 254 return version, dirty, nil 255 } 256 } 257 258 func (p *Redshift) Drop() (err error) { 259 // select all tables in current schema 260 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 261 tables, err := p.conn.QueryContext(context.Background(), query) 262 if err != nil { 263 return &database.Error{OrigErr: err, Query: []byte(query)} 264 } 265 defer func() { 266 if errClose := tables.Close(); errClose != nil { 267 err = multierror.Append(err, errClose) 268 } 269 }() 270 271 // delete one table after another 272 tableNames := make([]string, 0) 273 for tables.Next() { 274 var tableName string 275 if err := tables.Scan(&tableName); err != nil { 276 return err 277 } 278 if len(tableName) > 0 { 279 tableNames = append(tableNames, tableName) 280 } 281 } 282 283 if len(tableNames) > 0 { 284 // delete one by one ... 285 for _, t := range tableNames { 286 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 287 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 288 return &database.Error{OrigErr: err, Query: []byte(query)} 289 } 290 } 291 } 292 293 return nil 294 } 295 296 // ensureVersionTable checks if versions table exists and, if not, creates it. 297 // Note that this function locks the database, which deviates from the usual 298 // convention of "caller locks" in the Redshift type. 299 func (p *Redshift) ensureVersionTable() (err error) { 300 if err = p.Lock(); err != nil { 301 return err 302 } 303 304 defer func() { 305 if e := p.Unlock(); e != nil { 306 if err == nil { 307 err = e 308 } else { 309 err = multierror.Append(err, e) 310 } 311 } 312 }() 313 314 // check if migration table exists 315 var count int 316 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 317 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 318 return &database.Error{OrigErr: err, Query: []byte(query)} 319 } 320 if count == 1 { 321 return nil 322 } 323 324 // if not, create the empty migration table 325 query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` 326 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 327 return &database.Error{OrigErr: err, Query: []byte(query)} 328 } 329 return nil 330 }