github.com/mrqzzz/migrate@v5.1.7+incompatible/database/spanner/spanner.go (about)

     1  package spanner
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"io/ioutil"
     7  	"log"
     8  	nurl "net/url"
     9  	"regexp"
    10  	"strings"
    11  
    12  	"golang.org/x/net/context"
    13  
    14  	"cloud.google.com/go/spanner"
    15  	sdb "cloud.google.com/go/spanner/admin/database/apiv1"
    16  
    17  	"github.com/golang-migrate/migrate/v4"
    18  	"github.com/golang-migrate/migrate/v4/database"
    19  
    20  	"google.golang.org/api/iterator"
    21  	adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
    22  )
    23  
    24  func init() {
    25  	db := Spanner{}
    26  	database.Register("spanner", &db)
    27  }
    28  
// DefaultMigrationsTable is the name of the version-tracking table used when
// no custom table is specified, i.e. when Config.MigrationsTable is empty or
// the "x-migrations-table" URL query parameter is absent.
const DefaultMigrationsTable = "SchemaMigrations"
    31  
// Driver errors
var (
	// ErrNilConfig is returned by WithInstance when the config is nil.
	ErrNilConfig      = fmt.Errorf("no config")
	// ErrNoDatabaseName is returned by WithInstance when Config.DatabaseName
	// is empty.
	ErrNoDatabaseName = fmt.Errorf("no database name")
	// ErrNoSchema is not referenced by the code visible in this file.
	ErrNoSchema       = fmt.Errorf("no schema")
	// ErrDatabaseDirty is not referenced by the code visible in this file.
	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
)
    39  
// Config used for a Spanner instance
type Config struct {
	// MigrationsTable is the table holding the current version and dirty
	// flag. WithInstance replaces an empty value with DefaultMigrationsTable.
	MigrationsTable string
	// DatabaseName identifies the target database. It is required and is sent
	// as the Database field of the admin DDL requests.
	// NOTE(review): presumably the full
	// "projects/.../instances/.../databases/..." path — confirm.
	DatabaseName    string
}
    45  
// Spanner implements database.Driver for Google Cloud Spanner
type Spanner struct {
	// db holds the admin and data clients used by every driver method.
	db *DB

	// config carries the database name and migrations table name.
	config *Config
}
    52  
// DB bundles the two Spanner clients the driver works with.
type DB struct {
	// admin issues schema operations (GetDatabaseDdl, UpdateDatabaseDdl).
	admin *sdb.DatabaseAdminClient
	// data reads and writes rows of the migrations table.
	data  *spanner.Client
}
    57  
    58  func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB {
    59  	return &DB{
    60  		admin: &admin,
    61  		data:  &data,
    62  	}
    63  }
    64  
    65  // WithInstance implements database.Driver
    66  func WithInstance(instance *DB, config *Config) (database.Driver, error) {
    67  	if config == nil {
    68  		return nil, ErrNilConfig
    69  	}
    70  
    71  	if len(config.DatabaseName) == 0 {
    72  		return nil, ErrNoDatabaseName
    73  	}
    74  
    75  	if len(config.MigrationsTable) == 0 {
    76  		config.MigrationsTable = DefaultMigrationsTable
    77  	}
    78  
    79  	sx := &Spanner{
    80  		db:     instance,
    81  		config: config,
    82  	}
    83  
    84  	if err := sx.ensureVersionTable(); err != nil {
    85  		return nil, err
    86  	}
    87  
    88  	return sx, nil
    89  }
    90  
    91  // Open implements database.Driver
    92  func (s *Spanner) Open(url string) (database.Driver, error) {
    93  	purl, err := nurl.Parse(url)
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	ctx := context.Background()
    99  
   100  	adminClient, err := sdb.NewDatabaseAdminClient(ctx)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  	dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1)
   105  	dataClient, err := spanner.NewClient(ctx, dbname)
   106  	if err != nil {
   107  		log.Fatal(err)
   108  	}
   109  
   110  	migrationsTable := purl.Query().Get("x-migrations-table")
   111  	if len(migrationsTable) == 0 {
   112  		migrationsTable = DefaultMigrationsTable
   113  	}
   114  
   115  	db := &DB{admin: adminClient, data: dataClient}
   116  	return WithInstance(db, &Config{
   117  		DatabaseName:    dbname,
   118  		MigrationsTable: migrationsTable,
   119  	})
   120  }
   121  
   122  // Close implements database.Driver
   123  func (s *Spanner) Close() error {
   124  	s.db.data.Close()
   125  	return s.db.admin.Close()
   126  }
   127  
// Lock implements database.Driver but doesn't do anything because Spanner only
// enqueues the UpdateDatabaseDdlRequest. It always returns nil.
func (s *Spanner) Lock() error {
	return nil
}
   133  
