github.com/seashell-org/golang-migrate/v4@v4.15.3-0.20220722221203-6ab6c6c062d1/database/cassandra/cassandra.go (about)

     1  package cassandra
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	nurl "net/url"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"go.uber.org/atomic"
    14  
    15  	"github.com/gocql/gocql"
    16  	"github.com/hashicorp/go-multierror"
    17  	"github.com/seashell-org/golang-migrate/v4/database"
    18  	"github.com/seashell-org/golang-migrate/v4/database/multistmt"
    19  )
    20  
    21  func init() {
    22  	db := new(Cassandra)
    23  	database.Register("cassandra", db)
    24  }
    25  
    26  var (
    27  	multiStmtDelimiter = []byte(";")
    28  
    29  	DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
    30  )
    31  
    32  var DefaultMigrationsTable = "schema_migrations"
    33  
    34  var (
    35  	ErrNilConfig     = errors.New("no config")
    36  	ErrNoKeyspace    = errors.New("no keyspace provided")
    37  	ErrDatabaseDirty = errors.New("database is dirty")
    38  	ErrClosedSession = errors.New("session is closed")
    39  )
    40  
// Config holds the settings of a Cassandra driver instance.
type Config struct {
	// MigrationsTable is the table that records the schema version.
	// WithInstance replaces an empty value with DefaultMigrationsTable.
	MigrationsTable       string
	// KeyspaceName is the keyspace being migrated. It is required
	// (WithInstance returns ErrNoKeyspace if it is empty), and Drop uses it
	// to select the tables to remove.
	KeyspaceName          string
	// MultiStatementEnabled makes Run split a migration on ";" and execute
	// each non-blank statement separately.
	MultiStatementEnabled bool
	// MultiStatementMaxSize is the size limit passed to multistmt.Parse in
	// multi-statement mode. WithInstance replaces values <= 0 with
	// DefaultMultiStatementMaxSize.
	MultiStatementMaxSize int
}
    47  
