github.com/ldej/migrate@v3.5.4+incompatible/database/cassandra/cassandra.go (about)

     1  package cassandra
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	nurl "net/url"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/gocql/gocql"
    14  	"github.com/golang-migrate/migrate/database"
    15  )
    16  
    17  func init() {
    18  	db := new(Cassandra)
    19  	database.Register("cassandra", db)
    20  }
    21  
    22  var DefaultMigrationsTable = "schema_migrations"
    23  
    24  var (
    25  	ErrNilConfig     = errors.New("no config")
    26  	ErrNoKeyspace    = errors.New("no keyspace provided")
    27  	ErrDatabaseDirty = errors.New("database is dirty")
    28  	ErrClosedSession = errors.New("session is closed")
    29  )
    30  
// Config carries the settings WithInstance uses to build a Cassandra driver.
type Config struct {
	// MigrationsTable names the table holding the recorded migration
	// version; WithInstance replaces an empty value with
	// DefaultMigrationsTable.
	MigrationsTable       string
	// KeyspaceName is required (WithInstance returns ErrNoKeyspace when it
	// is empty); Drop lists the tables of this keyspace.
	KeyspaceName          string
	// MultiStatementEnabled makes Run split a migration on ';' and execute
	// each non-blank statement separately.
	MultiStatementEnabled bool
}
    36  
