github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/cassandra/wrapper/init.go (about) 1 package wrapper 2 3 import ( 4 "bytes" 5 "fmt" 6 "io/ioutil" 7 "os" 8 "path" 9 "regexp" 10 "strings" 11 "time" 12 13 "github.com/gocql/gocql" 14 log "github.com/sirupsen/logrus" 15 ) 16 17 // Schema file to create keyspace if required 18 const ( 19 schemaDefaultPath = "/usr/local/bin" 20 schemaDefaultFileName = "schema.sql" 21 defaultCassandraClusterConsistency = gocql.Quorum 22 defaultUsername = "" 23 defaultPassword = "" 24 defaultSSLCert = "" 25 26 envCassandraClusterConsistency = "CLUSTER_CONSISTENCY" 27 envCassandraSchemaPath = "CASSANDRA_SCHEMA_PATH" 28 envCassandraSchemaFileName = "CASSANDRA_SCHEMA_FILE_NAME" 29 envUsername = "CASSANDRA_USERNAME" 30 envPassword = "CASSANDRA_PASSWORD" 31 envSSLCert = "CASSANDRA_SSL_CERT" 32 ) 33 34 var schemaPath = "/usr/local/bin" 35 var schemaFileName = "schema.sql" 36 var clusterConsistency = gocql.Quorum 37 var username = "" 38 var password = "" 39 var sslCert = "" 40 41 // Package level initialization. 42 // 43 // init functions are automatically executed when the programs starts 44 func init() { 45 46 // reading and setting up environment variables 47 schemaPath = getenv(envCassandraSchemaPath, schemaDefaultPath) 48 schemaFileName = getenv(envCassandraSchemaFileName, schemaDefaultFileName) 49 clusterConsistency = checkConsistency(getenv(envCassandraClusterConsistency, defaultCassandraClusterConsistency.String())) 50 username = getenvnolog(envUsername, defaultUsername) 51 password = getenvnolog(envPassword, defaultPassword) 52 sslCert = getenvnolog(envSSLCert, defaultSSLCert) 53 54 log.Debugf("Got schema path: %s", schemaPath) 55 log.Debugf("Got schema file name: %s", schemaFileName) 56 log.Debugf("Got cluster consistency: %s", clusterConsistency) 57 log.Debugf("Got username: %s", username) 58 } 59 60 // sessionInitializer is an initializer for a cassandra session 61 type sessionInitializer struct { 62 clusterHostName string 63 clusterHostUsername string 64 clusterHostPassword string 65 clusterHostSSLCert string 66 keyspace string 67 consistency gocql.Consistency 68 } 69 70 // sessionHolder stores a cassandra session 71 type sessionHolder struct { 72 session SessionInterface 73 } 74 75 // New return a cassandra session Initializer 76 func New(clusterHostName, keyspace string) Initializer { 77 log.Debugf("in new") 78 79 return sessionInitializer{ 80 clusterHostName: clusterHostName, 81 clusterHostUsername: username, 82 clusterHostPassword: password, 83 clusterHostSSLCert: sslCert, 84 keyspace: keyspace, 85 consistency: clusterConsistency, 86 } 87 } 88 89 // Initialize waits for a Cassandra session, initializes Cassandra keyspace and creates tables if required. 90 // NOTE: Needs to be called only once on the app startup, won't fail if it is called multiple times but is not necessary. 91 // 92 // Params: 93 // clusterHostName: Cassandra cluster host 94 // systemKeyspace: System keyspace 95 // appKeyspace: Application keyspace 96 // connectionTimeout: timeout to get the connection 97 func Initialize(clusterHostName, systemKeyspace, appKeyspace string, connectionTimeout time.Duration) { 98 log.Debug("Setting up cassandra db") 99 connectionHolder, err := loop(connectionTimeout, New(clusterHostName, systemKeyspace), "cassandra-db") 100 if err != nil { 101 log.Fatalf("error connecting to Cassandra db: %v", err) 102 panic(err) 103 } 104 defer connectionHolder.CloseSession() 105 106 log.Debug("Setting up cassandra keyspace") 107 err = createAppKeyspaceIfRequired(clusterHostName, systemKeyspace, appKeyspace) 108 if err != nil { 109 log.Fatalf("error creating keyspace for Cassandra db: %v", err) 110 panic(err) 111 } 112 113 log.Info("Cassandra keyspace has been set up") 114 } 115 116 // NewSession starts a new cassandra session for the given keyspace 117 // NOTE: It is responsibility of the caller to close this new session. 118 // 119 // Returns a session Holder for the session, or an error if can't start the session 120 func (i sessionInitializer) NewSession() (Holder, error) { 121 session, err := newKeyspaceSession(i.clusterHostName, i.keyspace, 122 i.clusterHostUsername, i.clusterHostPassword, i.clusterHostSSLCert, i.consistency, 123 600*time.Millisecond) 124 if err != nil { 125 log.Errorf("error starting Cassandra session for the cluster hostname: %s and keyspace: %s - %v", 126 i.clusterHostName, i.keyspace, err) 127 return nil, err 128 } 129 sessionRetry := sessionRetry{session} 130 connectionHolder := sessionHolder{sessionRetry} 131 return connectionHolder, nil 132 } 133 134 // GetSession returns the stored cassandra session 135 func (holder sessionHolder) GetSession() SessionInterface { 136 return holder.session 137 } 138 139 // CloseSession closes the cassandra session 140 func (holder sessionHolder) CloseSession() { 141 holder.session.Close() 142 } 143 144 // newKeyspaceSession returns a new session for the given keyspace 145 func newKeyspaceSession(clusterHostName, keyspace, username, password, sslCert string, clusterConsistency gocql.Consistency, clusterTimeout time.Duration) (*gocql.Session, error) { 146 log.Infof("Creating new cassandra session for cluster hostname: %s and keyspace: %s", clusterHostName, keyspace) 147 cluster := gocql.NewCluster(clusterHostName) 148 cluster.Keyspace = keyspace 149 cluster.Timeout = clusterTimeout 150 if username != "" { 151 cluster.Authenticator = gocql.PasswordAuthenticator{ 152 Username: username, 153 Password: password, 154 } 155 } 156 if sslCert != "" { 157 cluster.SslOpts = &gocql.SslOptions{ 158 CaPath: sslCert, 159 } 160 } 161 cluster.Consistency = clusterConsistency 162 return cluster.CreateSession() 163 } 164 165 // createAppKeyspaceIfRequired creates the keyspace for the app if it doesn't exist 166 func createAppKeyspaceIfRequired(clusterHostName, systemKeyspace, appKeyspace string) error { 167 // Getting the schema file if exist 168 stmtList, err := getStmtsFromFile(path.Join(schemaPath, schemaFileName)) 169 if err != nil { 170 return err 171 } 172 if stmtList == nil { // Didn't fail but returned nil, probably the file does not exist 173 return nil 174 } 175 176 log.Info("about to create a session with a 5 minute timeout to allow for all schema creation") 177 session, err := newKeyspaceSession(clusterHostName, systemKeyspace, username, password, sslCert, clusterConsistency, 5*time.Minute) 178 if err != nil { 179 return err 180 } 181 currentKeyspace := systemKeyspace 182 183 var sessionList []*gocql.Session 184 defer func() { 185 for _, s := range sessionList { 186 if s != nil && !s.Closed() { 187 s.Close() 188 } 189 } 190 }() 191 192 log.Debugf("Creating new keyspace if required: %s", appKeyspace) 193 194 for _, stmt := range stmtList { 195 log.Debugf("Executing statement: %s", stmt) 196 // New session for use statement 197 newKeyspace, isCaseSensitive := getKeyspaceNameFromUseStmt(stmt) 198 if newKeyspace != "" { 199 if (isCaseSensitive && newKeyspace != currentKeyspace) || (!isCaseSensitive && 200 strings.ToLower(newKeyspace) != strings.ToLower(currentKeyspace)) { 201 log.Infof("about to create a session with a 5 minute timeout to set keyspace: %s", newKeyspace) 202 session, err = newKeyspaceSession(clusterHostName, newKeyspace, username, password, sslCert, clusterConsistency, 5*time.Minute) //5 minutes 203 if err != nil { 204 return err 205 } 206 currentKeyspace = newKeyspace 207 sessionList = append(sessionList, session) 208 log.Debugf("Changed to new keyspace: %s", newKeyspace) 209 } 210 continue 211 } 212 213 // execute statement 214 err = session.Query(stmt).Exec() 215 if err != nil { 216 log.Errorf("statement error: %v", err) 217 return err 218 } 219 log.Debug("Statement executed") 220 } 221 222 log.Debugf("app keyspace set to: %s", appKeyspace) 223 return nil 224 } 225 226 // getStmtsFromFile extracts CQL statements from the file 227 func getStmtsFromFile(fileName string, ) ([]string, error) { 228 // Verify first if the file exist 229 if _, err := os.Stat(fileName); err != nil { 230 if os.IsNotExist(err) { // Does not exist 231 log.Warnf("no schema file [%s] found initializing Cassandra.", fileName) 232 return nil, nil 233 } 234 } 235 236 content, err := ioutil.ReadFile(fileName) 237 if err != nil { 238 return nil, err 239 } 240 241 pattern := regexp.MustCompile(`(?ms)([^"';]*?)("(?:[^"]|"")*"|'(?:[^']|'')*'|\$\$.*?\$\$|(/\*.*?\*/)|((?:--|//).*?$)|;\n?)`) 242 243 var stmtList []string 244 245 i := 0 246 contentLength := len(content) 247 var stmt bytes.Buffer 248 for i < contentLength { 249 subIndexes := pattern.FindSubmatchIndex(content[i:]) 250 if len(subIndexes) > 0 { 251 end := subIndexes[1] 252 stmt.Write(getMatch(content, i, subIndexes, 2, 3)) 253 stmtTail := getMatch(content, i, subIndexes, 4, 5) 254 comment := getMatch(content, i, subIndexes, 6, 7) 255 lineComment := getMatch(content, i, subIndexes, 8, 9) 256 if comment == nil && lineComment == nil { 257 if stmtTail != nil && string(bytes.TrimSpace(stmtTail)) == ";" { 258 stmtList = append(stmtList, stmt.String()) 259 stmt.Reset() 260 } else { 261 stmt.Write(stmtTail) 262 } 263 } 264 i = i + end 265 } else { 266 break 267 } 268 } 269 270 return stmtList, nil 271 272 } 273 274 // getMatch returns the matched substring if there's a match, nil otherwise 275 func getMatch(src []byte, base int, match []int, start int, end int, ) []byte { 276 if match[start] >= 0 { 277 return src[base+match[start] : base+match[end]] 278 } else { 279 return nil 280 } 281 } 282 283 // getKeyspaceNameFromUseStmt return keyspace name for use statement 284 func getKeyspaceNameFromUseStmt(stmt string, ) (string, bool) { 285 pattern := regexp.MustCompile(`(?ms)[Uu][Ss][Ee]\s+("(?:[^"]|"")+"|\w+)`) 286 if pattern.MatchString(stmt) { 287 match := pattern.FindStringSubmatch(stmt) 288 if len(match) > 1 { 289 keyspace := match[1] 290 caseSensitive := false 291 if strings.HasPrefix(keyspace, "\"") && strings.HasSuffix(keyspace, "\"") { 292 keyspace = strings.Trim(keyspace, "\"") 293 caseSensitive = true 294 } 295 return keyspace, caseSensitive 296 } 297 } 298 return "", false 299 } 300 301 // Loop is a loop that tries to get a connection until a timeout is reached 302 // 303 // Params: 304 // timeout: timeout to get the connection 305 // initializer : initializer to start the session 306 // connectionHost : name of host for the connection 307 // 308 // Returns a session Holder to store the session, or an error if the timeout was reached 309 func loop(timeout time.Duration, initializer Initializer, connectionHost string, ) (Holder, error) { 310 log.Debugf("Connection loop to connect to %s, timeout to use: %s", connectionHost, timeout) 311 ticker := time.NewTicker(1 * time.Second) 312 defer ticker.Stop() 313 314 timeoutExceeded := time.After(timeout) 315 for { 316 select { 317 case <-timeoutExceeded: 318 return nil, fmt.Errorf("connection to %s failed after %s timeout", connectionHost, timeout) 319 320 case <-ticker.C: 321 log.Infof("Trying to connect to: %s", connectionHost) 322 connectionHolder, err := initializer.NewSession() 323 if err == nil { 324 log.Infof("Successful connection to: %s", connectionHost) 325 return connectionHolder, nil 326 } 327 log.Infof("Trying to connect to %s, failed attempt: %v", connectionHost, err) 328 } 329 } 330 331 } 332 333 // getenv get a string value from an environment variable 334 // or return the given default value if the environment variable is not set 335 // 336 // Params: 337 // envVariable : environment variable 338 // defaultValue : value to return if environment variable is not set 339 // 340 // Returns the string value for the specified variable 341 func getenv(envVariable string, defaultValue string) string { 342 343 log.Debugf("Setting value for: %s", envVariable) 344 returnValue := defaultValue 345 log.Debugf("Default value for %s : %s", envVariable, defaultValue) 346 envStr := os.Getenv(envVariable) 347 if envStr != "" { 348 returnValue = envStr 349 log.Debugf("Init value for %s set to: %s", envVariable, envStr) 350 } 351 352 return returnValue 353 } 354 355 // getenvnolog get a string value from an environment variable 356 // or return the given default value if the environment variable is not set 357 // 358 // Params: 359 // envVariable : environment variable 360 // defaultValue : value to return if environment variable is not set 361 // 362 // Returns the string value for the specified variable 363 func getenvnolog(envVariable string, defaultValue string) string { 364 365 log.Debugf("Setting value for: %s", envVariable) 366 returnValue := defaultValue 367 log.Debugf("Default value for %s : %s", envVariable, defaultValue) 368 envStr := os.Getenv(envVariable) 369 if envStr != "" { 370 returnValue = envStr 371 } 372 373 return returnValue 374 } 375 376 func checkConsistency(envVar string) gocql.Consistency { 377 switch strings.ToLower(envVar) { 378 case "any": 379 log.Debugf("consistency set to any") 380 return gocql.Any 381 case "one": 382 log.Debugf("consistency set to one") 383 return gocql.One 384 case "two": 385 log.Debugf("consistency set to two") 386 return gocql.Two 387 case "three": 388 log.Debugf("consistency set to three") 389 return gocql.Three 390 case "quorum": 391 log.Debugf("consistency set to quorum") 392 return gocql.Quorum 393 case "all": 394 log.Debugf("consistency set to all") 395 return gocql.All 396 case "localquorum": 397 log.Debugf("consistency set to local quorum") 398 return gocql.LocalQuorum 399 case "eachquorum": 400 log.Debugf("consistency set to each quorum") 401 return gocql.EachQuorum 402 case "localone": 403 log.Debugf("consistency set to local one") 404 return gocql.LocalOne 405 default: 406 log.Debugf("consistency set to %s", clusterConsistency.String()) 407 return clusterConsistency 408 } 409 }