github.com/fr-nvriep/migrate/v4@v4.3.2/database/cassandra/cassandra.go (about)

     1  package cassandra
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	nurl "net/url"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/gocql/gocql"
    14  	"github.com/fr-nvriep/migrate/v4/database"
    15  	"github.com/hashicorp/go-multierror"
    16  )
    17  
    18  func init() {
    19  	db := new(Cassandra)
    20  	database.Register("cassandra", db)
    21  }
    22  
    23  var DefaultMigrationsTable = "schema_migrations"
    24  
    25  var (
    26  	ErrNilConfig     = errors.New("no config")
    27  	ErrNoKeyspace    = errors.New("no keyspace provided")
    28  	ErrDatabaseDirty = errors.New("database is dirty")
    29  	ErrClosedSession = errors.New("session is closed")
    30  )
    31  
// Config holds the settings of a Cassandra driver instance.
type Config struct {
	// MigrationsTable is the table holding the current version and dirty
	// flag. WithInstance replaces an empty value with DefaultMigrationsTable.
	MigrationsTable       string
	// KeyspaceName is the keyspace the driver works on; WithInstance rejects
	// an empty value with ErrNoKeyspace. Drop lists tables of this keyspace.
	KeyspaceName          string
	// MultiStatementEnabled makes Run split a migration on ";" and execute
	// each non-blank piece as a separate statement.
	MultiStatementEnabled bool
}
    37  