// Cassandra is the migration driver registered under the name "cassandra".
// Instances are created by Open or WithInstance.
type Cassandra struct {
	// session is the gocql session all queries are issued on.
	session  *gocql.Session
	// isLocked is the lock flag toggled by Lock/Unlock. It lives only in
	// this value and is not shared with other processes.
	isLocked atomic.Bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    55  
    56  func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
    57  	if config == nil {
    58  		return nil, ErrNilConfig
    59  	} else if len(config.KeyspaceName) == 0 {
    60  		return nil, ErrNoKeyspace
    61  	}
    62  
    63  	if session.Closed() {
    64  		return nil, ErrClosedSession
    65  	}
    66  
    67  	if len(config.MigrationsTable) == 0 {
    68  		config.MigrationsTable = DefaultMigrationsTable
    69  	}
    70  
    71  	if config.MultiStatementMaxSize <= 0 {
    72  		config.MultiStatementMaxSize = DefaultMultiStatementMaxSize
    73  	}
    74  
    75  	c := &Cassandra{
    76  		session: session,
    77  		config:  config,
    78  	}
    79  
    80  	if err := c.ensureVersionTable(); err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	return c, nil
    85  }
    86  
    87  func (c *Cassandra) Open(url string) (database.Driver, error) {
    88  	u, err := nurl.Parse(url)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	// Check for missing mandatory attributes
    94  	if len(u.Path) == 0 {
    95  		return nil, ErrNoKeyspace
    96  	}
    97  
    98  	cluster := gocql.NewCluster(u.Host)
    99  	cluster.Keyspace = strings.TrimPrefix(u.Path, "/")
   100  	cluster.Consistency = gocql.All
   101  	cluster.Timeout = 1 * time.Minute
   102  
   103  	if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 {
   104  		authenticator := gocql.PasswordAuthenticator{
   105  			Username: u.Query().Get("username"),
   106  			Password: u.Query().Get("password"),
   107  		}
   108  		cluster.Authenticator = authenticator
   109  	}
   110  
   111  	// Retrieve query string configuration
   112  	if len(u.Query().Get("consistency")) > 0 {
   113  		var consistency gocql.Consistency
   114  		consistency, err = parseConsistency(u.Query().Get("consistency"))
   115  		if err != nil {
   116  			return nil, err
   117  		}
   118  
   119  		cluster.Consistency = consistency
   120  	}
   121  	if len(u.Query().Get("protocol")) > 0 {
   122  		var protoversion int
   123  		protoversion, err = strconv.Atoi(u.Query().Get("protocol"))
   124  		if err != nil {
   125  			return nil, err
   126  		}
   127  		cluster.ProtoVersion = protoversion
   128  	}
   129  	if len(u.Query().Get("timeout")) > 0 {
   130  		var timeout time.Duration
   131  		timeout, err = time.ParseDuration(u.Query().Get("timeout"))
   132  		if err != nil {
   133  			return nil, err
   134  		}
   135  		cluster.Timeout = timeout
   136  	}
   137  	if len(u.Query().Get("connect-timeout")) > 0 {
   138  		var connectTimeout time.Duration
   139  		connectTimeout, err = time.ParseDuration(u.Query().Get("connect-timeout"))
   140  		if err != nil {
   141  			return nil, err
   142  		}
   143  		cluster.ConnectTimeout = connectTimeout
   144  	}
   145  
   146  	if len(u.Query().Get("sslmode")) > 0 {
   147  		if u.Query().Get("sslmode") != "disable" {
   148  			sslOpts := &gocql.SslOptions{}
   149  
   150  			if len(u.Query().Get("sslrootcert")) > 0 {
   151  				sslOpts.CaPath = u.Query().Get("sslrootcert")
   152  			}
   153  			if len(u.Query().Get("sslcert")) > 0 {
   154  				sslOpts.CertPath = u.Query().Get("sslcert")
   155  			}
   156  			if len(u.Query().Get("sslkey")) > 0 {
   157  				sslOpts.KeyPath = u.Query().Get("sslkey")
   158  			}
   159  
   160  			if u.Query().Get("sslmode") == "verify-full" {
   161  				sslOpts.EnableHostVerification = true
   162  			}
   163  
   164  			cluster.SslOpts = sslOpts
   165  		}
   166  	}
   167  
   168  	if len(u.Query().Get("disable-host-lookup")) > 0 {
   169  		if flag, err := strconv.ParseBool(u.Query().Get("disable-host-lookup")); err != nil && flag {
   170  			cluster.DisableInitialHostLookup = true
   171  		} else if err != nil {
   172  			return nil, err
   173  		}
   174  	}
   175  
   176  	session, err := cluster.CreateSession()
   177  	if err != nil {
   178  		return nil, err
   179  	}
   180  
   181  	multiStatementMaxSize := DefaultMultiStatementMaxSize
   182  	if s := u.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
   183  		multiStatementMaxSize, err = strconv.Atoi(s)
   184  		if err != nil {
   185  			return nil, err
   186  		}
   187  	}
   188  
   189  	return WithInstance(session, &Config{
   190  		KeyspaceName:          strings.TrimPrefix(u.Path, "/"),
   191  		MigrationsTable:       u.Query().Get("x-migrations-table"),
   192  		MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true",
   193  		MultiStatementMaxSize: multiStatementMaxSize,
   194  	})
   195  }
   196  
   197  func (c *Cassandra) Close() error {
   198  	c.session.Close()
   199  	return nil
   200  }
   201  
   202  func (c *Cassandra) Lock() error {
   203  	if !c.isLocked.CAS(false, true) {
   204  		return database.ErrLocked
   205  	}
   206  	return nil
   207  }
   208  
   209  func (c *Cassandra) Unlock() error {
   210  	if !c.isLocked.CAS(true, false) {
   211  		return database.ErrNotLocked
   212  	}
   213  	return nil
   214  }
   215  
   216  func (c *Cassandra) Run(migration io.Reader) error {
   217  	if c.config.MultiStatementEnabled {
   218  		var err error
   219  		if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool {
   220  			tq := strings.TrimSpace(string(m))
   221  			if tq == "" {
   222  				return true
   223  			}
   224  			if e := c.session.Query(tq).Exec(); e != nil {
   225  				err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
   226  				return false
   227  			}
   228  			return true
   229  		}); e != nil {
   230  			return e
   231  		}
   232  		return err
   233  	}
   234  
   235  	migr, err := ioutil.ReadAll(migration)
   236  	if err != nil {
   237  		return err
   238  	}
   239  	// run migration
   240  	if err := c.session.Query(string(migr)).Exec(); err != nil {
   241  		// TODO: cast to Cassandra error and get line number
   242  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   243  	}
   244  	return nil
   245  }
   246  
   247  func (c *Cassandra) SetVersion(version int, dirty bool) error {
   248  	// DELETE instead of TRUNCATE because AWS Keyspaces does not support it
   249  	// see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html
   250  	squery := `SELECT version FROM "` + c.config.MigrationsTable + `"`
   251  	dquery := `DELETE FROM "` + c.config.MigrationsTable + `" WHERE version = ?`
   252  	iter := c.session.Query(squery).Iter()
   253  	var previous int
   254  	for iter.Scan(&previous) {
   255  		if err := c.session.Query(dquery, previous).Exec(); err != nil {
   256  			return &database.Error{OrigErr: err, Query: []byte(dquery)}
   257  		}
   258  	}
   259  	if err := iter.Close(); err != nil {
   260  		return &database.Error{OrigErr: err, Query: []byte(squery)}
   261  	}
   262  
   263  	// Also re-write the schema version for nil dirty versions to prevent
   264  	// empty schema version for failed down migration on the first migration
   265  	// See: https://github.com/seashell-org/golang-migrate/issues/330
   266  	if version >= 0 || (version == database.NilVersion && dirty) {
   267  		query := `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
   268  		if err := c.session.Query(query, version, dirty).Exec(); err != nil {
   269  			return &database.Error{OrigErr: err, Query: []byte(query)}
   270  		}
   271  	}
   272  
   273  	return nil
   274  }
   275  
   276  // Return current keyspace version
   277  func (c *Cassandra) Version() (version int, dirty bool, err error) {
   278  	query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
   279  	err = c.session.Query(query).Scan(&version, &dirty)
   280  	switch {
   281  	case err == gocql.ErrNotFound:
   282  		return database.NilVersion, false, nil
   283  
   284  	case err != nil:
   285  		if _, ok := err.(*gocql.Error); ok {
   286  			return database.NilVersion, false, nil
   287  		}
   288  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   289  
   290  	default:
   291  		return version, dirty, nil
   292  	}
   293  }
   294  
   295  func (c *Cassandra) Drop() error {
   296  	// select all tables in current schema
   297  	query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
   298  	iter := c.session.Query(query).Iter()
   299  	var tableName string
   300  	for iter.Scan(&tableName) {
   301  		err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
   302  		if err != nil {
   303  			return err
   304  		}
   305  	}
   306  
   307  	return nil
   308  }
   309  
   310  // ensureVersionTable checks if versions table exists and, if not, creates it.
   311  // Note that this function locks the database, which deviates from the usual
   312  // convention of "caller locks" in the Cassandra type.
   313  func (c *Cassandra) ensureVersionTable() (err error) {
   314  	if err = c.Lock(); err != nil {
   315  		return err
   316  	}
   317  
   318  	defer func() {
   319  		if e := c.Unlock(); e != nil {
   320  			if err == nil {
   321  				err = e
   322  			} else {
   323  				err = multierror.Append(err, e)
   324  			}
   325  		}
   326  	}()
   327  
   328  	err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
   329  	if err != nil {
   330  		return err
   331  	}
   332  	if _, _, err = c.Version(); err != nil {
   333  		return err
   334  	}
   335  	return nil
   336  }
   337  
   338  // ParseConsistency wraps gocql.ParseConsistency
   339  // to return an error instead of a panicking.
   340  func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) {
   341  	defer func() {
   342  		if r := recover(); r != nil {
   343  			var ok bool
   344  			err, ok = r.(error)
   345  			if !ok {
   346  				err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r)
   347  			}
   348  		}
   349  	}()
   350  	consistency = gocql.ParseConsistency(consistencyStr)
   351  
   352  	return consistency, nil
   353  }