github.com/eatigo/migrate@v3.0.2-0.20210729130915-7610befb1b6b+incompatible/database/cassandra/cassandra.go (about)

     1  package cassandra
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"io/ioutil"
     7  	nurl "net/url"
     8  	"strconv"
     9  	"time"
    10  
    11  	"github.com/gocql/gocql"
    12  	"github.com/eatigo/migrate/database"
    13  )
    14  
    15  func init() {
    16  	db := new(Cassandra)
    17  	database.Register("cassandra", db)
    18  }
    19  
    20  var DefaultMigrationsTable = "schema_migrations"
    21  var dbLocked = false
    22  
    23  var (
    24  	ErrNilConfig     = fmt.Errorf("no config")
    25  	ErrNoKeyspace    = fmt.Errorf("no keyspace provided")
    26  	ErrDatabaseDirty = fmt.Errorf("database is dirty")
    27  )
    28  
// Config holds the driver settings that Open derives from the connection URL.
type Config struct {
	// MigrationsTable is the table that stores the current migration version;
	// Open takes it from the x-migrations-table query parameter, falling back
	// to DefaultMigrationsTable.
	MigrationsTable string
	// KeyspaceName is the URL path exactly as parsed, including its leading
	// '/'; Drop strips that first character before using it as a keyspace.
	KeyspaceName    string
}
    33  
// Cassandra is the migration driver for Apache Cassandra. A zero value is
// registered under the "cassandra" name in init; Open connects it and returns
// it as a database.Driver.
type Cassandra struct {
	// session is the gocql session created by Open; every query goes through it.
	session  *gocql.Session
	// isLocked is a per-instance lock flag.
	isLocked bool

	// Open must set config before the driver is used: SetVersion, Version,
	// Drop and ensureVersionTable dereference it without a nil check.
	// (No WithInstance constructor is defined in this file.)
	config *Config
}
    41  
    42  func (p *Cassandra) Open(url string) (database.Driver, error) {
    43  	u, err := nurl.Parse(url)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	// Check for missing mandatory attributes
    49  	if len(u.Path) == 0 {
    50  		return nil, ErrNoKeyspace
    51  	}
    52  
    53  	migrationsTable := u.Query().Get("x-migrations-table")
    54  	if len(migrationsTable) == 0 {
    55  		migrationsTable = DefaultMigrationsTable
    56  	}
    57  
    58  	p.config = &Config{
    59  		KeyspaceName:    u.Path,
    60  		MigrationsTable: migrationsTable,
    61  	}
    62  
    63  	cluster := gocql.NewCluster(u.Host)
    64  	cluster.Keyspace = u.Path[1:len(u.Path)]
    65  	cluster.Consistency = gocql.All
    66  	cluster.Timeout = 1 * time.Minute
    67  
    68  	if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 {
    69  		authenticator := gocql.PasswordAuthenticator{
    70  			Username: u.Query().Get("username"),
    71  			Password: u.Query().Get("password"),
    72  		}
    73  		cluster.Authenticator = authenticator
    74  	}
    75  
    76  	// Retrieve query string configuration
    77  	if len(u.Query().Get("consistency")) > 0 {
    78  		var consistency gocql.Consistency
    79  		consistency, err = parseConsistency(u.Query().Get("consistency"))
    80  		if err != nil {
    81  			return nil, err
    82  		}
    83  
    84  		cluster.Consistency = consistency
    85  	}
    86  	if len(u.Query().Get("protocol")) > 0 {
    87  		var protoversion int
    88  		protoversion, err = strconv.Atoi(u.Query().Get("protocol"))
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  		cluster.ProtoVersion = protoversion
    93  	}
    94  	if len(u.Query().Get("timeout")) > 0 {
    95  		var timeout time.Duration
    96  		timeout, err = time.ParseDuration(u.Query().Get("timeout"))
    97  		if err != nil {
    98  			return nil, err
    99  		}
   100  		cluster.Timeout = timeout
   101  	}
   102  
   103  	p.session, err = cluster.CreateSession()
   104  
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  
   109  	if err := p.ensureVersionTable(); err != nil {
   110  		return nil, err
   111  	}
   112  
   113  	return p, nil
   114  }
   115  
   116  func (p *Cassandra) Close() error {
   117  	p.session.Close()
   118  	return nil
   119  }
   120  
   121  func (p *Cassandra) Lock() error {
   122  	if dbLocked {
   123  		return database.ErrLocked
   124  	}
   125  	dbLocked = true
   126  	return nil
   127  }
   128  
   129  func (p *Cassandra) Unlock() error {
   130  	dbLocked = false
   131  	return nil
   132  }
   133  
   134  func (p *Cassandra) Run(migration io.Reader) error {
   135  	migr, err := ioutil.ReadAll(migration)
   136  	if err != nil {
   137  		return err
   138  	}
   139  	// run migration
   140  	query := string(migr[:])
   141  	if err := p.session.Query(query).Exec(); err != nil {
   142  		// TODO: cast to Cassandra error and get line number
   143  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   144  	}
   145  
   146  	return nil
   147  }
   148  
   149  func (p *Cassandra) SetVersion(version int, dirty bool) error {
   150  	query := `TRUNCATE "` + p.config.MigrationsTable + `"`
   151  	if err := p.session.Query(query).Exec(); err != nil {
   152  		return &database.Error{OrigErr: err, Query: []byte(query)}
   153  	}
   154  	if version >= 0 {
   155  		query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
   156  		if err := p.session.Query(query, version, dirty).Exec(); err != nil {
   157  			return &database.Error{OrigErr: err, Query: []byte(query)}
   158  		}
   159  	}
   160  
   161  	return nil
   162  }
   163  
   164  // Return current keyspace version
   165  func (p *Cassandra) Version() (version int, dirty bool, err error) {
   166  	query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
   167  	err = p.session.Query(query).Scan(&version, &dirty)
   168  	switch {
   169  	case err == gocql.ErrNotFound:
   170  		return database.NilVersion, false, nil
   171  
   172  	case err != nil:
   173  		if _, ok := err.(*gocql.Error); ok {
   174  			return database.NilVersion, false, nil
   175  		}
   176  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   177  
   178  	default:
   179  		return version, dirty, nil
   180  	}
   181  }
   182  
   183  func (p *Cassandra) Drop() error {
   184  	// select all tables in current schema
   185  	query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, p.config.KeyspaceName[1:]) // Skip '/' character
   186  	iter := p.session.Query(query).Iter()
   187  	var tableName string
   188  	for iter.Scan(&tableName) {
   189  		err := p.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
   190  		if err != nil {
   191  			return err
   192  		}
   193  	}
   194  	// Re-create the version table
   195  	if err := p.ensureVersionTable(); err != nil {
   196  		return err
   197  	}
   198  	return nil
   199  }
   200  
   201  // Ensure version table exists
   202  func (p *Cassandra) ensureVersionTable() error {
   203  	err := p.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", p.config.MigrationsTable)).Exec()
   204  	if err != nil {
   205  		return err
   206  	}
   207  	if _, _, err = p.Version(); err != nil {
   208  		return err
   209  	}
   210  	return nil
   211  }
   212  
   213  // ParseConsistency wraps gocql.ParseConsistency
   214  // to return an error instead of a panicking.
   215  func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) {
   216  	defer func() {
   217  		if r := recover(); r != nil {
   218  			var ok bool
   219  			err, ok = r.(error)
   220  			if !ok {
   221  				err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r)
   222  			}
   223  		}
   224  	}()
   225  	consistency = gocql.ParseConsistency(consistencyStr)
   226  
   227  	return consistency, nil
   228  }