// Cassandra is the database.Driver implementation registered under the
// "cassandra" scheme. Instances are created by Open or WithInstance.
type Cassandra struct {
	// session is the gocql session every query is executed on.
	session  *gocql.Session
	// isLocked is an in-memory flag toggled by Lock/Unlock; nothing is
	// written to Cassandra, so it only guards this instance.
	isLocked bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    45  
    46  func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
    47  	if config == nil {
    48  		return nil, ErrNilConfig
    49  	} else if len(config.KeyspaceName) == 0 {
    50  		return nil, ErrNoKeyspace
    51  	}
    52  
    53  	if session.Closed() {
    54  		return nil, ErrClosedSession
    55  	}
    56  
    57  	if len(config.MigrationsTable) == 0 {
    58  		config.MigrationsTable = DefaultMigrationsTable
    59  	}
    60  
    61  	c := &Cassandra{
    62  		session: session,
    63  		config:  config,
    64  	}
    65  
    66  	if err := c.ensureVersionTable(); err != nil {
    67  		return nil, err
    68  	}
    69  
    70  	return c, nil
    71  }
    72  
    73  func (c *Cassandra) Open(url string) (database.Driver, error) {
    74  	u, err := nurl.Parse(url)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	// Check for missing mandatory attributes
    80  	if len(u.Path) == 0 {
    81  		return nil, ErrNoKeyspace
    82  	}
    83  
    84  	cluster := gocql.NewCluster(u.Host)
    85  	cluster.Keyspace = strings.TrimPrefix(u.Path, "/")
    86  	cluster.Consistency = gocql.All
    87  	cluster.Timeout = 1 * time.Minute
    88  
    89  	if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 {
    90  		authenticator := gocql.PasswordAuthenticator{
    91  			Username: u.Query().Get("username"),
    92  			Password: u.Query().Get("password"),
    93  		}
    94  		cluster.Authenticator = authenticator
    95  	}
    96  
    97  	// Retrieve query string configuration
    98  	if len(u.Query().Get("consistency")) > 0 {
    99  		var consistency gocql.Consistency
   100  		consistency, err = parseConsistency(u.Query().Get("consistency"))
   101  		if err != nil {
   102  			return nil, err
   103  		}
   104  
   105  		cluster.Consistency = consistency
   106  	}
   107  	if len(u.Query().Get("protocol")) > 0 {
   108  		var protoversion int
   109  		protoversion, err = strconv.Atoi(u.Query().Get("protocol"))
   110  		if err != nil {
   111  			return nil, err
   112  		}
   113  		cluster.ProtoVersion = protoversion
   114  	}
   115  	if len(u.Query().Get("timeout")) > 0 {
   116  		var timeout time.Duration
   117  		timeout, err = time.ParseDuration(u.Query().Get("timeout"))
   118  		if err != nil {
   119  			return nil, err
   120  		}
   121  		cluster.Timeout = timeout
   122  	}
   123  
   124  	if len(u.Query().Get("sslmode")) > 0 && len(u.Query().Get("sslrootcert")) > 0 && len(u.Query().Get("sslcert")) > 0 && len(u.Query().Get("sslkey")) > 0 {
   125  		if u.Query().Get("sslmode") != "disable" {
   126  			cluster.SslOpts = &gocql.SslOptions{
   127  				CaPath:   u.Query().Get("sslrootcert"),
   128  				CertPath: u.Query().Get("sslcert"),
   129  				KeyPath:  u.Query().Get("sslkey"),
   130  			}
   131  			if u.Query().Get("sslmode") == "verify-full" {
   132  				cluster.SslOpts.EnableHostVerification = true
   133  			}
   134  		}
   135  	}
   136  
   137  	session, err := cluster.CreateSession()
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  
   142  	return WithInstance(session, &Config{
   143  		KeyspaceName:          strings.TrimPrefix(u.Path, "/"),
   144  		MigrationsTable:       u.Query().Get("x-migrations-table"),
   145  		MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true",
   146  	})
   147  }
   148  
   149  func (c *Cassandra) Close() error {
   150  	c.session.Close()
   151  	return nil
   152  }
   153  
   154  func (c *Cassandra) Lock() error {
   155  	if c.isLocked {
   156  		return database.ErrLocked
   157  	}
   158  	c.isLocked = true
   159  	return nil
   160  }
   161  
   162  func (c *Cassandra) Unlock() error {
   163  	c.isLocked = false
   164  	return nil
   165  }
   166  
   167  func (c *Cassandra) Run(migration io.Reader) error {
   168  	migr, err := ioutil.ReadAll(migration)
   169  	if err != nil {
   170  		return err
   171  	}
   172  	// run migration
   173  	query := string(migr[:])
   174  
   175  	if c.config.MultiStatementEnabled {
   176  		// split query by semi-colon
   177  		queries := strings.Split(query, ";")
   178  
   179  		for _, q := range queries {
   180  			tq := strings.TrimSpace(q)
   181  			if tq == "" {
   182  				continue
   183  			}
   184  			if err := c.session.Query(tq).Exec(); err != nil {
   185  				// TODO: cast to Cassandra error and get line number
   186  				return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   187  			}
   188  		}
   189  		return nil
   190  	}
   191  
   192  	if err := c.session.Query(query).Exec(); err != nil {
   193  		// TODO: cast to Cassandra error and get line number
   194  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   195  	}
   196  	return nil
   197  }
   198  
   199  func (c *Cassandra) SetVersion(version int, dirty bool) error {
   200  	query := `TRUNCATE "` + c.config.MigrationsTable + `"`
   201  	if err := c.session.Query(query).Exec(); err != nil {
   202  		return &database.Error{OrigErr: err, Query: []byte(query)}
   203  	}
   204  	if version >= 0 {
   205  		query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
   206  		if err := c.session.Query(query, version, dirty).Exec(); err != nil {
   207  			return &database.Error{OrigErr: err, Query: []byte(query)}
   208  		}
   209  	}
   210  
   211  	return nil
   212  }
   213  
   214  // Return current keyspace version
   215  func (c *Cassandra) Version() (version int, dirty bool, err error) {
   216  	query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
   217  	err = c.session.Query(query).Scan(&version, &dirty)
   218  	switch {
   219  	case err == gocql.ErrNotFound:
   220  		return database.NilVersion, false, nil
   221  
   222  	case err != nil:
   223  		if _, ok := err.(*gocql.Error); ok {
   224  			return database.NilVersion, false, nil
   225  		}
   226  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   227  
   228  	default:
   229  		return version, dirty, nil
   230  	}
   231  }
   232  
   233  func (c *Cassandra) Drop() error {
   234  	// select all tables in current schema
   235  	query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
   236  	iter := c.session.Query(query).Iter()
   237  	var tableName string
   238  	for iter.Scan(&tableName) {
   239  		err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
   240  		if err != nil {
   241  			return err
   242  		}
   243  	}
   244  
   245  	return nil
   246  }
   247  
   248  // ensureVersionTable checks if versions table exists and, if not, creates it.
   249  // Note that this function locks the database, which deviates from the usual
   250  // convention of "caller locks" in the Cassandra type.
   251  func (c *Cassandra) ensureVersionTable() (err error) {
   252  	if err = c.Lock(); err != nil {
   253  		return err
   254  	}
   255  
   256  	defer func() {
   257  		if e := c.Unlock(); e != nil {
   258  			if err == nil {
   259  				err = e
   260  			} else {
   261  				err = multierror.Append(err, e)
   262  			}
   263  		}
   264  	}()
   265  
   266  	err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
   267  	if err != nil {
   268  		return err
   269  	}
   270  	if _, _, err = c.Version(); err != nil {
   271  		return err
   272  	}
   273  	return nil
   274  }
   275  
   276  // ParseConsistency wraps gocql.ParseConsistency
   277  // to return an error instead of a panicking.
   278  func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) {
   279  	defer func() {
   280  		if r := recover(); r != nil {
   281  			var ok bool
   282  			err, ok = r.(error)
   283  			if !ok {
   284  				err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r)
   285  			}
   286  		}
   287  	}()
   288  	consistency = gocql.ParseConsistency(consistencyStr)
   289  
   290  	return consistency, nil
   291  }