// Unlock implements database.Driver. No action is required because Lock does
// not acquire anything; it always returns nil.
func (s *Spanner) Unlock() error {
	return nil
}
   138  
   139  // Run implements database.Driver
   140  func (s *Spanner) Run(migration io.Reader) error {
   141  	migr, err := ioutil.ReadAll(migration)
   142  	if err != nil {
   143  		return err
   144  	}
   145  
   146  	// run migration
   147  	stmts := migrationStatements(migr)
   148  	ctx := context.Background()
   149  
   150  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   151  		Database:   s.config.DatabaseName,
   152  		Statements: stmts,
   153  	})
   154  
   155  	if err != nil {
   156  		return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   157  	}
   158  
   159  	if err := op.Wait(ctx); err != nil {
   160  		return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   161  	}
   162  
   163  	return nil
   164  }
   165  
   166  // SetVersion implements database.Driver
   167  func (s *Spanner) SetVersion(version int, dirty bool) error {
   168  	ctx := context.Background()
   169  
   170  	_, err := s.db.data.ReadWriteTransaction(ctx,
   171  		func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
   172  			m := []*spanner.Mutation{
   173  				spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()),
   174  				spanner.Insert(s.config.MigrationsTable,
   175  					[]string{"Version", "Dirty"},
   176  					[]interface{}{version, dirty},
   177  				)}
   178  			return txn.BufferWrite(m)
   179  		})
   180  	if err != nil {
   181  		return &database.Error{OrigErr: err}
   182  	}
   183  
   184  	return nil
   185  }
   186  
   187  // Version implements database.Driver
   188  func (s *Spanner) Version() (version int, dirty bool, err error) {
   189  	ctx := context.Background()
   190  
   191  	stmt := spanner.Statement{
   192  		SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`,
   193  	}
   194  	iter := s.db.data.Single().Query(ctx, stmt)
   195  	defer iter.Stop()
   196  
   197  	row, err := iter.Next()
   198  	switch err {
   199  	case iterator.Done:
   200  		return database.NilVersion, false, nil
   201  	case nil:
   202  		var v int64
   203  		if err = row.Columns(&v, &dirty); err != nil {
   204  			return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
   205  		}
   206  		version = int(v)
   207  	default:
   208  		return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
   209  	}
   210  
   211  	return version, dirty, nil
   212  }
   213  
   214  // Drop implements database.Driver. Retrieves the database schema first and
   215  // creates statements to drop the indexes and tables accordingly.
   216  // Note: The drop statements are created in reverse order to how they're
   217  // provided in the schema. Assuming the schema describes how the database can
   218  // be "build up", it seems logical to "unbuild" the database simply by going the
   219  // opposite direction. More testing
   220  func (s *Spanner) Drop() error {
   221  	ctx := context.Background()
   222  	res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{
   223  		Database: s.config.DatabaseName,
   224  	})
   225  	if err != nil {
   226  		return &database.Error{OrigErr: err, Err: "drop failed"}
   227  	}
   228  	if len(res.Statements) == 0 {
   229  		return nil
   230  	}
   231  
   232  	r := regexp.MustCompile(`(CREATE TABLE\s(\S+)\s)|(CREATE.+INDEX\s(\S+)\s)`)
   233  	stmts := make([]string, 0)
   234  	for i := len(res.Statements) - 1; i >= 0; i-- {
   235  		s := res.Statements[i]
   236  		m := r.FindSubmatch([]byte(s))
   237  
   238  		if len(m) == 0 {
   239  			continue
   240  		} else if tbl := m[2]; len(tbl) > 0 {
   241  			stmts = append(stmts, fmt.Sprintf(`DROP TABLE %s`, tbl))
   242  		} else if idx := m[4]; len(idx) > 0 {
   243  			stmts = append(stmts, fmt.Sprintf(`DROP INDEX %s`, idx))
   244  		}
   245  	}
   246  
   247  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   248  		Database:   s.config.DatabaseName,
   249  		Statements: stmts,
   250  	})
   251  	if err != nil {
   252  		return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
   253  	}
   254  	if err := op.Wait(ctx); err != nil {
   255  		return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
   256  	}
   257  
   258  	if err := s.ensureVersionTable(); err != nil {
   259  		return err
   260  	}
   261  
   262  	return nil
   263  }
   264  
   265  func (s *Spanner) ensureVersionTable() error {
   266  	ctx := context.Background()
   267  	tbl := s.config.MigrationsTable
   268  	iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"})
   269  	if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil {
   270  		return nil
   271  	}
   272  
   273  	stmt := fmt.Sprintf(`CREATE TABLE %s (
   274      Version INT64 NOT NULL,
   275      Dirty    BOOL NOT NULL
   276  	) PRIMARY KEY(Version)`, tbl)
   277  
   278  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   279  		Database:   s.config.DatabaseName,
   280  		Statements: []string{stmt},
   281  	})
   282  
   283  	if err != nil {
   284  		return &database.Error{OrigErr: err, Query: []byte(stmt)}
   285  	}
   286  	if err := op.Wait(ctx); err != nil {
   287  		return &database.Error{OrigErr: err, Query: []byte(stmt)}
   288  	}
   289  
   290  	return nil
   291  }
   292  
   293  func migrationStatements(migration []byte) []string {
   294  	regex := regexp.MustCompile(";$")
   295  	migrationString := string(migration[:])
   296  	migrationString = strings.TrimSpace(migrationString)
   297  	migrationString = regex.ReplaceAllString(migrationString, "")
   298  
   299  	statements := strings.Split(migrationString, ";")
   300  	return statements
   301  }