github.com/dynastymasra/migrate/v4@v4.11.0/database/spanner/spanner.go (about) 1 package spanner 2 3 import ( 4 "fmt" 5 "io" 6 "io/ioutil" 7 "log" 8 nurl "net/url" 9 "regexp" 10 "strings" 11 12 "golang.org/x/net/context" 13 14 "cloud.google.com/go/spanner" 15 sdb "cloud.google.com/go/spanner/admin/database/apiv1" 16 17 "github.com/golang-migrate/migrate/v4" 18 "github.com/golang-migrate/migrate/v4/database" 19 20 "github.com/hashicorp/go-multierror" 21 "google.golang.org/api/iterator" 22 adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" 23 ) 24 25 func init() { 26 db := Spanner{} 27 database.Register("spanner", &db) 28 } 29 30 // DefaultMigrationsTable is used if no custom table is specified 31 const DefaultMigrationsTable = "SchemaMigrations" 32 33 // Driver errors 34 var ( 35 ErrNilConfig = fmt.Errorf("no config") 36 ErrNoDatabaseName = fmt.Errorf("no database name") 37 ErrNoSchema = fmt.Errorf("no schema") 38 ErrDatabaseDirty = fmt.Errorf("database is dirty") 39 ) 40 41 // Config used for a Spanner instance 42 type Config struct { 43 MigrationsTable string 44 DatabaseName string 45 } 46 47 // Spanner implements database.Driver for Google Cloud Spanner 48 type Spanner struct { 49 db *DB 50 51 config *Config 52 } 53 54 type DB struct { 55 admin *sdb.DatabaseAdminClient 56 data *spanner.Client 57 } 58 59 func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB { 60 return &DB{ 61 admin: &admin, 62 data: &data, 63 } 64 } 65 66 // WithInstance implements database.Driver 67 func WithInstance(instance *DB, config *Config) (database.Driver, error) { 68 if config == nil { 69 return nil, ErrNilConfig 70 } 71 72 if len(config.DatabaseName) == 0 { 73 return nil, ErrNoDatabaseName 74 } 75 76 if len(config.MigrationsTable) == 0 { 77 config.MigrationsTable = DefaultMigrationsTable 78 } 79 80 sx := &Spanner{ 81 db: instance, 82 config: config, 83 } 84 85 if err := sx.ensureVersionTable(); err != nil { 86 return nil, err 87 } 88 89 return sx, nil 90 } 91 92 // Open implements database.Driver 93 func (s *Spanner) Open(url string) (database.Driver, error) { 94 purl, err := nurl.Parse(url) 95 if err != nil { 96 return nil, err 97 } 98 99 ctx := context.Background() 100 101 adminClient, err := sdb.NewDatabaseAdminClient(ctx) 102 if err != nil { 103 return nil, err 104 } 105 dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1) 106 dataClient, err := spanner.NewClient(ctx, dbname) 107 if err != nil { 108 log.Fatal(err) 109 } 110 111 migrationsTable := purl.Query().Get("x-migrations-table") 112 113 db := &DB{admin: adminClient, data: dataClient} 114 return WithInstance(db, &Config{ 115 DatabaseName: dbname, 116 MigrationsTable: migrationsTable, 117 }) 118 } 119 120 // Close implements database.Driver 121 func (s *Spanner) Close() error { 122 s.db.data.Close() 123 return s.db.admin.Close() 124 } 125 126 // Lock implements database.Driver but doesn't do anything because Spanner only 127 // enqueues the UpdateDatabaseDdlRequest. 128 func (s *Spanner) Lock() error { 129 return nil 130 } 131 132 // Unlock implements database.Driver but no action required, see Lock. 133 func (s *Spanner) Unlock() error { 134 return nil 135 } 136 137 // Run implements database.Driver 138 func (s *Spanner) Run(migration io.Reader) error { 139 migr, err := ioutil.ReadAll(migration) 140 if err != nil { 141 return err 142 } 143 144 // run migration 145 stmts := migrationStatements(migr) 146 ctx := context.Background() 147 148 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ 149 Database: s.config.DatabaseName, 150 Statements: stmts, 151 }) 152 153 if err != nil { 154 return &database.Error{OrigErr: err, Err: "migration failed", Query: migr} 155 } 156 157 if err := op.Wait(ctx); err != nil { 158 return &database.Error{OrigErr: err, Err: "migration failed", Query: migr} 159 } 160 161 return nil 162 } 163 164 // SetVersion implements database.Driver 165 func (s *Spanner) SetVersion(version int, dirty bool) error { 166 ctx := context.Background() 167 168 _, err := s.db.data.ReadWriteTransaction(ctx, 169 func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { 170 m := []*spanner.Mutation{ 171 spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()), 172 spanner.Insert(s.config.MigrationsTable, 173 []string{"Version", "Dirty"}, 174 []interface{}{version, dirty}, 175 )} 176 return txn.BufferWrite(m) 177 }) 178 if err != nil { 179 return &database.Error{OrigErr: err} 180 } 181 182 return nil 183 } 184 185 // Version implements database.Driver 186 func (s *Spanner) Version() (version int, dirty bool, err error) { 187 ctx := context.Background() 188 189 stmt := spanner.Statement{ 190 SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`, 191 } 192 iter := s.db.data.Single().Query(ctx, stmt) 193 defer iter.Stop() 194 195 row, err := iter.Next() 196 switch err { 197 case iterator.Done: 198 return database.NilVersion, false, nil 199 case nil: 200 var v int64 201 if err = row.Columns(&v, &dirty); err != nil { 202 return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)} 203 } 204 version = int(v) 205 default: 206 return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)} 207 } 208 209 return version, dirty, nil 210 } 211 212 var nameMatcher = regexp.MustCompile(`(CREATE TABLE\s(\S+)\s)|(CREATE.+INDEX\s(\S+)\s)`) 213 214 // Drop implements database.Driver. Retrieves the database schema first and 215 // creates statements to drop the indexes and tables accordingly. 216 // Note: The drop statements are created in reverse order to how they're 217 // provided in the schema. Assuming the schema describes how the database can 218 // be "build up", it seems logical to "unbuild" the database simply by going the 219 // opposite direction. More testing 220 func (s *Spanner) Drop() error { 221 ctx := context.Background() 222 res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{ 223 Database: s.config.DatabaseName, 224 }) 225 if err != nil { 226 return &database.Error{OrigErr: err, Err: "drop failed"} 227 } 228 if len(res.Statements) == 0 { 229 return nil 230 } 231 232 stmts := make([]string, 0) 233 for i := len(res.Statements) - 1; i >= 0; i-- { 234 s := res.Statements[i] 235 m := nameMatcher.FindSubmatch([]byte(s)) 236 237 if len(m) == 0 { 238 continue 239 } else if tbl := m[2]; len(tbl) > 0 { 240 stmts = append(stmts, fmt.Sprintf(`DROP TABLE %s`, tbl)) 241 } else if idx := m[4]; len(idx) > 0 { 242 stmts = append(stmts, fmt.Sprintf(`DROP INDEX %s`, idx)) 243 } 244 } 245 246 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ 247 Database: s.config.DatabaseName, 248 Statements: stmts, 249 }) 250 if err != nil { 251 return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))} 252 } 253 if err := op.Wait(ctx); err != nil { 254 return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))} 255 } 256 257 return nil 258 } 259 260 // ensureVersionTable checks if versions table exists and, if not, creates it. 261 // Note that this function locks the database, which deviates from the usual 262 // convention of "caller locks" in the Spanner type. 263 func (s *Spanner) ensureVersionTable() (err error) { 264 if err = s.Lock(); err != nil { 265 return err 266 } 267 268 defer func() { 269 if e := s.Unlock(); e != nil { 270 if err == nil { 271 err = e 272 } else { 273 err = multierror.Append(err, e) 274 } 275 } 276 }() 277 278 ctx := context.Background() 279 tbl := s.config.MigrationsTable 280 iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"}) 281 if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil { 282 return nil 283 } 284 285 stmt := fmt.Sprintf(`CREATE TABLE %s ( 286 Version INT64 NOT NULL, 287 Dirty BOOL NOT NULL 288 ) PRIMARY KEY(Version)`, tbl) 289 290 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ 291 Database: s.config.DatabaseName, 292 Statements: []string{stmt}, 293 }) 294 295 if err != nil { 296 return &database.Error{OrigErr: err, Query: []byte(stmt)} 297 } 298 if err := op.Wait(ctx); err != nil { 299 return &database.Error{OrigErr: err, Query: []byte(stmt)} 300 } 301 302 return nil 303 } 304 305 func migrationStatements(migration []byte) []string { 306 migrationString := string(migration[:]) 307 migrationString = strings.TrimSpace(migrationString) 308 309 allStatements := strings.Split(migrationString, ";") 310 nonEmptyStatements := allStatements[:0] 311 for _, s := range allStatements { 312 s = strings.TrimSpace(s) 313 if s != "" { 314 nonEmptyStatements = append(nonEmptyStatements, s) 315 } 316 } 317 return nonEmptyStatements 318 }