github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/cassandra/cassandra.go (about)

     1  package cassandra
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	nurl "net/url"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"go.uber.org/atomic"
    13  
    14  	"github.com/gocql/gocql"
    15  	"github.com/golang-migrate/migrate/v4/database"
    16  	"github.com/golang-migrate/migrate/v4/database/multistmt"
    17  	"github.com/hashicorp/go-multierror"
    18  )
    19  
    20  func init() {
    21  	db := new(Cassandra)
    22  	database.Register("cassandra", db)
    23  }
    24  
    25  var (
    26  	multiStmtDelimiter = []byte(";")
    27  
    28  	DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
    29  )
    30  
    31  var DefaultMigrationsTable = "schema_migrations"
    32  
    33  var (
    34  	ErrNilConfig     = errors.New("no config")
    35  	ErrNoKeyspace    = errors.New("no keyspace provided")
    36  	ErrDatabaseDirty = errors.New("database is dirty")
    37  	ErrClosedSession = errors.New("session is closed")
    38  )
    39  
// Config holds the settings for a Cassandra migration driver.
type Config struct {
	// MigrationsTable is the name of the table holding the schema version.
	// WithInstance replaces an empty value with DefaultMigrationsTable.
	MigrationsTable string
	// KeyspaceName is the keyspace to operate on. WithInstance rejects an
	// empty value with ErrNoKeyspace.
	KeyspaceName string
	// MultiStatementEnabled makes Run split a migration on ";" and execute
	// each non-empty statement separately.
	MultiStatementEnabled bool
	// MultiStatementMaxSize is the size limit handed to the multi-statement
	// parser. WithInstance replaces values <= 0 with
	// DefaultMultiStatementMaxSize.
	MultiStatementMaxSize int
}
    46  
