github.com/fr-nvriep/migrate/v4@v4.3.2/database/cassandra/cassandra.go (about) 1 package cassandra 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "io/ioutil" 8 nurl "net/url" 9 "strconv" 10 "strings" 11 "time" 12 13 "github.com/gocql/gocql" 14 "github.com/fr-nvriep/migrate/v4/database" 15 "github.com/hashicorp/go-multierror" 16 ) 17 18 func init() { 19 db := new(Cassandra) 20 database.Register("cassandra", db) 21 } 22 23 var DefaultMigrationsTable = "schema_migrations" 24 25 var ( 26 ErrNilConfig = errors.New("no config") 27 ErrNoKeyspace = errors.New("no keyspace provided") 28 ErrDatabaseDirty = errors.New("database is dirty") 29 ErrClosedSession = errors.New("session is closed") 30 ) 31 32 type Config struct { 33 MigrationsTable string 34 KeyspaceName string 35 MultiStatementEnabled bool 36 } 37 38 type Cassandra struct { 39 session *gocql.Session 40 isLocked bool 41 42 // Open and WithInstance need to guarantee that config is never nil 43 config *Config 44 } 45 46 func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) { 47 if config == nil { 48 return nil, ErrNilConfig 49 } else if len(config.KeyspaceName) == 0 { 50 return nil, ErrNoKeyspace 51 } 52 53 if session.Closed() { 54 return nil, ErrClosedSession 55 } 56 57 if len(config.MigrationsTable) == 0 { 58 config.MigrationsTable = DefaultMigrationsTable 59 } 60 61 c := &Cassandra{ 62 session: session, 63 config: config, 64 } 65 66 if err := c.ensureVersionTable(); err != nil { 67 return nil, err 68 } 69 70 return c, nil 71 } 72 73 func (c *Cassandra) Open(url string) (database.Driver, error) { 74 u, err := nurl.Parse(url) 75 if err != nil { 76 return nil, err 77 } 78 79 // Check for missing mandatory attributes 80 if len(u.Path) == 0 { 81 return nil, ErrNoKeyspace 82 } 83 84 cluster := gocql.NewCluster(u.Host) 85 cluster.Keyspace = strings.TrimPrefix(u.Path, "/") 86 cluster.Consistency = gocql.All 87 cluster.Timeout = 1 * time.Minute 88 89 if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 { 90 authenticator := gocql.PasswordAuthenticator{ 91 Username: u.Query().Get("username"), 92 Password: u.Query().Get("password"), 93 } 94 cluster.Authenticator = authenticator 95 } 96 97 // Retrieve query string configuration 98 if len(u.Query().Get("consistency")) > 0 { 99 var consistency gocql.Consistency 100 consistency, err = parseConsistency(u.Query().Get("consistency")) 101 if err != nil { 102 return nil, err 103 } 104 105 cluster.Consistency = consistency 106 } 107 if len(u.Query().Get("protocol")) > 0 { 108 var protoversion int 109 protoversion, err = strconv.Atoi(u.Query().Get("protocol")) 110 if err != nil { 111 return nil, err 112 } 113 cluster.ProtoVersion = protoversion 114 } 115 if len(u.Query().Get("timeout")) > 0 { 116 var timeout time.Duration 117 timeout, err = time.ParseDuration(u.Query().Get("timeout")) 118 if err != nil { 119 return nil, err 120 } 121 cluster.Timeout = timeout 122 } 123 124 if len(u.Query().Get("sslmode")) > 0 && len(u.Query().Get("sslrootcert")) > 0 && len(u.Query().Get("sslcert")) > 0 && len(u.Query().Get("sslkey")) > 0 { 125 if u.Query().Get("sslmode") != "disable" { 126 cluster.SslOpts = &gocql.SslOptions{ 127 CaPath: u.Query().Get("sslrootcert"), 128 CertPath: u.Query().Get("sslcert"), 129 KeyPath: u.Query().Get("sslkey"), 130 } 131 if u.Query().Get("sslmode") == "verify-full" { 132 cluster.SslOpts.EnableHostVerification = true 133 } 134 } 135 } 136 137 session, err := cluster.CreateSession() 138 if err != nil { 139 return nil, err 140 } 141 142 return WithInstance(session, &Config{ 143 KeyspaceName: strings.TrimPrefix(u.Path, "/"), 144 MigrationsTable: u.Query().Get("x-migrations-table"), 145 MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true", 146 }) 147 } 148 149 func (c *Cassandra) Close() error { 150 c.session.Close() 151 return nil 152 } 153 154 func (c *Cassandra) Lock() error { 155 if c.isLocked { 156 return database.ErrLocked 157 } 158 c.isLocked = true 159 return nil 160 } 161 162 func (c *Cassandra) Unlock() error { 163 c.isLocked = false 164 return nil 165 } 166 167 func (c *Cassandra) Run(migration io.Reader) error { 168 migr, err := ioutil.ReadAll(migration) 169 if err != nil { 170 return err 171 } 172 // run migration 173 query := string(migr[:]) 174 175 if c.config.MultiStatementEnabled { 176 // split query by semi-colon 177 queries := strings.Split(query, ";") 178 179 for _, q := range queries { 180 tq := strings.TrimSpace(q) 181 if tq == "" { 182 continue 183 } 184 if err := c.session.Query(tq).Exec(); err != nil { 185 // TODO: cast to Cassandra error and get line number 186 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 187 } 188 } 189 return nil 190 } 191 192 if err := c.session.Query(query).Exec(); err != nil { 193 // TODO: cast to Cassandra error and get line number 194 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 195 } 196 return nil 197 } 198 199 func (c *Cassandra) SetVersion(version int, dirty bool) error { 200 query := `TRUNCATE "` + c.config.MigrationsTable + `"` 201 if err := c.session.Query(query).Exec(); err != nil { 202 return &database.Error{OrigErr: err, Query: []byte(query)} 203 } 204 if version >= 0 { 205 query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)` 206 if err := c.session.Query(query, version, dirty).Exec(); err != nil { 207 return &database.Error{OrigErr: err, Query: []byte(query)} 208 } 209 } 210 211 return nil 212 } 213 214 // Return current keyspace version 215 func (c *Cassandra) Version() (version int, dirty bool, err error) { 216 query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1` 217 err = c.session.Query(query).Scan(&version, &dirty) 218 switch { 219 case err == gocql.ErrNotFound: 220 return database.NilVersion, false, nil 221 222 case err != nil: 223 if _, ok := err.(*gocql.Error); ok { 224 return database.NilVersion, false, nil 225 } 226 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 227 228 default: 229 return version, dirty, nil 230 } 231 } 232 233 func (c *Cassandra) Drop() error { 234 // select all tables in current schema 235 query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName) 236 iter := c.session.Query(query).Iter() 237 var tableName string 238 for iter.Scan(&tableName) { 239 err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec() 240 if err != nil { 241 return err 242 } 243 } 244 245 return nil 246 } 247 248 // ensureVersionTable checks if versions table exists and, if not, creates it. 249 // Note that this function locks the database, which deviates from the usual 250 // convention of "caller locks" in the Cassandra type. 251 func (c *Cassandra) ensureVersionTable() (err error) { 252 if err = c.Lock(); err != nil { 253 return err 254 } 255 256 defer func() { 257 if e := c.Unlock(); e != nil { 258 if err == nil { 259 err = e 260 } else { 261 err = multierror.Append(err, e) 262 } 263 } 264 }() 265 266 err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() 267 if err != nil { 268 return err 269 } 270 if _, _, err = c.Version(); err != nil { 271 return err 272 } 273 return nil 274 } 275 276 // ParseConsistency wraps gocql.ParseConsistency 277 // to return an error instead of a panicking. 278 func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) { 279 defer func() { 280 if r := recover(); r != nil { 281 var ok bool 282 err, ok = r.(error) 283 if !ok { 284 err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r) 285 } 286 } 287 }() 288 consistency = gocql.ParseConsistency(consistencyStr) 289 290 return consistency, nil 291 }