github.com/nokia/migrate/v4@v4.16.0/database/spanner/spanner.go (about)

     1  package spanner
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"log"
    10  	nurl "net/url"
    11  	"regexp"
    12  	"strconv"
    13  	"strings"
    14  
    15  	"cloud.google.com/go/spanner"
    16  	sdb "cloud.google.com/go/spanner/admin/database/apiv1"
    17  	"cloud.google.com/go/spanner/spansql"
    18  
    19  	"github.com/nokia/migrate/v4"
    20  	"github.com/nokia/migrate/v4/database"
    21  	"github.com/nokia/migrate/v4/source"
    22  
    23  	"github.com/hashicorp/go-multierror"
    24  	uatomic "go.uber.org/atomic"
    25  	"google.golang.org/api/iterator"
    26  	adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
    27  )
    28  
    29  func init() {
    30  	db := Spanner{}
    31  	database.Register("spanner", &db)
    32  }
    33  
// DefaultMigrationsTable is the migrations table name that WithInstance
// falls back to when Config.MigrationsTable is empty.
const DefaultMigrationsTable = "SchemaMigrations"
    36  
    37  const (
    38  	unlockedVal = 0
    39  	lockedVal   = 1
    40  )
    41  
// Driver errors.
//
// ErrNilConfig and ErrNoDatabaseName are returned by WithInstance when the
// config is nil or its DatabaseName is empty. ErrLockHeld is returned by Lock
// when the lock flag is already set, and ErrLockNotHeld by Unlock when it is
// not set. ErrNoSchema and ErrDatabaseDirty are not referenced by the code
// visible in this file.
var (
	ErrNilConfig      = errors.New("no config")
	ErrNoDatabaseName = errors.New("no database name")
	ErrNoSchema       = errors.New("no schema")
	ErrDatabaseDirty  = errors.New("database is dirty")
	ErrLockHeld       = errors.New("unable to obtain lock")
	ErrLockNotHeld    = errors.New("unable to release already released lock")
)
    51  
// Config holds the settings for a Spanner driver instance.
type Config struct {
	MigrationsTable string // name of the version table; WithInstance substitutes DefaultMigrationsTable when empty
	DatabaseName    string // database identifier sent with every admin request; WithInstance rejects an empty value
	// Whether to parse the migration DDL with spansql before
	// sending it to Spanner.
	// Parsing yields clean DDL statements: reformatted and
	// stripped of comments.
	CleanStatements bool
}
    62  
// Spanner implements database.Driver for Google Cloud Spanner
type Spanner struct {
	// db holds the admin and data clients every operation goes through.
	db *DB

	// config is the configuration handed to WithInstance.
	config *Config

	// lock is an in-process flag holding unlockedVal or lockedVal; Lock and
	// Unlock flip it with a compare-and-swap.
	lock *uatomic.Uint32
}
    71  
