github.com/seashell-org/golang-migrate/v4@v4.15.3-0.20220722221203-6ab6c6c062d1/database/redshift/redshift.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package redshift 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 "io/ioutil" 12 nurl "net/url" 13 "strconv" 14 "strings" 15 16 "go.uber.org/atomic" 17 18 "github.com/hashicorp/go-multierror" 19 "github.com/lib/pq" 20 migrate "github.com/seashell-org/golang-migrate/v4" 21 "github.com/seashell-org/golang-migrate/v4/database" 22 ) 23 24 func init() { 25 db := Redshift{} 26 database.Register("redshift", &db) 27 } 28 29 var DefaultMigrationsTable = "schema_migrations" 30 31 var ( 32 ErrNilConfig = fmt.Errorf("no config") 33 ErrNoDatabaseName = fmt.Errorf("no database name") 34 ) 35 36 type Config struct { 37 MigrationsTable string 38 DatabaseName string 39 } 40 41 type Redshift struct { 42 isLocked atomic.Bool 43 conn *sql.Conn 44 db *sql.DB 45 46 // Open and WithInstance need to guarantee that config is never nil 47 config *Config 48 } 49 50 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 51 if config == nil { 52 return nil, ErrNilConfig 53 } 54 55 if err := instance.Ping(); err != nil { 56 return nil, err 57 } 58 59 if config.DatabaseName == "" { 60 query := `SELECT CURRENT_DATABASE()` 61 var databaseName string 62 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 63 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 64 } 65 66 if len(databaseName) == 0 { 67 return nil, ErrNoDatabaseName 68 } 69 70 config.DatabaseName = databaseName 71 } 72 73 if len(config.MigrationsTable) == 0 { 74 config.MigrationsTable = DefaultMigrationsTable 75 } 76 77 conn, err := instance.Conn(context.Background()) 78 79 if err != nil { 80 return nil, err 81 } 82 83 px := &Redshift{ 84 conn: conn, 85 db: instance, 86 config: config, 87 } 88 89 if err := px.ensureVersionTable(); err != nil { 90 return nil, err 91 } 92 93 return px, nil 94 } 95 96 func (p *Redshift) Open(url string) (database.Driver, error) { 97 purl, err := nurl.Parse(url) 98 if err != nil { 99 return nil, err 100 } 101 purl.Scheme = "postgres" 102 103 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 104 if err != nil { 105 return nil, err 106 } 107 108 migrationsTable := purl.Query().Get("x-migrations-table") 109 110 px, err := WithInstance(db, &Config{ 111 DatabaseName: purl.Path, 112 MigrationsTable: migrationsTable, 113 }) 114 if err != nil { 115 return nil, err 116 } 117 118 return px, nil 119 } 120 121 func (p *Redshift) Close() error { 122 connErr := p.conn.Close() 123 dbErr := p.db.Close() 124 if connErr != nil || dbErr != nil { 125 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 126 } 127 return nil 128 } 129 130 // Redshift does not support advisory lock functions: https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-functions.html 131 func (p *Redshift) Lock() error { 132 if !p.isLocked.CAS(false, true) { 133 return database.ErrLocked 134 } 135 return nil 136 } 137 138 func (p *Redshift) Unlock() error { 139 if !p.isLocked.CAS(true, false) { 140 return database.ErrNotLocked 141 } 142 return nil 143 } 144 145 func (p *Redshift) Run(migration io.Reader) error { 146 migr, err := ioutil.ReadAll(migration) 147 if err != nil { 148 return err 149 } 150 151 // run migration 152 query := string(migr[:]) 153 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 154 if pgErr, ok := err.(*pq.Error); ok { 155 var line uint 156 var col uint 157 var lineColOK bool 158 if pgErr.Position != "" { 159 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 160 line, col, lineColOK = computeLineFromPos(query, int(pos)) 161 } 162 } 163 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 164 if lineColOK { 165 message = fmt.Sprintf("%s (column %d)", message, col) 166 } 167 if pgErr.Detail != "" { 168 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 169 } 170 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 171 } 172 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 173 } 174 175 return nil 176 } 177 178 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 179 // replace crlf with lf 180 s = strings.Replace(s, "\r\n", "\n", -1) 181 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 182 runes := []rune(s) 183 if pos > len(runes) { 184 return 0, 0, false 185 } 186 sel := runes[:pos] 187 line = uint(runesCount(sel, newLine) + 1) 188 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 189 return line, col, true 190 } 191 192 const newLine = '\n' 193 194 func runesCount(input []rune, target rune) int { 195 var count int 196 for _, r := range input { 197 if r == target { 198 count++ 199 } 200 } 201 return count 202 } 203 204 func runesLastIndex(input []rune, target rune) int { 205 for i := len(input) - 1; i >= 0; i-- { 206 if input[i] == target { 207 return i 208 } 209 } 210 return -1 211 } 212 213 func (p *Redshift) SetVersion(version int, dirty bool) error { 214 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 215 if err != nil { 216 return &database.Error{OrigErr: err, Err: "transaction start failed"} 217 } 218 219 query := `DELETE FROM "` + p.config.MigrationsTable + `"` 220 if _, err := tx.Exec(query); err != nil { 221 if errRollback := tx.Rollback(); errRollback != nil { 222 err = multierror.Append(err, errRollback) 223 } 224 return &database.Error{OrigErr: err, Query: []byte(query)} 225 } 226 227 // Also re-write the schema version for nil dirty versions to prevent 228 // empty schema version for failed down migration on the first migration 229 // See: https://github.com/seashell-org/golang-migrate/issues/330 230 if version >= 0 || (version == database.NilVersion && dirty) { 231 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)` 232 if _, err := tx.Exec(query, version, dirty); err != nil { 233 if errRollback := tx.Rollback(); errRollback != nil { 234 err = multierror.Append(err, errRollback) 235 } 236 return &database.Error{OrigErr: err, Query: []byte(query)} 237 } 238 } 239 240 if err := tx.Commit(); err != nil { 241 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 242 } 243 244 return nil 245 } 246 247 func (p *Redshift) Version() (version int, dirty bool, err error) { 248 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 249 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 250 switch { 251 case err == sql.ErrNoRows: 252 return database.NilVersion, false, nil 253 254 case err != nil: 255 if e, ok := err.(*pq.Error); ok { 256 if e.Code.Name() == "undefined_table" { 257 return database.NilVersion, false, nil 258 } 259 } 260 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 261 262 default: 263 return version, dirty, nil 264 } 265 } 266 267 func (p *Redshift) Drop() (err error) { 268 // select all tables in current schema 269 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 270 tables, err := p.conn.QueryContext(context.Background(), query) 271 if err != nil { 272 return &database.Error{OrigErr: err, Query: []byte(query)} 273 } 274 defer func() { 275 if errClose := tables.Close(); errClose != nil { 276 err = multierror.Append(err, errClose) 277 } 278 }() 279 280 // delete one table after another 281 tableNames := make([]string, 0) 282 for tables.Next() { 283 var tableName string 284 if err := tables.Scan(&tableName); err != nil { 285 return err 286 } 287 if len(tableName) > 0 { 288 tableNames = append(tableNames, tableName) 289 } 290 } 291 if err := tables.Err(); err != nil { 292 return &database.Error{OrigErr: err, Query: []byte(query)} 293 } 294 295 if len(tableNames) > 0 { 296 // delete one by one ... 297 for _, t := range tableNames { 298 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 299 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 300 return &database.Error{OrigErr: err, Query: []byte(query)} 301 } 302 } 303 } 304 305 return nil 306 } 307 308 // ensureVersionTable checks if versions table exists and, if not, creates it. 309 // Note that this function locks the database, which deviates from the usual 310 // convention of "caller locks" in the Redshift type. 311 func (p *Redshift) ensureVersionTable() (err error) { 312 if err = p.Lock(); err != nil { 313 return err 314 } 315 316 defer func() { 317 if e := p.Unlock(); e != nil { 318 if err == nil { 319 err = e 320 } else { 321 err = multierror.Append(err, e) 322 } 323 } 324 }() 325 326 // check if migration table exists 327 var count int 328 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 329 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 330 return &database.Error{OrigErr: err, Query: []byte(query)} 331 } 332 if count == 1 { 333 return nil 334 } 335 336 // if not, create the empty migration table 337 query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` 338 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 339 return &database.Error{OrigErr: err, Query: []byte(query)} 340 } 341 return nil 342 }