github.com/dynastymasra/migrate/v4@v4.11.0/database/spanner/spanner.go (about)

     1  package spanner
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"io/ioutil"
     7  	"log"
     8  	nurl "net/url"
     9  	"regexp"
    10  	"strings"
    11  
    12  	"golang.org/x/net/context"
    13  
    14  	"cloud.google.com/go/spanner"
    15  	sdb "cloud.google.com/go/spanner/admin/database/apiv1"
    16  
    17  	"github.com/golang-migrate/migrate/v4"
    18  	"github.com/golang-migrate/migrate/v4/database"
    19  
    20  	"github.com/hashicorp/go-multierror"
    21  	"google.golang.org/api/iterator"
    22  	adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
    23  )
    24  
    25  func init() {
    26  	db := Spanner{}
    27  	database.Register("spanner", &db)
    28  }
    29  
    30  // DefaultMigrationsTable is used if no custom table is specified
    31  const DefaultMigrationsTable = "SchemaMigrations"
    32  
    33  // Driver errors
    34  var (
    35  	ErrNilConfig      = fmt.Errorf("no config")
    36  	ErrNoDatabaseName = fmt.Errorf("no database name")
    37  	ErrNoSchema       = fmt.Errorf("no schema")
    38  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    39  )
    40  
    41  // Config used for a Spanner instance
    42  type Config struct {
    43  	MigrationsTable string
    44  	DatabaseName    string
    45  }
    46  
    47  // Spanner implements database.Driver for Google Cloud Spanner
    48  type Spanner struct {
    49  	db *DB
    50  
    51  	config *Config
    52  }
    53  
    54  type DB struct {
    55  	admin *sdb.DatabaseAdminClient
    56  	data  *spanner.Client
    57  }
    58  
    59  func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB {
    60  	return &DB{
    61  		admin: &admin,
    62  		data:  &data,
    63  	}
    64  }
    65  
    66  // WithInstance implements database.Driver
    67  func WithInstance(instance *DB, config *Config) (database.Driver, error) {
    68  	if config == nil {
    69  		return nil, ErrNilConfig
    70  	}
    71  
    72  	if len(config.DatabaseName) == 0 {
    73  		return nil, ErrNoDatabaseName
    74  	}
    75  
    76  	if len(config.MigrationsTable) == 0 {
    77  		config.MigrationsTable = DefaultMigrationsTable
    78  	}
    79  
    80  	sx := &Spanner{
    81  		db:     instance,
    82  		config: config,
    83  	}
    84  
    85  	if err := sx.ensureVersionTable(); err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	return sx, nil
    90  }
    91  
    92  // Open implements database.Driver
    93  func (s *Spanner) Open(url string) (database.Driver, error) {
    94  	purl, err := nurl.Parse(url)
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	ctx := context.Background()
   100  
   101  	adminClient, err := sdb.NewDatabaseAdminClient(ctx)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  	dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1)
   106  	dataClient, err := spanner.NewClient(ctx, dbname)
   107  	if err != nil {
   108  		log.Fatal(err)
   109  	}
   110  
   111  	migrationsTable := purl.Query().Get("x-migrations-table")
   112  
   113  	db := &DB{admin: adminClient, data: dataClient}
   114  	return WithInstance(db, &Config{
   115  		DatabaseName:    dbname,
   116  		MigrationsTable: migrationsTable,
   117  	})
   118  }
   119  
   120  // Close implements database.Driver
   121  func (s *Spanner) Close() error {
   122  	s.db.data.Close()
   123  	return s.db.admin.Close()
   124  }
   125  
   126  // Lock implements database.Driver but doesn't do anything because Spanner only
   127  // enqueues the UpdateDatabaseDdlRequest.
   128  func (s *Spanner) Lock() error {
   129  	return nil
   130  }
   131  
   132  // Unlock implements database.Driver but no action required, see Lock.
   133  func (s *Spanner) Unlock() error {
   134  	return nil
   135  }
   136  
   137  // Run implements database.Driver
   138  func (s *Spanner) Run(migration io.Reader) error {
   139  	migr, err := ioutil.ReadAll(migration)
   140  	if err != nil {
   141  		return err
   142  	}
   143  
   144  	// run migration
   145  	stmts := migrationStatements(migr)
   146  	ctx := context.Background()
   147  
   148  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   149  		Database:   s.config.DatabaseName,
   150  		Statements: stmts,
   151  	})
   152  
   153  	if err != nil {
   154  		return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   155  	}
   156  
   157  	if err := op.Wait(ctx); err != nil {
   158  		return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   159  	}
   160  
   161  	return nil
   162  }
   163  
   164  // SetVersion implements database.Driver
   165  func (s *Spanner) SetVersion(version int, dirty bool) error {
   166  	ctx := context.Background()
   167  
   168  	_, err := s.db.data.ReadWriteTransaction(ctx,
   169  		func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
   170  			m := []*spanner.Mutation{
   171  				spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()),
   172  				spanner.Insert(s.config.MigrationsTable,
   173  					[]string{"Version", "Dirty"},
   174  					[]interface{}{version, dirty},
   175  				)}
   176  			return txn.BufferWrite(m)
   177  		})
   178  	if err != nil {
   179  		return &database.Error{OrigErr: err}
   180  	}
   181  
   182  	return nil
   183  }
   184  
   185  // Version implements database.Driver
   186  func (s *Spanner) Version() (version int, dirty bool, err error) {
   187  	ctx := context.Background()
   188  
   189  	stmt := spanner.Statement{
   190  		SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`,
   191  	}
   192  	iter := s.db.data.Single().Query(ctx, stmt)
   193  	defer iter.Stop()
   194  
   195  	row, err := iter.Next()
   196  	switch err {
   197  	case iterator.Done:
   198  		return database.NilVersion, false, nil
   199  	case nil:
   200  		var v int64
   201  		if err = row.Columns(&v, &dirty); err != nil {
   202  			return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
   203  		}
   204  		version = int(v)
   205  	default:
   206  		return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
   207  	}
   208  
   209  	return version, dirty, nil
   210  }
   211  
   212  var nameMatcher = regexp.MustCompile(`(CREATE TABLE\s(\S+)\s)|(CREATE.+INDEX\s(\S+)\s)`)
   213  
   214  // Drop implements database.Driver. Retrieves the database schema first and
   215  // creates statements to drop the indexes and tables accordingly.
   216  // Note: The drop statements are created in reverse order to how they're
   217  // provided in the schema. Assuming the schema describes how the database can
   218  // be "build up", it seems logical to "unbuild" the database simply by going the
   219  // opposite direction. More testing
   220  func (s *Spanner) Drop() error {
   221  	ctx := context.Background()
   222  	res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{
   223  		Database: s.config.DatabaseName,
   224  	})
   225  	if err != nil {
   226  		return &database.Error{OrigErr: err, Err: "drop failed"}
   227  	}
   228  	if len(res.Statements) == 0 {
   229  		return nil
   230  	}
   231  
   232  	stmts := make([]string, 0)
   233  	for i := len(res.Statements) - 1; i >= 0; i-- {
   234  		s := res.Statements[i]
   235  		m := nameMatcher.FindSubmatch([]byte(s))
   236  
   237  		if len(m) == 0 {
   238  			continue
   239  		} else if tbl := m[2]; len(tbl) > 0 {
   240  			stmts = append(stmts, fmt.Sprintf(`DROP TABLE %s`, tbl))
   241  		} else if idx := m[4]; len(idx) > 0 {
   242  			stmts = append(stmts, fmt.Sprintf(`DROP INDEX %s`, idx))
   243  		}
   244  	}
   245  
   246  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   247  		Database:   s.config.DatabaseName,
   248  		Statements: stmts,
   249  	})
   250  	if err != nil {
   251  		return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
   252  	}
   253  	if err := op.Wait(ctx); err != nil {
   254  		return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
   255  	}
   256  
   257  	return nil
   258  }
   259  
   260  // ensureVersionTable checks if versions table exists and, if not, creates it.
   261  // Note that this function locks the database, which deviates from the usual
   262  // convention of "caller locks" in the Spanner type.
   263  func (s *Spanner) ensureVersionTable() (err error) {
   264  	if err = s.Lock(); err != nil {
   265  		return err
   266  	}
   267  
   268  	defer func() {
   269  		if e := s.Unlock(); e != nil {
   270  			if err == nil {
   271  				err = e
   272  			} else {
   273  				err = multierror.Append(err, e)
   274  			}
   275  		}
   276  	}()
   277  
   278  	ctx := context.Background()
   279  	tbl := s.config.MigrationsTable
   280  	iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"})
   281  	if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil {
   282  		return nil
   283  	}
   284  
   285  	stmt := fmt.Sprintf(`CREATE TABLE %s (
   286      Version INT64 NOT NULL,
   287      Dirty    BOOL NOT NULL
   288  	) PRIMARY KEY(Version)`, tbl)
   289  
   290  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   291  		Database:   s.config.DatabaseName,
   292  		Statements: []string{stmt},
   293  	})
   294  
   295  	if err != nil {
   296  		return &database.Error{OrigErr: err, Query: []byte(stmt)}
   297  	}
   298  	if err := op.Wait(ctx); err != nil {
   299  		return &database.Error{OrigErr: err, Query: []byte(stmt)}
   300  	}
   301  
   302  	return nil
   303  }
   304  
   305  func migrationStatements(migration []byte) []string {
   306  	migrationString := string(migration[:])
   307  	migrationString = strings.TrimSpace(migrationString)
   308  
   309  	allStatements := strings.Split(migrationString, ";")
   310  	nonEmptyStatements := allStatements[:0]
   311  	for _, s := range allStatements {
   312  		s = strings.TrimSpace(s)
   313  		if s != "" {
   314  			nonEmptyStatements = append(nonEmptyStatements, s)
   315  		}
   316  	}
   317  	return nonEmptyStatements
   318  }