github.com/seashell-org/golang-migrate/v4@v4.15.3-0.20220722221203-6ab6c6c062d1/database/spanner/spanner.go (about) 1 package spanner 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "io/ioutil" 8 "log" 9 nurl "net/url" 10 "regexp" 11 "strconv" 12 "strings" 13 14 "context" 15 16 "cloud.google.com/go/spanner" 17 sdb "cloud.google.com/go/spanner/admin/database/apiv1" 18 19 migrate "github.com/seashell-org/golang-migrate/v4" 20 "github.com/seashell-org/golang-migrate/v4/database" 21 "github.com/seashell-org/golang-migrate/v4/database/spanner/spansql" 22 23 "github.com/hashicorp/go-multierror" 24 uatomic "go.uber.org/atomic" 25 "google.golang.org/api/iterator" 26 adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" 27 ) 28 29 func init() { 30 db := Spanner{} 31 database.Register("spanner", &db) 32 } 33 34 // DefaultMigrationsTable is used if no custom table is specified 35 const DefaultMigrationsTable = "SchemaMigrations" 36 37 const ( 38 unlockedVal = 0 39 lockedVal = 1 40 ) 41 42 // Driver errors 43 var ( 44 ErrNilConfig = errors.New("no config") 45 ErrNoDatabaseName = errors.New("no database name") 46 ErrNoSchema = errors.New("no schema") 47 ErrDatabaseDirty = errors.New("database is dirty") 48 ErrLockHeld = errors.New("unable to obtain lock") 49 ErrLockNotHeld = errors.New("unable to release already released lock") 50 ErrInvalidStmtType = errors.New("Invalid statement type") 51 ) 52 53 // Config used for a Spanner instance 54 type Config struct { 55 MigrationsTable string 56 DatabaseName string 57 // Whether to parse the migration DDL with spansql before 58 // running them towards Spanner. 59 // Parsing outputs clean DDL statements such as reformatted 60 // and void of comments. 61 CleanStatements bool 62 } 63 64 // Spanner implements database.Driver for Google Cloud Spanner 65 type Spanner struct { 66 db *DB 67 68 config *Config 69 70 lock *uatomic.Uint32 71 } 72 73 type DB struct { 74 admin *sdb.DatabaseAdminClient 75 data *spanner.Client 76 } 77 78 func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB { 79 return &DB{ 80 admin: &admin, 81 data: &data, 82 } 83 } 84 85 // WithInstance implements database.Driver 86 func WithInstance(instance *DB, config *Config) (database.Driver, error) { 87 if config == nil { 88 return nil, ErrNilConfig 89 } 90 91 if len(config.DatabaseName) == 0 { 92 return nil, ErrNoDatabaseName 93 } 94 95 if len(config.MigrationsTable) == 0 { 96 config.MigrationsTable = DefaultMigrationsTable 97 } 98 99 sx := &Spanner{ 100 db: instance, 101 config: config, 102 lock: uatomic.NewUint32(unlockedVal), 103 } 104 105 if err := sx.ensureVersionTable(); err != nil { 106 return nil, err 107 } 108 109 return sx, nil 110 } 111 112 // Open implements database.Driver 113 func (s *Spanner) Open(url string) (database.Driver, error) { 114 purl, err := nurl.Parse(url) 115 if err != nil { 116 return nil, err 117 } 118 119 ctx := context.Background() 120 121 adminClient, err := sdb.NewDatabaseAdminClient(ctx) 122 if err != nil { 123 return nil, err 124 } 125 dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1) 126 dataClient, err := spanner.NewClient(ctx, dbname) 127 if err != nil { 128 log.Fatal(err) 129 } 130 131 migrationsTable := purl.Query().Get("x-migrations-table") 132 133 cleanQuery := purl.Query().Get("x-clean-statements") 134 clean := false 135 if cleanQuery != "" { 136 clean, err = strconv.ParseBool(cleanQuery) 137 if err != nil { 138 return nil, err 139 } 140 } 141 142 db := &DB{admin: adminClient, data: dataClient} 143 return WithInstance(db, &Config{ 144 DatabaseName: dbname, 145 MigrationsTable: migrationsTable, 146 CleanStatements: clean, 147 }) 148 } 149 150 // Close implements database.Driver 151 func (s *Spanner) Close() error { 152 s.db.data.Close() 153 return s.db.admin.Close() 154 } 155 156 // Lock implements database.Driver but doesn't do anything because Spanner only 157 // enqueues the UpdateDatabaseDdlRequest. 158 func (s *Spanner) Lock() error { 159 if swapped := s.lock.CAS(unlockedVal, lockedVal); swapped { 160 return nil 161 } 162 return ErrLockHeld 163 } 164 165 // Unlock implements database.Driver but no action required, see Lock. 166 func (s *Spanner) Unlock() error { 167 if swapped := s.lock.CAS(lockedVal, unlockedVal); swapped { 168 return nil 169 } 170 return ErrLockNotHeld 171 } 172 173 // Run implements database.Driver 174 func (s *Spanner) Run(migration io.Reader) error { 175 migr, err := ioutil.ReadAll(migration) 176 if err != nil { 177 return err 178 } 179 180 stmts, err := cleanStatements(migr) 181 if err != nil { 182 return err 183 } 184 185 ctx := context.Background() 186 187 for _, stmt := range stmts { 188 if stmt.t == "ddl" { 189 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ 190 Database: s.config.DatabaseName, 191 Statements: stmt.sqls, 192 }) 193 194 if err != nil { 195 return &database.Error{OrigErr: err, Err: "migration failed", Query: migr} 196 } 197 198 if err := op.Wait(ctx); err != nil { 199 return &database.Error{OrigErr: err, Err: "migration failed", Query: migr} 200 } 201 } else if stmt.t == "dml" { 202 spannerStmts := []spanner.Statement{} 203 for _, sql := range stmt.sqls { 204 spannerStmts = append(spannerStmts, spanner.Statement{ 205 SQL: sql, 206 }) 207 } 208 209 _, err = s.db.data.ReadWriteTransaction(ctx, func(ctx context.Context, rwt *spanner.ReadWriteTransaction) error { 210 _, err := rwt.BatchUpdate(ctx, spannerStmts) 211 if err != nil { 212 return err 213 } 214 215 return nil 216 }) 217 if err != nil { 218 return &database.Error{OrigErr: err, Err: "migration failed", Query: migr} 219 } 220 } else { 221 return ErrInvalidStmtType 222 } 223 } 224 225 return nil 226 } 227 228 // SetVersion implements database.Driver 229 func (s *Spanner) SetVersion(version int, dirty bool) error { 230 ctx := context.Background() 231 232 _, err := s.db.data.ReadWriteTransaction(ctx, 233 func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { 234 m := []*spanner.Mutation{ 235 spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()), 236 spanner.Insert(s.config.MigrationsTable, 237 []string{"Version", "Dirty"}, 238 []interface{}{version, dirty}, 239 )} 240 return txn.BufferWrite(m) 241 }) 242 if err != nil { 243 return &database.Error{OrigErr: err} 244 } 245 246 return nil 247 } 248 249 // Version implements database.Driver 250 func (s *Spanner) Version() (version int, dirty bool, err error) { 251 ctx := context.Background() 252 253 stmt := spanner.Statement{ 254 SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`, 255 } 256 iter := s.db.data.Single().Query(ctx, stmt) 257 defer iter.Stop() 258 259 row, err := iter.Next() 260 switch err { 261 case iterator.Done: 262 return database.NilVersion, false, nil 263 case nil: 264 var v int64 265 if err = row.Columns(&v, &dirty); err != nil { 266 return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)} 267 } 268 version = int(v) 269 default: 270 return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)} 271 } 272 273 return version, dirty, nil 274 } 275 276 var nameMatcher = regexp.MustCompile(`(CREATE TABLE\s(\S+)\s)|(CREATE.+INDEX\s(\S+)\s)`) 277 278 // Drop implements database.Driver. Retrieves the database schema first and 279 // creates statements to drop the indexes and tables accordingly. 280 // Note: The drop statements are created in reverse order to how they're 281 // provided in the schema. Assuming the schema describes how the database can 282 // be "build up", it seems logical to "unbuild" the database simply by going the 283 // opposite direction. More testing 284 func (s *Spanner) Drop() error { 285 ctx := context.Background() 286 res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{ 287 Database: s.config.DatabaseName, 288 }) 289 if err != nil { 290 return &database.Error{OrigErr: err, Err: "drop failed"} 291 } 292 if len(res.Statements) == 0 { 293 return nil 294 } 295 296 stmts := make([]string, 0) 297 for i := len(res.Statements) - 1; i >= 0; i-- { 298 s := res.Statements[i] 299 m := nameMatcher.FindSubmatch([]byte(s)) 300 301 if len(m) == 0 { 302 continue 303 } else if tbl := m[2]; len(tbl) > 0 { 304 stmts = append(stmts, fmt.Sprintf(`DROP TABLE %s`, tbl)) 305 } else if idx := m[4]; len(idx) > 0 { 306 stmts = append(stmts, fmt.Sprintf(`DROP INDEX %s`, idx)) 307 } 308 } 309 310 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ 311 Database: s.config.DatabaseName, 312 Statements: stmts, 313 }) 314 if err != nil { 315 return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))} 316 } 317 if err := op.Wait(ctx); err != nil { 318 return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))} 319 } 320 321 return nil 322 } 323 324 // ensureVersionTable checks if versions table exists and, if not, creates it. 325 // Note that this function locks the database, which deviates from the usual 326 // convention of "caller locks" in the Spanner type. 327 func (s *Spanner) ensureVersionTable() (err error) { 328 if err = s.Lock(); err != nil { 329 return err 330 } 331 332 defer func() { 333 if e := s.Unlock(); e != nil { 334 if err == nil { 335 err = e 336 } else { 337 err = multierror.Append(err, e) 338 } 339 } 340 }() 341 342 ctx := context.Background() 343 tbl := s.config.MigrationsTable 344 iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"}) 345 if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil { 346 return nil 347 } 348 349 stmt := fmt.Sprintf(`CREATE TABLE %s ( 350 Version INT64 NOT NULL, 351 Dirty BOOL NOT NULL 352 ) PRIMARY KEY(Version)`, tbl) 353 354 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ 355 Database: s.config.DatabaseName, 356 Statements: []string{stmt}, 357 }) 358 359 if err != nil { 360 return &database.Error{OrigErr: err, Query: []byte(stmt)} 361 } 362 if err := op.Wait(ctx); err != nil { 363 return &database.Error{OrigErr: err, Query: []byte(stmt)} 364 } 365 366 return nil 367 } 368 369 type statement struct { 370 t string 371 sqls []string 372 } 373 374 func cleanStatements(migration []byte) ([]statement, error) { 375 // The Spanner GCP backend does not yet support comments for the UpdateDatabaseDdl RPC 376 // (see https://issuetracker.google.com/issues/159730604) we use 377 // spansql to parse the DDL and output valid stamements without comments 378 ddlAndDml, err := spansql.ParseDDLAndDML("", string(migration)) 379 if err != nil { 380 return nil, err 381 } 382 stmts := make([]statement, 0, len(ddlAndDml.List)) 383 lastType := "" 384 stmtsBuffer := []string(nil) 385 for _, stmt := range ddlAndDml.List { 386 if lastType != stmt.T && lastType != "" { 387 stmts = append(stmts, statement{ 388 sqls: stmtsBuffer, 389 t: lastType, 390 }) 391 stmtsBuffer = []string(nil) 392 } 393 if stmt.T == "ddl" { 394 stmtsBuffer = append(stmtsBuffer, stmt.Ddl.SQL()) 395 } else if stmt.T == "dml" { 396 stmtsBuffer = append(stmtsBuffer, stmt.Dml.SQL()) 397 } else { 398 return nil, ErrInvalidStmtType 399 } 400 lastType = stmt.T 401 } 402 if len(stmtsBuffer) > 0 && lastType != "" { 403 stmts = append(stmts, statement{ 404 sqls: stmtsBuffer, 405 t: lastType, 406 }) 407 } 408 409 return stmts, nil 410 }