github.com/brandonmartin/migrate/v4@v4.14.2/database/snowflake/snowflake.go (about) 1 package snowflake 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "io" 8 "io/ioutil" 9 nurl "net/url" 10 "strconv" 11 "strings" 12 13 "github.com/golang-migrate/migrate/v4/database" 14 "github.com/hashicorp/go-multierror" 15 "github.com/lib/pq" 16 sf "github.com/snowflakedb/gosnowflake" 17 ) 18 19 func init() { 20 db := Snowflake{} 21 database.Register("snowflake", &db) 22 } 23 24 var DefaultMigrationsTable = "schema_migrations" 25 26 var ( 27 ErrNilConfig = fmt.Errorf("no config") 28 ErrNoDatabaseName = fmt.Errorf("no database name") 29 ErrNoPassword = fmt.Errorf("no password") 30 ErrNoSchema = fmt.Errorf("no schema") 31 ErrNoSchemaOrDatabase = fmt.Errorf("no schema/database name") 32 ) 33 34 type Config struct { 35 MigrationsTable string 36 DatabaseName string 37 } 38 39 type Snowflake struct { 40 isLocked bool 41 conn *sql.Conn 42 db *sql.DB 43 44 // Open and WithInstance need to guarantee that config is never nil 45 config *Config 46 } 47 48 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 49 if config == nil { 50 return nil, ErrNilConfig 51 } 52 53 if err := instance.Ping(); err != nil { 54 return nil, err 55 } 56 57 if config.DatabaseName == "" { 58 query := `SELECT CURRENT_DATABASE()` 59 var databaseName string 60 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 61 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 62 } 63 64 if len(databaseName) == 0 { 65 return nil, ErrNoDatabaseName 66 } 67 68 config.DatabaseName = databaseName 69 } 70 71 if len(config.MigrationsTable) == 0 { 72 config.MigrationsTable = DefaultMigrationsTable 73 } 74 75 conn, err := instance.Conn(context.Background()) 76 77 if err != nil { 78 return nil, err 79 } 80 81 px := &Snowflake{ 82 conn: conn, 83 db: instance, 84 config: config, 85 } 86 87 if err := px.ensureVersionTable(); err != nil { 88 return nil, err 89 } 90 91 return px, nil 92 } 93 94 func (p *Snowflake) Open(url string) (database.Driver, error) { 95 purl, err := nurl.Parse(url) 96 if err != nil { 97 return nil, err 98 } 99 100 password, isPasswordSet := purl.User.Password() 101 if !isPasswordSet { 102 return nil, ErrNoPassword 103 } 104 105 splitPath := strings.Split(purl.Path, "/") 106 if len(splitPath) < 3 { 107 return nil, ErrNoSchemaOrDatabase 108 } 109 110 database := splitPath[2] 111 if len(database) == 0 { 112 return nil, ErrNoDatabaseName 113 } 114 115 schema := splitPath[1] 116 if len(schema) == 0 { 117 return nil, ErrNoSchema 118 } 119 120 cfg := &sf.Config{ 121 Account: purl.Host, 122 User: purl.User.Username(), 123 Password: password, 124 Database: database, 125 Schema: schema, 126 } 127 128 dsn, err := sf.DSN(cfg) 129 if err != nil { 130 return nil, err 131 } 132 133 db, err := sql.Open("snowflake", dsn) 134 if err != nil { 135 return nil, err 136 } 137 138 migrationsTable := purl.Query().Get("x-migrations-table") 139 140 px, err := WithInstance(db, &Config{ 141 DatabaseName: database, 142 MigrationsTable: migrationsTable, 143 }) 144 if err != nil { 145 return nil, err 146 } 147 148 return px, nil 149 } 150 151 func (p *Snowflake) Close() error { 152 connErr := p.conn.Close() 153 dbErr := p.db.Close() 154 if connErr != nil || dbErr != nil { 155 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 156 } 157 return nil 158 } 159 160 func (p *Snowflake) Lock() error { 161 if p.isLocked { 162 return database.ErrLocked 163 } 164 p.isLocked = true 165 return nil 166 } 167 168 func (p *Snowflake) Unlock() error { 169 p.isLocked = false 170 return nil 171 } 172 173 func (p *Snowflake) Run(migration io.Reader) error { 174 migr, err := ioutil.ReadAll(migration) 175 if err != nil { 176 return err 177 } 178 179 // run migration 180 query := string(migr[:]) 181 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 182 if pgErr, ok := err.(*pq.Error); ok { 183 var line uint 184 var col uint 185 var lineColOK bool 186 if pgErr.Position != "" { 187 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 188 line, col, lineColOK = computeLineFromPos(query, int(pos)) 189 } 190 } 191 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 192 if lineColOK { 193 message = fmt.Sprintf("%s (column %d)", message, col) 194 } 195 if pgErr.Detail != "" { 196 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 197 } 198 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 199 } 200 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 201 } 202 203 return nil 204 } 205 206 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 207 // replace crlf with lf 208 s = strings.Replace(s, "\r\n", "\n", -1) 209 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 210 runes := []rune(s) 211 if pos > len(runes) { 212 return 0, 0, false 213 } 214 sel := runes[:pos] 215 line = uint(runesCount(sel, newLine) + 1) 216 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 217 return line, col, true 218 } 219 220 const newLine = '\n' 221 222 func runesCount(input []rune, target rune) int { 223 var count int 224 for _, r := range input { 225 if r == target { 226 count++ 227 } 228 } 229 return count 230 } 231 232 func runesLastIndex(input []rune, target rune) int { 233 for i := len(input) - 1; i >= 0; i-- { 234 if input[i] == target { 235 return i 236 } 237 } 238 return -1 239 } 240 241 func (p *Snowflake) SetVersion(version int, dirty bool) error { 242 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 243 if err != nil { 244 return &database.Error{OrigErr: err, Err: "transaction start failed"} 245 } 246 247 query := `DELETE FROM "` + p.config.MigrationsTable + `"` 248 if _, err := tx.Exec(query); err != nil { 249 if errRollback := tx.Rollback(); errRollback != nil { 250 err = multierror.Append(err, errRollback) 251 } 252 return &database.Error{OrigErr: err, Query: []byte(query)} 253 } 254 255 // Also re-write the schema version for nil dirty versions to prevent 256 // empty schema version for failed down migration on the first migration 257 // See: https://github.com/golang-migrate/migrate/issues/330 258 if version >= 0 || (version == database.NilVersion && dirty) { 259 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, 260 dirty) VALUES (` + strconv.FormatInt(int64(version), 10) + `, 261 ` + strconv.FormatBool(dirty) + `)` 262 if _, err := tx.Exec(query); err != nil { 263 if errRollback := tx.Rollback(); errRollback != nil { 264 err = multierror.Append(err, errRollback) 265 } 266 return &database.Error{OrigErr: err, Query: []byte(query)} 267 } 268 } 269 270 if err := tx.Commit(); err != nil { 271 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 272 } 273 274 return nil 275 } 276 277 func (p *Snowflake) Version() (version int, dirty bool, err error) { 278 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 279 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 280 switch { 281 case err == sql.ErrNoRows: 282 return database.NilVersion, false, nil 283 284 case err != nil: 285 if e, ok := err.(*pq.Error); ok { 286 if e.Code.Name() == "undefined_table" { 287 return database.NilVersion, false, nil 288 } 289 } 290 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 291 292 default: 293 return version, dirty, nil 294 } 295 } 296 297 func (p *Snowflake) Drop() (err error) { 298 // select all tables in current schema 299 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 300 tables, err := p.conn.QueryContext(context.Background(), query) 301 if err != nil { 302 return &database.Error{OrigErr: err, Query: []byte(query)} 303 } 304 defer func() { 305 if errClose := tables.Close(); errClose != nil { 306 err = multierror.Append(err, errClose) 307 } 308 }() 309 310 // delete one table after another 311 tableNames := make([]string, 0) 312 for tables.Next() { 313 var tableName string 314 if err := tables.Scan(&tableName); err != nil { 315 return err 316 } 317 if len(tableName) > 0 { 318 tableNames = append(tableNames, tableName) 319 } 320 } 321 if err := tables.Err(); err != nil { 322 return &database.Error{OrigErr: err, Query: []byte(query)} 323 } 324 325 if len(tableNames) > 0 { 326 // delete one by one ... 327 for _, t := range tableNames { 328 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 329 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 330 return &database.Error{OrigErr: err, Query: []byte(query)} 331 } 332 } 333 } 334 335 return nil 336 } 337 338 // ensureVersionTable checks if versions table exists and, if not, creates it. 339 // Note that this function locks the database, which deviates from the usual 340 // convention of "caller locks" in the Snowflake type. 341 func (p *Snowflake) ensureVersionTable() (err error) { 342 if err = p.Lock(); err != nil { 343 return err 344 } 345 346 defer func() { 347 if e := p.Unlock(); e != nil { 348 if err == nil { 349 err = e 350 } else { 351 err = multierror.Append(err, e) 352 } 353 } 354 }() 355 356 // check if migration table exists 357 var count int 358 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 359 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 360 return &database.Error{OrigErr: err, Query: []byte(query)} 361 } 362 if count == 1 { 363 return nil 364 } 365 366 // if not, create the empty migration table 367 query = `CREATE TABLE if not exists "` + p.config.MigrationsTable + `" ( 368 version bigint not null primary key, dirty boolean not null)` 369 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 370 return &database.Error{OrigErr: err, Query: []byte(query)} 371 } 372 373 return nil 374 }