github.com/seashell-org/golang-migrate/v4@v4.15.3-0.20220722221203-6ab6c6c062d1/database/spanner/spanner.go (about)

     1  package spanner
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"log"
     9  	nurl "net/url"
    10  	"regexp"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"context"
    15  
    16  	"cloud.google.com/go/spanner"
    17  	sdb "cloud.google.com/go/spanner/admin/database/apiv1"
    18  
    19  	migrate "github.com/seashell-org/golang-migrate/v4"
    20  	"github.com/seashell-org/golang-migrate/v4/database"
    21  	"github.com/seashell-org/golang-migrate/v4/database/spanner/spansql"
    22  
    23  	"github.com/hashicorp/go-multierror"
    24  	uatomic "go.uber.org/atomic"
    25  	"google.golang.org/api/iterator"
    26  	adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
    27  )
    28  
    29  func init() {
    30  	db := Spanner{}
    31  	database.Register("spanner", &db)
    32  }
    33  
    34  // DefaultMigrationsTable is used if no custom table is specified
    35  const DefaultMigrationsTable = "SchemaMigrations"
    36  
    37  const (
    38  	unlockedVal = 0
    39  	lockedVal   = 1
    40  )
    41  
    42  // Driver errors
    43  var (
    44  	ErrNilConfig       = errors.New("no config")
    45  	ErrNoDatabaseName  = errors.New("no database name")
    46  	ErrNoSchema        = errors.New("no schema")
    47  	ErrDatabaseDirty   = errors.New("database is dirty")
    48  	ErrLockHeld        = errors.New("unable to obtain lock")
    49  	ErrLockNotHeld     = errors.New("unable to release already released lock")
    50  	ErrInvalidStmtType = errors.New("Invalid statement type")
    51  )
    52  
    53  // Config used for a Spanner instance
    54  type Config struct {
    55  	MigrationsTable string
    56  	DatabaseName    string
    57  	// Whether to parse the migration DDL with spansql before
    58  	// running them towards Spanner.
    59  	// Parsing outputs clean DDL statements such as reformatted
    60  	// and void of comments.
    61  	CleanStatements bool
    62  }
    63  
    64  // Spanner implements database.Driver for Google Cloud Spanner
    65  type Spanner struct {
    66  	db *DB
    67  
    68  	config *Config
    69  
    70  	lock *uatomic.Uint32
    71  }
    72  
    73  type DB struct {
    74  	admin *sdb.DatabaseAdminClient
    75  	data  *spanner.Client
    76  }
    77  
    78  func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB {
    79  	return &DB{
    80  		admin: &admin,
    81  		data:  &data,
    82  	}
    83  }
    84  
    85  // WithInstance implements database.Driver
    86  func WithInstance(instance *DB, config *Config) (database.Driver, error) {
    87  	if config == nil {
    88  		return nil, ErrNilConfig
    89  	}
    90  
    91  	if len(config.DatabaseName) == 0 {
    92  		return nil, ErrNoDatabaseName
    93  	}
    94  
    95  	if len(config.MigrationsTable) == 0 {
    96  		config.MigrationsTable = DefaultMigrationsTable
    97  	}
    98  
    99  	sx := &Spanner{
   100  		db:     instance,
   101  		config: config,
   102  		lock:   uatomic.NewUint32(unlockedVal),
   103  	}
   104  
   105  	if err := sx.ensureVersionTable(); err != nil {
   106  		return nil, err
   107  	}
   108  
   109  	return sx, nil
   110  }
   111  
   112  // Open implements database.Driver
   113  func (s *Spanner) Open(url string) (database.Driver, error) {
   114  	purl, err := nurl.Parse(url)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	ctx := context.Background()
   120  
   121  	adminClient, err := sdb.NewDatabaseAdminClient(ctx)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  	dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1)
   126  	dataClient, err := spanner.NewClient(ctx, dbname)
   127  	if err != nil {
   128  		log.Fatal(err)
   129  	}
   130  
   131  	migrationsTable := purl.Query().Get("x-migrations-table")
   132  
   133  	cleanQuery := purl.Query().Get("x-clean-statements")
   134  	clean := false
   135  	if cleanQuery != "" {
   136  		clean, err = strconv.ParseBool(cleanQuery)
   137  		if err != nil {
   138  			return nil, err
   139  		}
   140  	}
   141  
   142  	db := &DB{admin: adminClient, data: dataClient}
   143  	return WithInstance(db, &Config{
   144  		DatabaseName:    dbname,
   145  		MigrationsTable: migrationsTable,
   146  		CleanStatements: clean,
   147  	})
   148  }
   149  
   150  // Close implements database.Driver
   151  func (s *Spanner) Close() error {
   152  	s.db.data.Close()
   153  	return s.db.admin.Close()
   154  }
   155  
   156  // Lock implements database.Driver but doesn't do anything because Spanner only
   157  // enqueues the UpdateDatabaseDdlRequest.
   158  func (s *Spanner) Lock() error {
   159  	if swapped := s.lock.CAS(unlockedVal, lockedVal); swapped {
   160  		return nil
   161  	}
   162  	return ErrLockHeld
   163  }
   164  
   165  // Unlock implements database.Driver but no action required, see Lock.
   166  func (s *Spanner) Unlock() error {
   167  	if swapped := s.lock.CAS(lockedVal, unlockedVal); swapped {
   168  		return nil
   169  	}
   170  	return ErrLockNotHeld
   171  }
   172  
   173  // Run implements database.Driver
   174  func (s *Spanner) Run(migration io.Reader) error {
   175  	migr, err := ioutil.ReadAll(migration)
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	stmts, err := cleanStatements(migr)
   181  	if err != nil {
   182  		return err
   183  	}
   184  
   185  	ctx := context.Background()
   186  
   187  	for _, stmt := range stmts {
   188  		if stmt.t == "ddl" {
   189  			op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   190  				Database:   s.config.DatabaseName,
   191  				Statements: stmt.sqls,
   192  			})
   193  
   194  			if err != nil {
   195  				return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   196  			}
   197  
   198  			if err := op.Wait(ctx); err != nil {
   199  				return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   200  			}
   201  		} else if stmt.t == "dml" {
   202  			spannerStmts := []spanner.Statement{}
   203  			for _, sql := range stmt.sqls {
   204  				spannerStmts = append(spannerStmts, spanner.Statement{
   205  					SQL: sql,
   206  				})
   207  			}
   208  
   209  			_, err = s.db.data.ReadWriteTransaction(ctx, func(ctx context.Context, rwt *spanner.ReadWriteTransaction) error {
   210  				_, err := rwt.BatchUpdate(ctx, spannerStmts)
   211  				if err != nil {
   212  					return err
   213  				}
   214  
   215  				return nil
   216  			})
   217  			if err != nil {
   218  				return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   219  			}
   220  		} else {
   221  			return ErrInvalidStmtType
   222  		}
   223  	}
   224  
   225  	return nil
   226  }
   227  
   228  // SetVersion implements database.Driver
   229  func (s *Spanner) SetVersion(version int, dirty bool) error {
   230  	ctx := context.Background()
   231  
   232  	_, err := s.db.data.ReadWriteTransaction(ctx,
   233  		func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
   234  			m := []*spanner.Mutation{
   235  				spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()),
   236  				spanner.Insert(s.config.MigrationsTable,
   237  					[]string{"Version", "Dirty"},
   238  					[]interface{}{version, dirty},
   239  				)}
   240  			return txn.BufferWrite(m)
   241  		})
   242  	if err != nil {
   243  		return &database.Error{OrigErr: err}
   244  	}
   245  
   246  	return nil
   247  }
   248  
   249  // Version implements database.Driver
   250  func (s *Spanner) Version() (version int, dirty bool, err error) {
   251  	ctx := context.Background()
   252  
   253  	stmt := spanner.Statement{
   254  		SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`,
   255  	}
   256  	iter := s.db.data.Single().Query(ctx, stmt)
   257  	defer iter.Stop()
   258  
   259  	row, err := iter.Next()
   260  	switch err {
   261  	case iterator.Done:
   262  		return database.NilVersion, false, nil
   263  	case nil:
   264  		var v int64
   265  		if err = row.Columns(&v, &dirty); err != nil {
   266  			return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
   267  		}
   268  		version = int(v)
   269  	default:
   270  		return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
   271  	}
   272  
   273  	return version, dirty, nil
   274  }
   275  
   276  var nameMatcher = regexp.MustCompile(`(CREATE TABLE\s(\S+)\s)|(CREATE.+INDEX\s(\S+)\s)`)
   277  
   278  // Drop implements database.Driver. Retrieves the database schema first and
   279  // creates statements to drop the indexes and tables accordingly.
   280  // Note: The drop statements are created in reverse order to how they're
   281  // provided in the schema. Assuming the schema describes how the database can
   282  // be "build up", it seems logical to "unbuild" the database simply by going the
   283  // opposite direction. More testing
   284  func (s *Spanner) Drop() error {
   285  	ctx := context.Background()
   286  	res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{
   287  		Database: s.config.DatabaseName,
   288  	})
   289  	if err != nil {
   290  		return &database.Error{OrigErr: err, Err: "drop failed"}
   291  	}
   292  	if len(res.Statements) == 0 {
   293  		return nil
   294  	}
   295  
   296  	stmts := make([]string, 0)
   297  	for i := len(res.Statements) - 1; i >= 0; i-- {
   298  		s := res.Statements[i]
   299  		m := nameMatcher.FindSubmatch([]byte(s))
   300  
   301  		if len(m) == 0 {
   302  			continue
   303  		} else if tbl := m[2]; len(tbl) > 0 {
   304  			stmts = append(stmts, fmt.Sprintf(`DROP TABLE %s`, tbl))
   305  		} else if idx := m[4]; len(idx) > 0 {
   306  			stmts = append(stmts, fmt.Sprintf(`DROP INDEX %s`, idx))
   307  		}
   308  	}
   309  
   310  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   311  		Database:   s.config.DatabaseName,
   312  		Statements: stmts,
   313  	})
   314  	if err != nil {
   315  		return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
   316  	}
   317  	if err := op.Wait(ctx); err != nil {
   318  		return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
   319  	}
   320  
   321  	return nil
   322  }
   323  
   324  // ensureVersionTable checks if versions table exists and, if not, creates it.
   325  // Note that this function locks the database, which deviates from the usual
   326  // convention of "caller locks" in the Spanner type.
   327  func (s *Spanner) ensureVersionTable() (err error) {
   328  	if err = s.Lock(); err != nil {
   329  		return err
   330  	}
   331  
   332  	defer func() {
   333  		if e := s.Unlock(); e != nil {
   334  			if err == nil {
   335  				err = e
   336  			} else {
   337  				err = multierror.Append(err, e)
   338  			}
   339  		}
   340  	}()
   341  
   342  	ctx := context.Background()
   343  	tbl := s.config.MigrationsTable
   344  	iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"})
   345  	if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil {
   346  		return nil
   347  	}
   348  
   349  	stmt := fmt.Sprintf(`CREATE TABLE %s (
   350      Version INT64 NOT NULL,
   351      Dirty    BOOL NOT NULL
   352  	) PRIMARY KEY(Version)`, tbl)
   353  
   354  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   355  		Database:   s.config.DatabaseName,
   356  		Statements: []string{stmt},
   357  	})
   358  
   359  	if err != nil {
   360  		return &database.Error{OrigErr: err, Query: []byte(stmt)}
   361  	}
   362  	if err := op.Wait(ctx); err != nil {
   363  		return &database.Error{OrigErr: err, Query: []byte(stmt)}
   364  	}
   365  
   366  	return nil
   367  }
   368  
   369  type statement struct {
   370  	t    string
   371  	sqls []string
   372  }
   373  
   374  func cleanStatements(migration []byte) ([]statement, error) {
   375  	// The Spanner GCP backend does not yet support comments for the UpdateDatabaseDdl RPC
   376  	// (see https://issuetracker.google.com/issues/159730604) we use
   377  	// spansql to parse the DDL and output valid stamements without comments
   378  	ddlAndDml, err := spansql.ParseDDLAndDML("", string(migration))
   379  	if err != nil {
   380  		return nil, err
   381  	}
   382  	stmts := make([]statement, 0, len(ddlAndDml.List))
   383  	lastType := ""
   384  	stmtsBuffer := []string(nil)
   385  	for _, stmt := range ddlAndDml.List {
   386  		if lastType != stmt.T && lastType != "" {
   387  			stmts = append(stmts, statement{
   388  				sqls: stmtsBuffer,
   389  				t:    lastType,
   390  			})
   391  			stmtsBuffer = []string(nil)
   392  		}
   393  		if stmt.T == "ddl" {
   394  			stmtsBuffer = append(stmtsBuffer, stmt.Ddl.SQL())
   395  		} else if stmt.T == "dml" {
   396  			stmtsBuffer = append(stmtsBuffer, stmt.Dml.SQL())
   397  		} else {
   398  			return nil, ErrInvalidStmtType
   399  		}
   400  		lastType = stmt.T
   401  	}
   402  	if len(stmtsBuffer) > 0 && lastType != "" {
   403  		stmts = append(stmts, statement{
   404  			sqls: stmtsBuffer,
   405  			t:    lastType,
   406  		})
   407  	}
   408  
   409  	return stmts, nil
   410  }