// DB bundles the two Spanner clients the driver works with.
type DB struct {
	admin *sdb.DatabaseAdminClient // schema operations: GetDatabaseDdl and UpdateDatabaseDdl
	data  *spanner.Client          // reads and writes of the migrations table
}
    76  
    77  func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB {
    78  	return &DB{
    79  		admin: &admin,
    80  		data:  &data,
    81  	}
    82  }
    83  
    84  // WithInstance implements database.Driver
    85  func WithInstance(instance *DB, config *Config) (database.Driver, error) {
    86  	if config == nil {
    87  		return nil, ErrNilConfig
    88  	}
    89  
    90  	if len(config.DatabaseName) == 0 {
    91  		return nil, ErrNoDatabaseName
    92  	}
    93  
    94  	if len(config.MigrationsTable) == 0 {
    95  		config.MigrationsTable = DefaultMigrationsTable
    96  	}
    97  
    98  	sx := &Spanner{
    99  		db:     instance,
   100  		config: config,
   101  		lock:   uatomic.NewUint32(unlockedVal),
   102  	}
   103  
   104  	if err := sx.ensureVersionTable(); err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	return sx, nil
   109  }
   110  
   111  // Open implements database.Driver
   112  func (s *Spanner) Open(url string) (database.Driver, error) {
   113  	purl, err := nurl.Parse(url)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	ctx := context.Background()
   119  
   120  	adminClient, err := sdb.NewDatabaseAdminClient(ctx)
   121  	if err != nil {
   122  		return nil, err
   123  	}
   124  	dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1)
   125  	dataClient, err := spanner.NewClient(ctx, dbname)
   126  	if err != nil {
   127  		log.Fatal(err)
   128  	}
   129  
   130  	migrationsTable := purl.Query().Get("x-migrations-table")
   131  
   132  	cleanQuery := purl.Query().Get("x-clean-statements")
   133  	clean := false
   134  	if cleanQuery != "" {
   135  		clean, err = strconv.ParseBool(cleanQuery)
   136  		if err != nil {
   137  			return nil, err
   138  		}
   139  	}
   140  
   141  	db := &DB{admin: adminClient, data: dataClient}
   142  	return WithInstance(db, &Config{
   143  		DatabaseName:    dbname,
   144  		MigrationsTable: migrationsTable,
   145  		CleanStatements: clean,
   146  	})
   147  }
   148  
   149  // Close implements database.Driver
   150  func (s *Spanner) Close() error {
   151  	s.db.data.Close()
   152  	return s.db.admin.Close()
   153  }
   154  
   155  // Lock implements database.Driver but doesn't do anything because Spanner only
   156  // enqueues the UpdateDatabaseDdlRequest.
   157  func (s *Spanner) Lock() error {
   158  	if swapped := s.lock.CAS(unlockedVal, lockedVal); swapped {
   159  		return nil
   160  	}
   161  	return ErrLockHeld
   162  }
   163  
   164  // Unlock implements database.Driver but no action required, see Lock.
   165  func (s *Spanner) Unlock() error {
   166  	if swapped := s.lock.CAS(lockedVal, unlockedVal); swapped {
   167  		return nil
   168  	}
   169  	return ErrLockNotHeld
   170  }
   171  
   172  // Run implements database.Driver
   173  func (s *Spanner) Run(migration io.Reader) error {
   174  	migr, err := ioutil.ReadAll(migration)
   175  	if err != nil {
   176  		return err
   177  	}
   178  
   179  	stmts := []string{string(migr)}
   180  	if s.config.CleanStatements {
   181  		stmts, err = cleanStatements(migr)
   182  		if err != nil {
   183  			return err
   184  		}
   185  	}
   186  
   187  	ctx := context.Background()
   188  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   189  		Database:   s.config.DatabaseName,
   190  		Statements: stmts,
   191  	})
   192  	if err != nil {
   193  		return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   194  	}
   195  
   196  	if err := op.Wait(ctx); err != nil {
   197  		return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   198  	}
   199  
   200  	return nil
   201  }
   202  
