github.com/mrqzzz/migrate@v5.1.7+incompatible/database/spanner/spanner.go (about) 1 package spanner 2 3 import ( 4 "fmt" 5 "io" 6 "io/ioutil" 7 "log" 8 nurl "net/url" 9 "regexp" 10 "strings" 11 12 "golang.org/x/net/context" 13 14 "cloud.google.com/go/spanner" 15 sdb "cloud.google.com/go/spanner/admin/database/apiv1" 16 17 "github.com/golang-migrate/migrate/v4" 18 "github.com/golang-migrate/migrate/v4/database" 19 20 "google.golang.org/api/iterator" 21 adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" 22 ) 23 24 func init() { 25 db := Spanner{} 26 database.Register("spanner", &db) 27 } 28 29 // DefaultMigrationsTable is used if no custom table is specified 30 const DefaultMigrationsTable = "SchemaMigrations" 31 32 // Driver errors 33 var ( 34 ErrNilConfig = fmt.Errorf("no config") 35 ErrNoDatabaseName = fmt.Errorf("no database name") 36 ErrNoSchema = fmt.Errorf("no schema") 37 ErrDatabaseDirty = fmt.Errorf("database is dirty") 38 ) 39 40 // Config used for a Spanner instance 41 type Config struct { 42 MigrationsTable string 43 DatabaseName string 44 } 45 46 // Spanner implements database.Driver for Google Cloud Spanner 47 type Spanner struct { 48 db *DB 49 50 config *Config 51 } 52 53 type DB struct { 54 admin *sdb.DatabaseAdminClient 55 data *spanner.Client 56 } 57 58 func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB { 59 return &DB{ 60 admin: &admin, 61 data: &data, 62 } 63 } 64 65 // WithInstance implements database.Driver 66 func WithInstance(instance *DB, config *Config) (database.Driver, error) { 67 if config == nil { 68 return nil, ErrNilConfig 69 } 70 71 if len(config.DatabaseName) == 0 { 72 return nil, ErrNoDatabaseName 73 } 74 75 if len(config.MigrationsTable) == 0 { 76 config.MigrationsTable = DefaultMigrationsTable 77 } 78 79 sx := &Spanner{ 80 db: instance, 81 config: config, 82 } 83 84 if err := sx.ensureVersionTable(); err != nil { 85 return nil, err 86 } 87 88 return sx, nil 89 } 90 91 // Open implements database.Driver 92 func (s *Spanner) Open(url string) (database.Driver, error) { 93 purl, err := nurl.Parse(url) 94 if err != nil { 95 return nil, err 96 } 97 98 ctx := context.Background() 99 100 adminClient, err := sdb.NewDatabaseAdminClient(ctx) 101 if err != nil { 102 return nil, err 103 } 104 dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1) 105 dataClient, err := spanner.NewClient(ctx, dbname) 106 if err != nil { 107 log.Fatal(err) 108 } 109 110 migrationsTable := purl.Query().Get("x-migrations-table") 111 if len(migrationsTable) == 0 { 112 migrationsTable = DefaultMigrationsTable 113 } 114 115 db := &DB{admin: adminClient, data: dataClient} 116 return WithInstance(db, &Config{ 117 DatabaseName: dbname, 118 MigrationsTable: migrationsTable, 119 }) 120 } 121 122 // Close implements database.Driver 123 func (s *Spanner) Close() error { 124 s.db.data.Close() 125 return s.db.admin.Close() 126 } 127 128 // Lock implements database.Driver but doesn't do anything because Spanner only 129 // enqueues the UpdateDatabaseDdlRequest. 130 func (s *Spanner) Lock() error { 131 return nil 132 } 133 134 // Unlock implements database.Driver but no action required, see Lock. 135 func (s *Spanner) Unlock() error { 136 return nil 137 } 138 139 // Run implements database.Driver 140 func (s *Spanner) Run(migration io.Reader) error { 141 migr, err := ioutil.ReadAll(migration) 142 if err != nil { 143 return err 144 } 145 146 // run migration 147 stmts := migrationStatements(migr) 148 ctx := context.Background() 149 150 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ 151 Database: s.config.DatabaseName, 152 Statements: stmts, 153 }) 154 155 if err != nil { 156 return &database.Error{OrigErr: err, Err: "migration failed", Query: migr} 157 } 158 159 if err := op.Wait(ctx); err != nil { 160 return &database.Error{OrigErr: err, Err: "migration failed", Query: migr} 161 } 162 163 return nil 164 } 165 166 // SetVersion implements database.Driver 167 func (s *Spanner) SetVersion(version int, dirty bool) error { 168 ctx := context.Background() 169 170 _, err := s.db.data.ReadWriteTransaction(ctx, 171 func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { 172 m := []*spanner.Mutation{ 173 spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()), 174 spanner.Insert(s.config.MigrationsTable, 175 []string{"Version", "Dirty"}, 176 []interface{}{version, dirty}, 177 )} 178 return txn.BufferWrite(m) 179 }) 180 if err != nil { 181 return &database.Error{OrigErr: err} 182 } 183 184 return nil 185 } 186 187 // Version implements database.Driver 188 func (s *Spanner) Version() (version int, dirty bool, err error) { 189 ctx := context.Background() 190 191 stmt := spanner.Statement{ 192 SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`, 193 } 194 iter := s.db.data.Single().Query(ctx, stmt) 195 defer iter.Stop() 196 197 row, err := iter.Next() 198 switch err { 199 case iterator.Done: 200 return database.NilVersion, false, nil 201 case nil: 202 var v int64 203 if err = row.Columns(&v, &dirty); err != nil { 204 return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)} 205 } 206 version = int(v) 207 default: 208 return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)} 209 } 210 211 return version, dirty, nil 212 } 213 214 // Drop implements database.Driver. Retrieves the database schema first and 215 // creates statements to drop the indexes and tables accordingly. 216 // Note: The drop statements are created in reverse order to how they're 217 // provided in the schema. Assuming the schema describes how the database can 218 // be "build up", it seems logical to "unbuild" the database simply by going the 219 // opposite direction. More testing 220 func (s *Spanner) Drop() error { 221 ctx := context.Background() 222 res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{ 223 Database: s.config.DatabaseName, 224 }) 225 if err != nil { 226 return &database.Error{OrigErr: err, Err: "drop failed"} 227 } 228 if len(res.Statements) == 0 { 229 return nil 230 } 231 232 r := regexp.MustCompile(`(CREATE TABLE\s(\S+)\s)|(CREATE.+INDEX\s(\S+)\s)`) 233 stmts := make([]string, 0) 234 for i := len(res.Statements) - 1; i >= 0; i-- { 235 s := res.Statements[i] 236 m := r.FindSubmatch([]byte(s)) 237 238 if len(m) == 0 { 239 continue 240 } else if tbl := m[2]; len(tbl) > 0 { 241 stmts = append(stmts, fmt.Sprintf(`DROP TABLE %s`, tbl)) 242 } else if idx := m[4]; len(idx) > 0 { 243 stmts = append(stmts, fmt.Sprintf(`DROP INDEX %s`, idx)) 244 } 245 } 246 247 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ 248 Database: s.config.DatabaseName, 249 Statements: stmts, 250 }) 251 if err != nil { 252 return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))} 253 } 254 if err := op.Wait(ctx); err != nil { 255 return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))} 256 } 257 258 if err := s.ensureVersionTable(); err != nil { 259 return err 260 } 261 262 return nil 263 } 264 265 func (s *Spanner) ensureVersionTable() error { 266 ctx := context.Background() 267 tbl := s.config.MigrationsTable 268 iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"}) 269 if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil { 270 return nil 271 } 272 273 stmt := fmt.Sprintf(`CREATE TABLE %s ( 274 Version INT64 NOT NULL, 275 Dirty BOOL NOT NULL 276 ) PRIMARY KEY(Version)`, tbl) 277 278 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ 279 Database: s.config.DatabaseName, 280 Statements: []string{stmt}, 281 }) 282 283 if err != nil { 284 return &database.Error{OrigErr: err, Query: []byte(stmt)} 285 } 286 if err := op.Wait(ctx); err != nil { 287 return &database.Error{OrigErr: err, Query: []byte(stmt)} 288 } 289 290 return nil 291 } 292 293 func migrationStatements(migration []byte) []string { 294 regex := regexp.MustCompile(";$") 295 migrationString := string(migration[:]) 296 migrationString = strings.TrimSpace(migrationString) 297 migrationString = regex.ReplaceAllString(migrationString, "") 298 299 statements := strings.Split(migrationString, ";") 300 return statements 301 }