github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/spanner/spanner.go (about)

     1  package spanner
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"log"
     9  	nurl "net/url"
    10  	"regexp"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"cloud.google.com/go/spanner"
    15  	sdb "cloud.google.com/go/spanner/admin/database/apiv1"
    16  	"cloud.google.com/go/spanner/spansql"
    17  
    18  	"github.com/golang-migrate/migrate/v4"
    19  	"github.com/golang-migrate/migrate/v4/database"
    20  
    21  	adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
    22  	"github.com/hashicorp/go-multierror"
    23  	uatomic "go.uber.org/atomic"
    24  	"google.golang.org/api/iterator"
    25  )
    26  
    27  func init() {
    28  	db := Spanner{}
    29  	database.Register("spanner", &db)
    30  }
    31  
// DefaultMigrationsTable is the name of the migrations table that
// WithInstance falls back to when Config.MigrationsTable is empty.
const DefaultMigrationsTable = "SchemaMigrations"
    34  
// Values held by Spanner.lock. Lock swaps unlockedVal for lockedVal and
// Unlock swaps it back.
const (
	unlockedVal = 0
	lockedVal   = 1
)
    39  
// Driver errors
var (
	// ErrNilConfig is returned by WithInstance when config is nil.
	ErrNilConfig = errors.New("no config")
	// ErrNoDatabaseName is returned by WithInstance when
	// Config.DatabaseName is empty.
	ErrNoDatabaseName = errors.New("no database name")
	// ErrNoSchema is exported for callers; nothing in this file returns it.
	ErrNoSchema = errors.New("no schema")
	// ErrDatabaseDirty is exported for callers; nothing in this file
	// returns it.
	ErrDatabaseDirty = errors.New("database is dirty")
	// ErrLockHeld is returned by Lock when the driver is already locked.
	ErrLockHeld = errors.New("unable to obtain lock")
	// ErrLockNotHeld is returned by Unlock when the driver is not locked.
	ErrLockNotHeld = errors.New("unable to release already released lock")
)
    49  
// Config holds the settings for a Spanner driver instance.
type Config struct {
	// MigrationsTable is the table that stores the migration version.
	// WithInstance replaces an empty value with DefaultMigrationsTable.
	MigrationsTable string
	// DatabaseName is passed as the Database field of every DDL request.
	// WithInstance rejects an empty value with ErrNoDatabaseName.
	DatabaseName string
	// CleanStatements selects whether migration DDL is parsed with spansql
	// before it is sent to Spanner.
	// Parsing yields clean DDL statements: reformatted and stripped of
	// comments.
	CleanStatements bool
}
    60  
// Spanner implements database.Driver for Google Cloud Spanner
type Spanner struct {
	// db holds the admin and data clients used to talk to Spanner.
	db *DB

	// config holds the database name, migrations table and statement
	// cleaning option this driver was created with.
	config *Config

	// lock is the in-process lock flag toggled by Lock and Unlock; it holds
	// either unlockedVal or lockedVal.
	lock *uatomic.Uint32
}
    69  
