github.com/ldej/migrate@v3.5.4+incompatible/database/cassandra/cassandra.go (about) 1 package cassandra 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "io/ioutil" 8 nurl "net/url" 9 "strconv" 10 "strings" 11 "time" 12 13 "github.com/gocql/gocql" 14 "github.com/golang-migrate/migrate/database" 15 ) 16 17 func init() { 18 db := new(Cassandra) 19 database.Register("cassandra", db) 20 } 21 22 var DefaultMigrationsTable = "schema_migrations" 23 24 var ( 25 ErrNilConfig = errors.New("no config") 26 ErrNoKeyspace = errors.New("no keyspace provided") 27 ErrDatabaseDirty = errors.New("database is dirty") 28 ErrClosedSession = errors.New("session is closed") 29 ) 30 31 type Config struct { 32 MigrationsTable string 33 KeyspaceName string 34 MultiStatementEnabled bool 35 } 36 37 type Cassandra struct { 38 session *gocql.Session 39 isLocked bool 40 41 // Open and WithInstance need to guarantee that config is never nil 42 config *Config 43 } 44 45 func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) { 46 if config == nil { 47 return nil, ErrNilConfig 48 } else if len(config.KeyspaceName) == 0 { 49 return nil, ErrNoKeyspace 50 } 51 52 if session.Closed() { 53 return nil, ErrClosedSession 54 } 55 56 if len(config.MigrationsTable) == 0 { 57 config.MigrationsTable = DefaultMigrationsTable 58 } 59 60 c := &Cassandra{ 61 session: session, 62 config: config, 63 } 64 65 if err := c.ensureVersionTable(); err != nil { 66 return nil, err 67 } 68 69 return c, nil 70 } 71 72 func (c *Cassandra) Open(url string) (database.Driver, error) { 73 u, err := nurl.Parse(url) 74 if err != nil { 75 return nil, err 76 } 77 78 // Check for missing mandatory attributes 79 if len(u.Path) == 0 { 80 return nil, ErrNoKeyspace 81 } 82 83 cluster := gocql.NewCluster(u.Host) 84 cluster.Keyspace = strings.TrimPrefix(u.Path, "/") 85 cluster.Consistency = gocql.All 86 cluster.Timeout = 1 * time.Minute 87 88 if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 { 89 authenticator := gocql.PasswordAuthenticator{ 90 Username: u.Query().Get("username"), 91 Password: u.Query().Get("password"), 92 } 93 cluster.Authenticator = authenticator 94 } 95 96 // Retrieve query string configuration 97 if len(u.Query().Get("consistency")) > 0 { 98 var consistency gocql.Consistency 99 consistency, err = parseConsistency(u.Query().Get("consistency")) 100 if err != nil { 101 return nil, err 102 } 103 104 cluster.Consistency = consistency 105 } 106 if len(u.Query().Get("protocol")) > 0 { 107 var protoversion int 108 protoversion, err = strconv.Atoi(u.Query().Get("protocol")) 109 if err != nil { 110 return nil, err 111 } 112 cluster.ProtoVersion = protoversion 113 } 114 if len(u.Query().Get("timeout")) > 0 { 115 var timeout time.Duration 116 timeout, err = time.ParseDuration(u.Query().Get("timeout")) 117 if err != nil { 118 return nil, err 119 } 120 cluster.Timeout = timeout 121 } 122 123 session, err := cluster.CreateSession() 124 if err != nil { 125 return nil, err 126 } 127 128 return WithInstance(session, &Config{ 129 KeyspaceName: strings.TrimPrefix(u.Path, "/"), 130 MigrationsTable: u.Query().Get("x-migrations-table"), 131 MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true", 132 }) 133 } 134 135 func (c *Cassandra) Close() error { 136 c.session.Close() 137 return nil 138 } 139 140 func (c *Cassandra) Lock() error { 141 if c.isLocked { 142 return database.ErrLocked 143 } 144 c.isLocked = true 145 return nil 146 } 147 148 func (c *Cassandra) Unlock() error { 149 c.isLocked = false 150 return nil 151 } 152 153 func (c *Cassandra) Run(migration io.Reader) error { 154 migr, err := ioutil.ReadAll(migration) 155 if err != nil { 156 return err 157 } 158 // run migration 159 query := string(migr[:]) 160 161 if c.config.MultiStatementEnabled { 162 // split query by semi-colon 163 queries := strings.Split(query, ";") 164 165 for _, q := range queries { 166 tq := strings.TrimSpace(q) 167 if tq == "" { 168 continue 169 } 170 if err := c.session.Query(tq).Exec(); err != nil { 171 // TODO: cast to Cassandra error and get line number 172 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 173 } 174 } 175 return nil 176 } 177 178 if err := c.session.Query(query).Exec(); err != nil { 179 // TODO: cast to Cassandra error and get line number 180 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 181 } 182 return nil 183 } 184 185 func (c *Cassandra) SetVersion(version int, dirty bool) error { 186 query := `TRUNCATE "` + c.config.MigrationsTable + `"` 187 if err := c.session.Query(query).Exec(); err != nil { 188 return &database.Error{OrigErr: err, Query: []byte(query)} 189 } 190 if version >= 0 { 191 query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)` 192 if err := c.session.Query(query, version, dirty).Exec(); err != nil { 193 return &database.Error{OrigErr: err, Query: []byte(query)} 194 } 195 } 196 197 return nil 198 } 199 200 // Return current keyspace version 201 func (c *Cassandra) Version() (version int, dirty bool, err error) { 202 query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1` 203 err = c.session.Query(query).Scan(&version, &dirty) 204 switch { 205 case err == gocql.ErrNotFound: 206 return database.NilVersion, false, nil 207 208 case err != nil: 209 if _, ok := err.(*gocql.Error); ok { 210 return database.NilVersion, false, nil 211 } 212 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 213 214 default: 215 return version, dirty, nil 216 } 217 } 218 219 func (c *Cassandra) Drop() error { 220 // select all tables in current schema 221 query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName) 222 iter := c.session.Query(query).Iter() 223 var tableName string 224 for iter.Scan(&tableName) { 225 err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec() 226 if err != nil { 227 return err 228 } 229 } 230 // Re-create the version table 231 return c.ensureVersionTable() 232 } 233 234 // Ensure version table exists 235 func (c *Cassandra) ensureVersionTable() error { 236 err := c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() 237 if err != nil { 238 return err 239 } 240 if _, _, err = c.Version(); err != nil { 241 return err 242 } 243 return nil 244 } 245 246 // ParseConsistency wraps gocql.ParseConsistency 247 // to return an error instead of a panicking. 248 func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) { 249 defer func() { 250 if r := recover(); r != nil { 251 var ok bool 252 err, ok = r.(error) 253 if !ok { 254 err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r) 255 } 256 } 257 }() 258 consistency = gocql.ParseConsistency(consistencyStr) 259 260 return consistency, nil 261 }