github.com/matcornic/migrate@v3.3.2-0.20180717234201-feea45c20506+incompatible/database/cassandra/cassandra.go (about)

     1  package cassandra
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	nurl "net/url"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/gocql/gocql"
    14  	"github.com/golang-migrate/migrate/database"
    15  )
    16  
    17  func init() {
    18  	db := new(Cassandra)
    19  	database.Register("cassandra", db)
    20  }
    21  
// DefaultMigrationsTable is the table name WithInstance falls back to when
// Config.MigrationsTable is empty.
var DefaultMigrationsTable = "schema_migrations"
    23  
// Sentinel errors returned by this driver.
var (
	// ErrNilConfig is returned by WithInstance when it is given a nil config.
	ErrNilConfig = errors.New("no config")
	// ErrNoKeyspace is returned by Open and WithInstance when no keyspace
	// name was supplied.
	ErrNoKeyspace = errors.New("no keyspace provided")
	// ErrDatabaseDirty is not referenced in this file.
	// NOTE(review): presumably kept for callers of the package - confirm.
	ErrDatabaseDirty = errors.New("database is dirty")
	// ErrClosedSession is returned by WithInstance when the supplied session
	// reports itself as closed.
	ErrClosedSession = errors.New("session is closed")
)
    30  
// Config holds the settings a Cassandra driver instance works with.
type Config struct {
	// MigrationsTable is the name of the table that stores the current
	// version and dirty flag. WithInstance replaces an empty value with
	// DefaultMigrationsTable.
	MigrationsTable string
	// KeyspaceName is mandatory; WithInstance rejects an empty value with
	// ErrNoKeyspace. Drop uses it to look up the tables to remove.
	KeyspaceName string
}
    35  
// Cassandra is the migrate database driver backed by a gocql session.
type Cassandra struct {
	// session is the gocql session every query of the driver runs on.
	session *gocql.Session
	// isLocked is the flag toggled by Lock and Unlock.
	isLocked bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    43  
    44  func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
    45  	if config == nil {
    46  		return nil, ErrNilConfig
    47  	} else if len(config.KeyspaceName) == 0 {
    48  		return nil, ErrNoKeyspace
    49  	}
    50  
    51  	if session.Closed() {
    52  		return nil, ErrClosedSession
    53  	}
    54  
    55  	if len(config.MigrationsTable) == 0 {
    56  		config.MigrationsTable = DefaultMigrationsTable
    57  	}
    58  
    59  	c := &Cassandra{
    60  		session: session,
    61  		config:  config,
    62  	}
    63  
    64  	if err := c.ensureVersionTable(); err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	return c, nil
    69  }
    70  
    71  func (c *Cassandra) Open(url string) (database.Driver, error) {
    72  	u, err := nurl.Parse(url)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	// Check for missing mandatory attributes
    78  	if len(u.Path) == 0 {
    79  		return nil, ErrNoKeyspace
    80  	}
    81  
    82  	cluster := gocql.NewCluster(u.Host)
    83  	cluster.Keyspace = strings.TrimPrefix(u.Path, "/")
    84  	cluster.Consistency = gocql.All
    85  	cluster.Timeout = 1 * time.Minute
    86  
    87  	if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 {
    88  		authenticator := gocql.PasswordAuthenticator{
    89  			Username: u.Query().Get("username"),
    90  			Password: u.Query().Get("password"),
    91  		}
    92  		cluster.Authenticator = authenticator
    93  	}
    94  
    95  	// Retrieve query string configuration
    96  	if len(u.Query().Get("consistency")) > 0 {
    97  		var consistency gocql.Consistency
    98  		consistency, err = parseConsistency(u.Query().Get("consistency"))
    99  		if err != nil {
   100  			return nil, err
   101  		}
   102  
   103  		cluster.Consistency = consistency
   104  	}
   105  	if len(u.Query().Get("protocol")) > 0 {
   106  		var protoversion int
   107  		protoversion, err = strconv.Atoi(u.Query().Get("protocol"))
   108  		if err != nil {
   109  			return nil, err
   110  		}
   111  		cluster.ProtoVersion = protoversion
   112  	}
   113  	if len(u.Query().Get("timeout")) > 0 {
   114  		var timeout time.Duration
   115  		timeout, err = time.ParseDuration(u.Query().Get("timeout"))
   116  		if err != nil {
   117  			return nil, err
   118  		}
   119  		cluster.Timeout = timeout
   120  	}
   121  
   122  	session, err := cluster.CreateSession()
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  
   127  	return WithInstance(session, &Config{
   128  		KeyspaceName:    strings.TrimPrefix(u.Path, "/"),
   129  		MigrationsTable: u.Query().Get("x-migrations-table"),
   130  	})
   131  }
   132  
