github.com/seashell-org/golang-migrate/v4@v4.15.3-0.20220722221203-6ab6c6c062d1/database/cassandra/cassandra.go (about) 1 package cassandra 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "io/ioutil" 8 nurl "net/url" 9 "strconv" 10 "strings" 11 "time" 12 13 "go.uber.org/atomic" 14 15 "github.com/gocql/gocql" 16 "github.com/hashicorp/go-multierror" 17 "github.com/seashell-org/golang-migrate/v4/database" 18 "github.com/seashell-org/golang-migrate/v4/database/multistmt" 19 ) 20 21 func init() { 22 db := new(Cassandra) 23 database.Register("cassandra", db) 24 } 25 26 var ( 27 multiStmtDelimiter = []byte(";") 28 29 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 30 ) 31 32 var DefaultMigrationsTable = "schema_migrations" 33 34 var ( 35 ErrNilConfig = errors.New("no config") 36 ErrNoKeyspace = errors.New("no keyspace provided") 37 ErrDatabaseDirty = errors.New("database is dirty") 38 ErrClosedSession = errors.New("session is closed") 39 ) 40 41 type Config struct { 42 MigrationsTable string 43 KeyspaceName string 44 MultiStatementEnabled bool 45 MultiStatementMaxSize int 46 } 47 48 type Cassandra struct { 49 session *gocql.Session 50 isLocked atomic.Bool 51 52 // Open and WithInstance need to guarantee that config is never nil 53 config *Config 54 } 55 56 func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) { 57 if config == nil { 58 return nil, ErrNilConfig 59 } else if len(config.KeyspaceName) == 0 { 60 return nil, ErrNoKeyspace 61 } 62 63 if session.Closed() { 64 return nil, ErrClosedSession 65 } 66 67 if len(config.MigrationsTable) == 0 { 68 config.MigrationsTable = DefaultMigrationsTable 69 } 70 71 if config.MultiStatementMaxSize <= 0 { 72 config.MultiStatementMaxSize = DefaultMultiStatementMaxSize 73 } 74 75 c := &Cassandra{ 76 session: session, 77 config: config, 78 } 79 80 if err := c.ensureVersionTable(); err != nil { 81 return nil, err 82 } 83 84 return c, nil 85 } 86 87 func (c *Cassandra) Open(url string) (database.Driver, error) { 88 u, err := nurl.Parse(url) 89 if err != nil { 90 return nil, err 91 } 92 93 // Check for missing mandatory attributes 94 if len(u.Path) == 0 { 95 return nil, ErrNoKeyspace 96 } 97 98 cluster := gocql.NewCluster(u.Host) 99 cluster.Keyspace = strings.TrimPrefix(u.Path, "/") 100 cluster.Consistency = gocql.All 101 cluster.Timeout = 1 * time.Minute 102 103 if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 { 104 authenticator := gocql.PasswordAuthenticator{ 105 Username: u.Query().Get("username"), 106 Password: u.Query().Get("password"), 107 } 108 cluster.Authenticator = authenticator 109 } 110 111 // Retrieve query string configuration 112 if len(u.Query().Get("consistency")) > 0 { 113 var consistency gocql.Consistency 114 consistency, err = parseConsistency(u.Query().Get("consistency")) 115 if err != nil { 116 return nil, err 117 } 118 119 cluster.Consistency = consistency 120 } 121 if len(u.Query().Get("protocol")) > 0 { 122 var protoversion int 123 protoversion, err = strconv.Atoi(u.Query().Get("protocol")) 124 if err != nil { 125 return nil, err 126 } 127 cluster.ProtoVersion = protoversion 128 } 129 if len(u.Query().Get("timeout")) > 0 { 130 var timeout time.Duration 131 timeout, err = time.ParseDuration(u.Query().Get("timeout")) 132 if err != nil { 133 return nil, err 134 } 135 cluster.Timeout = timeout 136 } 137 if len(u.Query().Get("connect-timeout")) > 0 { 138 var connectTimeout time.Duration 139 connectTimeout, err = time.ParseDuration(u.Query().Get("connect-timeout")) 140 if err != nil { 141 return nil, err 142 } 143 cluster.ConnectTimeout = connectTimeout 144 } 145 146 if len(u.Query().Get("sslmode")) > 0 { 147 if u.Query().Get("sslmode") != "disable" { 148 sslOpts := &gocql.SslOptions{} 149 150 if len(u.Query().Get("sslrootcert")) > 0 { 151 sslOpts.CaPath = u.Query().Get("sslrootcert") 152 } 153 if len(u.Query().Get("sslcert")) > 0 { 154 sslOpts.CertPath = u.Query().Get("sslcert") 155 } 156 if len(u.Query().Get("sslkey")) > 0 { 157 sslOpts.KeyPath = u.Query().Get("sslkey") 158 } 159 160 if u.Query().Get("sslmode") == "verify-full" { 161 sslOpts.EnableHostVerification = true 162 } 163 164 cluster.SslOpts = sslOpts 165 } 166 } 167 168 if len(u.Query().Get("disable-host-lookup")) > 0 { 169 if flag, err := strconv.ParseBool(u.Query().Get("disable-host-lookup")); err != nil && flag { 170 cluster.DisableInitialHostLookup = true 171 } else if err != nil { 172 return nil, err 173 } 174 } 175 176 session, err := cluster.CreateSession() 177 if err != nil { 178 return nil, err 179 } 180 181 multiStatementMaxSize := DefaultMultiStatementMaxSize 182 if s := u.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 183 multiStatementMaxSize, err = strconv.Atoi(s) 184 if err != nil { 185 return nil, err 186 } 187 } 188 189 return WithInstance(session, &Config{ 190 KeyspaceName: strings.TrimPrefix(u.Path, "/"), 191 MigrationsTable: u.Query().Get("x-migrations-table"), 192 MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true", 193 MultiStatementMaxSize: multiStatementMaxSize, 194 }) 195 } 196 197 func (c *Cassandra) Close() error { 198 c.session.Close() 199 return nil 200 } 201 202 func (c *Cassandra) Lock() error { 203 if !c.isLocked.CAS(false, true) { 204 return database.ErrLocked 205 } 206 return nil 207 } 208 209 func (c *Cassandra) Unlock() error { 210 if !c.isLocked.CAS(true, false) { 211 return database.ErrNotLocked 212 } 213 return nil 214 } 215 216 func (c *Cassandra) Run(migration io.Reader) error { 217 if c.config.MultiStatementEnabled { 218 var err error 219 if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool { 220 tq := strings.TrimSpace(string(m)) 221 if tq == "" { 222 return true 223 } 224 if e := c.session.Query(tq).Exec(); e != nil { 225 err = database.Error{OrigErr: e, Err: "migration failed", Query: m} 226 return false 227 } 228 return true 229 }); e != nil { 230 return e 231 } 232 return err 233 } 234 235 migr, err := ioutil.ReadAll(migration) 236 if err != nil { 237 return err 238 } 239 // run migration 240 if err := c.session.Query(string(migr)).Exec(); err != nil { 241 // TODO: cast to Cassandra error and get line number 242 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 243 } 244 return nil 245 } 246 247 func (c *Cassandra) SetVersion(version int, dirty bool) error { 248 // DELETE instead of TRUNCATE because AWS Keyspaces does not support it 249 // see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html 250 squery := `SELECT version FROM "` + c.config.MigrationsTable + `"` 251 dquery := `DELETE FROM "` + c.config.MigrationsTable + `" WHERE version = ?` 252 iter := c.session.Query(squery).Iter() 253 var previous int 254 for iter.Scan(&previous) { 255 if err := c.session.Query(dquery, previous).Exec(); err != nil { 256 return &database.Error{OrigErr: err, Query: []byte(dquery)} 257 } 258 } 259 if err := iter.Close(); err != nil { 260 return &database.Error{OrigErr: err, Query: []byte(squery)} 261 } 262 263 // Also re-write the schema version for nil dirty versions to prevent 264 // empty schema version for failed down migration on the first migration 265 // See: https://github.com/seashell-org/golang-migrate/issues/330 266 if version >= 0 || (version == database.NilVersion && dirty) { 267 query := `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)` 268 if err := c.session.Query(query, version, dirty).Exec(); err != nil { 269 return &database.Error{OrigErr: err, Query: []byte(query)} 270 } 271 } 272 273 return nil 274 } 275 276 // Return current keyspace version 277 func (c *Cassandra) Version() (version int, dirty bool, err error) { 278 query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1` 279 err = c.session.Query(query).Scan(&version, &dirty) 280 switch { 281 case err == gocql.ErrNotFound: 282 return database.NilVersion, false, nil 283 284 case err != nil: 285 if _, ok := err.(*gocql.Error); ok { 286 return database.NilVersion, false, nil 287 } 288 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 289 290 default: 291 return version, dirty, nil 292 } 293 } 294 295 func (c *Cassandra) Drop() error { 296 // select all tables in current schema 297 query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName) 298 iter := c.session.Query(query).Iter() 299 var tableName string 300 for iter.Scan(&tableName) { 301 err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec() 302 if err != nil { 303 return err 304 } 305 } 306 307 return nil 308 } 309 310 // ensureVersionTable checks if versions table exists and, if not, creates it. 311 // Note that this function locks the database, which deviates from the usual 312 // convention of "caller locks" in the Cassandra type. 313 func (c *Cassandra) ensureVersionTable() (err error) { 314 if err = c.Lock(); err != nil { 315 return err 316 } 317 318 defer func() { 319 if e := c.Unlock(); e != nil { 320 if err == nil { 321 err = e 322 } else { 323 err = multierror.Append(err, e) 324 } 325 } 326 }() 327 328 err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() 329 if err != nil { 330 return err 331 } 332 if _, _, err = c.Version(); err != nil { 333 return err 334 } 335 return nil 336 } 337 338 // ParseConsistency wraps gocql.ParseConsistency 339 // to return an error instead of a panicking. 340 func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) { 341 defer func() { 342 if r := recover(); r != nil { 343 var ok bool 344 err, ok = r.(error) 345 if !ok { 346 err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r) 347 } 348 } 349 }() 350 consistency = gocql.ParseConsistency(consistencyStr) 351 352 return consistency, nil 353 }