github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/spanner/spanner.go (about) 1 package spanner 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "log" 9 nurl "net/url" 10 "regexp" 11 "strconv" 12 "strings" 13 14 "cloud.google.com/go/spanner" 15 sdb "cloud.google.com/go/spanner/admin/database/apiv1" 16 "cloud.google.com/go/spanner/spansql" 17 18 "github.com/golang-migrate/migrate/v4" 19 "github.com/golang-migrate/migrate/v4/database" 20 21 adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" 22 "github.com/hashicorp/go-multierror" 23 uatomic "go.uber.org/atomic" 24 "google.golang.org/api/iterator" 25 ) 26 27 func init() { 28 db := Spanner{} 29 database.Register("spanner", &db) 30 } 31 32 // DefaultMigrationsTable is used if no custom table is specified 33 const DefaultMigrationsTable = "SchemaMigrations" 34 35 const ( 36 unlockedVal = 0 37 lockedVal = 1 38 ) 39 40 // Driver errors 41 var ( 42 ErrNilConfig = errors.New("no config") 43 ErrNoDatabaseName = errors.New("no database name") 44 ErrNoSchema = errors.New("no schema") 45 ErrDatabaseDirty = errors.New("database is dirty") 46 ErrLockHeld = errors.New("unable to obtain lock") 47 ErrLockNotHeld = errors.New("unable to release already released lock") 48 ) 49 50 // Config used for a Spanner instance 51 type Config struct { 52 MigrationsTable string 53 DatabaseName string 54 // Whether to parse the migration DDL with spansql before 55 // running them towards Spanner. 56 // Parsing outputs clean DDL statements such as reformatted 57 // and void of comments. 58 CleanStatements bool 59 } 60 61 // Spanner implements database.Driver for Google Cloud Spanner 62 type Spanner struct { 63 db *DB 64 65 config *Config 66 67 lock *uatomic.Uint32 68 } 69 70 type DB struct { 71 admin *sdb.DatabaseAdminClient 72 data *spanner.Client 73 } 74 75 func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB { 76 return &DB{ 77 admin: &admin, 78 data: &data, 79 } 80 } 81 82 // WithInstance implements database.Driver 83 func WithInstance(instance *DB, config *Config) (database.Driver, error) { 84 if config == nil { 85 return nil, ErrNilConfig 86 } 87 88 if len(config.DatabaseName) == 0 { 89 return nil, ErrNoDatabaseName 90 } 91 92 if len(config.MigrationsTable) == 0 { 93 config.MigrationsTable = DefaultMigrationsTable 94 } 95 96 sx := &Spanner{ 97 db: instance, 98 config: config, 99 lock: uatomic.NewUint32(unlockedVal), 100 } 101 102 if err := sx.ensureVersionTable(); err != nil { 103 return nil, err 104 } 105 106 return sx, nil 107 } 108 109 // Open implements database.Driver 110 func (s *Spanner) Open(url string) (database.Driver, error) { 111 purl, err := nurl.Parse(url) 112 if err != nil { 113 return nil, err 114 } 115 116 ctx := context.Background() 117 118 adminClient, err := sdb.NewDatabaseAdminClient(ctx) 119 if err != nil { 120 return nil, err 121 } 122 dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1) 123 dataClient, err := spanner.NewClient(ctx, dbname) 124 if err != nil { 125 log.Fatal(err) 126 } 127 128 migrationsTable := purl.Query().Get("x-migrations-table") 129 130 cleanQuery := purl.Query().Get("x-clean-statements") 131 clean := false 132 if cleanQuery != "" { 133 clean, err = strconv.ParseBool(cleanQuery) 134 if err != nil { 135 return nil, err 136 } 137 } 138 139 db := &DB{admin: adminClient, data: dataClient} 140 return WithInstance(db, &Config{ 141 DatabaseName: dbname, 142 MigrationsTable: migrationsTable, 143 CleanStatements: clean, 144 }) 145 } 146 147 // Close implements database.Driver 148 func (s *Spanner) Close() error { 149 s.db.data.Close() 150 return s.db.admin.Close() 151 } 152 153 // Lock implements database.Driver but doesn't do anything because Spanner only 154 // enqueues the UpdateDatabaseDdlRequest. 155 func (s *Spanner) Lock() error { 156 if swapped := s.lock.CAS(unlockedVal, lockedVal); swapped { 157 return nil 158 } 159 return ErrLockHeld 160 } 161 162 // Unlock implements database.Driver but no action required, see Lock. 163 func (s *Spanner) Unlock() error { 164 if swapped := s.lock.CAS(lockedVal, unlockedVal); swapped { 165 return nil 166 } 167 return ErrLockNotHeld 168 } 169 170 // Run implements database.Driver 171 func (s *Spanner) Run(migration io.Reader) error { 172 migr, err := io.ReadAll(migration) 173 if err != nil { 174 return err 175 } 176 177 stmts := []string{string(migr)} 178 if s.config.CleanStatements { 179 stmts, err = cleanStatements(migr) 180 if err != nil { 181 return err 182 } 183 } 184 185 ctx := context.Background() 186 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ 187 Database: s.config.DatabaseName, 188 Statements: stmts, 189 }) 190 191 if err != nil { 192 return &database.Error{OrigErr: err, Err: "migration failed", Query: migr} 193 } 194 195 if err := op.Wait(ctx); err != nil { 196 return &database.Error{OrigErr: err, Err: "migration failed", Query: migr} 197 } 198 199 return nil 200 } 201 202 // SetVersion implements database.Driver 203 func (s *Spanner) SetVersion(version int, dirty bool) error { 204 ctx := context.Background() 205 206 _, err := s.db.data.ReadWriteTransaction(ctx, 207 func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { 208 m := []*spanner.Mutation{ 209 spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()), 210 spanner.Insert(s.config.MigrationsTable, 211 []string{"Version", "Dirty"}, 212 []interface{}{version, dirty}, 213 )} 214 return txn.BufferWrite(m) 215 }) 216 if err != nil { 217 return &database.Error{OrigErr: err} 218 } 219 220 return nil 221 } 222 223 // Version implements database.Driver 224 func (s *Spanner) Version() (version int, dirty bool, err error) { 225 ctx := context.Background() 226 227 stmt := spanner.Statement{ 228 SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`, 229 } 230 iter := s.db.data.Single().Query(ctx, stmt) 231 defer iter.Stop() 232 233 row, err := iter.Next() 234 switch err { 235 case iterator.Done: 236 return database.NilVersion, false, nil 237 case nil: 238 var v int64 239 if err = row.Columns(&v, &dirty); err != nil { 240 return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)} 241 } 242 version = int(v) 243 default: 244 return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)} 245 } 246 247 return version, dirty, nil 248 } 249 250 var nameMatcher = regexp.MustCompile(`(CREATE TABLE\s(\S+)\s)|(CREATE.+INDEX\s(\S+)\s)`) 251 252 // Drop implements database.Driver. Retrieves the database schema first and 253 // creates statements to drop the indexes and tables accordingly. 254 // Note: The drop statements are created in reverse order to how they're 255 // provided in the schema. Assuming the schema describes how the database can 256 // be "build up", it seems logical to "unbuild" the database simply by going the 257 // opposite direction. More testing 258 func (s *Spanner) Drop() error { 259 ctx := context.Background() 260 res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{ 261 Database: s.config.DatabaseName, 262 }) 263 if err != nil { 264 return &database.Error{OrigErr: err, Err: "drop failed"} 265 } 266 if len(res.Statements) == 0 { 267 return nil 268 } 269 270 stmts := make([]string, 0) 271 for i := len(res.Statements) - 1; i >= 0; i-- { 272 s := res.Statements[i] 273 m := nameMatcher.FindSubmatch([]byte(s)) 274 275 if len(m) == 0 { 276 continue 277 } else if tbl := m[2]; len(tbl) > 0 { 278 stmts = append(stmts, fmt.Sprintf(`DROP TABLE %s`, tbl)) 279 } else if idx := m[4]; len(idx) > 0 { 280 stmts = append(stmts, fmt.Sprintf(`DROP INDEX %s`, idx)) 281 } 282 } 283 284 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ 285 Database: s.config.DatabaseName, 286 Statements: stmts, 287 }) 288 if err != nil { 289 return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))} 290 } 291 if err := op.Wait(ctx); err != nil { 292 return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))} 293 } 294 295 return nil 296 } 297 298 // ensureVersionTable checks if versions table exists and, if not, creates it. 299 // Note that this function locks the database, which deviates from the usual 300 // convention of "caller locks" in the Spanner type. 301 func (s *Spanner) ensureVersionTable() (err error) { 302 if err = s.Lock(); err != nil { 303 return err 304 } 305 306 defer func() { 307 if e := s.Unlock(); e != nil { 308 if err == nil { 309 err = e 310 } else { 311 err = multierror.Append(err, e) 312 } 313 } 314 }() 315 316 ctx := context.Background() 317 tbl := s.config.MigrationsTable 318 iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"}) 319 if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil { 320 return nil 321 } 322 323 stmt := fmt.Sprintf(`CREATE TABLE %s ( 324 Version INT64 NOT NULL, 325 Dirty BOOL NOT NULL 326 ) PRIMARY KEY(Version)`, tbl) 327 328 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ 329 Database: s.config.DatabaseName, 330 Statements: []string{stmt}, 331 }) 332 333 if err != nil { 334 return &database.Error{OrigErr: err, Query: []byte(stmt)} 335 } 336 if err := op.Wait(ctx); err != nil { 337 return &database.Error{OrigErr: err, Query: []byte(stmt)} 338 } 339 340 return nil 341 } 342 343 func cleanStatements(migration []byte) ([]string, error) { 344 // The Spanner GCP backend does not yet support comments for the UpdateDatabaseDdl RPC 345 // (see https://issuetracker.google.com/issues/159730604) we use 346 // spansql to parse the DDL and output valid stamements without comments 347 ddl, err := spansql.ParseDDL("", string(migration)) 348 if err != nil { 349 return nil, err 350 } 351 stmts := make([]string, 0, len(ddl.List)) 352 for _, stmt := range ddl.List { 353 stmts = append(stmts, stmt.SQL()) 354 } 355 return stmts, nil 356 }