// Close closes the underlying gocql session. It always returns nil.
func (c *Cassandra) Close() error {
	c.session.Close()
	return nil
}
   137  
   138  func (c *Cassandra) Lock() error {
   139  	if c.isLocked {
   140  		return database.ErrLocked
   141  	}
   142  	c.isLocked = true
   143  	return nil
   144  }
   145  
// Unlock clears the lock flag set by Lock. It does not check whether the
// instance was locked and always returns nil.
func (c *Cassandra) Unlock() error {
	c.isLocked = false
	return nil
}
   150  
   151  func (c *Cassandra) Run(migration io.Reader) error {
   152  	migr, err := ioutil.ReadAll(migration)
   153  	if err != nil {
   154  		return err
   155  	}
   156  	// run migration
   157  	query := string(migr[:])
   158  	if err := c.session.Query(query).Exec(); err != nil {
   159  		// TODO: cast to Cassandra error and get line number
   160  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   161  	}
   162  
   163  	return nil
   164  }
   165  
   166  func (c *Cassandra) SetVersion(version int, dirty bool) error {
   167  	query := `TRUNCATE "` + c.config.MigrationsTable + `"`
   168  	if err := c.session.Query(query).Exec(); err != nil {
   169  		return &database.Error{OrigErr: err, Query: []byte(query)}
   170  	}
   171  	if version >= 0 {
   172  		query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
   173  		if err := c.session.Query(query, version, dirty).Exec(); err != nil {
   174  			return &database.Error{OrigErr: err, Query: []byte(query)}
   175  		}
   176  	}
   177  
   178  	return nil
   179  }
   180  
   181  // Return current keyspace version
   182  func (c *Cassandra) Version() (version int, dirty bool, err error) {
   183  	query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
   184  	err = c.session.Query(query).Scan(&version, &dirty)
   185  	switch {
   186  	case err == gocql.ErrNotFound:
   187  		return database.NilVersion, false, nil
   188  
   189  	case err != nil:
   190  		if _, ok := err.(*gocql.Error); ok {
   191  			return database.NilVersion, false, nil
   192  		}
   193  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   194  
   195  	default:
   196  		return version, dirty, nil
   197  	}
   198  }
   199  
   200  func (c *Cassandra) Drop() error {
   201  	// select all tables in current schema
   202  	query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
   203  	iter := c.session.Query(query).Iter()
   204  	var tableName string
   205  	for iter.Scan(&tableName) {
   206  		err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
   207  		if err != nil {
   208  			return err
   209  		}
   210  	}
   211  	// Re-create the version table
   212  	return c.ensureVersionTable()
   213  }
   214  
   215  // Ensure version table exists
   216  func (c *Cassandra) ensureVersionTable() error {
   217  	err := c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
   218  	if err != nil {
   219  		return err
   220  	}
   221  	if _, _, err = c.Version(); err != nil {
   222  		return err
   223  	}
   224  	return nil
   225  }
   226  
   227  // ParseConsistency wraps gocql.ParseConsistency
   228  // to return an error instead of a panicking.
   229  func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) {
   230  	defer func() {
   231  		if r := recover(); r != nil {
   232  			var ok bool
   233  			err, ok = r.(error)
   234  			if !ok {
   235  				err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r)
   236  			}
   237  		}
   238  	}()
   239  	consistency = gocql.ParseConsistency(consistencyStr)
   240  
   241  	return consistency, nil
   242  }