github.com/nokia/migrate/v4@v4.16.0/database/cassandra/cassandra.go (about) 1 package cassandra 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "io/ioutil" 8 nurl "net/url" 9 "strconv" 10 "strings" 11 "time" 12 13 "go.uber.org/atomic" 14 15 "github.com/gocql/gocql" 16 "github.com/hashicorp/go-multierror" 17 "github.com/nokia/migrate/v4/database" 18 "github.com/nokia/migrate/v4/database/multistmt" 19 "github.com/nokia/migrate/v4/source" 20 ) 21 22 func init() { 23 db := new(Cassandra) 24 database.Register("cassandra", db) 25 } 26 27 var ( 28 multiStmtDelimiter = []byte(";") 29 30 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 31 ) 32 33 var DefaultMigrationsTable = "schema_migrations" 34 35 var ( 36 ErrNilConfig = errors.New("no config") 37 ErrNoKeyspace = errors.New("no keyspace provided") 38 ErrDatabaseDirty = errors.New("database is dirty") 39 ErrClosedSession = errors.New("session is closed") 40 ) 41 42 type Config struct { 43 MigrationsTable string 44 KeyspaceName string 45 MultiStatementEnabled bool 46 MultiStatementMaxSize int 47 } 48 49 type Cassandra struct { 50 session *gocql.Session 51 isLocked atomic.Bool 52 53 // Open and WithInstance need to guarantee that config is never nil 54 config *Config 55 } 56 57 func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) { 58 if config == nil { 59 return nil, ErrNilConfig 60 } else if len(config.KeyspaceName) == 0 { 61 return nil, ErrNoKeyspace 62 } 63 64 if session.Closed() { 65 return nil, ErrClosedSession 66 } 67 68 if len(config.MigrationsTable) == 0 { 69 config.MigrationsTable = DefaultMigrationsTable 70 } 71 72 if config.MultiStatementMaxSize <= 0 { 73 config.MultiStatementMaxSize = DefaultMultiStatementMaxSize 74 } 75 76 c := &Cassandra{ 77 session: session, 78 config: config, 79 } 80 81 if err := c.ensureVersionTable(); err != nil { 82 return nil, err 83 } 84 85 return c, nil 86 } 87 88 func (c *Cassandra) Open(url string) (database.Driver, error) { 89 u, err := nurl.Parse(url) 90 if err != nil { 91 return nil, err 92 } 93 94 // Check for missing mandatory attributes 95 if len(u.Path) == 0 { 96 return nil, ErrNoKeyspace 97 } 98 99 cluster := gocql.NewCluster(u.Host) 100 cluster.Keyspace = strings.TrimPrefix(u.Path, "/") 101 cluster.Consistency = gocql.All 102 cluster.Timeout = 1 * time.Minute 103 104 if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 { 105 authenticator := gocql.PasswordAuthenticator{ 106 Username: u.Query().Get("username"), 107 Password: u.Query().Get("password"), 108 } 109 cluster.Authenticator = authenticator 110 } 111 112 // Retrieve query string configuration 113 if len(u.Query().Get("consistency")) > 0 { 114 var consistency gocql.Consistency 115 consistency, err = parseConsistency(u.Query().Get("consistency")) 116 if err != nil { 117 return nil, err 118 } 119 120 cluster.Consistency = consistency 121 } 122 if len(u.Query().Get("protocol")) > 0 { 123 var protoversion int 124 protoversion, err = strconv.Atoi(u.Query().Get("protocol")) 125 if err != nil { 126 return nil, err 127 } 128 cluster.ProtoVersion = protoversion 129 } 130 if len(u.Query().Get("timeout")) > 0 { 131 var timeout time.Duration 132 timeout, err = time.ParseDuration(u.Query().Get("timeout")) 133 if err != nil { 134 return nil, err 135 } 136 cluster.Timeout = timeout 137 } 138 139 if len(u.Query().Get("sslmode")) > 0 { 140 if u.Query().Get("sslmode") != "disable" { 141 sslOpts := &gocql.SslOptions{} 142 143 if len(u.Query().Get("sslrootcert")) > 0 { 144 sslOpts.CaPath = u.Query().Get("sslrootcert") 145 } 146 if len(u.Query().Get("sslcert")) > 0 { 147 sslOpts.CertPath = u.Query().Get("sslcert") 148 } 149 if len(u.Query().Get("sslkey")) > 0 { 150 sslOpts.KeyPath = u.Query().Get("sslkey") 151 } 152 153 if u.Query().Get("sslmode") == "verify-full" { 154 sslOpts.EnableHostVerification = true 155 } 156 157 cluster.SslOpts = sslOpts 158 } 159 } 160 161 if len(u.Query().Get("disable-host-lookup")) > 0 { 162 if flag, err := strconv.ParseBool(u.Query().Get("disable-host-lookup")); err != nil && flag { 163 cluster.DisableInitialHostLookup = true 164 } else if err != nil { 165 return nil, err 166 } 167 } 168 169 session, err := cluster.CreateSession() 170 if err != nil { 171 return nil, err 172 } 173 174 multiStatementMaxSize := DefaultMultiStatementMaxSize 175 if s := u.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 176 multiStatementMaxSize, err = strconv.Atoi(s) 177 if err != nil { 178 return nil, err 179 } 180 } 181 182 return WithInstance(session, &Config{ 183 KeyspaceName: strings.TrimPrefix(u.Path, "/"), 184 MigrationsTable: u.Query().Get("x-migrations-table"), 185 MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true", 186 MultiStatementMaxSize: multiStatementMaxSize, 187 }) 188 } 189 190 func (c *Cassandra) Close() error { 191 c.session.Close() 192 return nil 193 } 194 195 func (c *Cassandra) Lock() error { 196 if !c.isLocked.CAS(false, true) { 197 return database.ErrLocked 198 } 199 return nil 200 } 201 202 func (c *Cassandra) Unlock() error { 203 if !c.isLocked.CAS(true, false) { 204 return database.ErrNotLocked 205 } 206 return nil 207 } 208 209 func (c *Cassandra) Run(migration io.Reader) error { 210 if c.config.MultiStatementEnabled { 211 var err error 212 if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool { 213 tq := strings.TrimSpace(string(m)) 214 if tq == "" { 215 return true 216 } 217 if e := c.session.Query(tq).Exec(); e != nil { 218 err = database.Error{OrigErr: e, Err: "migration failed", Query: m} 219 return false 220 } 221 return true 222 }); e != nil { 223 return e 224 } 225 return err 226 } 227 228 migr, err := ioutil.ReadAll(migration) 229 if err != nil { 230 return err 231 } 232 // run migration 233 if err := c.session.Query(string(migr)).Exec(); err != nil { 234 // TODO: cast to Cassandra error and get line number 235 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 236 } 237 return nil 238 } 239 240 func (c *Cassandra) RunFunctionMigration(fn source.MigrationFunc) error { 241 return database.ErrNotImpl 242 } 243 244 func (c *Cassandra) SetVersion(version int, dirty bool) error { 245 // DELETE instead of TRUNCATE because AWS Keyspaces does not support it 246 // see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html 247 squery := `SELECT version FROM "` + c.config.MigrationsTable + `"` 248 dquery := `DELETE FROM "` + c.config.MigrationsTable + `" WHERE version = ?` 249 iter := c.session.Query(squery).Iter() 250 var previous int 251 for iter.Scan(&previous) { 252 if err := c.session.Query(dquery, previous).Exec(); err != nil { 253 return &database.Error{OrigErr: err, Query: []byte(dquery)} 254 } 255 } 256 if err := iter.Close(); err != nil { 257 return &database.Error{OrigErr: err, Query: []byte(squery)} 258 } 259 260 // Also re-write the schema version for nil dirty versions to prevent 261 // empty schema version for failed down migration on the first migration 262 // See: https://github.com/nokia/migrate/issues/330 263 if version >= 0 || (version == database.NilVersion && dirty) { 264 query := `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)` 265 if err := c.session.Query(query, version, dirty).Exec(); err != nil { 266 return &database.Error{OrigErr: err, Query: []byte(query)} 267 } 268 } 269 270 return nil 271 } 272 273 // Return current keyspace version 274 func (c *Cassandra) Version() (version int, dirty bool, err error) { 275 query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1` 276 err = c.session.Query(query).Scan(&version, &dirty) 277 switch { 278 case err == gocql.ErrNotFound: 279 return database.NilVersion, false, nil 280 281 case err != nil: 282 if _, ok := err.(*gocql.Error); ok { 283 return database.NilVersion, false, nil 284 } 285 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 286 287 default: 288 return version, dirty, nil 289 } 290 } 291 292 func (c *Cassandra) Drop() error { 293 // select all tables in current schema 294 query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName) 295 iter := c.session.Query(query).Iter() 296 var tableName string 297 for iter.Scan(&tableName) { 298 err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec() 299 if err != nil { 300 return err 301 } 302 } 303 304 return nil 305 } 306 307 // ensureVersionTable checks if versions table exists and, if not, creates it. 308 // Note that this function locks the database, which deviates from the usual 309 // convention of "caller locks" in the Cassandra type. 310 func (c *Cassandra) ensureVersionTable() (err error) { 311 if err = c.Lock(); err != nil { 312 return err 313 } 314 315 defer func() { 316 if e := c.Unlock(); e != nil { 317 if err == nil { 318 err = e 319 } else { 320 err = multierror.Append(err, e) 321 } 322 } 323 }() 324 325 err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() 326 if err != nil { 327 return err 328 } 329 if _, _, err = c.Version(); err != nil { 330 return err 331 } 332 return nil 333 } 334 335 // ParseConsistency wraps gocql.ParseConsistency 336 // to return an error instead of a panicking. 337 func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) { 338 defer func() { 339 if r := recover(); r != nil { 340 var ok bool 341 err, ok = r.(error) 342 if !ok { 343 err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r) 344 } 345 } 346 }() 347 consistency = gocql.ParseConsistency(consistencyStr) 348 349 return consistency, nil 350 }