// DB bundles the two Spanner clients the driver works with.
type DB struct {
	// admin issues schema requests (GetDatabaseDdl, UpdateDatabaseDdl).
	admin *sdb.DatabaseAdminClient
	// data reads and writes rows of the migrations table.
	data *spanner.Client
}
    74  
    75  func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB {
    76  	return &DB{
    77  		admin: &admin,
    78  		data:  &data,
    79  	}
    80  }
    81  
    82  // WithInstance implements database.Driver
    83  func WithInstance(instance *DB, config *Config) (database.Driver, error) {
    84  	if config == nil {
    85  		return nil, ErrNilConfig
    86  	}
    87  
    88  	if len(config.DatabaseName) == 0 {
    89  		return nil, ErrNoDatabaseName
    90  	}
    91  
    92  	if len(config.MigrationsTable) == 0 {
    93  		config.MigrationsTable = DefaultMigrationsTable
    94  	}
    95  
    96  	sx := &Spanner{
    97  		db:     instance,
    98  		config: config,
    99  		lock:   uatomic.NewUint32(unlockedVal),
   100  	}
   101  
   102  	if err := sx.ensureVersionTable(); err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	return sx, nil
   107  }
   108  
   109  // Open implements database.Driver
   110  func (s *Spanner) Open(url string) (database.Driver, error) {
   111  	purl, err := nurl.Parse(url)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	ctx := context.Background()
   117  
   118  	adminClient, err := sdb.NewDatabaseAdminClient(ctx)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1)
   123  	dataClient, err := spanner.NewClient(ctx, dbname)
   124  	if err != nil {
   125  		log.Fatal(err)
   126  	}
   127  
   128  	migrationsTable := purl.Query().Get("x-migrations-table")
   129  
   130  	cleanQuery := purl.Query().Get("x-clean-statements")
   131  	clean := false
   132  	if cleanQuery != "" {
   133  		clean, err = strconv.ParseBool(cleanQuery)
   134  		if err != nil {
   135  			return nil, err
   136  		}
   137  	}
   138  
   139  	db := &DB{admin: adminClient, data: dataClient}
   140  	return WithInstance(db, &Config{
   141  		DatabaseName:    dbname,
   142  		MigrationsTable: migrationsTable,
   143  		CleanStatements: clean,
   144  	})
   145  }
   146  
   147  // Close implements database.Driver
   148  func (s *Spanner) Close() error {
   149  	s.db.data.Close()
   150  	return s.db.admin.Close()
   151  }
   152  
   153  // Lock implements database.Driver but doesn't do anything because Spanner only
   154  // enqueues the UpdateDatabaseDdlRequest.
   155  func (s *Spanner) Lock() error {
   156  	if swapped := s.lock.CAS(unlockedVal, lockedVal); swapped {
   157  		return nil
   158  	}
   159  	return ErrLockHeld
   160  }
   161  
   162  // Unlock implements database.Driver but no action required, see Lock.
   163  func (s *Spanner) Unlock() error {
   164  	if swapped := s.lock.CAS(lockedVal, unlockedVal); swapped {
   165  		return nil
   166  	}
   167  	return ErrLockNotHeld
   168  }
   169  
   170  // Run implements database.Driver
   171  func (s *Spanner) Run(migration io.Reader) error {
   172  	migr, err := io.ReadAll(migration)
   173  	if err != nil {
   174  		return err
   175  	}
   176  
   177  	stmts := []string{string(migr)}
   178  	if s.config.CleanStatements {
   179  		stmts, err = cleanStatements(migr)
   180  		if err != nil {
   181  			return err
   182  		}
   183  	}
   184  
   185  	ctx := context.Background()
   186  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   187  		Database:   s.config.DatabaseName,
   188  		Statements: stmts,
   189  	})
   190  
   191  	if err != nil {
   192  		return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   193  	}
   194  
   195  	if err := op.Wait(ctx); err != nil {
   196  		return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   197  	}
   198  
   199  	return nil
   200  }
   201  
   202  // SetVersion implements database.Driver
   203  func (s *Spanner) SetVersion(version int, dirty bool) error {
   204  	ctx := context.Background()
   205  
   206  	_, err := s.db.data.ReadWriteTransaction(ctx,
   207  		func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
   208  			m := []*spanner.Mutation{
   209  				spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()),
   210  				spanner.Insert(s.config.MigrationsTable,
   211  					[]string{"Version", "Dirty"},
   212  					[]interface{}{version, dirty},
   213  				)}
   214  			return txn.BufferWrite(m)
   215  		})
   216  	if err != nil {
   217  		return &database.Error{OrigErr: err}
   218  	}
   219  
   220  	return nil
   221  }
   222  
   223  // Version implements database.Driver
   224  func (s *Spanner) Version() (version int, dirty bool, err error) {
   225  	ctx := context.Background()
   226  
   227  	stmt := spanner.Statement{
   228  		SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`,
   229  	}
   230  	iter := s.db.data.Single().Query(ctx, stmt)
   231  	defer iter.Stop()
   232  
   233  	row, err := iter.Next()
   234  	switch err {
   235  	case iterator.Done:
   236  		return database.NilVersion, false, nil
   237  	case nil:
   238  		var v int64
   239  		if err = row.Columns(&v, &dirty); err != nil {
   240  			return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
   241  		}
   242  		version = int(v)
   243  	default:
   244  		return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
   245  	}
   246  
   247  	return version, dirty, nil
   248  }
   249  
// nameMatcher picks the object name out of a schema statement. For a
// "CREATE TABLE" statement the name is in submatch 2; for a
// "CREATE ... INDEX" statement it is in submatch 4. The match is
// case-sensitive and requires whitespace after the name.
var nameMatcher = regexp.MustCompile(`(CREATE TABLE\s(\S+)\s)|(CREATE.+INDEX\s(\S+)\s)`)
   251  
   252  // Drop implements database.Driver. Retrieves the database schema first and
   253  // creates statements to drop the indexes and tables accordingly.
   254  // Note: The drop statements are created in reverse order to how they're
   255  // provided in the schema. Assuming the schema describes how the database can
   256  // be "build up", it seems logical to "unbuild" the database simply by going the
   257  // opposite direction. More testing
   258  func (s *Spanner) Drop() error {
   259  	ctx := context.Background()
   260  	res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{
   261  		Database: s.config.DatabaseName,
   262  	})
   263  	if err != nil {
   264  		return &database.Error{OrigErr: err, Err: "drop failed"}
   265  	}
   266  	if len(res.Statements) == 0 {
   267  		return nil
   268  	}
   269  
   270  	stmts := make([]string, 0)
   271  	for i := len(res.Statements) - 1; i >= 0; i-- {
   272  		s := res.Statements[i]
   273  		m := nameMatcher.FindSubmatch([]byte(s))
   274  
   275  		if len(m) == 0 {
   276  			continue
   277  		} else if tbl := m[2]; len(tbl) > 0 {
   278  			stmts = append(stmts, fmt.Sprintf(`DROP TABLE %s`, tbl))
   279  		} else if idx := m[4]; len(idx) > 0 {
   280  			stmts = append(stmts, fmt.Sprintf(`DROP INDEX %s`, idx))
   281  		}
   282  	}
   283  
   284  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   285  		Database:   s.config.DatabaseName,
   286  		Statements: stmts,
   287  	})
   288  	if err != nil {
   289  		return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
   290  	}
   291  	if err := op.Wait(ctx); err != nil {
   292  		return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
   293  	}
   294  
   295  	return nil
   296  }
   297  
   298  // ensureVersionTable checks if versions table exists and, if not, creates it.
   299  // Note that this function locks the database, which deviates from the usual
   300  // convention of "caller locks" in the Spanner type.
   301  func (s *Spanner) ensureVersionTable() (err error) {
   302  	if err = s.Lock(); err != nil {
   303  		return err
   304  	}
   305  
   306  	defer func() {
   307  		if e := s.Unlock(); e != nil {
   308  			if err == nil {
   309  				err = e
   310  			} else {
   311  				err = multierror.Append(err, e)
   312  			}
   313  		}
   314  	}()
   315  
   316  	ctx := context.Background()
   317  	tbl := s.config.MigrationsTable
   318  	iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"})
   319  	if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil {
   320  		return nil
   321  	}
   322  
   323  	stmt := fmt.Sprintf(`CREATE TABLE %s (
   324      Version INT64 NOT NULL,
   325      Dirty    BOOL NOT NULL
   326  	) PRIMARY KEY(Version)`, tbl)
   327  
   328  	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
   329  		Database:   s.config.DatabaseName,
   330  		Statements: []string{stmt},
   331  	})
   332  
   333  	if err != nil {
   334  		return &database.Error{OrigErr: err, Query: []byte(stmt)}
   335  	}
   336  	if err := op.Wait(ctx); err != nil {
   337  		return &database.Error{OrigErr: err, Query: []byte(stmt)}
   338  	}
   339  
   340  	return nil
   341  }
   342  
   343  func cleanStatements(migration []byte) ([]string, error) {
   344  	// The Spanner GCP backend does not yet support comments for the UpdateDatabaseDdl RPC
   345  	// (see https://issuetracker.google.com/issues/159730604) we use
   346  	// spansql to parse the DDL and output valid stamements without comments
   347  	ddl, err := spansql.ParseDDL("", string(migration))
   348  	if err != nil {
   349  		return nil, err
   350  	}
   351  	stmts := make([]string, 0, len(ddl.List))
   352  	for _, stmt := range ddl.List {
   353  		stmts = append(stmts, stmt.SQL())
   354  	}
   355  	return stmts, nil
   356  }