github.com/nokia/migrate/v4@v4.16.0/database/snowflake/snowflake.go (about) 1 package snowflake 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "io" 8 "io/ioutil" 9 nurl "net/url" 10 "strconv" 11 "strings" 12 13 "go.uber.org/atomic" 14 15 "github.com/hashicorp/go-multierror" 16 "github.com/lib/pq" 17 "github.com/nokia/migrate/v4/database" 18 "github.com/nokia/migrate/v4/source" 19 sf "github.com/snowflakedb/gosnowflake" 20 ) 21 22 func init() { 23 db := Snowflake{} 24 database.Register("snowflake", &db) 25 } 26 27 var DefaultMigrationsTable = "schema_migrations" 28 29 var ( 30 ErrNilConfig = fmt.Errorf("no config") 31 ErrNoDatabaseName = fmt.Errorf("no database name") 32 ErrNoPassword = fmt.Errorf("no password") 33 ErrNoSchema = fmt.Errorf("no schema") 34 ErrNoSchemaOrDatabase = fmt.Errorf("no schema/database name") 35 ) 36 37 type Config struct { 38 MigrationsTable string 39 DatabaseName string 40 } 41 42 type Snowflake struct { 43 isLocked atomic.Bool 44 conn *sql.Conn 45 db *sql.DB 46 47 // Open and WithInstance need to guarantee that config is never nil 48 config *Config 49 } 50 51 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 52 if config == nil { 53 return nil, ErrNilConfig 54 } 55 56 if err := instance.Ping(); err != nil { 57 return nil, err 58 } 59 60 if config.DatabaseName == "" { 61 query := `SELECT CURRENT_DATABASE()` 62 var databaseName string 63 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 64 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 65 } 66 67 if len(databaseName) == 0 { 68 return nil, ErrNoDatabaseName 69 } 70 71 config.DatabaseName = databaseName 72 } 73 74 if len(config.MigrationsTable) == 0 { 75 config.MigrationsTable = DefaultMigrationsTable 76 } 77 78 conn, err := instance.Conn(context.Background()) 79 if err != nil { 80 return nil, err 81 } 82 83 px := &Snowflake{ 84 conn: conn, 85 db: instance, 86 config: config, 87 } 88 89 if err := px.ensureVersionTable(); err != nil { 90 return nil, err 91 } 92 93 return px, nil 94 } 95 96 func (p *Snowflake) Open(url string) (database.Driver, error) { 97 purl, err := nurl.Parse(url) 98 if err != nil { 99 return nil, err 100 } 101 102 password, isPasswordSet := purl.User.Password() 103 if !isPasswordSet { 104 return nil, ErrNoPassword 105 } 106 107 splitPath := strings.Split(purl.Path, "/") 108 if len(splitPath) < 3 { 109 return nil, ErrNoSchemaOrDatabase 110 } 111 112 database := splitPath[2] 113 if len(database) == 0 { 114 return nil, ErrNoDatabaseName 115 } 116 117 schema := splitPath[1] 118 if len(schema) == 0 { 119 return nil, ErrNoSchema 120 } 121 122 cfg := &sf.Config{ 123 Account: purl.Host, 124 User: purl.User.Username(), 125 Password: password, 126 Database: database, 127 Schema: schema, 128 } 129 130 dsn, err := sf.DSN(cfg) 131 if err != nil { 132 return nil, err 133 } 134 135 db, err := sql.Open("snowflake", dsn) 136 if err != nil { 137 return nil, err 138 } 139 140 migrationsTable := purl.Query().Get("x-migrations-table") 141 142 px, err := WithInstance(db, &Config{ 143 DatabaseName: database, 144 MigrationsTable: migrationsTable, 145 }) 146 if err != nil { 147 return nil, err 148 } 149 150 return px, nil 151 } 152 153 func (p *Snowflake) Close() error { 154 connErr := p.conn.Close() 155 dbErr := p.db.Close() 156 if connErr != nil || dbErr != nil { 157 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 158 } 159 return nil 160 } 161 162 func (p *Snowflake) Lock() error { 163 if !p.isLocked.CAS(false, true) { 164 return database.ErrLocked 165 } 166 return nil 167 } 168 169 func (p *Snowflake) Unlock() error { 170 if !p.isLocked.CAS(true, false) { 171 return database.ErrNotLocked 172 } 173 return nil 174 } 175 176 func (p *Snowflake) Run(migration io.Reader) error { 177 migr, err := ioutil.ReadAll(migration) 178 if err != nil { 179 return err 180 } 181 182 // run migration 183 query := string(migr[:]) 184 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 185 if pgErr, ok := err.(*pq.Error); ok { 186 var line uint 187 var col uint 188 var lineColOK bool 189 if pgErr.Position != "" { 190 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 191 line, col, lineColOK = computeLineFromPos(query, int(pos)) 192 } 193 } 194 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 195 if lineColOK { 196 message = fmt.Sprintf("%s (column %d)", message, col) 197 } 198 if pgErr.Detail != "" { 199 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 200 } 201 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 202 } 203 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 204 } 205 206 return nil 207 } 208 209 func (p *Snowflake) RunFunctionMigration(fn source.MigrationFunc) error { 210 return database.ErrNotImpl 211 } 212 213 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 214 // replace crlf with lf 215 s = strings.Replace(s, "\r\n", "\n", -1) 216 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 217 runes := []rune(s) 218 if pos > len(runes) { 219 return 0, 0, false 220 } 221 sel := runes[:pos] 222 line = uint(runesCount(sel, newLine) + 1) 223 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 224 return line, col, true 225 } 226 227 const newLine = '\n' 228 229 func runesCount(input []rune, target rune) int { 230 var count int 231 for _, r := range input { 232 if r == target { 233 count++ 234 } 235 } 236 return count 237 } 238 239 func runesLastIndex(input []rune, target rune) int { 240 for i := len(input) - 1; i >= 0; i-- { 241 if input[i] == target { 242 return i 243 } 244 } 245 return -1 246 } 247 248 func (p *Snowflake) SetVersion(version int, dirty bool) error { 249 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 250 if err != nil { 251 return &database.Error{OrigErr: err, Err: "transaction start failed"} 252 } 253 254 query := `DELETE FROM "` + p.config.MigrationsTable + `"` 255 if _, err := tx.Exec(query); err != nil { 256 if errRollback := tx.Rollback(); errRollback != nil { 257 err = multierror.Append(err, errRollback) 258 } 259 return &database.Error{OrigErr: err, Query: []byte(query)} 260 } 261 262 // Also re-write the schema version for nil dirty versions to prevent 263 // empty schema version for failed down migration on the first migration 264 // See: https://github.com/nokia/migrate/issues/330 265 if version >= 0 || (version == database.NilVersion && dirty) { 266 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, 267 dirty) VALUES (` + strconv.FormatInt(int64(version), 10) + `, 268 ` + strconv.FormatBool(dirty) + `)` 269 if _, err := tx.Exec(query); err != nil { 270 if errRollback := tx.Rollback(); errRollback != nil { 271 err = multierror.Append(err, errRollback) 272 } 273 return &database.Error{OrigErr: err, Query: []byte(query)} 274 } 275 } 276 277 if err := tx.Commit(); err != nil { 278 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 279 } 280 281 return nil 282 } 283 284 func (p *Snowflake) Version() (version int, dirty bool, err error) { 285 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 286 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 287 switch { 288 case err == sql.ErrNoRows: 289 return database.NilVersion, false, nil 290 291 case err != nil: 292 if e, ok := err.(*pq.Error); ok { 293 if e.Code.Name() == "undefined_table" { 294 return database.NilVersion, false, nil 295 } 296 } 297 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 298 299 default: 300 return version, dirty, nil 301 } 302 } 303 304 func (p *Snowflake) Drop() (err error) { 305 // select all tables in current schema 306 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 307 tables, err := p.conn.QueryContext(context.Background(), query) 308 if err != nil { 309 return &database.Error{OrigErr: err, Query: []byte(query)} 310 } 311 defer func() { 312 if errClose := tables.Close(); errClose != nil { 313 err = multierror.Append(err, errClose) 314 } 315 }() 316 317 // delete one table after another 318 tableNames := make([]string, 0) 319 for tables.Next() { 320 var tableName string 321 if err := tables.Scan(&tableName); err != nil { 322 return err 323 } 324 if len(tableName) > 0 { 325 tableNames = append(tableNames, tableName) 326 } 327 } 328 if err := tables.Err(); err != nil { 329 return &database.Error{OrigErr: err, Query: []byte(query)} 330 } 331 332 if len(tableNames) > 0 { 333 // delete one by one ... 334 for _, t := range tableNames { 335 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 336 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 337 return &database.Error{OrigErr: err, Query: []byte(query)} 338 } 339 } 340 } 341 342 return nil 343 } 344 345 // ensureVersionTable checks if versions table exists and, if not, creates it. 346 // Note that this function locks the database, which deviates from the usual 347 // convention of "caller locks" in the Snowflake type. 348 func (p *Snowflake) ensureVersionTable() (err error) { 349 if err = p.Lock(); err != nil { 350 return err 351 } 352 353 defer func() { 354 if e := p.Unlock(); e != nil { 355 if err == nil { 356 err = e 357 } else { 358 err = multierror.Append(err, e) 359 } 360 } 361 }() 362 363 // check if migration table exists 364 var count int 365 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 366 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 367 return &database.Error{OrigErr: err, Query: []byte(query)} 368 } 369 if count == 1 { 370 return nil 371 } 372 373 // if not, create the empty migration table 374 query = `CREATE TABLE if not exists "` + p.config.MigrationsTable + `" ( 375 version bigint not null primary key, dirty boolean not null)` 376 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 377 return &database.Error{OrigErr: err, Query: []byte(query)} 378 } 379 380 return nil 381 }