github.com/nagyist/migrate/v4@v4.14.6/database/cassandra/cassandra.go (about) 1 package cassandra 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "io/ioutil" 8 nurl "net/url" 9 "strconv" 10 "strings" 11 "time" 12 13 "github.com/gocql/gocql" 14 "github.com/golang-migrate/migrate/v4/database" 15 "github.com/golang-migrate/migrate/v4/database/multistmt" 16 "github.com/hashicorp/go-multierror" 17 ) 18 19 func init() { 20 db := new(Cassandra) 21 database.Register("cassandra", db) 22 } 23 24 var ( 25 multiStmtDelimiter = []byte(";") 26 27 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 28 ) 29 30 var DefaultMigrationsTable = "schema_migrations" 31 32 var ( 33 ErrNilConfig = errors.New("no config") 34 ErrNoKeyspace = errors.New("no keyspace provided") 35 ErrDatabaseDirty = errors.New("database is dirty") 36 ErrClosedSession = errors.New("session is closed") 37 ) 38 39 type Config struct { 40 MigrationsTable string 41 KeyspaceName string 42 MultiStatementEnabled bool 43 MultiStatementMaxSize int 44 } 45 46 type Cassandra struct { 47 session *gocql.Session 48 isLocked bool 49 50 // Open and WithInstance need to guarantee that config is never nil 51 config *Config 52 } 53 54 func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) { 55 if config == nil { 56 return nil, ErrNilConfig 57 } else if len(config.KeyspaceName) == 0 { 58 return nil, ErrNoKeyspace 59 } 60 61 if session.Closed() { 62 return nil, ErrClosedSession 63 } 64 65 if len(config.MigrationsTable) == 0 { 66 config.MigrationsTable = DefaultMigrationsTable 67 } 68 69 if config.MultiStatementMaxSize <= 0 { 70 config.MultiStatementMaxSize = DefaultMultiStatementMaxSize 71 } 72 73 c := &Cassandra{ 74 session: session, 75 config: config, 76 } 77 78 if err := c.ensureVersionTable(); err != nil { 79 return nil, err 80 } 81 82 return c, nil 83 } 84 85 func (c *Cassandra) Open(url string) (database.Driver, error) { 86 u, err := nurl.Parse(url) 87 if err != nil { 88 return nil, err 89 } 90 91 // Check for missing mandatory attributes 92 if len(u.Path) == 0 { 93 return nil, ErrNoKeyspace 94 } 95 96 cluster := gocql.NewCluster(u.Host) 97 cluster.Keyspace = strings.TrimPrefix(u.Path, "/") 98 cluster.Consistency = gocql.All 99 cluster.Timeout = 1 * time.Minute 100 101 if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 { 102 authenticator := gocql.PasswordAuthenticator{ 103 Username: u.Query().Get("username"), 104 Password: u.Query().Get("password"), 105 } 106 cluster.Authenticator = authenticator 107 } 108 109 // Retrieve query string configuration 110 if len(u.Query().Get("consistency")) > 0 { 111 var consistency gocql.Consistency 112 consistency, err = parseConsistency(u.Query().Get("consistency")) 113 if err != nil { 114 return nil, err 115 } 116 117 cluster.Consistency = consistency 118 } 119 if len(u.Query().Get("protocol")) > 0 { 120 var protoversion int 121 protoversion, err = strconv.Atoi(u.Query().Get("protocol")) 122 if err != nil { 123 return nil, err 124 } 125 cluster.ProtoVersion = protoversion 126 } 127 if len(u.Query().Get("timeout")) > 0 { 128 var timeout time.Duration 129 timeout, err = time.ParseDuration(u.Query().Get("timeout")) 130 if err != nil { 131 return nil, err 132 } 133 cluster.Timeout = timeout 134 } 135 136 if len(u.Query().Get("sslmode")) > 0 { 137 if u.Query().Get("sslmode") != "disable" { 138 sslOpts := &gocql.SslOptions{} 139 140 if len(u.Query().Get("sslrootcert")) > 0 { 141 sslOpts.CaPath = u.Query().Get("sslrootcert") 142 } 143 if len(u.Query().Get("sslcert")) > 0 { 144 sslOpts.CertPath = u.Query().Get("sslcert") 145 } 146 if len(u.Query().Get("sslkey")) > 0 { 147 sslOpts.KeyPath = u.Query().Get("sslkey") 148 } 149 150 if u.Query().Get("sslmode") == "verify-full" { 151 sslOpts.EnableHostVerification = true 152 } 153 154 cluster.SslOpts = sslOpts 155 } 156 } 157 158 session, err := cluster.CreateSession() 159 if err != nil { 160 return nil, err 161 } 162 163 multiStatementMaxSize := DefaultMultiStatementMaxSize 164 if s := u.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 165 multiStatementMaxSize, err = strconv.Atoi(s) 166 if err != nil { 167 return nil, err 168 } 169 } 170 171 return WithInstance(session, &Config{ 172 KeyspaceName: strings.TrimPrefix(u.Path, "/"), 173 MigrationsTable: u.Query().Get("x-migrations-table"), 174 MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true", 175 MultiStatementMaxSize: multiStatementMaxSize, 176 }) 177 } 178 179 func (c *Cassandra) Close() error { 180 c.session.Close() 181 return nil 182 } 183 184 func (c *Cassandra) Lock() error { 185 if c.isLocked { 186 return database.ErrLocked 187 } 188 c.isLocked = true 189 return nil 190 } 191 192 func (c *Cassandra) Unlock() error { 193 c.isLocked = false 194 return nil 195 } 196 197 func (c *Cassandra) Run(migration io.Reader) error { 198 if c.config.MultiStatementEnabled { 199 var err error 200 if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool { 201 tq := strings.TrimSpace(string(m)) 202 if tq == "" { 203 return true 204 } 205 if e := c.session.Query(tq).Exec(); e != nil { 206 err = database.Error{OrigErr: e, Err: "migration failed", Query: m} 207 return false 208 } 209 return true 210 }); e != nil { 211 return e 212 } 213 return err 214 } 215 216 migr, err := ioutil.ReadAll(migration) 217 if err != nil { 218 return err 219 } 220 // run migration 221 if err := c.session.Query(string(migr)).Exec(); err != nil { 222 // TODO: cast to Cassandra error and get line number 223 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 224 } 225 return nil 226 } 227 228 func (c *Cassandra) SetVersion(version int, dirty bool) error { 229 query := `TRUNCATE "` + c.config.MigrationsTable + `"` 230 if err := c.session.Query(query).Exec(); err != nil { 231 return &database.Error{OrigErr: err, Query: []byte(query)} 232 } 233 234 // Also re-write the schema version for nil dirty versions to prevent 235 // empty schema version for failed down migration on the first migration 236 // See: https://github.com/golang-migrate/migrate/issues/330 237 if version >= 0 || (version == database.NilVersion && dirty) { 238 query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)` 239 if err := c.session.Query(query, version, dirty).Exec(); err != nil { 240 return &database.Error{OrigErr: err, Query: []byte(query)} 241 } 242 } 243 244 return nil 245 } 246 247 // Return current keyspace version 248 func (c *Cassandra) Version() (version int, dirty bool, err error) { 249 query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1` 250 err = c.session.Query(query).Scan(&version, &dirty) 251 switch { 252 case err == gocql.ErrNotFound: 253 return database.NilVersion, false, nil 254 255 case err != nil: 256 if _, ok := err.(*gocql.Error); ok { 257 return database.NilVersion, false, nil 258 } 259 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 260 261 default: 262 return version, dirty, nil 263 } 264 } 265 266 func (c *Cassandra) Drop() error { 267 // select all tables in current schema 268 query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName) 269 iter := c.session.Query(query).Iter() 270 var tableName string 271 for iter.Scan(&tableName) { 272 err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec() 273 if err != nil { 274 return err 275 } 276 } 277 278 return nil 279 } 280 281 // ensureVersionTable checks if versions table exists and, if not, creates it. 282 // Note that this function locks the database, which deviates from the usual 283 // convention of "caller locks" in the Cassandra type. 284 func (c *Cassandra) ensureVersionTable() (err error) { 285 if err = c.Lock(); err != nil { 286 return err 287 } 288 289 defer func() { 290 if e := c.Unlock(); e != nil { 291 if err == nil { 292 err = e 293 } else { 294 err = multierror.Append(err, e) 295 } 296 } 297 }() 298 299 err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() 300 if err != nil { 301 return err 302 } 303 if _, _, err = c.Version(); err != nil { 304 return err 305 } 306 return nil 307 } 308 309 // ParseConsistency wraps gocql.ParseConsistency 310 // to return an error instead of a panicking. 311 func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) { 312 defer func() { 313 if r := recover(); r != nil { 314 var ok bool 315 err, ok = r.(error) 316 if !ok { 317 err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r) 318 } 319 } 320 }() 321 consistency = gocql.ParseConsistency(consistencyStr) 322 323 return consistency, nil 324 }