github.com/nokia/migrate/v4@v4.16.0/database/redshift/redshift.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package redshift 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 "io/ioutil" 12 nurl "net/url" 13 "strconv" 14 "strings" 15 16 "go.uber.org/atomic" 17 18 "github.com/hashicorp/go-multierror" 19 "github.com/lib/pq" 20 "github.com/nokia/migrate/v4" 21 "github.com/nokia/migrate/v4/database" 22 "github.com/nokia/migrate/v4/source" 23 ) 24 25 func init() { 26 db := Redshift{} 27 database.Register("redshift", &db) 28 } 29 30 var DefaultMigrationsTable = "schema_migrations" 31 32 var ( 33 ErrNilConfig = fmt.Errorf("no config") 34 ErrNoDatabaseName = fmt.Errorf("no database name") 35 ) 36 37 type Config struct { 38 MigrationsTable string 39 DatabaseName string 40 } 41 42 type Redshift struct { 43 isLocked atomic.Bool 44 conn *sql.Conn 45 db *sql.DB 46 47 // Open and WithInstance need to guarantee that config is never nil 48 config *Config 49 } 50 51 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 52 if config == nil { 53 return nil, ErrNilConfig 54 } 55 56 if err := instance.Ping(); err != nil { 57 return nil, err 58 } 59 60 if config.DatabaseName == "" { 61 query := `SELECT CURRENT_DATABASE()` 62 var databaseName string 63 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 64 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 65 } 66 67 if len(databaseName) == 0 { 68 return nil, ErrNoDatabaseName 69 } 70 71 config.DatabaseName = databaseName 72 } 73 74 if len(config.MigrationsTable) == 0 { 75 config.MigrationsTable = DefaultMigrationsTable 76 } 77 78 conn, err := instance.Conn(context.Background()) 79 if err != nil { 80 return nil, err 81 } 82 83 px := &Redshift{ 84 conn: conn, 85 db: instance, 86 config: config, 87 } 88 89 if err := px.ensureVersionTable(); err != nil { 90 return nil, err 91 } 92 93 return px, nil 94 } 95 96 func (p *Redshift) Open(url string) (database.Driver, error) { 97 purl, err := nurl.Parse(url) 98 if err != nil { 99 return nil, err 100 } 101 purl.Scheme = "postgres" 102 103 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 104 if err != nil { 105 return nil, err 106 } 107 108 migrationsTable := purl.Query().Get("x-migrations-table") 109 110 px, err := WithInstance(db, &Config{ 111 DatabaseName: purl.Path, 112 MigrationsTable: migrationsTable, 113 }) 114 if err != nil { 115 return nil, err 116 } 117 118 return px, nil 119 } 120 121 func (p *Redshift) Close() error { 122 connErr := p.conn.Close() 123 dbErr := p.db.Close() 124 if connErr != nil || dbErr != nil { 125 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 126 } 127 return nil 128 } 129 130 // Redshift does not support advisory lock functions: https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-functions.html 131 func (p *Redshift) Lock() error { 132 if !p.isLocked.CAS(false, true) { 133 return database.ErrLocked 134 } 135 return nil 136 } 137 138 func (p *Redshift) Unlock() error { 139 if !p.isLocked.CAS(true, false) { 140 return database.ErrNotLocked 141 } 142 return nil 143 } 144 145 func (p *Redshift) Run(migration io.Reader) error { 146 migr, err := ioutil.ReadAll(migration) 147 if err != nil { 148 return err 149 } 150 151 // run migration 152 query := string(migr[:]) 153 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 154 if pgErr, ok := err.(*pq.Error); ok { 155 var line uint 156 var col uint 157 var lineColOK bool 158 if pgErr.Position != "" { 159 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 160 line, col, lineColOK = computeLineFromPos(query, int(pos)) 161 } 162 } 163 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 164 if lineColOK { 165 message = fmt.Sprintf("%s (column %d)", message, col) 166 } 167 if pgErr.Detail != "" { 168 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 169 } 170 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 171 } 172 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 173 } 174 175 return nil 176 } 177 178 func (p *Redshift) RunFunctionMigration(fn source.MigrationFunc) error { 179 return database.ErrNotImpl 180 } 181 182 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 183 // replace crlf with lf 184 s = strings.Replace(s, "\r\n", "\n", -1) 185 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 186 runes := []rune(s) 187 if pos > len(runes) { 188 return 0, 0, false 189 } 190 sel := runes[:pos] 191 line = uint(runesCount(sel, newLine) + 1) 192 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 193 return line, col, true 194 } 195 196 const newLine = '\n' 197 198 func runesCount(input []rune, target rune) int { 199 var count int 200 for _, r := range input { 201 if r == target { 202 count++ 203 } 204 } 205 return count 206 } 207 208 func runesLastIndex(input []rune, target rune) int { 209 for i := len(input) - 1; i >= 0; i-- { 210 if input[i] == target { 211 return i 212 } 213 } 214 return -1 215 } 216 217 func (p *Redshift) SetVersion(version int, dirty bool) error { 218 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 219 if err != nil { 220 return &database.Error{OrigErr: err, Err: "transaction start failed"} 221 } 222 223 query := `DELETE FROM "` + p.config.MigrationsTable + `"` 224 if _, err := tx.Exec(query); err != nil { 225 if errRollback := tx.Rollback(); errRollback != nil { 226 err = multierror.Append(err, errRollback) 227 } 228 return &database.Error{OrigErr: err, Query: []byte(query)} 229 } 230 231 // Also re-write the schema version for nil dirty versions to prevent 232 // empty schema version for failed down migration on the first migration 233 // See: https://github.com/nokia/migrate/issues/330 234 if version >= 0 || (version == database.NilVersion && dirty) { 235 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)` 236 if _, err := tx.Exec(query, version, dirty); err != nil { 237 if errRollback := tx.Rollback(); errRollback != nil { 238 err = multierror.Append(err, errRollback) 239 } 240 return &database.Error{OrigErr: err, Query: []byte(query)} 241 } 242 } 243 244 if err := tx.Commit(); err != nil { 245 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 246 } 247 248 return nil 249 } 250 251 func (p *Redshift) Version() (version int, dirty bool, err error) { 252 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 253 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 254 switch { 255 case err == sql.ErrNoRows: 256 return database.NilVersion, false, nil 257 258 case err != nil: 259 if e, ok := err.(*pq.Error); ok { 260 if e.Code.Name() == "undefined_table" { 261 return database.NilVersion, false, nil 262 } 263 } 264 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 265 266 default: 267 return version, dirty, nil 268 } 269 } 270 271 func (p *Redshift) Drop() (err error) { 272 // select all tables in current schema 273 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 274 tables, err := p.conn.QueryContext(context.Background(), query) 275 if err != nil { 276 return &database.Error{OrigErr: err, Query: []byte(query)} 277 } 278 defer func() { 279 if errClose := tables.Close(); errClose != nil { 280 err = multierror.Append(err, errClose) 281 } 282 }() 283 284 // delete one table after another 285 tableNames := make([]string, 0) 286 for tables.Next() { 287 var tableName string 288 if err := tables.Scan(&tableName); err != nil { 289 return err 290 } 291 if len(tableName) > 0 { 292 tableNames = append(tableNames, tableName) 293 } 294 } 295 if err := tables.Err(); err != nil { 296 return &database.Error{OrigErr: err, Query: []byte(query)} 297 } 298 299 if len(tableNames) > 0 { 300 // delete one by one ... 301 for _, t := range tableNames { 302 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 303 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 304 return &database.Error{OrigErr: err, Query: []byte(query)} 305 } 306 } 307 } 308 309 return nil 310 } 311 312 // ensureVersionTable checks if versions table exists and, if not, creates it. 313 // Note that this function locks the database, which deviates from the usual 314 // convention of "caller locks" in the Redshift type. 315 func (p *Redshift) ensureVersionTable() (err error) { 316 if err = p.Lock(); err != nil { 317 return err 318 } 319 320 defer func() { 321 if e := p.Unlock(); e != nil { 322 if err == nil { 323 err = e 324 } else { 325 err = multierror.Append(err, e) 326 } 327 } 328 }() 329 330 // check if migration table exists 331 var count int 332 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 333 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 334 return &database.Error{OrigErr: err, Query: []byte(query)} 335 } 336 if count == 1 { 337 return nil 338 } 339 340 // if not, create the empty migration table 341 query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` 342 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 343 return &database.Error{OrigErr: err, Query: []byte(query)} 344 } 345 return nil 346 }