github.com/matcornic/migrate@v3.3.2-0.20180717234201-feea45c20506+incompatible/database/cassandra/cassandra.go (about) 1 package cassandra 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "io/ioutil" 8 nurl "net/url" 9 "strconv" 10 "strings" 11 "time" 12 13 "github.com/gocql/gocql" 14 "github.com/golang-migrate/migrate/database" 15 ) 16 17 func init() { 18 db := new(Cassandra) 19 database.Register("cassandra", db) 20 } 21 22 var DefaultMigrationsTable = "schema_migrations" 23 24 var ( 25 ErrNilConfig = errors.New("no config") 26 ErrNoKeyspace = errors.New("no keyspace provided") 27 ErrDatabaseDirty = errors.New("database is dirty") 28 ErrClosedSession = errors.New("session is closed") 29 ) 30 31 type Config struct { 32 MigrationsTable string 33 KeyspaceName string 34 } 35 36 type Cassandra struct { 37 session *gocql.Session 38 isLocked bool 39 40 // Open and WithInstance need to guarantee that config is never nil 41 config *Config 42 } 43 44 func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) { 45 if config == nil { 46 return nil, ErrNilConfig 47 } else if len(config.KeyspaceName) == 0 { 48 return nil, ErrNoKeyspace 49 } 50 51 if session.Closed() { 52 return nil, ErrClosedSession 53 } 54 55 if len(config.MigrationsTable) == 0 { 56 config.MigrationsTable = DefaultMigrationsTable 57 } 58 59 c := &Cassandra{ 60 session: session, 61 config: config, 62 } 63 64 if err := c.ensureVersionTable(); err != nil { 65 return nil, err 66 } 67 68 return c, nil 69 } 70 71 func (c *Cassandra) Open(url string) (database.Driver, error) { 72 u, err := nurl.Parse(url) 73 if err != nil { 74 return nil, err 75 } 76 77 // Check for missing mandatory attributes 78 if len(u.Path) == 0 { 79 return nil, ErrNoKeyspace 80 } 81 82 cluster := gocql.NewCluster(u.Host) 83 cluster.Keyspace = strings.TrimPrefix(u.Path, "/") 84 cluster.Consistency = gocql.All 85 cluster.Timeout = 1 * time.Minute 86 87 if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 { 88 authenticator := gocql.PasswordAuthenticator{ 89 Username: u.Query().Get("username"), 90 Password: u.Query().Get("password"), 91 } 92 cluster.Authenticator = authenticator 93 } 94 95 // Retrieve query string configuration 96 if len(u.Query().Get("consistency")) > 0 { 97 var consistency gocql.Consistency 98 consistency, err = parseConsistency(u.Query().Get("consistency")) 99 if err != nil { 100 return nil, err 101 } 102 103 cluster.Consistency = consistency 104 } 105 if len(u.Query().Get("protocol")) > 0 { 106 var protoversion int 107 protoversion, err = strconv.Atoi(u.Query().Get("protocol")) 108 if err != nil { 109 return nil, err 110 } 111 cluster.ProtoVersion = protoversion 112 } 113 if len(u.Query().Get("timeout")) > 0 { 114 var timeout time.Duration 115 timeout, err = time.ParseDuration(u.Query().Get("timeout")) 116 if err != nil { 117 return nil, err 118 } 119 cluster.Timeout = timeout 120 } 121 122 session, err := cluster.CreateSession() 123 if err != nil { 124 return nil, err 125 } 126 127 return WithInstance(session, &Config{ 128 KeyspaceName: strings.TrimPrefix(u.Path, "/"), 129 MigrationsTable: u.Query().Get("x-migrations-table"), 130 }) 131 } 132 133 func (c *Cassandra) Close() error { 134 c.session.Close() 135 return nil 136 } 137 138 func (c *Cassandra) Lock() error { 139 if c.isLocked { 140 return database.ErrLocked 141 } 142 c.isLocked = true 143 return nil 144 } 145 146 func (c *Cassandra) Unlock() error { 147 c.isLocked = false 148 return nil 149 } 150 151 func (c *Cassandra) Run(migration io.Reader) error { 152 migr, err := ioutil.ReadAll(migration) 153 if err != nil { 154 return err 155 } 156 // run migration 157 query := string(migr[:]) 158 if err := c.session.Query(query).Exec(); err != nil { 159 // TODO: cast to Cassandra error and get line number 160 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 161 } 162 163 return nil 164 } 165 166 func (c *Cassandra) SetVersion(version int, dirty bool) error { 167 query := `TRUNCATE "` + c.config.MigrationsTable + `"` 168 if err := c.session.Query(query).Exec(); err != nil { 169 return &database.Error{OrigErr: err, Query: []byte(query)} 170 } 171 if version >= 0 { 172 query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)` 173 if err := c.session.Query(query, version, dirty).Exec(); err != nil { 174 return &database.Error{OrigErr: err, Query: []byte(query)} 175 } 176 } 177 178 return nil 179 } 180 181 // Return current keyspace version 182 func (c *Cassandra) Version() (version int, dirty bool, err error) { 183 query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1` 184 err = c.session.Query(query).Scan(&version, &dirty) 185 switch { 186 case err == gocql.ErrNotFound: 187 return database.NilVersion, false, nil 188 189 case err != nil: 190 if _, ok := err.(*gocql.Error); ok { 191 return database.NilVersion, false, nil 192 } 193 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 194 195 default: 196 return version, dirty, nil 197 } 198 } 199 200 func (c *Cassandra) Drop() error { 201 // select all tables in current schema 202 query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName) 203 iter := c.session.Query(query).Iter() 204 var tableName string 205 for iter.Scan(&tableName) { 206 err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec() 207 if err != nil { 208 return err 209 } 210 } 211 // Re-create the version table 212 return c.ensureVersionTable() 213 } 214 215 // Ensure version table exists 216 func (c *Cassandra) ensureVersionTable() error { 217 err := c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() 218 if err != nil { 219 return err 220 } 221 if _, _, err = c.Version(); err != nil { 222 return err 223 } 224 return nil 225 } 226 227 // ParseConsistency wraps gocql.ParseConsistency 228 // to return an error instead of a panicking. 229 func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) { 230 defer func() { 231 if r := recover(); r != nil { 232 var ok bool 233 err, ok = r.(error) 234 if !ok { 235 err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r) 236 } 237 } 238 }() 239 consistency = gocql.ParseConsistency(consistencyStr) 240 241 return consistency, nil 242 }