github.com/nagyist/migrate/v4@v4.14.6/database/cassandra/cassandra.go (about)

     1  package cassandra
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	nurl "net/url"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/gocql/gocql"
    14  	"github.com/golang-migrate/migrate/v4/database"
    15  	"github.com/golang-migrate/migrate/v4/database/multistmt"
    16  	"github.com/hashicorp/go-multierror"
    17  )
    18  
    19  func init() {
    20  	db := new(Cassandra)
    21  	database.Register("cassandra", db)
    22  }
    23  
    24  var (
    25  	multiStmtDelimiter = []byte(";")
    26  
    27  	DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
    28  )
    29  
    30  var DefaultMigrationsTable = "schema_migrations"
    31  
    32  var (
    33  	ErrNilConfig     = errors.New("no config")
    34  	ErrNoKeyspace    = errors.New("no keyspace provided")
    35  	ErrDatabaseDirty = errors.New("database is dirty")
    36  	ErrClosedSession = errors.New("session is closed")
    37  )
    38  
    39  type Config struct {
    40  	MigrationsTable       string
    41  	KeyspaceName          string
    42  	MultiStatementEnabled bool
    43  	MultiStatementMaxSize int
    44  }
    45  
    46  type Cassandra struct {
    47  	session  *gocql.Session
    48  	isLocked bool
    49  
    50  	// Open and WithInstance need to guarantee that config is never nil
    51  	config *Config
    52  }
    53  
    54  func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
    55  	if config == nil {
    56  		return nil, ErrNilConfig
    57  	} else if len(config.KeyspaceName) == 0 {
    58  		return nil, ErrNoKeyspace
    59  	}
    60  
    61  	if session.Closed() {
    62  		return nil, ErrClosedSession
    63  	}
    64  
    65  	if len(config.MigrationsTable) == 0 {
    66  		config.MigrationsTable = DefaultMigrationsTable
    67  	}
    68  
    69  	if config.MultiStatementMaxSize <= 0 {
    70  		config.MultiStatementMaxSize = DefaultMultiStatementMaxSize
    71  	}
    72  
    73  	c := &Cassandra{
    74  		session: session,
    75  		config:  config,
    76  	}
    77  
    78  	if err := c.ensureVersionTable(); err != nil {
    79  		return nil, err
    80  	}
    81  
    82  	return c, nil
    83  }
    84  
    85  func (c *Cassandra) Open(url string) (database.Driver, error) {
    86  	u, err := nurl.Parse(url)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	// Check for missing mandatory attributes
    92  	if len(u.Path) == 0 {
    93  		return nil, ErrNoKeyspace
    94  	}
    95  
    96  	cluster := gocql.NewCluster(u.Host)
    97  	cluster.Keyspace = strings.TrimPrefix(u.Path, "/")
    98  	cluster.Consistency = gocql.All
    99  	cluster.Timeout = 1 * time.Minute
   100  
   101  	if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 {
   102  		authenticator := gocql.PasswordAuthenticator{
   103  			Username: u.Query().Get("username"),
   104  			Password: u.Query().Get("password"),
   105  		}
   106  		cluster.Authenticator = authenticator
   107  	}
   108  
   109  	// Retrieve query string configuration
   110  	if len(u.Query().Get("consistency")) > 0 {
   111  		var consistency gocql.Consistency
   112  		consistency, err = parseConsistency(u.Query().Get("consistency"))
   113  		if err != nil {
   114  			return nil, err
   115  		}
   116  
   117  		cluster.Consistency = consistency
   118  	}
   119  	if len(u.Query().Get("protocol")) > 0 {
   120  		var protoversion int
   121  		protoversion, err = strconv.Atoi(u.Query().Get("protocol"))
   122  		if err != nil {
   123  			return nil, err
   124  		}
   125  		cluster.ProtoVersion = protoversion
   126  	}
   127  	if len(u.Query().Get("timeout")) > 0 {
   128  		var timeout time.Duration
   129  		timeout, err = time.ParseDuration(u.Query().Get("timeout"))
   130  		if err != nil {
   131  			return nil, err
   132  		}
   133  		cluster.Timeout = timeout
   134  	}
   135  
   136  	if len(u.Query().Get("sslmode")) > 0 {
   137  		if u.Query().Get("sslmode") != "disable" {
   138  			sslOpts := &gocql.SslOptions{}
   139  
   140  			if len(u.Query().Get("sslrootcert")) > 0 {
   141  				sslOpts.CaPath = u.Query().Get("sslrootcert")
   142  			}
   143  			if len(u.Query().Get("sslcert")) > 0 {
   144  				sslOpts.CertPath = u.Query().Get("sslcert")
   145  			}
   146  			if len(u.Query().Get("sslkey")) > 0 {
   147  				sslOpts.KeyPath = u.Query().Get("sslkey")
   148  			}
   149  
   150  			if u.Query().Get("sslmode") == "verify-full" {
   151  				sslOpts.EnableHostVerification = true
   152  			}
   153  
   154  			cluster.SslOpts = sslOpts
   155  		}
   156  	}
   157  
   158  	session, err := cluster.CreateSession()
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  
   163  	multiStatementMaxSize := DefaultMultiStatementMaxSize
   164  	if s := u.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
   165  		multiStatementMaxSize, err = strconv.Atoi(s)
   166  		if err != nil {
   167  			return nil, err
   168  		}
   169  	}
   170  
   171  	return WithInstance(session, &Config{
   172  		KeyspaceName:          strings.TrimPrefix(u.Path, "/"),
   173  		MigrationsTable:       u.Query().Get("x-migrations-table"),
   174  		MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true",
   175  		MultiStatementMaxSize: multiStatementMaxSize,
   176  	})
   177  }
   178  
   179  func (c *Cassandra) Close() error {
   180  	c.session.Close()
   181  	return nil
   182  }
   183  
   184  func (c *Cassandra) Lock() error {
   185  	if c.isLocked {
   186  		return database.ErrLocked
   187  	}
   188  	c.isLocked = true
   189  	return nil
   190  }
   191  
   192  func (c *Cassandra) Unlock() error {
   193  	c.isLocked = false
   194  	return nil
   195  }
   196  
   197  func (c *Cassandra) Run(migration io.Reader) error {
   198  	if c.config.MultiStatementEnabled {
   199  		var err error
   200  		if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool {
   201  			tq := strings.TrimSpace(string(m))
   202  			if tq == "" {
   203  				return true
   204  			}
   205  			if e := c.session.Query(tq).Exec(); e != nil {
   206  				err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
   207  				return false
   208  			}
   209  			return true
   210  		}); e != nil {
   211  			return e
   212  		}
   213  		return err
   214  	}
   215  
   216  	migr, err := ioutil.ReadAll(migration)
   217  	if err != nil {
   218  		return err
   219  	}
   220  	// run migration
   221  	if err := c.session.Query(string(migr)).Exec(); err != nil {
   222  		// TODO: cast to Cassandra error and get line number
   223  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   224  	}
   225  	return nil
   226  }
   227  
   228  func (c *Cassandra) SetVersion(version int, dirty bool) error {
   229  	query := `TRUNCATE "` + c.config.MigrationsTable + `"`
   230  	if err := c.session.Query(query).Exec(); err != nil {
   231  		return &database.Error{OrigErr: err, Query: []byte(query)}
   232  	}
   233  
   234  	// Also re-write the schema version for nil dirty versions to prevent
   235  	// empty schema version for failed down migration on the first migration
   236  	// See: https://github.com/golang-migrate/migrate/issues/330
   237  	if version >= 0 || (version == database.NilVersion && dirty) {
   238  		query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
   239  		if err := c.session.Query(query, version, dirty).Exec(); err != nil {
   240  			return &database.Error{OrigErr: err, Query: []byte(query)}
   241  		}
   242  	}
   243  
   244  	return nil
   245  }
   246  
   247  // Return current keyspace version
   248  func (c *Cassandra) Version() (version int, dirty bool, err error) {
   249  	query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
   250  	err = c.session.Query(query).Scan(&version, &dirty)
   251  	switch {
   252  	case err == gocql.ErrNotFound:
   253  		return database.NilVersion, false, nil
   254  
   255  	case err != nil:
   256  		if _, ok := err.(*gocql.Error); ok {
   257  			return database.NilVersion, false, nil
   258  		}
   259  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   260  
   261  	default:
   262  		return version, dirty, nil
   263  	}
   264  }
   265  
   266  func (c *Cassandra) Drop() error {
   267  	// select all tables in current schema
   268  	query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
   269  	iter := c.session.Query(query).Iter()
   270  	var tableName string
   271  	for iter.Scan(&tableName) {
   272  		err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
   273  		if err != nil {
   274  			return err
   275  		}
   276  	}
   277  
   278  	return nil
   279  }
   280  
   281  // ensureVersionTable checks if versions table exists and, if not, creates it.
   282  // Note that this function locks the database, which deviates from the usual
   283  // convention of "caller locks" in the Cassandra type.
   284  func (c *Cassandra) ensureVersionTable() (err error) {
   285  	if err = c.Lock(); err != nil {
   286  		return err
   287  	}
   288  
   289  	defer func() {
   290  		if e := c.Unlock(); e != nil {
   291  			if err == nil {
   292  				err = e
   293  			} else {
   294  				err = multierror.Append(err, e)
   295  			}
   296  		}
   297  	}()
   298  
   299  	err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
   300  	if err != nil {
   301  		return err
   302  	}
   303  	if _, _, err = c.Version(); err != nil {
   304  		return err
   305  	}
   306  	return nil
   307  }
   308  
   309  // ParseConsistency wraps gocql.ParseConsistency
   310  // to return an error instead of a panicking.
   311  func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) {
   312  	defer func() {
   313  		if r := recover(); r != nil {
   314  			var ok bool
   315  			err, ok = r.(error)
   316  			if !ok {
   317  				err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r)
   318  			}
   319  		}
   320  	}()
   321  	consistency = gocql.ParseConsistency(consistencyStr)
   322  
   323  	return consistency, nil
   324  }