github.com/dhui/migrate@v3.4.0+incompatible/database/cassandra/cassandra.go (about) 1 package cassandra 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "io/ioutil" 8 nurl "net/url" 9 "strconv" 10 "strings" 11 "time" 12 13 "github.com/gocql/gocql" 14 "github.com/golang-migrate/migrate/database" 15 ) 16 17 func init() { 18 db := new(Cassandra) 19 database.Register("cassandra", db) 20 } 21 22 var DefaultMigrationsTable = "schema_migrations" 23 24 var ( 25 ErrNilConfig = errors.New("no config") 26 ErrNoKeyspace = errors.New("no keyspace provided") 27 ErrDatabaseDirty = errors.New("database is dirty") 28 ErrClosedSession = errors.New("session is closed") 29 ) 30 31 type Config struct { 32 MigrationsTable string 33 KeyspaceName string 34 MultiStatementEnabled bool 35 } 36 37 type Cassandra struct { 38 session *gocql.Session 39 isLocked bool 40 41 // Open and WithInstance need to guarantee that config is never nil 42 config *Config 43 } 44 45 func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) { 46 if config == nil { 47 return nil, ErrNilConfig 48 } else if len(config.KeyspaceName) == 0 { 49 return nil, ErrNoKeyspace 50 } 51 52 if session.Closed() { 53 return nil, ErrClosedSession 54 } 55 56 if len(config.MigrationsTable) == 0 { 57 config.MigrationsTable = DefaultMigrationsTable 58 } 59 60 c := &Cassandra{ 61 session: session, 62 config: config, 63 } 64 65 if err := c.ensureVersionTable(); err != nil { 66 return nil, err 67 } 68 69 return c, nil 70 } 71 72 func (c *Cassandra) Open(url string) (database.Driver, error) { 73 u, err := nurl.Parse(url) 74 if err != nil { 75 return nil, err 76 } 77 78 // Check for missing mandatory attributes 79 if len(u.Path) == 0 { 80 return nil, ErrNoKeyspace 81 } 82 83 cluster := gocql.NewCluster(u.Host) 84 cluster.Keyspace = strings.TrimPrefix(u.Path, "/") 85 cluster.Consistency = gocql.All 86 cluster.Timeout = 1 * time.Minute 87 88 if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 { 89 authenticator := gocql.PasswordAuthenticator{ 90 Username: u.Query().Get("username"), 91 Password: u.Query().Get("password"), 92 } 93 cluster.Authenticator = authenticator 94 } 95 96 // Retrieve query string configuration 97 if len(u.Query().Get("consistency")) > 0 { 98 var consistency gocql.Consistency 99 consistency, err = parseConsistency(u.Query().Get("consistency")) 100 if err != nil { 101 return nil, err 102 } 103 104 cluster.Consistency = consistency 105 } 106 if len(u.Query().Get("protocol")) > 0 { 107 var protoversion int 108 protoversion, err = strconv.Atoi(u.Query().Get("protocol")) 109 if err != nil { 110 return nil, err 111 } 112 cluster.ProtoVersion = protoversion 113 } 114 if len(u.Query().Get("timeout")) > 0 { 115 var timeout time.Duration 116 timeout, err = time.ParseDuration(u.Query().Get("timeout")) 117 if err != nil { 118 return nil, err 119 } 120 cluster.Timeout = timeout 121 } 122 123 session, err := cluster.CreateSession() 124 if err != nil { 125 return nil, err 126 } 127 128 return WithInstance(session, &Config{ 129 KeyspaceName: strings.TrimPrefix(u.Path, "/"), 130 MigrationsTable: u.Query().Get("x-migrations-table"), 131 MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true", 132 }) 133 } 134 135 func (c *Cassandra) Close() error { 136 c.session.Close() 137 return nil 138 } 139 140 func (c *Cassandra) Lock() error { 141 if c.isLocked { 142 return database.ErrLocked 143 } 144 c.isLocked = true 145 return nil 146 } 147 148 func (c *Cassandra) Unlock() error { 149 c.isLocked = false 150 return nil 151 } 152 153 func (c *Cassandra) Run(migration io.Reader) error { 154 migr, err := ioutil.ReadAll(migration) 155 if err != nil { 156 return err 157 } 158 // run migration 159 query := string(migr[:]) 160 161 if c.config.MultiStatementEnabled { 162 // split query by semi-colon 163 queries := strings.Split(query, ";") 164 165 for _, q := range(queries) { 166 tq := strings.TrimSpace(q) 167 if (tq == "") { continue } 168 if err := c.session.Query(tq).Exec(); err != nil { 169 // TODO: cast to Cassandra error and get line number 170 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 171 } 172 } 173 return nil 174 } 175 176 if err := c.session.Query(query).Exec(); err != nil { 177 // TODO: cast to Cassandra error and get line number 178 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 179 } 180 return nil 181 } 182 183 func (c *Cassandra) SetVersion(version int, dirty bool) error { 184 query := `TRUNCATE "` + c.config.MigrationsTable + `"` 185 if err := c.session.Query(query).Exec(); err != nil { 186 return &database.Error{OrigErr: err, Query: []byte(query)} 187 } 188 if version >= 0 { 189 query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)` 190 if err := c.session.Query(query, version, dirty).Exec(); err != nil { 191 return &database.Error{OrigErr: err, Query: []byte(query)} 192 } 193 } 194 195 return nil 196 } 197 198 // Return current keyspace version 199 func (c *Cassandra) Version() (version int, dirty bool, err error) { 200 query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1` 201 err = c.session.Query(query).Scan(&version, &dirty) 202 switch { 203 case err == gocql.ErrNotFound: 204 return database.NilVersion, false, nil 205 206 case err != nil: 207 if _, ok := err.(*gocql.Error); ok { 208 return database.NilVersion, false, nil 209 } 210 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 211 212 default: 213 return version, dirty, nil 214 } 215 } 216 217 func (c *Cassandra) Drop() error { 218 // select all tables in current schema 219 query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName) 220 iter := c.session.Query(query).Iter() 221 var tableName string 222 for iter.Scan(&tableName) { 223 err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec() 224 if err != nil { 225 return err 226 } 227 } 228 // Re-create the version table 229 return c.ensureVersionTable() 230 } 231 232 // Ensure version table exists 233 func (c *Cassandra) ensureVersionTable() error { 234 err := c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() 235 if err != nil { 236 return err 237 } 238 if _, _, err = c.Version(); err != nil { 239 return err 240 } 241 return nil 242 } 243 244 // ParseConsistency wraps gocql.ParseConsistency 245 // to return an error instead of a panicking. 246 func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) { 247 defer func() { 248 if r := recover(); r != nil { 249 var ok bool 250 err, ok = r.(error) 251 if !ok { 252 err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r) 253 } 254 } 255 }() 256 consistency = gocql.ParseConsistency(consistencyStr) 257 258 return consistency, nil 259 }