github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/redshift/redshift.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package redshift 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 nurl "net/url" 12 "strconv" 13 "strings" 14 15 "go.uber.org/atomic" 16 17 "github.com/golang-migrate/migrate/v4" 18 "github.com/golang-migrate/migrate/v4/database" 19 "github.com/hashicorp/go-multierror" 20 "github.com/lib/pq" 21 ) 22 23 func init() { 24 db := Redshift{} 25 database.Register("redshift", &db) 26 } 27 28 var DefaultMigrationsTable = "schema_migrations" 29 30 var ( 31 ErrNilConfig = fmt.Errorf("no config") 32 ErrNoDatabaseName = fmt.Errorf("no database name") 33 ) 34 35 type Config struct { 36 MigrationsTable string 37 DatabaseName string 38 } 39 40 type Redshift struct { 41 isLocked atomic.Bool 42 conn *sql.Conn 43 db *sql.DB 44 45 // Open and WithInstance need to guarantee that config is never nil 46 config *Config 47 } 48 49 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 50 if config == nil { 51 return nil, ErrNilConfig 52 } 53 54 if err := instance.Ping(); err != nil { 55 return nil, err 56 } 57 58 if config.DatabaseName == "" { 59 query := `SELECT CURRENT_DATABASE()` 60 var databaseName string 61 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 62 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 63 } 64 65 if len(databaseName) == 0 { 66 return nil, ErrNoDatabaseName 67 } 68 69 config.DatabaseName = databaseName 70 } 71 72 if len(config.MigrationsTable) == 0 { 73 config.MigrationsTable = DefaultMigrationsTable 74 } 75 76 conn, err := instance.Conn(context.Background()) 77 78 if err != nil { 79 return nil, err 80 } 81 82 px := &Redshift{ 83 conn: conn, 84 db: instance, 85 config: config, 86 } 87 88 if err := px.ensureVersionTable(); err != nil { 89 return nil, err 90 } 91 92 return px, nil 93 } 94 95 func (p *Redshift) Open(url string) (database.Driver, error) { 96 purl, err := nurl.Parse(url) 97 if err != nil { 98 return nil, err 99 } 100 purl.Scheme = "postgres" 101 102 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 103 if err != nil { 104 return nil, err 105 } 106 107 migrationsTable := purl.Query().Get("x-migrations-table") 108 109 px, err := WithInstance(db, &Config{ 110 DatabaseName: purl.Path, 111 MigrationsTable: migrationsTable, 112 }) 113 if err != nil { 114 return nil, err 115 } 116 117 return px, nil 118 } 119 120 func (p *Redshift) Close() error { 121 connErr := p.conn.Close() 122 dbErr := p.db.Close() 123 if connErr != nil || dbErr != nil { 124 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 125 } 126 return nil 127 } 128 129 // Redshift does not support advisory lock functions: https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-functions.html 130 func (p *Redshift) Lock() error { 131 if !p.isLocked.CAS(false, true) { 132 return database.ErrLocked 133 } 134 return nil 135 } 136 137 func (p *Redshift) Unlock() error { 138 if !p.isLocked.CAS(true, false) { 139 return database.ErrNotLocked 140 } 141 return nil 142 } 143 144 func (p *Redshift) Run(migration io.Reader) error { 145 migr, err := io.ReadAll(migration) 146 if err != nil { 147 return err 148 } 149 150 // run migration 151 query := string(migr[:]) 152 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 153 if pgErr, ok := err.(*pq.Error); ok { 154 var line uint 155 var col uint 156 var lineColOK bool 157 if pgErr.Position != "" { 158 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 159 line, col, lineColOK = computeLineFromPos(query, int(pos)) 160 } 161 } 162 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 163 if lineColOK { 164 message = fmt.Sprintf("%s (column %d)", message, col) 165 } 166 if pgErr.Detail != "" { 167 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 168 } 169 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 170 } 171 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 172 } 173 174 return nil 175 } 176 177 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 178 // replace crlf with lf 179 s = strings.Replace(s, "\r\n", "\n", -1) 180 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 181 runes := []rune(s) 182 if pos > len(runes) { 183 return 0, 0, false 184 } 185 sel := runes[:pos] 186 line = uint(runesCount(sel, newLine) + 1) 187 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 188 return line, col, true 189 } 190 191 const newLine = '\n' 192 193 func runesCount(input []rune, target rune) int { 194 var count int 195 for _, r := range input { 196 if r == target { 197 count++ 198 } 199 } 200 return count 201 } 202 203 func runesLastIndex(input []rune, target rune) int { 204 for i := len(input) - 1; i >= 0; i-- { 205 if input[i] == target { 206 return i 207 } 208 } 209 return -1 210 } 211 212 func (p *Redshift) SetVersion(version int, dirty bool) error { 213 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 214 if err != nil { 215 return &database.Error{OrigErr: err, Err: "transaction start failed"} 216 } 217 218 query := `DELETE FROM "` + p.config.MigrationsTable + `"` 219 if _, err := tx.Exec(query); err != nil { 220 if errRollback := tx.Rollback(); errRollback != nil { 221 err = multierror.Append(err, errRollback) 222 } 223 return &database.Error{OrigErr: err, Query: []byte(query)} 224 } 225 226 // Also re-write the schema version for nil dirty versions to prevent 227 // empty schema version for failed down migration on the first migration 228 // See: https://github.com/golang-migrate/migrate/issues/330 229 if version >= 0 || (version == database.NilVersion && dirty) { 230 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)` 231 if _, err := tx.Exec(query, version, dirty); err != nil { 232 if errRollback := tx.Rollback(); errRollback != nil { 233 err = multierror.Append(err, errRollback) 234 } 235 return &database.Error{OrigErr: err, Query: []byte(query)} 236 } 237 } 238 239 if err := tx.Commit(); err != nil { 240 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 241 } 242 243 return nil 244 } 245 246 func (p *Redshift) Version() (version int, dirty bool, err error) { 247 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 248 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 249 switch { 250 case err == sql.ErrNoRows: 251 return database.NilVersion, false, nil 252 253 case err != nil: 254 if e, ok := err.(*pq.Error); ok { 255 if e.Code.Name() == "undefined_table" { 256 return database.NilVersion, false, nil 257 } 258 } 259 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 260 261 default: 262 return version, dirty, nil 263 } 264 } 265 266 func (p *Redshift) Drop() (err error) { 267 // select all tables in current schema 268 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 269 tables, err := p.conn.QueryContext(context.Background(), query) 270 if err != nil { 271 return &database.Error{OrigErr: err, Query: []byte(query)} 272 } 273 defer func() { 274 if errClose := tables.Close(); errClose != nil { 275 err = multierror.Append(err, errClose) 276 } 277 }() 278 279 // delete one table after another 280 tableNames := make([]string, 0) 281 for tables.Next() { 282 var tableName string 283 if err := tables.Scan(&tableName); err != nil { 284 return err 285 } 286 if len(tableName) > 0 { 287 tableNames = append(tableNames, tableName) 288 } 289 } 290 if err := tables.Err(); err != nil { 291 return &database.Error{OrigErr: err, Query: []byte(query)} 292 } 293 294 if len(tableNames) > 0 { 295 // delete one by one ... 296 for _, t := range tableNames { 297 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 298 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 299 return &database.Error{OrigErr: err, Query: []byte(query)} 300 } 301 } 302 } 303 304 return nil 305 } 306 307 // ensureVersionTable checks if versions table exists and, if not, creates it. 308 // Note that this function locks the database, which deviates from the usual 309 // convention of "caller locks" in the Redshift type. 310 func (p *Redshift) ensureVersionTable() (err error) { 311 if err = p.Lock(); err != nil { 312 return err 313 } 314 315 defer func() { 316 if e := p.Unlock(); e != nil { 317 if err == nil { 318 err = e 319 } else { 320 err = multierror.Append(err, e) 321 } 322 } 323 }() 324 325 // check if migration table exists 326 var count int 327 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 328 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 329 return &database.Error{OrigErr: err, Query: []byte(query)} 330 } 331 if count == 1 { 332 return nil 333 } 334 335 // if not, create the empty migration table 336 query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` 337 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 338 return &database.Error{OrigErr: err, Query: []byte(query)} 339 } 340 return nil 341 }