github.com/kruftik/go-migrate@v3.5.4+incompatible/database/spanner/spanner.go (about)

     1  package spanner
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"io/ioutil"
     7  	"log"
     8  	nurl "net/url"
     9  	"regexp"
    10  	"strings"
    11  
    12  	"golang.org/x/net/context"
    13  
    14  	"cloud.google.com/go/spanner"
    15  	sdb "cloud.google.com/go/spanner/admin/database/apiv1"
    16  
    17  	"github.com/golang-migrate/migrate"
    18  	"github.com/golang-migrate/migrate/database"
    19  
    20  	"google.golang.org/api/iterator"
    21  	adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
    22  )
    23  
    24  func init() {
    25  	db := Spanner{}
    26  	database.Register("spanner", &db)
    27  }
    28  
    29  // DefaultMigrationsTable is used if no custom table is specified
    30  const DefaultMigrationsTable = "SchemaMigrations"
    31  
    32  // Driver errors
    33  var (
    34  	ErrNilConfig      = fmt.Errorf("no config")
    35  	ErrNoDatabaseName = fmt.Errorf("no database name")
    36  	ErrNoSchema       = fmt.Errorf("no schema")
    37  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    38  )
    39  
    40  // Config used for a Spanner instance
    41  type Config struct {
    42  	MigrationsTable string
    43  	DatabaseName    string
    44  }
    45  
    46  // Spanner implements database.Driver for Google Cloud Spanner
    47  type Spanner struct {
    48  	db *DB
    49  
    50  	config *Config
    51  }
    52  
    53  type DB struct {
    54  	admin *sdb.DatabaseAdminClient
    55  	data  *spanner.Client
    56  }
    57  
    58  // WithInstance implements database.Driver
    59  func WithInstance(instance *DB, config *Config) (database.Driver, error) {
    60  	if config == nil {
    61  		return nil, ErrNilConfig
    62  	}
    63  
    64  	if len(config.DatabaseName) == 0 {
    65  		return nil, ErrNoDatabaseName
    66  	}
    67  
    68  	if len(config.MigrationsTable) == 0 {
    69  		config.MigrationsTable = DefaultMigrationsTable
    70  	}
    71  
    72  	sx := &Spanner{
    73  		db:     instance,
    74  		config: config,
    75  	}
    76  
    77  	if err := sx.ensureVersionTable(); err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	return sx, nil
    82  }
    83  
    84  // Open implements database.Driver
    85  func (s *Spanner) Open(url string) (database.Driver, error) {
    86  	purl, err := nurl.Parse(url)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	ctx := context.Background()
    92  
    93  	adminClient, err := sdb.NewDatabaseAdminClient(ctx)
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  	dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1)
    98  	dataClient, err := spanner.NewClient(ctx, dbname)
    99  	if err != nil {
   100  		log.Fatal(err)
   101  	}
   102  
   103  	migrationsTable := purl.Query().Get("x-migrations-table")
   104  	if len(migrationsTable) == 0 {
   105  		migrationsTable = DefaultMigrationsTable
   106  	}
   107  
   108  	db := &DB{admin: adminClient, data: dataClient}
   109  	return WithInstance(db, &Config{
   110  		DatabaseName:    dbname,
   111  		MigrationsTable: migrationsTable,
   112  	})
   113  }
   114  
   115  // Close implements database.Driver
   116  func (s *Spanner) Close() error {
   117  	s.db.data.Close()
   118  	return s.db.admin.Close()
   119  }
   120  
   121  // Lock implements database.Driver but doesn't do anything because Spanner only
   122  // enqueues the UpdateDatabaseDdlRequest.
   123  func (s *Spanner) Lock() error {
   124  	return nil
   125  }
   126  
   127  // Unlock implements database.Driver but no action required, see Lock.
   128  func (s *Spanner) Unlock() error {
   129  	return nil
   130  }
   131  
   132  // Run implements database.Driver
   133  func (s *Spanner) Run(migration io.Reader) error {
   134  	migr, err := ioutil.ReadAll(migration)
   135  	if err != nil {
   136  		return err
   137  	}
   138  
   139  	// run migration
   140  	stmts := migrationStatements(migr)
   141  	ctx := context.Background()
   142  
   143  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   144  		Database:   s.config.DatabaseName,
   145  		Statements: stmts,
   146  	})
   147  
   148  	if err != nil {
   149  		return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   150  	}
   151  
   152  	if err := op.Wait(ctx); err != nil {
   153  		return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   154  	}
   155  
   156  	return nil
   157  }
   158  
   159  // SetVersion implements database.Driver
   160  func (s *Spanner) SetVersion(version int, dirty bool) error {
   161  	ctx := context.Background()
   162  
   163  	_, err := s.db.data.ReadWriteTransaction(ctx,
   164  		func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
   165  			m := []*spanner.Mutation{
   166  				spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()),
   167  				spanner.Insert(s.config.MigrationsTable,
   168  					[]string{"Version", "Dirty"},
   169  					[]interface{}{version, dirty},
   170  				)}
   171  			return txn.BufferWrite(m)
   172  		})
   173  	if err != nil {
   174  		return &database.Error{OrigErr: err}
   175  	}
   176  
   177  	return nil
   178  }
   179  
   180  // Version implements database.Driver
   181  func (s *Spanner) Version() (version int, dirty bool, err error) {
   182  	ctx := context.Background()
   183  
   184  	stmt := spanner.Statement{
   185  		SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`,
   186  	}
   187  	iter := s.db.data.Single().Query(ctx, stmt)
   188  	defer iter.Stop()
   189  
   190  	row, err := iter.Next()
   191  	switch err {
   192  	case iterator.Done:
   193  		return database.NilVersion, false, nil
   194  	case nil:
   195  		var v int64
   196  		if err = row.Columns(&v, &dirty); err != nil {
   197  			return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
   198  		}
   199  		version = int(v)
   200  	default:
   201  		return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
   202  	}
   203  
   204  	return version, dirty, nil
   205  }
   206  
   207  // Drop implements database.Driver. Retrieves the database schema first and
   208  // creates statements to drop the indexes and tables accordingly.
   209  // Note: The drop statements are created in reverse order to how they're
   210  // provided in the schema. Assuming the schema describes how the database can
   211  // be "build up", it seems logical to "unbuild" the database simply by going the
   212  // opposite direction. More testing
   213  func (s *Spanner) Drop() error {
   214  	ctx := context.Background()
   215  	res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{
   216  		Database: s.config.DatabaseName,
   217  	})
   218  	if err != nil {
   219  		return &database.Error{OrigErr: err, Err: "drop failed"}
   220  	}
   221  	if len(res.Statements) == 0 {
   222  		return nil
   223  	}
   224  
   225  	r := regexp.MustCompile(`(CREATE TABLE\s(\S+)\s)|(CREATE.+INDEX\s(\S+)\s)`)
   226  	stmts := make([]string, 0)
   227  	for i := len(res.Statements) - 1; i >= 0; i-- {
   228  		s := res.Statements[i]
   229  		m := r.FindSubmatch([]byte(s))
   230  
   231  		if len(m) == 0 {
   232  			continue
   233  		} else if tbl := m[2]; len(tbl) > 0 {
   234  			stmts = append(stmts, fmt.Sprintf(`DROP TABLE %s`, tbl))
   235  		} else if idx := m[4]; len(idx) > 0 {
   236  			stmts = append(stmts, fmt.Sprintf(`DROP INDEX %s`, idx))
   237  		}
   238  	}
   239  
   240  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   241  		Database:   s.config.DatabaseName,
   242  		Statements: stmts,
   243  	})
   244  	if err != nil {
   245  		return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
   246  	}
   247  	if err := op.Wait(ctx); err != nil {
   248  		return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
   249  	}
   250  
   251  	if err := s.ensureVersionTable(); err != nil {
   252  		return err
   253  	}
   254  
   255  	return nil
   256  }
   257  
   258  func (s *Spanner) ensureVersionTable() error {
   259  	ctx := context.Background()
   260  	tbl := s.config.MigrationsTable
   261  	iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"})
   262  	if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil {
   263  		return nil
   264  	}
   265  
   266  	stmt := fmt.Sprintf(`CREATE TABLE %s (
   267      Version INT64 NOT NULL,
   268      Dirty    BOOL NOT NULL
   269  	) PRIMARY KEY(Version)`, tbl)
   270  
   271  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   272  		Database:   s.config.DatabaseName,
   273  		Statements: []string{stmt},
   274  	})
   275  
   276  	if err != nil {
   277  		return &database.Error{OrigErr: err, Query: []byte(stmt)}
   278  	}
   279  	if err := op.Wait(ctx); err != nil {
   280  		return &database.Error{OrigErr: err, Query: []byte(stmt)}
   281  	}
   282  
   283  	return nil
   284  }
   285  
   286  func migrationStatements(migration []byte) []string {
   287  	regex := regexp.MustCompile(";$")
   288  	migrationString := string(migration[:])
   289  	migrationString = strings.TrimSpace(migrationString)
   290  	migrationString = regex.ReplaceAllString(migrationString, "")
   291  
   292  	statements := strings.Split(migrationString, ";")
   293  	return statements
   294  }