github.com/nokia/migrate/v4@v4.16.0/database/cassandra/cassandra.go (about)

     1  package cassandra
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	nurl "net/url"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"go.uber.org/atomic"
    14  
    15  	"github.com/gocql/gocql"
    16  	"github.com/hashicorp/go-multierror"
    17  	"github.com/nokia/migrate/v4/database"
    18  	"github.com/nokia/migrate/v4/database/multistmt"
    19  	"github.com/nokia/migrate/v4/source"
    20  )
    21  
    22  func init() {
    23  	db := new(Cassandra)
    24  	database.Register("cassandra", db)
    25  }
    26  
    27  var (
    28  	multiStmtDelimiter = []byte(";")
    29  
    30  	DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
    31  )
    32  
    33  var DefaultMigrationsTable = "schema_migrations"
    34  
    35  var (
    36  	ErrNilConfig     = errors.New("no config")
    37  	ErrNoKeyspace    = errors.New("no keyspace provided")
    38  	ErrDatabaseDirty = errors.New("database is dirty")
    39  	ErrClosedSession = errors.New("session is closed")
    40  )
    41  
    42  type Config struct {
    43  	MigrationsTable       string
    44  	KeyspaceName          string
    45  	MultiStatementEnabled bool
    46  	MultiStatementMaxSize int
    47  }
    48  
    49  type Cassandra struct {
    50  	session  *gocql.Session
    51  	isLocked atomic.Bool
    52  
    53  	// Open and WithInstance need to guarantee that config is never nil
    54  	config *Config
    55  }
    56  
    57  func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
    58  	if config == nil {
    59  		return nil, ErrNilConfig
    60  	} else if len(config.KeyspaceName) == 0 {
    61  		return nil, ErrNoKeyspace
    62  	}
    63  
    64  	if session.Closed() {
    65  		return nil, ErrClosedSession
    66  	}
    67  
    68  	if len(config.MigrationsTable) == 0 {
    69  		config.MigrationsTable = DefaultMigrationsTable
    70  	}
    71  
    72  	if config.MultiStatementMaxSize <= 0 {
    73  		config.MultiStatementMaxSize = DefaultMultiStatementMaxSize
    74  	}
    75  
    76  	c := &Cassandra{
    77  		session: session,
    78  		config:  config,
    79  	}
    80  
    81  	if err := c.ensureVersionTable(); err != nil {
    82  		return nil, err
    83  	}
    84  
    85  	return c, nil
    86  }
    87  
    88  func (c *Cassandra) Open(url string) (database.Driver, error) {
    89  	u, err := nurl.Parse(url)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	// Check for missing mandatory attributes
    95  	if len(u.Path) == 0 {
    96  		return nil, ErrNoKeyspace
    97  	}
    98  
    99  	cluster := gocql.NewCluster(u.Host)
   100  	cluster.Keyspace = strings.TrimPrefix(u.Path, "/")
   101  	cluster.Consistency = gocql.All
   102  	cluster.Timeout = 1 * time.Minute
   103  
   104  	if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 {
   105  		authenticator := gocql.PasswordAuthenticator{
   106  			Username: u.Query().Get("username"),
   107  			Password: u.Query().Get("password"),
   108  		}
   109  		cluster.Authenticator = authenticator
   110  	}
   111  
   112  	// Retrieve query string configuration
   113  	if len(u.Query().Get("consistency")) > 0 {
   114  		var consistency gocql.Consistency
   115  		consistency, err = parseConsistency(u.Query().Get("consistency"))
   116  		if err != nil {
   117  			return nil, err
   118  		}
   119  
   120  		cluster.Consistency = consistency
   121  	}
   122  	if len(u.Query().Get("protocol")) > 0 {
   123  		var protoversion int
   124  		protoversion, err = strconv.Atoi(u.Query().Get("protocol"))
   125  		if err != nil {
   126  			return nil, err
   127  		}
   128  		cluster.ProtoVersion = protoversion
   129  	}
   130  	if len(u.Query().Get("timeout")) > 0 {
   131  		var timeout time.Duration
   132  		timeout, err = time.ParseDuration(u.Query().Get("timeout"))
   133  		if err != nil {
   134  			return nil, err
   135  		}
   136  		cluster.Timeout = timeout
   137  	}
   138  
   139  	if len(u.Query().Get("sslmode")) > 0 {
   140  		if u.Query().Get("sslmode") != "disable" {
   141  			sslOpts := &gocql.SslOptions{}
   142  
   143  			if len(u.Query().Get("sslrootcert")) > 0 {
   144  				sslOpts.CaPath = u.Query().Get("sslrootcert")
   145  			}
   146  			if len(u.Query().Get("sslcert")) > 0 {
   147  				sslOpts.CertPath = u.Query().Get("sslcert")
   148  			}
   149  			if len(u.Query().Get("sslkey")) > 0 {
   150  				sslOpts.KeyPath = u.Query().Get("sslkey")
   151  			}
   152  
   153  			if u.Query().Get("sslmode") == "verify-full" {
   154  				sslOpts.EnableHostVerification = true
   155  			}
   156  
   157  			cluster.SslOpts = sslOpts
   158  		}
   159  	}
   160  
   161  	if len(u.Query().Get("disable-host-lookup")) > 0 {
   162  		if flag, err := strconv.ParseBool(u.Query().Get("disable-host-lookup")); err != nil && flag {
   163  			cluster.DisableInitialHostLookup = true
   164  		} else if err != nil {
   165  			return nil, err
   166  		}
   167  	}
   168  
   169  	session, err := cluster.CreateSession()
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  
   174  	multiStatementMaxSize := DefaultMultiStatementMaxSize
   175  	if s := u.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
   176  		multiStatementMaxSize, err = strconv.Atoi(s)
   177  		if err != nil {
   178  			return nil, err
   179  		}
   180  	}
   181  
   182  	return WithInstance(session, &Config{
   183  		KeyspaceName:          strings.TrimPrefix(u.Path, "/"),
   184  		MigrationsTable:       u.Query().Get("x-migrations-table"),
   185  		MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true",
   186  		MultiStatementMaxSize: multiStatementMaxSize,
   187  	})
   188  }
   189  
   190  func (c *Cassandra) Close() error {
   191  	c.session.Close()
   192  	return nil
   193  }
   194  
   195  func (c *Cassandra) Lock() error {
   196  	if !c.isLocked.CAS(false, true) {
   197  		return database.ErrLocked
   198  	}
   199  	return nil
   200  }
   201  
   202  func (c *Cassandra) Unlock() error {
   203  	if !c.isLocked.CAS(true, false) {
   204  		return database.ErrNotLocked
   205  	}
   206  	return nil
   207  }
   208  
   209  func (c *Cassandra) Run(migration io.Reader) error {
   210  	if c.config.MultiStatementEnabled {
   211  		var err error
   212  		if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool {
   213  			tq := strings.TrimSpace(string(m))
   214  			if tq == "" {
   215  				return true
   216  			}
   217  			if e := c.session.Query(tq).Exec(); e != nil {
   218  				err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
   219  				return false
   220  			}
   221  			return true
   222  		}); e != nil {
   223  			return e
   224  		}
   225  		return err
   226  	}
   227  
   228  	migr, err := ioutil.ReadAll(migration)
   229  	if err != nil {
   230  		return err
   231  	}
   232  	// run migration
   233  	if err := c.session.Query(string(migr)).Exec(); err != nil {
   234  		// TODO: cast to Cassandra error and get line number
   235  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   236  	}
   237  	return nil
   238  }
   239  
   240  func (c *Cassandra) RunFunctionMigration(fn source.MigrationFunc) error {
   241  	return database.ErrNotImpl
   242  }
   243  
   244  func (c *Cassandra) SetVersion(version int, dirty bool) error {
   245  	// DELETE instead of TRUNCATE because AWS Keyspaces does not support it
   246  	// see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html
   247  	squery := `SELECT version FROM "` + c.config.MigrationsTable + `"`
   248  	dquery := `DELETE FROM "` + c.config.MigrationsTable + `" WHERE version = ?`
   249  	iter := c.session.Query(squery).Iter()
   250  	var previous int
   251  	for iter.Scan(&previous) {
   252  		if err := c.session.Query(dquery, previous).Exec(); err != nil {
   253  			return &database.Error{OrigErr: err, Query: []byte(dquery)}
   254  		}
   255  	}
   256  	if err := iter.Close(); err != nil {
   257  		return &database.Error{OrigErr: err, Query: []byte(squery)}
   258  	}
   259  
   260  	// Also re-write the schema version for nil dirty versions to prevent
   261  	// empty schema version for failed down migration on the first migration
   262  	// See: https://github.com/nokia/migrate/issues/330
   263  	if version >= 0 || (version == database.NilVersion && dirty) {
   264  		query := `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
   265  		if err := c.session.Query(query, version, dirty).Exec(); err != nil {
   266  			return &database.Error{OrigErr: err, Query: []byte(query)}
   267  		}
   268  	}
   269  
   270  	return nil
   271  }
   272  
   273  // Return current keyspace version
   274  func (c *Cassandra) Version() (version int, dirty bool, err error) {
   275  	query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
   276  	err = c.session.Query(query).Scan(&version, &dirty)
   277  	switch {
   278  	case err == gocql.ErrNotFound:
   279  		return database.NilVersion, false, nil
   280  
   281  	case err != nil:
   282  		if _, ok := err.(*gocql.Error); ok {
   283  			return database.NilVersion, false, nil
   284  		}
   285  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   286  
   287  	default:
   288  		return version, dirty, nil
   289  	}
   290  }
   291  
   292  func (c *Cassandra) Drop() error {
   293  	// select all tables in current schema
   294  	query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
   295  	iter := c.session.Query(query).Iter()
   296  	var tableName string
   297  	for iter.Scan(&tableName) {
   298  		err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
   299  		if err != nil {
   300  			return err
   301  		}
   302  	}
   303  
   304  	return nil
   305  }
   306  
   307  // ensureVersionTable checks if versions table exists and, if not, creates it.
   308  // Note that this function locks the database, which deviates from the usual
   309  // convention of "caller locks" in the Cassandra type.
   310  func (c *Cassandra) ensureVersionTable() (err error) {
   311  	if err = c.Lock(); err != nil {
   312  		return err
   313  	}
   314  
   315  	defer func() {
   316  		if e := c.Unlock(); e != nil {
   317  			if err == nil {
   318  				err = e
   319  			} else {
   320  				err = multierror.Append(err, e)
   321  			}
   322  		}
   323  	}()
   324  
   325  	err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
   326  	if err != nil {
   327  		return err
   328  	}
   329  	if _, _, err = c.Version(); err != nil {
   330  		return err
   331  	}
   332  	return nil
   333  }
   334  
   335  // ParseConsistency wraps gocql.ParseConsistency
   336  // to return an error instead of a panicking.
   337  func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) {
   338  	defer func() {
   339  		if r := recover(); r != nil {
   340  			var ok bool
   341  			err, ok = r.(error)
   342  			if !ok {
   343  				err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r)
   344  			}
   345  		}
   346  	}()
   347  	consistency = gocql.ParseConsistency(consistencyStr)
   348  
   349  	return consistency, nil
   350  }