github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/cassandra/wrapper/init.go (about)

     1  package wrapper
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"os"
     8  	"path"
     9  	"regexp"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/gocql/gocql"
    14  	log "github.com/sirupsen/logrus"
    15  )
    16  
    17  // Schema file to create keyspace if required
    18  const (
    19  	schemaDefaultPath     = "/usr/local/bin"
    20  	schemaDefaultFileName = "schema.sql"
    21  	defaultCassandraClusterConsistency = gocql.Quorum
    22  	defaultUsername = ""
    23  	defaultPassword = ""
    24  	defaultSSLCert = ""
    25  
    26  	envCassandraClusterConsistency = "CLUSTER_CONSISTENCY"
    27  	envCassandraSchemaPath     = "CASSANDRA_SCHEMA_PATH"
    28  	envCassandraSchemaFileName = "CASSANDRA_SCHEMA_FILE_NAME"
    29  	envUsername = "CASSANDRA_USERNAME"
    30  	envPassword = "CASSANDRA_PASSWORD"
    31  	envSSLCert = "CASSANDRA_SSL_CERT"
    32  )
    33  
    34  var schemaPath = "/usr/local/bin"
    35  var schemaFileName = "schema.sql"
    36  var clusterConsistency = gocql.Quorum
    37  var username = ""
    38  var password = ""
    39  var sslCert = ""
    40  
    41  // Package level initialization.
    42  //
    43  // init functions are automatically executed when the programs starts
    44  func init() {
    45  
    46  	// reading and setting up environment variables
    47  	schemaPath = getenv(envCassandraSchemaPath, schemaDefaultPath)
    48  	schemaFileName = getenv(envCassandraSchemaFileName, schemaDefaultFileName)
    49  	clusterConsistency = checkConsistency(getenv(envCassandraClusterConsistency, defaultCassandraClusterConsistency.String()))
    50  	username = getenvnolog(envUsername, defaultUsername)
    51  	password = getenvnolog(envPassword, defaultPassword)
    52  	sslCert = getenvnolog(envSSLCert, defaultSSLCert)
    53  
    54  	log.Debugf("Got schema path: %s", schemaPath)
    55  	log.Debugf("Got schema file name: %s", schemaFileName)
    56  	log.Debugf("Got cluster consistency: %s", clusterConsistency)
    57  	log.Debugf("Got username: %s", username)
    58  }
    59  
    60  // sessionInitializer is an initializer for a cassandra session
    61  type sessionInitializer struct {
    62  	clusterHostName string
    63  	clusterHostUsername string
    64  	clusterHostPassword string
    65  	clusterHostSSLCert string
    66  	keyspace        string
    67  	consistency gocql.Consistency
    68  }
    69  
    70  // sessionHolder stores a cassandra session
    71  type sessionHolder struct {
    72  	session SessionInterface
    73  }
    74  
    75  // New return a cassandra session Initializer
    76  func New(clusterHostName, keyspace string) Initializer {
    77  	log.Debugf("in new")
    78  
    79  	return sessionInitializer{
    80  		clusterHostName:     clusterHostName,
    81  		clusterHostUsername: username,
    82  		clusterHostPassword: password,
    83  		clusterHostSSLCert:  sslCert,
    84  		keyspace:            keyspace,
    85  		consistency:         clusterConsistency,
    86  	}
    87  }
    88  
    89  // Initialize waits for a Cassandra session, initializes Cassandra keyspace and creates tables if required.
    90  // NOTE: Needs to be called only once on the app startup, won't fail if it is called multiple times but is not necessary.
    91  //
    92  // Params:
    93  //	clusterHostName: Cassandra cluster host
    94  //	systemKeyspace: System keyspace
    95  //	appKeyspace: Application keyspace
    96  //	connectionTimeout: timeout to get the connection
    97  func Initialize(clusterHostName, systemKeyspace, appKeyspace string, connectionTimeout time.Duration) {
    98  	log.Debug("Setting up cassandra db")
    99  	connectionHolder, err := loop(connectionTimeout, New(clusterHostName, systemKeyspace), "cassandra-db")
   100  	if err != nil {
   101  		log.Fatalf("error connecting to Cassandra db: %v", err)
   102  		panic(err)
   103  	}
   104  	defer connectionHolder.CloseSession()
   105  
   106  	log.Debug("Setting up cassandra keyspace")
   107  	err = createAppKeyspaceIfRequired(clusterHostName, systemKeyspace, appKeyspace)
   108  	if err != nil {
   109  		log.Fatalf("error creating keyspace for Cassandra db: %v", err)
   110  		panic(err)
   111  	}
   112  
   113  	log.Info("Cassandra keyspace has been set up")
   114  }
   115  
   116  // NewSession starts a new cassandra session for the given keyspace
   117  // NOTE: It is responsibility of the caller to close this new session.
   118  //
   119  // Returns a session Holder for the session, or an error if can't start the session
   120  func (i sessionInitializer) NewSession() (Holder, error) {
   121  	session, err := newKeyspaceSession(i.clusterHostName, i.keyspace,
   122  		i.clusterHostUsername, i.clusterHostPassword, i.clusterHostSSLCert, i.consistency,
   123  		600*time.Millisecond)
   124  	if err != nil {
   125  		log.Errorf("error starting Cassandra session for the cluster hostname: %s and keyspace: %s - %v",
   126  			i.clusterHostName, i.keyspace, err)
   127  		return nil, err
   128  	}
   129  	sessionRetry := sessionRetry{session}
   130  	connectionHolder := sessionHolder{sessionRetry}
   131  	return connectionHolder, nil
   132  }
   133  
   134  // GetSession returns the stored cassandra session
   135  func (holder sessionHolder) GetSession() SessionInterface {
   136  	return holder.session
   137  }
   138  
   139  // CloseSession closes the cassandra session
   140  func (holder sessionHolder) CloseSession() {
   141  	holder.session.Close()
   142  }
   143  
   144  // newKeyspaceSession returns a new session for the given keyspace
   145  func newKeyspaceSession(clusterHostName, keyspace, username, password, sslCert string, clusterConsistency gocql.Consistency, clusterTimeout time.Duration) (*gocql.Session, error) {
   146  	log.Infof("Creating new cassandra session for cluster hostname: %s and keyspace: %s", clusterHostName, keyspace)
   147  	cluster := gocql.NewCluster(clusterHostName)
   148  	cluster.Keyspace = keyspace
   149  	cluster.Timeout = clusterTimeout
   150  	if username != "" {
   151  		cluster.Authenticator = gocql.PasswordAuthenticator{
   152  			Username: username,
   153  			Password: password,
   154  		}
   155  	}
   156  	if sslCert != "" {
   157  		cluster.SslOpts = &gocql.SslOptions{
   158  			CaPath: sslCert,
   159  		}
   160  	}
   161  	cluster.Consistency = clusterConsistency
   162  	return cluster.CreateSession()
   163  }
   164  
   165  // createAppKeyspaceIfRequired creates the keyspace for the app if it doesn't exist
   166  func createAppKeyspaceIfRequired(clusterHostName, systemKeyspace, appKeyspace string) error {
   167  	// Getting the schema file if exist
   168  	stmtList, err := getStmtsFromFile(path.Join(schemaPath, schemaFileName))
   169  	if err != nil {
   170  		return err
   171  	}
   172  	if stmtList == nil { // Didn't fail but returned nil, probably the file does not exist
   173  		return nil
   174  	}
   175  
   176  	log.Info("about to create a session with a 5 minute timeout to allow for all schema creation")
   177  	session, err := newKeyspaceSession(clusterHostName, systemKeyspace, username, password, sslCert, clusterConsistency, 5*time.Minute)
   178  	if err != nil {
   179  		return err
   180  	}
   181  	currentKeyspace := systemKeyspace
   182  
   183  	var sessionList []*gocql.Session
   184  	defer func() {
   185  		for _, s := range sessionList {
   186  			if s != nil && !s.Closed() {
   187  				s.Close()
   188  			}
   189  		}
   190  	}()
   191  
   192  	log.Debugf("Creating new keyspace if required: %s", appKeyspace)
   193  
   194  	for _, stmt := range stmtList {
   195  		log.Debugf("Executing statement: %s", stmt)
   196  		// New session for use statement
   197  		newKeyspace, isCaseSensitive := getKeyspaceNameFromUseStmt(stmt)
   198  		if newKeyspace != "" {
   199  			if (isCaseSensitive && newKeyspace != currentKeyspace) || (!isCaseSensitive &&
   200  				strings.ToLower(newKeyspace) != strings.ToLower(currentKeyspace)) {
   201  				log.Infof("about to create a session with a 5 minute timeout to set keyspace: %s", newKeyspace)
   202  				session, err = newKeyspaceSession(clusterHostName, newKeyspace, username, password, sslCert, clusterConsistency, 5*time.Minute) //5 minutes
   203  				if err != nil {
   204  					return err
   205  				}
   206  				currentKeyspace = newKeyspace
   207  				sessionList = append(sessionList, session)
   208  				log.Debugf("Changed to new keyspace: %s", newKeyspace)
   209  			}
   210  			continue
   211  		}
   212  
   213  		// execute statement
   214  		err = session.Query(stmt).Exec()
   215  		if err != nil {
   216  			log.Errorf("statement error: %v", err)
   217  			return err
   218  		}
   219  		log.Debug("Statement executed")
   220  	}
   221  
   222  	log.Debugf("app keyspace set to: %s", appKeyspace)
   223  	return nil
   224  }
   225  
   226  // getStmtsFromFile extracts CQL statements from the file
   227  func getStmtsFromFile(fileName string, ) ([]string, error) {
   228  	// Verify first if the file exist
   229  	if _, err := os.Stat(fileName); err != nil {
   230  		if os.IsNotExist(err) { // Does not exist
   231  			log.Warnf("no schema file [%s] found initializing Cassandra.", fileName)
   232  			return nil, nil
   233  		}
   234  	}
   235  
   236  	content, err := ioutil.ReadFile(fileName)
   237  	if err != nil {
   238  		return nil, err
   239  	}
   240  
   241  	pattern := regexp.MustCompile(`(?ms)([^"';]*?)("(?:[^"]|"")*"|'(?:[^']|'')*'|\$\$.*?\$\$|(/\*.*?\*/)|((?:--|//).*?$)|;\n?)`)
   242  
   243  	var stmtList []string
   244  
   245  	i := 0
   246  	contentLength := len(content)
   247  	var stmt bytes.Buffer
   248  	for i < contentLength {
   249  		subIndexes := pattern.FindSubmatchIndex(content[i:])
   250  		if len(subIndexes) > 0 {
   251  			end := subIndexes[1]
   252  			stmt.Write(getMatch(content, i, subIndexes, 2, 3))
   253  			stmtTail := getMatch(content, i, subIndexes, 4, 5)
   254  			comment := getMatch(content, i, subIndexes, 6, 7)
   255  			lineComment := getMatch(content, i, subIndexes, 8, 9)
   256  			if comment == nil && lineComment == nil {
   257  				if stmtTail != nil && string(bytes.TrimSpace(stmtTail)) == ";" {
   258  					stmtList = append(stmtList, stmt.String())
   259  					stmt.Reset()
   260  				} else {
   261  					stmt.Write(stmtTail)
   262  				}
   263  			}
   264  			i = i + end
   265  		} else {
   266  			break
   267  		}
   268  	}
   269  
   270  	return stmtList, nil
   271  
   272  }
   273  
   274  // getMatch returns the matched substring if there's a match, nil otherwise
   275  func getMatch(src []byte, base int, match []int, start int, end int, ) []byte {
   276  	if match[start] >= 0 {
   277  		return src[base+match[start] : base+match[end]]
   278  	} else {
   279  		return nil
   280  	}
   281  }
   282  
   283  // getKeyspaceNameFromUseStmt return keyspace name for use statement
   284  func getKeyspaceNameFromUseStmt(stmt string, ) (string, bool) {
   285  	pattern := regexp.MustCompile(`(?ms)[Uu][Ss][Ee]\s+("(?:[^"]|"")+"|\w+)`)
   286  	if pattern.MatchString(stmt) {
   287  		match := pattern.FindStringSubmatch(stmt)
   288  		if len(match) > 1 {
   289  			keyspace := match[1]
   290  			caseSensitive := false
   291  			if strings.HasPrefix(keyspace, "\"") && strings.HasSuffix(keyspace, "\"") {
   292  				keyspace = strings.Trim(keyspace, "\"")
   293  				caseSensitive = true
   294  			}
   295  			return keyspace, caseSensitive
   296  		}
   297  	}
   298  	return "", false
   299  }
   300  
   301  // Loop is a loop that tries to get a connection until a timeout is reached
   302  //
   303  // Params:
   304  //	 timeout: timeout to get the connection
   305  //	 initializer : initializer to start the session
   306  //	 connectionHost : name of host for the connection
   307  //
   308  // Returns a session Holder to store the session, or an error if the timeout was reached
   309  func loop(timeout time.Duration, initializer Initializer, connectionHost string, ) (Holder, error) {
   310  	log.Debugf("Connection loop to connect to %s, timeout to use: %s", connectionHost, timeout)
   311  	ticker := time.NewTicker(1 * time.Second)
   312  	defer ticker.Stop()
   313  
   314  	timeoutExceeded := time.After(timeout)
   315  	for {
   316  		select {
   317  		case <-timeoutExceeded:
   318  			return nil, fmt.Errorf("connection to %s failed after %s timeout", connectionHost, timeout)
   319  
   320  		case <-ticker.C:
   321  			log.Infof("Trying to connect to: %s", connectionHost)
   322  			connectionHolder, err := initializer.NewSession()
   323  			if err == nil {
   324  				log.Infof("Successful connection to: %s", connectionHost)
   325  				return connectionHolder, nil
   326  			}
   327  			log.Infof("Trying to connect to %s, failed attempt: %v", connectionHost, err)
   328  		}
   329  	}
   330  
   331  }
   332  
   333  // getenv get a string value from an environment variable
   334  // or return the given default value if the environment variable is not set
   335  //
   336  // Params:
   337  //  envVariable : environment variable
   338  //  defaultValue : value to return if environment variable is not set
   339  //
   340  // Returns the string value for the specified variable
   341  func getenv(envVariable string, defaultValue string) string {
   342  
   343  	log.Debugf("Setting value for: %s", envVariable)
   344  	returnValue := defaultValue
   345  	log.Debugf("Default value for %s : %s", envVariable, defaultValue)
   346  	envStr := os.Getenv(envVariable)
   347  	if envStr != "" {
   348  		returnValue = envStr
   349  		log.Debugf("Init value for %s set to: %s", envVariable, envStr)
   350  	}
   351  
   352  	return returnValue
   353  }
   354  
   355  // getenvnolog get a string value from an environment variable
   356  // or return the given default value if the environment variable is not set
   357  //
   358  // Params:
   359  //  envVariable : environment variable
   360  //  defaultValue : value to return if environment variable is not set
   361  //
   362  // Returns the string value for the specified variable
   363  func getenvnolog(envVariable string, defaultValue string) string {
   364  
   365  	log.Debugf("Setting value for: %s", envVariable)
   366  	returnValue := defaultValue
   367  	log.Debugf("Default value for %s : %s", envVariable, defaultValue)
   368  	envStr := os.Getenv(envVariable)
   369  	if envStr != "" {
   370  		returnValue = envStr
   371  	}
   372  
   373  	return returnValue
   374  }
   375  
   376  func checkConsistency(envVar string) gocql.Consistency {
   377  	switch strings.ToLower(envVar) {
   378  	case "any":
   379  		log.Debugf("consistency set to any")
   380  		return gocql.Any
   381  	case "one":
   382  		log.Debugf("consistency set to one")
   383  		return gocql.One
   384  	case "two":
   385  		log.Debugf("consistency set to two")
   386  		return gocql.Two
   387  	case "three":
   388  		log.Debugf("consistency set to three")
   389  		return gocql.Three
   390  	case "quorum":
   391  		log.Debugf("consistency set to quorum")
   392  		return gocql.Quorum
   393  	case "all":
   394  		log.Debugf("consistency set to all")
   395  		return gocql.All
   396  	case "localquorum":
   397  		log.Debugf("consistency set to local quorum")
   398  		return gocql.LocalQuorum
   399  	case "eachquorum":
   400  		log.Debugf("consistency set to each quorum")
   401  		return gocql.EachQuorum
   402  	case "localone":
   403  		log.Debugf("consistency set to local one")
   404  		return gocql.LocalOne
   405  	default:
   406  		log.Debugf("consistency set to %s", clusterConsistency.String())
   407  		return clusterConsistency
   408  	}
   409  }