github.com/Elate-DevOps/migrate/v4@v4.0.12/database/snowflake/snowflake.go (about) 1 package snowflake 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "io" 8 nurl "net/url" 9 "strconv" 10 "strings" 11 12 "go.uber.org/atomic" 13 14 "github.com/Elate-DevOps/migrate/v4/database" 15 "github.com/hashicorp/go-multierror" 16 "github.com/lib/pq" 17 sf "github.com/snowflakedb/gosnowflake" 18 ) 19 20 func init() { 21 db := Snowflake{} 22 database.Register("snowflake", &db) 23 } 24 25 var DefaultMigrationsTable = "schema_migrations" 26 27 var ( 28 ErrNilConfig = fmt.Errorf("no config") 29 ErrNoDatabaseName = fmt.Errorf("no database name") 30 ErrNoPassword = fmt.Errorf("no password") 31 ErrNoSchema = fmt.Errorf("no schema") 32 ErrNoSchemaOrDatabase = fmt.Errorf("no schema/database name") 33 ) 34 35 type Config struct { 36 MigrationsTable string 37 DatabaseName string 38 } 39 40 type Snowflake struct { 41 isLocked atomic.Bool 42 conn *sql.Conn 43 db *sql.DB 44 45 // Open and WithInstance need to guarantee that config is never nil 46 config *Config 47 } 48 49 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 50 if config == nil { 51 return nil, ErrNilConfig 52 } 53 54 if err := instance.Ping(); err != nil { 55 return nil, err 56 } 57 58 if config.DatabaseName == "" { 59 query := `SELECT CURRENT_DATABASE()` 60 var databaseName string 61 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 62 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 63 } 64 65 if len(databaseName) == 0 { 66 return nil, ErrNoDatabaseName 67 } 68 69 config.DatabaseName = databaseName 70 } 71 72 if len(config.MigrationsTable) == 0 { 73 config.MigrationsTable = DefaultMigrationsTable 74 } 75 76 conn, err := instance.Conn(context.Background()) 77 if err != nil { 78 return nil, err 79 } 80 81 px := &Snowflake{ 82 conn: conn, 83 db: instance, 84 config: config, 85 } 86 87 if err := px.ensureVersionTable(); err != nil { 88 return nil, err 89 } 90 91 return px, nil 92 } 93 94 func (p *Snowflake) Open(url string) (database.Driver, error) { 95 purl, err := nurl.Parse(url) 96 if err != nil { 97 return nil, err 98 } 99 100 password, isPasswordSet := purl.User.Password() 101 if !isPasswordSet { 102 return nil, ErrNoPassword 103 } 104 105 splitPath := strings.Split(purl.Path, "/") 106 if len(splitPath) < 3 { 107 return nil, ErrNoSchemaOrDatabase 108 } 109 110 database := splitPath[2] 111 if len(database) == 0 { 112 return nil, ErrNoDatabaseName 113 } 114 115 schema := splitPath[1] 116 if len(schema) == 0 { 117 return nil, ErrNoSchema 118 } 119 120 cfg := &sf.Config{ 121 Account: purl.Host, 122 User: purl.User.Username(), 123 Password: password, 124 Database: database, 125 Schema: schema, 126 } 127 128 dsn, err := sf.DSN(cfg) 129 if err != nil { 130 return nil, err 131 } 132 133 db, err := sql.Open("snowflake", dsn) 134 if err != nil { 135 return nil, err 136 } 137 138 migrationsTable := purl.Query().Get("x-migrations-table") 139 140 px, err := WithInstance(db, &Config{ 141 DatabaseName: database, 142 MigrationsTable: migrationsTable, 143 }) 144 if err != nil { 145 return nil, err 146 } 147 148 return px, nil 149 } 150 151 func (p *Snowflake) Close() error { 152 connErr := p.conn.Close() 153 dbErr := p.db.Close() 154 if connErr != nil || dbErr != nil { 155 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 156 } 157 return nil 158 } 159 160 func (p *Snowflake) Lock() error { 161 if !p.isLocked.CAS(false, true) { 162 return database.ErrLocked 163 } 164 return nil 165 } 166 167 func (p *Snowflake) Unlock() error { 168 if !p.isLocked.CAS(true, false) { 169 return database.ErrNotLocked 170 } 171 return nil 172 } 173 174 func (p *Snowflake) Run(migration io.Reader) error { 175 migr, err := io.ReadAll(migration) 176 if err != nil { 177 return err 178 } 179 180 // run migration 181 query := string(migr[:]) 182 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 183 if pgErr, ok := err.(*pq.Error); ok { 184 var line uint 185 var col uint 186 var lineColOK bool 187 if pgErr.Position != "" { 188 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 189 line, col, lineColOK = computeLineFromPos(query, int(pos)) 190 } 191 } 192 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 193 if lineColOK { 194 message = fmt.Sprintf("%s (column %d)", message, col) 195 } 196 if pgErr.Detail != "" { 197 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 198 } 199 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 200 } 201 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 202 } 203 204 return nil 205 } 206 207 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 208 // replace crlf with lf 209 s = strings.Replace(s, "\r\n", "\n", -1) 210 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 211 runes := []rune(s) 212 if pos > len(runes) { 213 return 0, 0, false 214 } 215 sel := runes[:pos] 216 line = uint(runesCount(sel, newLine) + 1) 217 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 218 return line, col, true 219 } 220 221 const newLine = '\n' 222 223 func runesCount(input []rune, target rune) int { 224 var count int 225 for _, r := range input { 226 if r == target { 227 count++ 228 } 229 } 230 return count 231 } 232 233 func runesLastIndex(input []rune, target rune) int { 234 for i := len(input) - 1; i >= 0; i-- { 235 if input[i] == target { 236 return i 237 } 238 } 239 return -1 240 } 241 242 func (p *Snowflake) SetVersion(version int, dirty bool) error { 243 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 244 if err != nil { 245 return &database.Error{OrigErr: err, Err: "transaction start failed"} 246 } 247 248 query := `DELETE FROM "` + p.config.MigrationsTable + `"` 249 if _, err := tx.Exec(query); err != nil { 250 if errRollback := tx.Rollback(); errRollback != nil { 251 err = multierror.Append(err, errRollback) 252 } 253 return &database.Error{OrigErr: err, Query: []byte(query)} 254 } 255 256 // Also re-write the schema version for nil dirty versions to prevent 257 // empty schema version for failed down migration on the first migration 258 // See: https://github.com/golang-migrate/migrate/issues/330 259 if version >= 0 || (version == database.NilVersion && dirty) { 260 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, 261 dirty) VALUES (` + strconv.FormatInt(int64(version), 10) + `, 262 ` + strconv.FormatBool(dirty) + `)` 263 if _, err := tx.Exec(query); err != nil { 264 if errRollback := tx.Rollback(); errRollback != nil { 265 err = multierror.Append(err, errRollback) 266 } 267 return &database.Error{OrigErr: err, Query: []byte(query)} 268 } 269 } 270 271 if err := tx.Commit(); err != nil { 272 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 273 } 274 275 return nil 276 } 277 278 func (p *Snowflake) Version() (version int, dirty bool, err error) { 279 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 280 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 281 switch { 282 case err == sql.ErrNoRows: 283 return database.NilVersion, false, nil 284 285 case err != nil: 286 if e, ok := err.(*pq.Error); ok { 287 if e.Code.Name() == "undefined_table" { 288 return database.NilVersion, false, nil 289 } 290 } 291 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 292 293 default: 294 return version, dirty, nil 295 } 296 } 297 298 func (p *Snowflake) Drop() (err error) { 299 // select all tables in current schema 300 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 301 tables, err := p.conn.QueryContext(context.Background(), query) 302 if err != nil { 303 return &database.Error{OrigErr: err, Query: []byte(query)} 304 } 305 defer func() { 306 if errClose := tables.Close(); errClose != nil { 307 err = multierror.Append(err, errClose) 308 } 309 }() 310 311 // delete one table after another 312 tableNames := make([]string, 0) 313 for tables.Next() { 314 var tableName string 315 if err := tables.Scan(&tableName); err != nil { 316 return err 317 } 318 if len(tableName) > 0 { 319 tableNames = append(tableNames, tableName) 320 } 321 } 322 if err := tables.Err(); err != nil { 323 return &database.Error{OrigErr: err, Query: []byte(query)} 324 } 325 326 if len(tableNames) > 0 { 327 // delete one by one ... 328 for _, t := range tableNames { 329 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 330 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 331 return &database.Error{OrigErr: err, Query: []byte(query)} 332 } 333 } 334 } 335 336 return nil 337 } 338 339 // ensureVersionTable checks if versions table exists and, if not, creates it. 340 // Note that this function locks the database, which deviates from the usual 341 // convention of "caller locks" in the Snowflake type. 342 func (p *Snowflake) ensureVersionTable() (err error) { 343 if err = p.Lock(); err != nil { 344 return err 345 } 346 347 defer func() { 348 if e := p.Unlock(); e != nil { 349 if err == nil { 350 err = e 351 } else { 352 err = multierror.Append(err, e) 353 } 354 } 355 }() 356 357 // check if migration table exists 358 var count int 359 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 360 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 361 return &database.Error{OrigErr: err, Query: []byte(query)} 362 } 363 if count == 1 { 364 return nil 365 } 366 367 // if not, create the empty migration table 368 query = `CREATE TABLE if not exists "` + p.config.MigrationsTable + `" ( 369 version bigint not null primary key, dirty boolean not null)` 370 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 371 return &database.Error{OrigErr: err, Query: []byte(query)} 372 } 373 374 return nil 375 }