// RunFunctionMigration is not supported by the Spanner driver: it ignores fn
// and always returns database.ErrNotImpl.
func (s *Spanner) RunFunctionMigration(fn source.MigrationFunc) error {
	return database.ErrNotImpl
}
   206  
   207  // SetVersion implements database.Driver
   208  func (s *Spanner) SetVersion(version int, dirty bool) error {
   209  	ctx := context.Background()
   210  
   211  	_, err := s.db.data.ReadWriteTransaction(ctx,
   212  		func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
   213  			m := []*spanner.Mutation{
   214  				spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()),
   215  				spanner.Insert(s.config.MigrationsTable,
   216  					[]string{"Version", "Dirty"},
   217  					[]interface{}{version, dirty},
   218  				),
   219  			}
   220  			return txn.BufferWrite(m)
   221  		})
   222  	if err != nil {
   223  		return &database.Error{OrigErr: err}
   224  	}
   225  
   226  	return nil
   227  }
   228  
   229  // Version implements database.Driver
   230  func (s *Spanner) Version() (version int, dirty bool, err error) {
   231  	ctx := context.Background()
   232  
   233  	stmt := spanner.Statement{
   234  		SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`,
   235  	}
   236  	iter := s.db.data.Single().Query(ctx, stmt)
   237  	defer iter.Stop()
   238  
   239  	row, err := iter.Next()
   240  	switch err {
   241  	case iterator.Done:
   242  		return database.NilVersion, false, nil
   243  	case nil:
   244  		var v int64
   245  		if err = row.Columns(&v, &dirty); err != nil {
   246  			return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
   247  		}
   248  		version = int(v)
   249  	default:
   250  		return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
   251  	}
   252  
   253  	return version, dirty, nil
   254  }
   255  
// nameMatcher pulls the object name out of a DDL statement. Submatch 2 is
// the name that follows "CREATE TABLE"; submatch 4 is the name that follows
// "CREATE ... INDEX", where any non-empty text (for example a UNIQUE
// qualifier) may sit between the two keywords. Drop uses it to derive DROP
// statements from the current schema.
var nameMatcher = regexp.MustCompile(`(CREATE TABLE\s(\S+)\s)|(CREATE.+INDEX\s(\S+)\s)`)
   257  
   258  // Drop implements database.Driver. Retrieves the database schema first and
   259  // creates statements to drop the indexes and tables accordingly.
   260  // Note: The drop statements are created in reverse order to how they're
   261  // provided in the schema. Assuming the schema describes how the database can
   262  // be "build up", it seems logical to "unbuild" the database simply by going the
   263  // opposite direction. More testing
   264  func (s *Spanner) Drop() error {
   265  	ctx := context.Background()
   266  	res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{
   267  		Database: s.config.DatabaseName,
   268  	})
   269  	if err != nil {
   270  		return &database.Error{OrigErr: err, Err: "drop failed"}
   271  	}
   272  	if len(res.Statements) == 0 {
   273  		return nil
   274  	}
   275  
   276  	stmts := make([]string, 0)
   277  	for i := len(res.Statements) - 1; i >= 0; i-- {
   278  		s := res.Statements[i]
   279  		m := nameMatcher.FindSubmatch([]byte(s))
   280  
   281  		if len(m) == 0 {
   282  			continue
   283  		} else if tbl := m[2]; len(tbl) > 0 {
   284  			stmts = append(stmts, fmt.Sprintf(`DROP TABLE %s`, tbl))
   285  		} else if idx := m[4]; len(idx) > 0 {
   286  			stmts = append(stmts, fmt.Sprintf(`DROP INDEX %s`, idx))
   287  		}
   288  	}
   289  
   290  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   291  		Database:   s.config.DatabaseName,
   292  		Statements: stmts,
   293  	})
   294  	if err != nil {
   295  		return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
   296  	}
   297  	if err := op.Wait(ctx); err != nil {
   298  		return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
   299  	}
   300  
   301  	return nil
   302  }
   303  
   304  // ensureVersionTable checks if versions table exists and, if not, creates it.
   305  // Note that this function locks the database, which deviates from the usual
   306  // convention of "caller locks" in the Spanner type.
   307  func (s *Spanner) ensureVersionTable() (err error) {
   308  	if err = s.Lock(); err != nil {
   309  		return err
   310  	}
   311  
   312  	defer func() {
   313  		if e := s.Unlock(); e != nil {
   314  			if err == nil {
   315  				err = e
   316  			} else {
   317  				err = multierror.Append(err, e)
   318  			}
   319  		}
   320  	}()
   321  
   322  	ctx := context.Background()
   323  	tbl := s.config.MigrationsTable
   324  	iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"})
   325  	if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil {
   326  		return nil
   327  	}
   328  
   329  	stmt := fmt.Sprintf(`CREATE TABLE %s (
   330      Version INT64 NOT NULL,
   331      Dirty    BOOL NOT NULL
   332  	) PRIMARY KEY(Version)`, tbl)
   333  
   334  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   335  		Database:   s.config.DatabaseName,
   336  		Statements: []string{stmt},
   337  	})
   338  	if err != nil {
   339  		return &database.Error{OrigErr: err, Query: []byte(stmt)}
   340  	}
   341  	if err := op.Wait(ctx); err != nil {
   342  		return &database.Error{OrigErr: err, Query: []byte(stmt)}
   343  	}
   344  
   345  	return nil
   346  }
   347  
   348  func cleanStatements(migration []byte) ([]string, error) {
   349  	// The Spanner GCP backend does not yet support comments for the UpdateDatabaseDdl RPC
   350  	// (see https://issuetracker.google.com/issues/159730604) we use
   351  	// spansql to parse the DDL and output valid stamements without comments
   352  	ddl, err := spansql.ParseDDL("", string(migration))
   353  	if err != nil {
   354  		return nil, err
   355  	}
   356  	stmts := make([]string, 0, len(ddl.List))
   357  	for _, stmt := range ddl.List {
   358  		stmts = append(stmts, stmt.SQL())
   359  	}
   360  	return stmts, nil
   361  }