github.com/postmates/migrate@v3.0.2-0.20200730201548-1a6ead3e680d+incompatible/database/cassandra/cassandra.go (about) 1 package cassandra 2 3 import ( 4 "fmt" 5 "io" 6 "io/ioutil" 7 nurl "net/url" 8 "github.com/gocql/gocql" 9 "time" 10 "github.com/mattes/migrate/database" 11 "strconv" 12 ) 13 14 func init() { 15 db := new(Cassandra) 16 database.Register("cassandra", db) 17 } 18 19 var DefaultMigrationsTable = "schema_migrations" 20 var dbLocked = false 21 22 var ( 23 ErrNilConfig = fmt.Errorf("no config") 24 ErrNoKeyspace = fmt.Errorf("no keyspace provided") 25 ErrDatabaseDirty = fmt.Errorf("database is dirty") 26 ) 27 28 type Config struct { 29 MigrationsTable string 30 KeyspaceName string 31 } 32 33 type Cassandra struct { 34 session *gocql.Session 35 isLocked bool 36 37 // Open and WithInstance need to guarantee that config is never nil 38 config *Config 39 } 40 41 func (p *Cassandra) Open(url string) (database.Driver, error) { 42 u, err := nurl.Parse(url) 43 if err != nil { 44 return nil, err 45 } 46 47 // Check for missing mandatory attributes 48 if len(u.Path) == 0 { 49 return nil, ErrNoKeyspace 50 } 51 52 migrationsTable := u.Query().Get("x-migrations-table") 53 if len(migrationsTable) == 0 { 54 migrationsTable = DefaultMigrationsTable 55 } 56 57 p.config = &Config{ 58 KeyspaceName: u.Path, 59 MigrationsTable: migrationsTable, 60 } 61 62 cluster := gocql.NewCluster(u.Host) 63 cluster.Keyspace = u.Path[1:len(u.Path)] 64 cluster.Consistency = gocql.All 65 cluster.Timeout = 1 * time.Minute 66 67 68 // Retrieve query string configuration 69 if len(u.Query().Get("consistency")) > 0 { 70 var consistency gocql.Consistency 71 consistency, err = parseConsistency(u.Query().Get("consistency")) 72 if err != nil { 73 return nil, err 74 } 75 76 cluster.Consistency = consistency 77 } 78 if len(u.Query().Get("protocol")) > 0 { 79 var protoversion int 80 protoversion, err = strconv.Atoi(u.Query().Get("protocol")) 81 if err != nil { 82 return nil, err 83 } 84 cluster.ProtoVersion = protoversion 85 } 86 if len(u.Query().Get("timeout")) > 0 { 87 var timeout time.Duration 88 timeout, err = time.ParseDuration(u.Query().Get("timeout")) 89 if err != nil { 90 return nil, err 91 } 92 cluster.Timeout = timeout 93 } 94 95 p.session, err = cluster.CreateSession() 96 97 if err != nil { 98 return nil, err 99 } 100 101 if err := p.ensureVersionTable(); err != nil { 102 return nil, err 103 } 104 105 return p, nil 106 } 107 108 func (p *Cassandra) Close() error { 109 p.session.Close() 110 return nil 111 } 112 113 func (p *Cassandra) Lock() error { 114 if (dbLocked) { 115 return database.ErrLocked 116 } 117 dbLocked = true 118 return nil 119 } 120 121 func (p *Cassandra) Unlock() error { 122 dbLocked = false 123 return nil 124 } 125 126 func (p *Cassandra) Run(migration io.Reader) error { 127 migr, err := ioutil.ReadAll(migration) 128 if err != nil { 129 return err 130 } 131 // run migration 132 query := string(migr[:]) 133 if err := p.session.Query(query).Exec(); err != nil { 134 // TODO: cast to Cassandra error and get line number 135 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 136 } 137 138 return nil 139 } 140 141 func (p *Cassandra) SetVersion(version int, dirty bool) error { 142 query := `TRUNCATE "` + p.config.MigrationsTable + `"` 143 if err := p.session.Query(query).Exec(); err != nil { 144 return &database.Error{OrigErr: err, Query: []byte(query)} 145 } 146 if version >= 0 { 147 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)` 148 if err := p.session.Query(query, version, dirty).Exec(); err != nil { 149 return &database.Error{OrigErr: err, Query: []byte(query)} 150 } 151 } 152 153 return nil 154 } 155 156 157 // Return current keyspace version 158 func (p *Cassandra) Version() (version int, dirty bool, err error) { 159 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` 160 err = p.session.Query(query).Scan(&version, &dirty) 161 switch { 162 case err == gocql.ErrNotFound: 163 return database.NilVersion, false, nil 164 165 case err != nil: 166 if _, ok := err.(*gocql.Error); ok { 167 return database.NilVersion, false, nil 168 } 169 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 170 171 default: 172 return version, dirty, nil 173 } 174 } 175 176 func (p *Cassandra) Drop() error { 177 // select all tables in current schema 178 query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, p.config.KeyspaceName[1:]) // Skip '/' character 179 iter := p.session.Query(query).Iter() 180 var tableName string 181 for iter.Scan(&tableName) { 182 err := p.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec() 183 if err != nil { 184 return err 185 } 186 } 187 // Re-create the version table 188 if err := p.ensureVersionTable(); err != nil { 189 return err 190 } 191 return nil 192 } 193 194 195 // Ensure version table exists 196 func (p *Cassandra) ensureVersionTable() error { 197 err := p.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", p.config.MigrationsTable)).Exec() 198 if err != nil { 199 return err 200 } 201 if _, _, err = p.Version(); err != nil { 202 return err 203 } 204 return nil 205 } 206 207 208 // ParseConsistency wraps gocql.ParseConsistency 209 // to return an error instead of a panicking. 210 func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) { 211 defer func() { 212 if r := recover(); r != nil { 213 var ok bool 214 err, ok = r.(error) 215 if !ok { 216 err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r) 217 } 218 } 219 }() 220 consistency = gocql.ParseConsistency(consistencyStr) 221 222 return consistency, nil 223 }