github.com/Elate-DevOps/migrate/v4@v4.0.12/database/redshift/redshift.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package redshift 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 nurl "net/url" 12 "strconv" 13 "strings" 14 15 "go.uber.org/atomic" 16 17 "github.com/Elate-DevOps/migrate/v4" 18 "github.com/Elate-DevOps/migrate/v4/database" 19 "github.com/hashicorp/go-multierror" 20 "github.com/lib/pq" 21 ) 22 23 func init() { 24 db := Redshift{} 25 database.Register("redshift", &db) 26 } 27 28 var DefaultMigrationsTable = "schema_migrations" 29 30 var ( 31 ErrNilConfig = fmt.Errorf("no config") 32 ErrNoDatabaseName = fmt.Errorf("no database name") 33 ) 34 35 type Config struct { 36 MigrationsTable string 37 DatabaseName string 38 } 39 40 type Redshift struct { 41 isLocked atomic.Bool 42 conn *sql.Conn 43 db *sql.DB 44 45 // Open and WithInstance need to guarantee that config is never nil 46 config *Config 47 } 48 49 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 50 if config == nil { 51 return nil, ErrNilConfig 52 } 53 54 if err := instance.Ping(); err != nil { 55 return nil, err 56 } 57 58 if config.DatabaseName == "" { 59 query := `SELECT CURRENT_DATABASE()` 60 var databaseName string 61 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 62 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 63 } 64 65 if len(databaseName) == 0 { 66 return nil, ErrNoDatabaseName 67 } 68 69 config.DatabaseName = databaseName 70 } 71 72 if len(config.MigrationsTable) == 0 { 73 config.MigrationsTable = DefaultMigrationsTable 74 } 75 76 conn, err := instance.Conn(context.Background()) 77 if err != nil { 78 return nil, err 79 } 80 81 px := &Redshift{ 82 conn: conn, 83 db: instance, 84 config: config, 85 } 86 87 if err := px.ensureVersionTable(); err != nil { 88 return nil, err 89 } 90 91 return px, nil 92 } 93 94 func (p *Redshift) Open(url string) (database.Driver, error) { 95 purl, err := nurl.Parse(url) 96 if err != nil { 97 return nil, err 98 } 99 purl.Scheme = "postgres" 100 101 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 102 if err != nil { 103 return nil, err 104 } 105 106 migrationsTable := purl.Query().Get("x-migrations-table") 107 108 px, err := WithInstance(db, &Config{ 109 DatabaseName: purl.Path, 110 MigrationsTable: migrationsTable, 111 }) 112 if err != nil { 113 return nil, err 114 } 115 116 return px, nil 117 } 118 119 func (p *Redshift) Close() error { 120 connErr := p.conn.Close() 121 dbErr := p.db.Close() 122 if connErr != nil || dbErr != nil { 123 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 124 } 125 return nil 126 } 127 128 // Redshift does not support advisory lock functions: https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-functions.html 129 func (p *Redshift) Lock() error { 130 if !p.isLocked.CAS(false, true) { 131 return database.ErrLocked 132 } 133 return nil 134 } 135 136 func (p *Redshift) Unlock() error { 137 if !p.isLocked.CAS(true, false) { 138 return database.ErrNotLocked 139 } 140 return nil 141 } 142 143 func (p *Redshift) Run(migration io.Reader) error { 144 migr, err := io.ReadAll(migration) 145 if err != nil { 146 return err 147 } 148 149 // run migration 150 query := string(migr[:]) 151 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 152 if pgErr, ok := err.(*pq.Error); ok { 153 var line uint 154 var col uint 155 var lineColOK bool 156 if pgErr.Position != "" { 157 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 158 line, col, lineColOK = computeLineFromPos(query, int(pos)) 159 } 160 } 161 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 162 if lineColOK { 163 message = fmt.Sprintf("%s (column %d)", message, col) 164 } 165 if pgErr.Detail != "" { 166 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 167 } 168 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 169 } 170 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 171 } 172 173 return nil 174 } 175 176 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 177 // replace crlf with lf 178 s = strings.Replace(s, "\r\n", "\n", -1) 179 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 180 runes := []rune(s) 181 if pos > len(runes) { 182 return 0, 0, false 183 } 184 sel := runes[:pos] 185 line = uint(runesCount(sel, newLine) + 1) 186 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 187 return line, col, true 188 } 189 190 const newLine = '\n' 191 192 func runesCount(input []rune, target rune) int { 193 var count int 194 for _, r := range input { 195 if r == target { 196 count++ 197 } 198 } 199 return count 200 } 201 202 func runesLastIndex(input []rune, target rune) int { 203 for i := len(input) - 1; i >= 0; i-- { 204 if input[i] == target { 205 return i 206 } 207 } 208 return -1 209 } 210 211 func (p *Redshift) SetVersion(version int, dirty bool) error { 212 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 213 if err != nil { 214 return &database.Error{OrigErr: err, Err: "transaction start failed"} 215 } 216 217 query := `DELETE FROM "` + p.config.MigrationsTable + `"` 218 if _, err := tx.Exec(query); err != nil { 219 if errRollback := tx.Rollback(); errRollback != nil { 220 err = multierror.Append(err, errRollback) 221 } 222 return &database.Error{OrigErr: err, Query: []byte(query)} 223 } 224 225 // Also re-write the schema version for nil dirty versions to prevent 226 // empty schema version for failed down migration on the first migration 227 // See: https://github.com/golang-migrate/migrate/issues/330 228 if version >= 0 || (version == database.NilVersion && dirty) { 229 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)` 230 if _, err := tx.Exec(query, version, dirty); err != nil { 231 if errRollback := tx.Rollback(); errRollback != nil { 232 err = multierror.Append(err, errRollback) 233 } 234 return &database.Error{OrigErr: err, Query: []byte(query)} 235 } 236 } 237 238 if err := tx.Commit(); err != nil { 239 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 240 } 241 242 return nil 243 } 244 245 func (p *Redshift) Version() (version int, dirty bool, err error) { 246 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 247 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 248 switch { 249 case err == sql.ErrNoRows: 250 return database.NilVersion, false, nil 251 252 case err != nil: 253 if e, ok := err.(*pq.Error); ok { 254 if e.Code.Name() == "undefined_table" { 255 return database.NilVersion, false, nil 256 } 257 } 258 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 259 260 default: 261 return version, dirty, nil 262 } 263 } 264 265 func (p *Redshift) Drop() (err error) { 266 // select all tables in current schema 267 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 268 tables, err := p.conn.QueryContext(context.Background(), query) 269 if err != nil { 270 return &database.Error{OrigErr: err, Query: []byte(query)} 271 } 272 defer func() { 273 if errClose := tables.Close(); errClose != nil { 274 err = multierror.Append(err, errClose) 275 } 276 }() 277 278 // delete one table after another 279 tableNames := make([]string, 0) 280 for tables.Next() { 281 var tableName string 282 if err := tables.Scan(&tableName); err != nil { 283 return err 284 } 285 if len(tableName) > 0 { 286 tableNames = append(tableNames, tableName) 287 } 288 } 289 if err := tables.Err(); err != nil { 290 return &database.Error{OrigErr: err, Query: []byte(query)} 291 } 292 293 if len(tableNames) > 0 { 294 // delete one by one ... 295 for _, t := range tableNames { 296 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 297 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 298 return &database.Error{OrigErr: err, Query: []byte(query)} 299 } 300 } 301 } 302 303 return nil 304 } 305 306 // ensureVersionTable checks if versions table exists and, if not, creates it. 307 // Note that this function locks the database, which deviates from the usual 308 // convention of "caller locks" in the Redshift type. 309 func (p *Redshift) ensureVersionTable() (err error) { 310 if err = p.Lock(); err != nil { 311 return err 312 } 313 314 defer func() { 315 if e := p.Unlock(); e != nil { 316 if err == nil { 317 err = e 318 } else { 319 err = multierror.Append(err, e) 320 } 321 } 322 }() 323 324 // check if migration table exists 325 var count int 326 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 327 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 328 return &database.Error{OrigErr: err, Query: []byte(query)} 329 } 330 if count == 1 { 331 return nil 332 } 333 334 // if not, create the empty migration table 335 query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` 336 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 337 return &database.Error{OrigErr: err, Query: []byte(query)} 338 } 339 return nil 340 }