github.com/eatigo/migrate@v3.0.2-0.20210729130915-7610befb1b6b+incompatible/database/cassandra/cassandra.go (about) 1 package cassandra 2 3 import ( 4 "fmt" 5 "io" 6 "io/ioutil" 7 nurl "net/url" 8 "strconv" 9 "time" 10 11 "github.com/gocql/gocql" 12 "github.com/eatigo/migrate/database" 13 ) 14 15 func init() { 16 db := new(Cassandra) 17 database.Register("cassandra", db) 18 } 19 20 var DefaultMigrationsTable = "schema_migrations" 21 var dbLocked = false 22 23 var ( 24 ErrNilConfig = fmt.Errorf("no config") 25 ErrNoKeyspace = fmt.Errorf("no keyspace provided") 26 ErrDatabaseDirty = fmt.Errorf("database is dirty") 27 ) 28 29 type Config struct { 30 MigrationsTable string 31 KeyspaceName string 32 } 33 34 type Cassandra struct { 35 session *gocql.Session 36 isLocked bool 37 38 // Open and WithInstance need to guarantee that config is never nil 39 config *Config 40 } 41 42 func (p *Cassandra) Open(url string) (database.Driver, error) { 43 u, err := nurl.Parse(url) 44 if err != nil { 45 return nil, err 46 } 47 48 // Check for missing mandatory attributes 49 if len(u.Path) == 0 { 50 return nil, ErrNoKeyspace 51 } 52 53 migrationsTable := u.Query().Get("x-migrations-table") 54 if len(migrationsTable) == 0 { 55 migrationsTable = DefaultMigrationsTable 56 } 57 58 p.config = &Config{ 59 KeyspaceName: u.Path, 60 MigrationsTable: migrationsTable, 61 } 62 63 cluster := gocql.NewCluster(u.Host) 64 cluster.Keyspace = u.Path[1:len(u.Path)] 65 cluster.Consistency = gocql.All 66 cluster.Timeout = 1 * time.Minute 67 68 if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 { 69 authenticator := gocql.PasswordAuthenticator{ 70 Username: u.Query().Get("username"), 71 Password: u.Query().Get("password"), 72 } 73 cluster.Authenticator = authenticator 74 } 75 76 // Retrieve query string configuration 77 if len(u.Query().Get("consistency")) > 0 { 78 var consistency gocql.Consistency 79 consistency, err = parseConsistency(u.Query().Get("consistency")) 80 if err != nil { 81 return nil, err 82 } 83 84 cluster.Consistency = consistency 85 } 86 if len(u.Query().Get("protocol")) > 0 { 87 var protoversion int 88 protoversion, err = strconv.Atoi(u.Query().Get("protocol")) 89 if err != nil { 90 return nil, err 91 } 92 cluster.ProtoVersion = protoversion 93 } 94 if len(u.Query().Get("timeout")) > 0 { 95 var timeout time.Duration 96 timeout, err = time.ParseDuration(u.Query().Get("timeout")) 97 if err != nil { 98 return nil, err 99 } 100 cluster.Timeout = timeout 101 } 102 103 p.session, err = cluster.CreateSession() 104 105 if err != nil { 106 return nil, err 107 } 108 109 if err := p.ensureVersionTable(); err != nil { 110 return nil, err 111 } 112 113 return p, nil 114 } 115 116 func (p *Cassandra) Close() error { 117 p.session.Close() 118 return nil 119 } 120 121 func (p *Cassandra) Lock() error { 122 if dbLocked { 123 return database.ErrLocked 124 } 125 dbLocked = true 126 return nil 127 } 128 129 func (p *Cassandra) Unlock() error { 130 dbLocked = false 131 return nil 132 } 133 134 func (p *Cassandra) Run(migration io.Reader) error { 135 migr, err := ioutil.ReadAll(migration) 136 if err != nil { 137 return err 138 } 139 // run migration 140 query := string(migr[:]) 141 if err := p.session.Query(query).Exec(); err != nil { 142 // TODO: cast to Cassandra error and get line number 143 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 144 } 145 146 return nil 147 } 148 149 func (p *Cassandra) SetVersion(version int, dirty bool) error { 150 query := `TRUNCATE "` + p.config.MigrationsTable + `"` 151 if err := p.session.Query(query).Exec(); err != nil { 152 return &database.Error{OrigErr: err, Query: []byte(query)} 153 } 154 if version >= 0 { 155 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)` 156 if err := p.session.Query(query, version, dirty).Exec(); err != nil { 157 return &database.Error{OrigErr: err, Query: []byte(query)} 158 } 159 } 160 161 return nil 162 } 163 164 // Return current keyspace version 165 func (p *Cassandra) Version() (version int, dirty bool, err error) { 166 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 167 err = p.session.Query(query).Scan(&version, &dirty) 168 switch { 169 case err == gocql.ErrNotFound: 170 return database.NilVersion, false, nil 171 172 case err != nil: 173 if _, ok := err.(*gocql.Error); ok { 174 return database.NilVersion, false, nil 175 } 176 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 177 178 default: 179 return version, dirty, nil 180 } 181 } 182 183 func (p *Cassandra) Drop() error { 184 // select all tables in current schema 185 query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, p.config.KeyspaceName[1:]) // Skip '/' character 186 iter := p.session.Query(query).Iter() 187 var tableName string 188 for iter.Scan(&tableName) { 189 err := p.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec() 190 if err != nil { 191 return err 192 } 193 } 194 // Re-create the version table 195 if err := p.ensureVersionTable(); err != nil { 196 return err 197 } 198 return nil 199 } 200 201 // Ensure version table exists 202 func (p *Cassandra) ensureVersionTable() error { 203 err := p.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", p.config.MigrationsTable)).Exec() 204 if err != nil { 205 return err 206 } 207 if _, _, err = p.Version(); err != nil { 208 return err 209 } 210 return nil 211 } 212 213 // ParseConsistency wraps gocql.ParseConsistency 214 // to return an error instead of a panicking. 215 func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) { 216 defer func() { 217 if r := recover(); r != nil { 218 var ok bool 219 err, ok = r.(error) 220 if !ok { 221 err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r) 222 } 223 } 224 }() 225 consistency = gocql.ParseConsistency(consistencyStr) 226 227 return consistency, nil 228 }