// Cassandra is the migrate database driver backed by a gocql session.
type Cassandra struct {
	// session is the gocql session every query is executed on.
	session *gocql.Session
	// isLocked is the in-memory flag toggled by Lock and Unlock.
	isLocked atomic.Bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    54  
    55  func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
    56  	if config == nil {
    57  		return nil, ErrNilConfig
    58  	} else if len(config.KeyspaceName) == 0 {
    59  		return nil, ErrNoKeyspace
    60  	}
    61  
    62  	if session.Closed() {
    63  		return nil, ErrClosedSession
    64  	}
    65  
    66  	if len(config.MigrationsTable) == 0 {
    67  		config.MigrationsTable = DefaultMigrationsTable
    68  	}
    69  
    70  	if config.MultiStatementMaxSize <= 0 {
    71  		config.MultiStatementMaxSize = DefaultMultiStatementMaxSize
    72  	}
    73  
    74  	c := &Cassandra{
    75  		session: session,
    76  		config:  config,
    77  	}
    78  
    79  	if err := c.ensureVersionTable(); err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	return c, nil
    84  }
    85  
    86  func (c *Cassandra) Open(url string) (database.Driver, error) {
    87  	u, err := nurl.Parse(url)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	// Check for missing mandatory attributes
    93  	if len(u.Path) == 0 {
    94  		return nil, ErrNoKeyspace
    95  	}
    96  
    97  	cluster := gocql.NewCluster(u.Host)
    98  	cluster.Keyspace = strings.TrimPrefix(u.Path, "/")
    99  	cluster.Consistency = gocql.All
   100  	cluster.Timeout = 1 * time.Minute
   101  
   102  	if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 {
   103  		authenticator := gocql.PasswordAuthenticator{
   104  			Username: u.Query().Get("username"),
   105  			Password: u.Query().Get("password"),
   106  		}
   107  		cluster.Authenticator = authenticator
   108  	}
   109  
   110  	// Retrieve query string configuration
   111  	if len(u.Query().Get("consistency")) > 0 {
   112  		var consistency gocql.Consistency
   113  		consistency, err = parseConsistency(u.Query().Get("consistency"))
   114  		if err != nil {
   115  			return nil, err
   116  		}
   117  
   118  		cluster.Consistency = consistency
   119  	}
   120  	if len(u.Query().Get("protocol")) > 0 {
   121  		var protoversion int
   122  		protoversion, err = strconv.Atoi(u.Query().Get("protocol"))
   123  		if err != nil {
   124  			return nil, err
   125  		}
   126  		cluster.ProtoVersion = protoversion
   127  	}
   128  	if len(u.Query().Get("timeout")) > 0 {
   129  		var timeout time.Duration
   130  		timeout, err = time.ParseDuration(u.Query().Get("timeout"))
   131  		if err != nil {
   132  			return nil, err
   133  		}
   134  		cluster.Timeout = timeout
   135  	}
   136  	if len(u.Query().Get("connect-timeout")) > 0 {
   137  		var connectTimeout time.Duration
   138  		connectTimeout, err = time.ParseDuration(u.Query().Get("connect-timeout"))
   139  		if err != nil {
   140  			return nil, err
   141  		}
   142  		cluster.ConnectTimeout = connectTimeout
   143  	}
   144  
   145  	if len(u.Query().Get("sslmode")) > 0 {
   146  		if u.Query().Get("sslmode") != "disable" {
   147  			sslOpts := &gocql.SslOptions{}
   148  
   149  			if len(u.Query().Get("sslrootcert")) > 0 {
   150  				sslOpts.CaPath = u.Query().Get("sslrootcert")
   151  			}
   152  			if len(u.Query().Get("sslcert")) > 0 {
   153  				sslOpts.CertPath = u.Query().Get("sslcert")
   154  			}
   155  			if len(u.Query().Get("sslkey")) > 0 {
   156  				sslOpts.KeyPath = u.Query().Get("sslkey")
   157  			}
   158  
   159  			if u.Query().Get("sslmode") == "verify-full" {
   160  				sslOpts.EnableHostVerification = true
   161  			}
   162  
   163  			cluster.SslOpts = sslOpts
   164  		}
   165  	}
   166  
   167  	if len(u.Query().Get("disable-host-lookup")) > 0 {
   168  		if flag, err := strconv.ParseBool(u.Query().Get("disable-host-lookup")); err != nil && flag {
   169  			cluster.DisableInitialHostLookup = true
   170  		} else if err != nil {
   171  			return nil, err
   172  		}
   173  	}
   174  
   175  	session, err := cluster.CreateSession()
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  
   180  	multiStatementMaxSize := DefaultMultiStatementMaxSize
   181  	if s := u.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
   182  		multiStatementMaxSize, err = strconv.Atoi(s)
   183  		if err != nil {
   184  			return nil, err
   185  		}
   186  	}
   187  
   188  	return WithInstance(session, &Config{
   189  		KeyspaceName:          strings.TrimPrefix(u.Path, "/"),
   190  		MigrationsTable:       u.Query().Get("x-migrations-table"),
   191  		MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true",
   192  		MultiStatementMaxSize: multiStatementMaxSize,
   193  	})
   194  }
   195  
   196  func (c *Cassandra) Close() error {
   197  	c.session.Close()
   198  	return nil
   199  }
   200  
   201  func (c *Cassandra) Lock() error {
   202  	if !c.isLocked.CAS(false, true) {
   203  		return database.ErrLocked
   204  	}
   205  	return nil
   206  }
   207  
   208  func (c *Cassandra) Unlock() error {
   209  	if !c.isLocked.CAS(true, false) {
   210  		return database.ErrNotLocked
   211  	}
   212  	return nil
   213  }
   214  
   215  func (c *Cassandra) Run(migration io.Reader) error {
   216  	if c.config.MultiStatementEnabled {
   217  		var err error
   218  		if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool {
   219  			tq := strings.TrimSpace(string(m))
   220  			if tq == "" {
   221  				return true
   222  			}
   223  			if e := c.session.Query(tq).Exec(); e != nil {
   224  				err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
   225  				return false
   226  			}
   227  			return true
   228  		}); e != nil {
   229  			return e
   230  		}
   231  		return err
   232  	}
   233  
   234  	migr, err := io.ReadAll(migration)
   235  	if err != nil {
   236  		return err
   237  	}
   238  	// run migration
   239  	if err := c.session.Query(string(migr)).Exec(); err != nil {
   240  		// TODO: cast to Cassandra error and get line number
   241  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   242  	}
   243  	return nil
   244  }
   245  
   246  func (c *Cassandra) SetVersion(version int, dirty bool) error {
   247  	// DELETE instead of TRUNCATE because AWS Keyspaces does not support it
   248  	// see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html
   249  	squery := `SELECT version FROM "` + c.config.MigrationsTable + `"`
   250  	dquery := `DELETE FROM "` + c.config.MigrationsTable + `" WHERE version = ?`
   251  	iter := c.session.Query(squery).Iter()
   252  	var previous int
   253  	for iter.Scan(&previous) {
   254  		if err := c.session.Query(dquery, previous).Exec(); err != nil {
   255  			return &database.Error{OrigErr: err, Query: []byte(dquery)}
   256  		}
   257  	}
   258  	if err := iter.Close(); err != nil {
   259  		return &database.Error{OrigErr: err, Query: []byte(squery)}
   260  	}
   261  
   262  	// Also re-write the schema version for nil dirty versions to prevent
   263  	// empty schema version for failed down migration on the first migration
   264  	// See: https://github.com/golang-migrate/migrate/issues/330
   265  	if version >= 0 || (version == database.NilVersion && dirty) {
   266  		query := `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
   267  		if err := c.session.Query(query, version, dirty).Exec(); err != nil {
   268  			return &database.Error{OrigErr: err, Query: []byte(query)}
   269  		}
   270  	}
   271  
   272  	return nil
   273  }
   274  
   275  // Return current keyspace version
   276  func (c *Cassandra) Version() (version int, dirty bool, err error) {
   277  	query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
   278  	err = c.session.Query(query).Scan(&version, &dirty)
   279  	switch {
   280  	case err == gocql.ErrNotFound:
   281  		return database.NilVersion, false, nil
   282  
   283  	case err != nil:
   284  		if _, ok := err.(*gocql.Error); ok {
   285  			return database.NilVersion, false, nil
   286  		}
   287  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   288  
   289  	default:
   290  		return version, dirty, nil
   291  	}
   292  }
   293  
   294  func (c *Cassandra) Drop() error {
   295  	// select all tables in current schema
   296  	query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
   297  	iter := c.session.Query(query).Iter()
   298  	var tableName string
   299  	for iter.Scan(&tableName) {
   300  		err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
   301  		if err != nil {
   302  			return err
   303  		}
   304  	}
   305  
   306  	return nil
   307  }
   308  
   309  // ensureVersionTable checks if versions table exists and, if not, creates it.
   310  // Note that this function locks the database, which deviates from the usual
   311  // convention of "caller locks" in the Cassandra type.
   312  func (c *Cassandra) ensureVersionTable() (err error) {
   313  	if err = c.Lock(); err != nil {
   314  		return err
   315  	}
   316  
   317  	defer func() {
   318  		if e := c.Unlock(); e != nil {
   319  			if err == nil {
   320  				err = e
   321  			} else {
   322  				err = multierror.Append(err, e)
   323  			}
   324  		}
   325  	}()
   326  
   327  	err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
   328  	if err != nil {
   329  		return err
   330  	}
   331  	if _, _, err = c.Version(); err != nil {
   332  		return err
   333  	}
   334  	return nil
   335  }
   336  
   337  // ParseConsistency wraps gocql.ParseConsistency
   338  // to return an error instead of a panicking.
   339  func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) {
   340  	defer func() {
   341  		if r := recover(); r != nil {
   342  			var ok bool
   343  			err, ok = r.(error)
   344  			if !ok {
   345  				err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r)
   346  			}
   347  		}
   348  	}()
   349  	consistency = gocql.ParseConsistency(consistencyStr)
   350  
   351  	return consistency, nil
   352  }