github.com/postmates/migrate@v3.0.2-0.20200730201548-1a6ead3e680d+incompatible/database/cassandra/cassandra.go (about)

     1  package cassandra
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"io/ioutil"
     7  	nurl "net/url"
     8  	"github.com/gocql/gocql"
     9  	"time"
    10  	"github.com/mattes/migrate/database"
    11  	"strconv"
    12  )
    13  
    14  func init() {
    15  	db := new(Cassandra)
    16  	database.Register("cassandra", db)
    17  }
    18  
    19  var DefaultMigrationsTable = "schema_migrations"
    20  var dbLocked = false
    21  
    22  var (
    23  	ErrNilConfig = fmt.Errorf("no config")
    24  	ErrNoKeyspace = fmt.Errorf("no keyspace provided")
    25  	ErrDatabaseDirty = fmt.Errorf("database is dirty")
    26  )
    27  
// Config holds the per-connection settings of the Cassandra driver.
type Config struct {
	// MigrationsTable is the table holding the current (version, dirty) row.
	// Open takes it from the x-migrations-table URL parameter and falls back
	// to DefaultMigrationsTable.
	MigrationsTable string
	// KeyspaceName is the URL path as set by Open, i.e. it still carries the
	// leading '/'; Drop skips that first character before using it.
	KeyspaceName    string
}
    32  
// Cassandra is the database.Driver registered in init under the "cassandra"
// scheme; it executes all statements through a gocql session.
type Cassandra struct {
	// session is the gocql session every query of this driver runs on.
	session  *gocql.Session
	// isLocked is not read or written by any code visible in this file; the
	// Lock/Unlock methods use package-level state instead.
	// NOTE(review): candidate for removal once no other file uses it.
	isLocked bool

	// Open and WithInstance need to guarantee that config is never nil
	config   *Config
}
    40  
    41  func (p *Cassandra) Open(url string) (database.Driver, error) {
    42  	u, err := nurl.Parse(url)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	// Check for missing mandatory attributes
    48  	if len(u.Path) == 0 {
    49  		return nil, ErrNoKeyspace
    50  	}
    51  
    52  	migrationsTable := u.Query().Get("x-migrations-table")
    53  	if len(migrationsTable) == 0 {
    54  		migrationsTable = DefaultMigrationsTable
    55  	}
    56  
    57  	p.config = &Config{
    58  		KeyspaceName:    u.Path,
    59  		MigrationsTable: migrationsTable,
    60  	}
    61  
    62  	cluster := gocql.NewCluster(u.Host)
    63  	cluster.Keyspace = u.Path[1:len(u.Path)]
    64  	cluster.Consistency = gocql.All
    65  	cluster.Timeout = 1 * time.Minute
    66  
    67  
    68  	// Retrieve query string configuration
    69  	if len(u.Query().Get("consistency")) > 0 {
    70  		var consistency gocql.Consistency
    71  		consistency, err = parseConsistency(u.Query().Get("consistency"))
    72  		if err != nil {
    73  			return nil, err
    74  		}
    75  
    76  		cluster.Consistency = consistency
    77  	}
    78  	if len(u.Query().Get("protocol")) > 0 {
    79  		var protoversion int
    80  		protoversion, err = strconv.Atoi(u.Query().Get("protocol"))
    81  		if err != nil {
    82  			return nil, err
    83  		}
    84  		cluster.ProtoVersion = protoversion
    85  	}
    86  	if len(u.Query().Get("timeout")) > 0 {
    87  		var timeout time.Duration
    88  		timeout, err = time.ParseDuration(u.Query().Get("timeout"))
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  		cluster.Timeout = timeout
    93  	}
    94  
    95  	p.session, err = cluster.CreateSession()
    96  
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	if err := p.ensureVersionTable(); err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	return p, nil
   106  }
   107  
   108  func (p *Cassandra) Close() error {
   109  	p.session.Close()
   110  	return nil
   111  }
   112  
   113  func (p *Cassandra) Lock() error {
   114  	if (dbLocked) {
   115  		return database.ErrLocked
   116  	}
   117  	dbLocked = true
   118  	return nil
   119  }
   120  
   121  func (p *Cassandra) Unlock() error {
   122  	dbLocked = false
   123  	return nil
   124  }
   125  
   126  func (p *Cassandra) Run(migration io.Reader) error {
   127  	migr, err := ioutil.ReadAll(migration)
   128  	if err != nil {
   129  		return err
   130  	}
   131  	// run migration
   132  	query := string(migr[:])
   133  	if err := p.session.Query(query).Exec(); err != nil {
   134  		// TODO: cast to Cassandra error and get line number
   135  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   136  	}
   137  
   138  	return nil
   139  }
   140  
   141  func (p *Cassandra) SetVersion(version int, dirty bool) error {
   142  	query := `TRUNCATE "` + p.config.MigrationsTable + `"`
   143  	if err := p.session.Query(query).Exec(); err != nil {
   144  		return &database.Error{OrigErr: err, Query: []byte(query)}
   145  	}
   146  	if version >= 0 {
   147  		query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
   148  		if err := p.session.Query(query, version, dirty).Exec(); err != nil {
   149  			return &database.Error{OrigErr: err, Query: []byte(query)}
   150  		}
   151  	}
   152  
   153  	return nil
   154  }
   155  
   156  
   157  // Return current keyspace version
   158  func (p *Cassandra) Version() (version int, dirty bool, err error) {
   159  	query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
   160  	err = p.session.Query(query).Scan(&version, &dirty)
   161  	switch {
   162  	case err == gocql.ErrNotFound:
   163  		return database.NilVersion, false, nil
   164  
   165  	case err != nil:
   166  		if _, ok := err.(*gocql.Error); ok {
   167  			return database.NilVersion, false, nil
   168  		}
   169  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   170  
   171  	default:
   172  		return version, dirty, nil
   173  	}
   174  }
   175  
   176  func (p *Cassandra) Drop() error {
   177  	// select all tables in current schema
   178  	query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, p.config.KeyspaceName[1:]) // Skip '/' character
   179  	iter := p.session.Query(query).Iter()
   180  	var tableName string
   181  	for iter.Scan(&tableName) {
   182  		err := p.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
   183  		if err != nil {
   184  			return err
   185  		}
   186  	}
   187  	// Re-create the version table
   188  	if err := p.ensureVersionTable(); err != nil {
   189  		return err
   190  	}
   191  	return nil
   192  }
   193  
   194  
   195  // Ensure version table exists
   196  func (p *Cassandra) ensureVersionTable() error {
   197  	err := p.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", p.config.MigrationsTable)).Exec()
   198  	if err != nil {
   199  		return err
   200  	}
   201  	if _, _, err = p.Version(); err != nil {
   202  		return err
   203  	}
   204  	return nil
   205  }
   206  
   207  
   208  // ParseConsistency wraps gocql.ParseConsistency
   209  // to return an error instead of a panicking.
   210  func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) {
   211  	defer func() {
   212  		if r := recover(); r != nil {
   213  			var ok bool
   214  			err, ok = r.(error)
   215  			if !ok {
   216  				err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r)
   217  			}
   218  		}
   219  	}()
   220  	consistency = gocql.ParseConsistency(consistencyStr)
   221  
   222  	return consistency, nil
   223  }