github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/snowflake/snowflake.go (about) 1 package snowflake 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "io" 8 nurl "net/url" 9 "strconv" 10 "strings" 11 12 "go.uber.org/atomic" 13 14 "github.com/golang-migrate/migrate/v4/database" 15 "github.com/hashicorp/go-multierror" 16 "github.com/lib/pq" 17 sf "github.com/snowflakedb/gosnowflake" 18 ) 19 20 func init() { 21 db := Snowflake{} 22 database.Register("snowflake", &db) 23 } 24 25 var DefaultMigrationsTable = "schema_migrations" 26 27 var ( 28 ErrNilConfig = fmt.Errorf("no config") 29 ErrNoDatabaseName = fmt.Errorf("no database name") 30 ErrNoPassword = fmt.Errorf("no password") 31 ErrNoSchema = fmt.Errorf("no schema") 32 ErrNoSchemaOrDatabase = fmt.Errorf("no schema/database name") 33 ) 34 35 type Config struct { 36 MigrationsTable string 37 DatabaseName string 38 } 39 40 type Snowflake struct { 41 isLocked atomic.Bool 42 conn *sql.Conn 43 db *sql.DB 44 45 // Open and WithInstance need to guarantee that config is never nil 46 config *Config 47 } 48 49 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 50 if config == nil { 51 return nil, ErrNilConfig 52 } 53 54 if err := instance.Ping(); err != nil { 55 return nil, err 56 } 57 58 if config.DatabaseName == "" { 59 query := `SELECT CURRENT_DATABASE()` 60 var databaseName string 61 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 62 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 63 } 64 65 if len(databaseName) == 0 { 66 return nil, ErrNoDatabaseName 67 } 68 69 config.DatabaseName = databaseName 70 } 71 72 if len(config.MigrationsTable) == 0 { 73 config.MigrationsTable = DefaultMigrationsTable 74 } 75 76 conn, err := instance.Conn(context.Background()) 77 78 if err != nil { 79 return nil, err 80 } 81 82 px := &Snowflake{ 83 conn: conn, 84 db: instance, 85 config: config, 86 } 87 88 if err := px.ensureVersionTable(); err != nil { 89 return nil, err 90 } 91 92 return px, nil 93 } 94 95 func (p *Snowflake) Open(url string) (database.Driver, error) { 96 purl, err := nurl.Parse(url) 97 if err != nil { 98 return nil, err 99 } 100 101 password, isPasswordSet := purl.User.Password() 102 if !isPasswordSet { 103 return nil, ErrNoPassword 104 } 105 106 splitPath := strings.Split(purl.Path, "/") 107 if len(splitPath) < 3 { 108 return nil, ErrNoSchemaOrDatabase 109 } 110 111 database := splitPath[2] 112 if len(database) == 0 { 113 return nil, ErrNoDatabaseName 114 } 115 116 schema := splitPath[1] 117 if len(schema) == 0 { 118 return nil, ErrNoSchema 119 } 120 121 cfg := &sf.Config{ 122 Account: purl.Host, 123 User: purl.User.Username(), 124 Password: password, 125 Database: database, 126 Schema: schema, 127 } 128 129 dsn, err := sf.DSN(cfg) 130 if err != nil { 131 return nil, err 132 } 133 134 db, err := sql.Open("snowflake", dsn) 135 if err != nil { 136 return nil, err 137 } 138 139 migrationsTable := purl.Query().Get("x-migrations-table") 140 141 px, err := WithInstance(db, &Config{ 142 DatabaseName: database, 143 MigrationsTable: migrationsTable, 144 }) 145 if err != nil { 146 return nil, err 147 } 148 149 return px, nil 150 } 151 152 func (p *Snowflake) Close() error { 153 connErr := p.conn.Close() 154 dbErr := p.db.Close() 155 if connErr != nil || dbErr != nil { 156 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 157 } 158 return nil 159 } 160 161 func (p *Snowflake) Lock() error { 162 if !p.isLocked.CAS(false, true) { 163 return database.ErrLocked 164 } 165 return nil 166 } 167 168 func (p *Snowflake) Unlock() error { 169 if !p.isLocked.CAS(true, false) { 170 return database.ErrNotLocked 171 } 172 return nil 173 } 174 175 func (p *Snowflake) Run(migration io.Reader) error { 176 migr, err := io.ReadAll(migration) 177 if err != nil { 178 return err 179 } 180 181 // run migration 182 query := string(migr[:]) 183 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 184 if pgErr, ok := err.(*pq.Error); ok { 185 var line uint 186 var col uint 187 var lineColOK bool 188 if pgErr.Position != "" { 189 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 190 line, col, lineColOK = computeLineFromPos(query, int(pos)) 191 } 192 } 193 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 194 if lineColOK { 195 message = fmt.Sprintf("%s (column %d)", message, col) 196 } 197 if pgErr.Detail != "" { 198 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 199 } 200 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line} 201 } 202 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 203 } 204 205 return nil 206 } 207 208 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 209 // replace crlf with lf 210 s = strings.Replace(s, "\r\n", "\n", -1) 211 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 212 runes := []rune(s) 213 if pos > len(runes) { 214 return 0, 0, false 215 } 216 sel := runes[:pos] 217 line = uint(runesCount(sel, newLine) + 1) 218 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 219 return line, col, true 220 } 221 222 const newLine = '\n' 223 224 func runesCount(input []rune, target rune) int { 225 var count int 226 for _, r := range input { 227 if r == target { 228 count++ 229 } 230 } 231 return count 232 } 233 234 func runesLastIndex(input []rune, target rune) int { 235 for i := len(input) - 1; i >= 0; i-- { 236 if input[i] == target { 237 return i 238 } 239 } 240 return -1 241 } 242 243 func (p *Snowflake) SetVersion(version int, dirty bool) error { 244 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 245 if err != nil { 246 return &database.Error{OrigErr: err, Err: "transaction start failed"} 247 } 248 249 query := `DELETE FROM "` + p.config.MigrationsTable + `"` 250 if _, err := tx.Exec(query); err != nil { 251 if errRollback := tx.Rollback(); errRollback != nil { 252 err = multierror.Append(err, errRollback) 253 } 254 return &database.Error{OrigErr: err, Query: []byte(query)} 255 } 256 257 // Also re-write the schema version for nil dirty versions to prevent 258 // empty schema version for failed down migration on the first migration 259 // See: https://github.com/golang-migrate/migrate/issues/330 260 if version >= 0 || (version == database.NilVersion && dirty) { 261 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, 262 dirty) VALUES (` + strconv.FormatInt(int64(version), 10) + `, 263 ` + strconv.FormatBool(dirty) + `)` 264 if _, err := tx.Exec(query); err != nil { 265 if errRollback := tx.Rollback(); errRollback != nil { 266 err = multierror.Append(err, errRollback) 267 } 268 return &database.Error{OrigErr: err, Query: []byte(query)} 269 } 270 } 271 272 if err := tx.Commit(); err != nil { 273 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 274 } 275 276 return nil 277 } 278 279 func (p *Snowflake) Version() (version int, dirty bool, err error) { 280 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 281 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 282 switch { 283 case err == sql.ErrNoRows: 284 return database.NilVersion, false, nil 285 286 case err != nil: 287 if e, ok := err.(*pq.Error); ok { 288 if e.Code.Name() == "undefined_table" { 289 return database.NilVersion, false, nil 290 } 291 } 292 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 293 294 default: 295 return version, dirty, nil 296 } 297 } 298 299 func (p *Snowflake) Drop() (err error) { 300 // select all tables in current schema 301 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 302 tables, err := p.conn.QueryContext(context.Background(), query) 303 if err != nil { 304 return &database.Error{OrigErr: err, Query: []byte(query)} 305 } 306 defer func() { 307 if errClose := tables.Close(); errClose != nil { 308 err = multierror.Append(err, errClose) 309 } 310 }() 311 312 // delete one table after another 313 tableNames := make([]string, 0) 314 for tables.Next() { 315 var tableName string 316 if err := tables.Scan(&tableName); err != nil { 317 return err 318 } 319 if len(tableName) > 0 { 320 tableNames = append(tableNames, tableName) 321 } 322 } 323 if err := tables.Err(); err != nil { 324 return &database.Error{OrigErr: err, Query: []byte(query)} 325 } 326 327 if len(tableNames) > 0 { 328 // delete one by one ... 329 for _, t := range tableNames { 330 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` 331 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 332 return &database.Error{OrigErr: err, Query: []byte(query)} 333 } 334 } 335 } 336 337 return nil 338 } 339 340 // ensureVersionTable checks if versions table exists and, if not, creates it. 341 // Note that this function locks the database, which deviates from the usual 342 // convention of "caller locks" in the Snowflake type. 343 func (p *Snowflake) ensureVersionTable() (err error) { 344 if err = p.Lock(); err != nil { 345 return err 346 } 347 348 defer func() { 349 if e := p.Unlock(); e != nil { 350 if err == nil { 351 err = e 352 } else { 353 err = multierror.Append(err, e) 354 } 355 } 356 }() 357 358 // check if migration table exists 359 var count int 360 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` 361 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { 362 return &database.Error{OrigErr: err, Query: []byte(query)} 363 } 364 if count == 1 { 365 return nil 366 } 367 368 // if not, create the empty migration table 369 query = `CREATE TABLE if not exists "` + p.config.MigrationsTable + `" ( 370 version bigint not null primary key, dirty boolean not null)` 371 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 372 return &database.Error{OrigErr: err, Query: []byte(query)} 373 } 374 375 return nil 376 }