// Cassandra is the migrate database driver backed by a gocql session.
type Cassandra struct {
	// session executes every CQL statement issued by the driver.
	session  *gocql.Session
	// isLocked is the flag toggled by Lock/Unlock. It lives only in this
	// value: it is not shared with other processes and not mutex-guarded.
	isLocked bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    44  
    45  func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
    46  	if config == nil {
    47  		return nil, ErrNilConfig
    48  	} else if len(config.KeyspaceName) == 0 {
    49  		return nil, ErrNoKeyspace
    50  	}
    51  
    52  	if session.Closed() {
    53  		return nil, ErrClosedSession
    54  	}
    55  
    56  	if len(config.MigrationsTable) == 0 {
    57  		config.MigrationsTable = DefaultMigrationsTable
    58  	}
    59  
    60  	c := &Cassandra{
    61  		session: session,
    62  		config:  config,
    63  	}
    64  
    65  	if err := c.ensureVersionTable(); err != nil {
    66  		return nil, err
    67  	}
    68  
    69  	return c, nil
    70  }
    71  
    72  func (c *Cassandra) Open(url string) (database.Driver, error) {
    73  	u, err := nurl.Parse(url)
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	// Check for missing mandatory attributes
    79  	if len(u.Path) == 0 {
    80  		return nil, ErrNoKeyspace
    81  	}
    82  
    83  	cluster := gocql.NewCluster(u.Host)
    84  	cluster.Keyspace = strings.TrimPrefix(u.Path, "/")
    85  	cluster.Consistency = gocql.All
    86  	cluster.Timeout = 1 * time.Minute
    87  
    88  	if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 {
    89  		authenticator := gocql.PasswordAuthenticator{
    90  			Username: u.Query().Get("username"),
    91  			Password: u.Query().Get("password"),
    92  		}
    93  		cluster.Authenticator = authenticator
    94  	}
    95  
    96  	// Retrieve query string configuration
    97  	if len(u.Query().Get("consistency")) > 0 {
    98  		var consistency gocql.Consistency
    99  		consistency, err = parseConsistency(u.Query().Get("consistency"))
   100  		if err != nil {
   101  			return nil, err
   102  		}
   103  
   104  		cluster.Consistency = consistency
   105  	}
   106  	if len(u.Query().Get("protocol")) > 0 {
   107  		var protoversion int
   108  		protoversion, err = strconv.Atoi(u.Query().Get("protocol"))
   109  		if err != nil {
   110  			return nil, err
   111  		}
   112  		cluster.ProtoVersion = protoversion
   113  	}
   114  	if len(u.Query().Get("timeout")) > 0 {
   115  		var timeout time.Duration
   116  		timeout, err = time.ParseDuration(u.Query().Get("timeout"))
   117  		if err != nil {
   118  			return nil, err
   119  		}
   120  		cluster.Timeout = timeout
   121  	}
   122  
   123  	session, err := cluster.CreateSession()
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	return WithInstance(session, &Config{
   129  		KeyspaceName:          strings.TrimPrefix(u.Path, "/"),
   130  		MigrationsTable:       u.Query().Get("x-migrations-table"),
   131  		MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true",
   132  	})
   133  }
   134  
   135  func (c *Cassandra) Close() error {
   136  	c.session.Close()
   137  	return nil
   138  }
   139  
   140  func (c *Cassandra) Lock() error {
   141  	if c.isLocked {
   142  		return database.ErrLocked
   143  	}
   144  	c.isLocked = true
   145  	return nil
   146  }
   147  
   148  func (c *Cassandra) Unlock() error {
   149  	c.isLocked = false
   150  	return nil
   151  }
   152  
   153  func (c *Cassandra) Run(migration io.Reader) error {
   154  	migr, err := ioutil.ReadAll(migration)
   155  	if err != nil {
   156  		return err
   157  	}
   158  	// run migration
   159  	query := string(migr[:])
   160  
   161  	if c.config.MultiStatementEnabled {
   162  		// split query by semi-colon
   163  		queries := strings.Split(query, ";")
   164  
   165  		for _, q := range queries {
   166  			tq := strings.TrimSpace(q)
   167  			if tq == "" {
   168  				continue
   169  			}
   170  			if err := c.session.Query(tq).Exec(); err != nil {
   171  				// TODO: cast to Cassandra error and get line number
   172  				return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   173  			}
   174  		}
   175  		return nil
   176  	}
   177  
   178  	if err := c.session.Query(query).Exec(); err != nil {
   179  		// TODO: cast to Cassandra error and get line number
   180  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   181  	}
   182  	return nil
   183  }
   184  
   185  func (c *Cassandra) SetVersion(version int, dirty bool) error {
   186  	query := `TRUNCATE "` + c.config.MigrationsTable + `"`
   187  	if err := c.session.Query(query).Exec(); err != nil {
   188  		return &database.Error{OrigErr: err, Query: []byte(query)}
   189  	}
   190  	if version >= 0 {
   191  		query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
   192  		if err := c.session.Query(query, version, dirty).Exec(); err != nil {
   193  			return &database.Error{OrigErr: err, Query: []byte(query)}
   194  		}
   195  	}
   196  
   197  	return nil
   198  }
   199  
   200  // Return current keyspace version
   201  func (c *Cassandra) Version() (version int, dirty bool, err error) {
   202  	query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
   203  	err = c.session.Query(query).Scan(&version, &dirty)
   204  	switch {
   205  	case err == gocql.ErrNotFound:
   206  		return database.NilVersion, false, nil
   207  
   208  	case err != nil:
   209  		if _, ok := err.(*gocql.Error); ok {
   210  			return database.NilVersion, false, nil
   211  		}
   212  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   213  
   214  	default:
   215  		return version, dirty, nil
   216  	}
   217  }
   218  
   219  func (c *Cassandra) Drop() error {
   220  	// select all tables in current schema
   221  	query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
   222  	iter := c.session.Query(query).Iter()
   223  	var tableName string
   224  	for iter.Scan(&tableName) {
   225  		err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
   226  		if err != nil {
   227  			return err
   228  		}
   229  	}
   230  	// Re-create the version table
   231  	return c.ensureVersionTable()
   232  }
   233  
   234  // Ensure version table exists
   235  func (c *Cassandra) ensureVersionTable() error {
   236  	err := c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
   237  	if err != nil {
   238  		return err
   239  	}
   240  	if _, _, err = c.Version(); err != nil {
   241  		return err
   242  	}
   243  	return nil
   244  }
   245  
   246  // ParseConsistency wraps gocql.ParseConsistency
   247  // to return an error instead of a panicking.
   248  func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) {
   249  	defer func() {
   250  		if r := recover(); r != nil {
   251  			var ok bool
   252  			err, ok = r.(error)
   253  			if !ok {
   254  				err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r)
   255  			}
   256  		}
   257  	}()
   258  	consistency = gocql.ParseConsistency(consistencyStr)
   259  
   260  	return consistency, nil
   261  }