github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/cassandra/cassandra.go (about) 1 package cassandra 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 nurl "net/url" 8 "strconv" 9 "strings" 10 "time" 11 12 "go.uber.org/atomic" 13 14 "github.com/gocql/gocql" 15 "github.com/golang-migrate/migrate/v4/database" 16 "github.com/golang-migrate/migrate/v4/database/multistmt" 17 "github.com/hashicorp/go-multierror" 18 ) 19 20 func init() { 21 db := new(Cassandra) 22 database.Register("cassandra", db) 23 } 24 25 var ( 26 multiStmtDelimiter = []byte(";") 27 28 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 29 ) 30 31 var DefaultMigrationsTable = "schema_migrations" 32 33 var ( 34 ErrNilConfig = errors.New("no config") 35 ErrNoKeyspace = errors.New("no keyspace provided") 36 ErrDatabaseDirty = errors.New("database is dirty") 37 ErrClosedSession = errors.New("session is closed") 38 ) 39 40 type Config struct { 41 MigrationsTable string 42 KeyspaceName string 43 MultiStatementEnabled bool 44 MultiStatementMaxSize int 45 } 46 47 type Cassandra struct { 48 session *gocql.Session 49 isLocked atomic.Bool 50 51 // Open and WithInstance need to guarantee that config is never nil 52 config *Config 53 } 54 55 func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) { 56 if config == nil { 57 return nil, ErrNilConfig 58 } else if len(config.KeyspaceName) == 0 { 59 return nil, ErrNoKeyspace 60 } 61 62 if session.Closed() { 63 return nil, ErrClosedSession 64 } 65 66 if len(config.MigrationsTable) == 0 { 67 config.MigrationsTable = DefaultMigrationsTable 68 } 69 70 if config.MultiStatementMaxSize <= 0 { 71 config.MultiStatementMaxSize = DefaultMultiStatementMaxSize 72 } 73 74 c := &Cassandra{ 75 session: session, 76 config: config, 77 } 78 79 if err := c.ensureVersionTable(); err != nil { 80 return nil, err 81 } 82 83 return c, nil 84 } 85 86 func (c *Cassandra) Open(url string) (database.Driver, error) { 87 u, err := nurl.Parse(url) 88 if err != nil { 89 return nil, err 90 } 91 92 // Check for missing mandatory attributes 93 if len(u.Path) == 0 { 94 return nil, ErrNoKeyspace 95 } 96 97 cluster := gocql.NewCluster(u.Host) 98 cluster.Keyspace = strings.TrimPrefix(u.Path, "/") 99 cluster.Consistency = gocql.All 100 cluster.Timeout = 1 * time.Minute 101 102 if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 { 103 authenticator := gocql.PasswordAuthenticator{ 104 Username: u.Query().Get("username"), 105 Password: u.Query().Get("password"), 106 } 107 cluster.Authenticator = authenticator 108 } 109 110 // Retrieve query string configuration 111 if len(u.Query().Get("consistency")) > 0 { 112 var consistency gocql.Consistency 113 consistency, err = parseConsistency(u.Query().Get("consistency")) 114 if err != nil { 115 return nil, err 116 } 117 118 cluster.Consistency = consistency 119 } 120 if len(u.Query().Get("protocol")) > 0 { 121 var protoversion int 122 protoversion, err = strconv.Atoi(u.Query().Get("protocol")) 123 if err != nil { 124 return nil, err 125 } 126 cluster.ProtoVersion = protoversion 127 } 128 if len(u.Query().Get("timeout")) > 0 { 129 var timeout time.Duration 130 timeout, err = time.ParseDuration(u.Query().Get("timeout")) 131 if err != nil { 132 return nil, err 133 } 134 cluster.Timeout = timeout 135 } 136 if len(u.Query().Get("connect-timeout")) > 0 { 137 var connectTimeout time.Duration 138 connectTimeout, err = time.ParseDuration(u.Query().Get("connect-timeout")) 139 if err != nil { 140 return nil, err 141 } 142 cluster.ConnectTimeout = connectTimeout 143 } 144 145 if len(u.Query().Get("sslmode")) > 0 { 146 if u.Query().Get("sslmode") != "disable" { 147 sslOpts := &gocql.SslOptions{} 148 149 if len(u.Query().Get("sslrootcert")) > 0 { 150 sslOpts.CaPath = u.Query().Get("sslrootcert") 151 } 152 if len(u.Query().Get("sslcert")) > 0 { 153 sslOpts.CertPath = u.Query().Get("sslcert") 154 } 155 if len(u.Query().Get("sslkey")) > 0 { 156 sslOpts.KeyPath = u.Query().Get("sslkey") 157 } 158 159 if u.Query().Get("sslmode") == "verify-full" { 160 sslOpts.EnableHostVerification = true 161 } 162 163 cluster.SslOpts = sslOpts 164 } 165 } 166 167 if len(u.Query().Get("disable-host-lookup")) > 0 { 168 if flag, err := strconv.ParseBool(u.Query().Get("disable-host-lookup")); err != nil && flag { 169 cluster.DisableInitialHostLookup = true 170 } else if err != nil { 171 return nil, err 172 } 173 } 174 175 session, err := cluster.CreateSession() 176 if err != nil { 177 return nil, err 178 } 179 180 multiStatementMaxSize := DefaultMultiStatementMaxSize 181 if s := u.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 182 multiStatementMaxSize, err = strconv.Atoi(s) 183 if err != nil { 184 return nil, err 185 } 186 } 187 188 return WithInstance(session, &Config{ 189 KeyspaceName: strings.TrimPrefix(u.Path, "/"), 190 MigrationsTable: u.Query().Get("x-migrations-table"), 191 MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true", 192 MultiStatementMaxSize: multiStatementMaxSize, 193 }) 194 } 195 196 func (c *Cassandra) Close() error { 197 c.session.Close() 198 return nil 199 } 200 201 func (c *Cassandra) Lock() error { 202 if !c.isLocked.CAS(false, true) { 203 return database.ErrLocked 204 } 205 return nil 206 } 207 208 func (c *Cassandra) Unlock() error { 209 if !c.isLocked.CAS(true, false) { 210 return database.ErrNotLocked 211 } 212 return nil 213 } 214 215 func (c *Cassandra) Run(migration io.Reader) error { 216 if c.config.MultiStatementEnabled { 217 var err error 218 if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool { 219 tq := strings.TrimSpace(string(m)) 220 if tq == "" { 221 return true 222 } 223 if e := c.session.Query(tq).Exec(); e != nil { 224 err = database.Error{OrigErr: e, Err: "migration failed", Query: m} 225 return false 226 } 227 return true 228 }); e != nil { 229 return e 230 } 231 return err 232 } 233 234 migr, err := io.ReadAll(migration) 235 if err != nil { 236 return err 237 } 238 // run migration 239 if err := c.session.Query(string(migr)).Exec(); err != nil { 240 // TODO: cast to Cassandra error and get line number 241 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 242 } 243 return nil 244 } 245 246 func (c *Cassandra) SetVersion(version int, dirty bool) error { 247 // DELETE instead of TRUNCATE because AWS Keyspaces does not support it 248 // see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html 249 squery := `SELECT version FROM "` + c.config.MigrationsTable + `"` 250 dquery := `DELETE FROM "` + c.config.MigrationsTable + `" WHERE version = ?` 251 iter := c.session.Query(squery).Iter() 252 var previous int 253 for iter.Scan(&previous) { 254 if err := c.session.Query(dquery, previous).Exec(); err != nil { 255 return &database.Error{OrigErr: err, Query: []byte(dquery)} 256 } 257 } 258 if err := iter.Close(); err != nil { 259 return &database.Error{OrigErr: err, Query: []byte(squery)} 260 } 261 262 // Also re-write the schema version for nil dirty versions to prevent 263 // empty schema version for failed down migration on the first migration 264 // See: https://github.com/golang-migrate/migrate/issues/330 265 if version >= 0 || (version == database.NilVersion && dirty) { 266 query := `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)` 267 if err := c.session.Query(query, version, dirty).Exec(); err != nil { 268 return &database.Error{OrigErr: err, Query: []byte(query)} 269 } 270 } 271 272 return nil 273 } 274 275 // Return current keyspace version 276 func (c *Cassandra) Version() (version int, dirty bool, err error) { 277 query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1` 278 err = c.session.Query(query).Scan(&version, &dirty) 279 switch { 280 case err == gocql.ErrNotFound: 281 return database.NilVersion, false, nil 282 283 case err != nil: 284 if _, ok := err.(*gocql.Error); ok { 285 return database.NilVersion, false, nil 286 } 287 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 288 289 default: 290 return version, dirty, nil 291 } 292 } 293 294 func (c *Cassandra) Drop() error { 295 // select all tables in current schema 296 query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName) 297 iter := c.session.Query(query).Iter() 298 var tableName string 299 for iter.Scan(&tableName) { 300 err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec() 301 if err != nil { 302 return err 303 } 304 } 305 306 return nil 307 } 308 309 // ensureVersionTable checks if versions table exists and, if not, creates it. 310 // Note that this function locks the database, which deviates from the usual 311 // convention of "caller locks" in the Cassandra type. 312 func (c *Cassandra) ensureVersionTable() (err error) { 313 if err = c.Lock(); err != nil { 314 return err 315 } 316 317 defer func() { 318 if e := c.Unlock(); e != nil { 319 if err == nil { 320 err = e 321 } else { 322 err = multierror.Append(err, e) 323 } 324 } 325 }() 326 327 err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() 328 if err != nil { 329 return err 330 } 331 if _, _, err = c.Version(); err != nil { 332 return err 333 } 334 return nil 335 } 336 337 // ParseConsistency wraps gocql.ParseConsistency 338 // to return an error instead of a panicking. 339 func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) { 340 defer func() { 341 if r := recover(); r != nil { 342 var ok bool 343 err, ok = r.(error) 344 if !ok { 345 err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r) 346 } 347 } 348 }() 349 consistency = gocql.ParseConsistency(consistencyStr) 350 351 return consistency, nil 352 }