github.com/getsynq/migrate/v4@v4.15.3-0.20220615182648-8e72daaa5ed9/database/snowflake/snowflake.go (about) 1 package snowflake 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "io" 8 "io/ioutil" 9 nurl "net/url" 10 "strconv" 11 "strings" 12 13 "go.uber.org/atomic" 14 15 "github.com/getsynq/migrate/v4/database" 16 "github.com/hashicorp/go-multierror" 17 "github.com/lib/pq" 18 sf "github.com/snowflakedb/gosnowflake" 19 ) 20 21 func init() { 22 db := Snowflake{} 23 database.Register("snowflake", &db) 24 } 25 26 var DefaultMigrationsTable = "schema_migrations" 27 28 var ( 29 ErrNilConfig = fmt.Errorf("no config") 30 ErrNoDatabaseName = fmt.Errorf("no database name") 31 ErrNoPassword = fmt.Errorf("no password") 32 ErrNoSchema = fmt.Errorf("no schema") 33 ErrNoSchemaOrDatabase = fmt.Errorf("no schema/database name") 34 ) 35 36 type Config struct { 37 MigrationsTable string 38 DatabaseName string 39 } 40 41 type Snowflake struct { 42 isLocked atomic.Bool 43 conn *sql.Conn 44 db *sql.DB 45 46 // Open and WithInstance need to guarantee that config is never nil 47 config *Config 48 } 49 50 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 51 if config == nil { 52 return nil, ErrNilConfig 53 } 54 55 if err := instance.Ping(); err != nil { 56 return nil, err 57 } 58 59 if config.DatabaseName == "" { 60 query := `SELECT CURRENT_DATABASE()` 61 var databaseName string 62 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 63 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 64 } 65 66 if len(databaseName) == 0 { 67 return nil, ErrNoDatabaseName 68 } 69 70 config.DatabaseName = databaseName 71 } 72 73 if len(config.MigrationsTable) == 0 { 74 config.MigrationsTable = DefaultMigrationsTable 75 } 76 77 conn, err := instance.Conn(context.Background()) 78 79 if err != nil { 80 return nil, err 81 } 82 83 px := &Snowflake{ 84 conn: conn, 85 db: instance, 86 config: config, 87 } 88 89 if err := px.ensureVersionTable(); err != nil { 90 return nil, err 91 } 92 93 return px, nil 94 } 95 96 func (p *Snowflake) Open(url string) (database.Driver, error) { 97 purl, err := nurl.Parse(url) 98 if err != nil { 99 return nil, err 100 } 101 102 password, isPasswordSet := purl.User.Password() 103 if !isPasswordSet { 104 return nil, ErrNoPassword 105 } 106 107 splitPath := strings.Split(purl.Path, "/") 108 if len(splitPath) < 3 { 109 return nil, ErrNoSchemaOrDatabase 110 } 111 112 database := splitPath[2] 113 if len(database) == 0 { 114 return nil, ErrNoDatabaseName 115 } 116 117 schema := splitPath[1] 118 if len(schema) == 0 { 119 return nil, ErrNoSchema 120 } 121 122 cfg := &sf.Config{ 123 Account: purl.Host, 124 User: purl.User.Username(), 125 Password: password, 126 Database: database, 127 Schema: schema, 128 } 129 130 dsn, err := sf.DSN(cfg) 131 if err != nil { 132 return nil, err 133 } 134 135 db, err := sql.Open("snowflake", dsn) 136 if err != nil { 137 return nil, err 138 } 139 140 migrationsTable := purl.Query().Get("x-migrations-table") 141 142 px, err := WithInstance(db, &Config{ 143 DatabaseName: database, 144 MigrationsTable: migrationsTable, 145 }) 146 if err != nil { 147 return nil, err 148 } 149 150 return px, nil 151 } 152 153 func (p *Snowflake) Close() error { 154 connErr := p.conn.Close() 155 dbErr := p.db.Close() 156 if connErr != nil || dbErr != nil { 157 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 158 } 159 return nil 160 } 161 162 func (p *Snowflake) Lock() error { 163 if !p.isLocked.CAS(false, true) { 164 return database.ErrLocked 165 } 166 return nil 167 } 168 169 func (p *Snowflake) Unlock() error { 170 if !p.isLocked.CAS(true, false) { 171 return database.ErrNotLocked 172 } 173 return nil 174 } 175 176 func (p *Snowflake) Run(migration io.Reader) error { 177 migr, err := ioutil.ReadAll(migration) 178 if err != nil { 179 return err 180 } 181 182 // run migration 183 query := string(migr[:]) 184 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 185 if pgErr, ok := err.(*pq.Error); ok { 186 var line uint 187 var col uint 188 var lineColOK bool 189 if pgErr.Position != "" { 190 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 191 line, col, lineColOK = computeLineFromPos(query, int(pos)) 192 } 193 } 194 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 195 if lineColOK { 196 message = fmt.Sprintf("%s (column %d)", message, col) 197 } 198 if pgErr.Detail != "" { 199 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 200 } 201 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 202 } 203 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 204 } 205 206 return nil 207 } 208 209 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 210 // replace crlf with lf 211 s = strings.Replace(s, "\r\n", "\n", -1) 212 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 213 runes := []rune(s) 214 if pos > len(runes) { 215 return 0, 0, false 216 } 217 sel := runes[:pos] 218 line = uint(runesCount(sel, newLine) + 1) 219 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 220 return line, col, true 221 } 222 223 const newLine = '\n' 224 225 func runesCount(input []rune, target rune) int { 226 var count int 227 for _, r := range input { 228 if r == target { 229 count++ 230 } 231 } 232 return count 233 } 234 235 func runesLastIndex(input []rune, target rune) int { 236 for i := len(input) - 1; i >= 0; i-- { 237 if input[i] == target { 238 return i 239 } 240 } 241 return -1 242 } 243 244 func (p *Snowflake) SetVersion(version int, dirty bool) error { 245 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 246 if err != nil { 247 return &database.Error{OrigErr: err, Err: "transaction start failed"} 248 } 249 250 query := `DELETE FROM "` + p.config.MigrationsTable + `"` 251 if _, err := tx.Exec(query); err != nil { 252 if errRollback := tx.Rollback(); errRollback != nil { 253 err = multierror.Append(err, errRollback) 254 } 255 return &database.Error{OrigErr: err, Query: []byte(query)} 256 } 257 258 // Also re-write the schema version for nil dirty versions to prevent 259 // empty schema version for failed down migration on the first migration 260 // See: https://github.com/getsynq/migrate/issues/330 261 if version >= 0 || (version == database.NilVersion && dirty) { 262 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, 263 dirty) VALUES (` + strconv.FormatInt(int64(version), 10) + `, 264 ` + strconv.FormatBool(dirty) + `)` 265 if _, err := tx.Exec(query); err != nil { 266 if errRollback := tx.Rollback(); errRollback != nil { 267 err = multierror.Append(err, errRollback) 268 } 269 return &database.Error{OrigErr: err, Query: []byte(query)} 270 } 271 } 272 273 if err := tx.Commit(); err != nil { 274 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 275 } 276 277 return nil 278 } 279 280 func (p *Snowflake) Version() (version int, dirty bool, err error) { 281 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 282 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 283 switch { 284 case err == sql.ErrNoRows: 285 return database.NilVersion, false, nil 286 287 case err != nil: 288 if e, ok := err.(*pq.Error); ok { 289 if e.Code.Name() == "undefined_table" { 290 return database.NilVersion, false, nil 291 } 292 } 293 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 294 295 default: 296 return version, dirty, nil 297 } 298 } 299 300 func (p *Snowflake) Drop() (err error) { 301 // select all tables in current schema 302 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 303 tables, err := p.conn.QueryContext(context.Background(), query) 304 if err != nil { 305 return &database.Error{OrigErr: err, Query: []byte(query)} 306 } 307 defer func() { 308 if errClose := tables.Close(); errClose != nil { 309 err = multierror.Append(err, errClose) 310 } 311 }() 312 313 // delete one table after another 314 tableNames := make([]string, 0) 315 for tables.Next() { 316 var tableName string 317 if err := tables.Scan(&tableName); err != nil { 318 return err 319 } 320 if len(tableName) > 0 { 321 tableNames = append(tableNames, tableName) 322 } 323 } 324 if err := tables.Err(); err != nil { 325 return &database.Error{OrigErr: err, Query: []byte(query)} 326 } 327 328 if len(tableNames) > 0 { 329 // delete one by one ... 330 for _, t := range tableNames { 331 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 332 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 333 return &database.Error{OrigErr: err, Query: []byte(query)} 334 } 335 } 336 } 337 338 return nil 339 } 340 341 // ensureVersionTable checks if versions table exists and, if not, creates it. 342 // Note that this function locks the database, which deviates from the usual 343 // convention of "caller locks" in the Snowflake type. 344 func (p *Snowflake) ensureVersionTable() (err error) { 345 if err = p.Lock(); err != nil { 346 return err 347 } 348 349 defer func() { 350 if e := p.Unlock(); e != nil { 351 if err == nil { 352 err = e 353 } else { 354 err = multierror.Append(err, e) 355 } 356 } 357 }() 358 359 // check if migration table exists 360 var count int 361 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 362 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 363 return &database.Error{OrigErr: err, Query: []byte(query)} 364 } 365 if count == 1 { 366 return nil 367 } 368 369 // if not, create the empty migration table 370 query = `CREATE TABLE if not exists "` + p.config.MigrationsTable + `" ( 371 version bigint not null primary key, dirty boolean not null)` 372 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 373 return &database.Error{OrigErr: err, Query: []byte(query)} 374 } 375 376 return nil 377 }