go.temporal.io/server@v1.23.0/common/persistence/nosql/nosqlplugin/cassandra/gocql/client.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package gocql
    26  
    27  import (
    28  	"crypto/tls"
    29  	"crypto/x509"
    30  	"encoding/base64"
    31  	"errors"
    32  	"fmt"
    33  	"os"
    34  	"strings"
    35  	"time"
    36  
    37  	"github.com/gocql/gocql"
    38  
    39  	"go.temporal.io/server/common/auth"
    40  	"go.temporal.io/server/common/config"
    41  	"go.temporal.io/server/common/debug"
    42  	"go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/translator"
    43  	"go.temporal.io/server/common/resolver"
    44  )
    45  
    46  func NewCassandraCluster(
    47  	cfg config.Cassandra,
    48  	resolver resolver.ServiceResolver,
    49  ) (*gocql.ClusterConfig, error) {
    50  	var resolvedHosts []string
    51  	for _, host := range parseHosts(cfg.Hosts) {
    52  		resolvedHosts = append(resolvedHosts, resolver.Resolve(host)...)
    53  	}
    54  
    55  	cluster := gocql.NewCluster(resolvedHosts...)
    56  	if err := ConfigureCassandraCluster(cfg, cluster); err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	return cluster, nil
    61  }
    62  
    63  // Modifies the input cluster config in place.
    64  //
    65  //nolint:revive // cognitive complexity 61 (> max enabled 25)
    66  func ConfigureCassandraCluster(cfg config.Cassandra, cluster *gocql.ClusterConfig) error {
    67  	cluster.ProtoVersion = 4
    68  	if cfg.Port > 0 {
    69  		cluster.Port = cfg.Port
    70  	}
    71  	if cfg.User != "" && cfg.Password != "" {
    72  		cluster.Authenticator = gocql.PasswordAuthenticator{
    73  			Username: cfg.User,
    74  			Password: cfg.Password,
    75  		}
    76  	}
    77  	if cfg.Keyspace != "" {
    78  		cluster.Keyspace = cfg.Keyspace
    79  	}
    80  	if cfg.Datacenter != "" {
    81  		cluster.HostFilter = gocql.DataCentreHostFilter(cfg.Datacenter)
    82  	}
    83  	if cfg.TLS != nil && cfg.TLS.Enabled {
    84  		if cfg.TLS.CertData != "" && cfg.TLS.CertFile != "" {
    85  			return errors.New("only one of certData or certFile properties should be specified")
    86  		}
    87  
    88  		if cfg.TLS.KeyData != "" && cfg.TLS.KeyFile != "" {
    89  			return errors.New("only one of keyData or keyFile properties should be specified")
    90  		}
    91  
    92  		if cfg.TLS.CaData != "" && cfg.TLS.CaFile != "" {
    93  			return errors.New("only one of caData or caFile properties should be specified")
    94  		}
    95  
    96  		cluster.SslOpts = &gocql.SslOptions{
    97  			CaPath:                 cfg.TLS.CaFile,
    98  			EnableHostVerification: cfg.TLS.EnableHostVerification,
    99  			Config:                 auth.NewTLSConfigForServer(cfg.TLS.ServerName, cfg.TLS.EnableHostVerification),
   100  		}
   101  
   102  		var certBytes []byte
   103  		var keyBytes []byte
   104  		var err error
   105  
   106  		if cfg.TLS.CertFile != "" {
   107  			certBytes, err = os.ReadFile(cfg.TLS.CertFile)
   108  			if err != nil {
   109  				return fmt.Errorf("error reading client certificate file: %w", err)
   110  			}
   111  		} else if cfg.TLS.CertData != "" {
   112  			certBytes, err = base64.StdEncoding.DecodeString(cfg.TLS.CertData)
   113  			if err != nil {
   114  				return fmt.Errorf("client certificate could not be decoded: %w", err)
   115  			}
   116  		}
   117  
   118  		if cfg.TLS.KeyFile != "" {
   119  			keyBytes, err = os.ReadFile(cfg.TLS.KeyFile)
   120  			if err != nil {
   121  				return fmt.Errorf("error reading client certificate private key file: %w", err)
   122  			}
   123  		} else if cfg.TLS.KeyData != "" {
   124  			keyBytes, err = base64.StdEncoding.DecodeString(cfg.TLS.KeyData)
   125  			if err != nil {
   126  				return fmt.Errorf("client certificate private key could not be decoded: %w", err)
   127  			}
   128  		}
   129  
   130  		if len(certBytes) > 0 {
   131  			clientCert, err := tls.X509KeyPair(certBytes, keyBytes)
   132  			if err != nil {
   133  				return fmt.Errorf("unable to generate x509 key pair: %w", err)
   134  			}
   135  
   136  			cluster.SslOpts.Certificates = []tls.Certificate{clientCert}
   137  		}
   138  
   139  		if cfg.TLS.CaData != "" {
   140  			cluster.SslOpts.RootCAs = x509.NewCertPool()
   141  			pem, err := base64.StdEncoding.DecodeString(cfg.TLS.CaData)
   142  			if err != nil {
   143  				return fmt.Errorf("caData could not be decoded: %w", err)
   144  			}
   145  			if !cluster.SslOpts.RootCAs.AppendCertsFromPEM(pem) {
   146  				return errors.New("failed to load decoded CA Cert as PEM")
   147  			}
   148  		}
   149  	}
   150  
   151  	if cfg.MaxConns > 0 {
   152  		cluster.NumConns = cfg.MaxConns
   153  	}
   154  
   155  	if cfg.ConnectTimeout > 0 {
   156  		cluster.Timeout = cfg.ConnectTimeout
   157  		cluster.ConnectTimeout = cfg.ConnectTimeout
   158  	} else {
   159  		cluster.Timeout = 10 * time.Second * debug.TimeoutMultiplier
   160  		cluster.ConnectTimeout = 10 * time.Second * debug.TimeoutMultiplier
   161  	}
   162  
   163  	cluster.ProtoVersion = 4
   164  	cluster.Consistency = cfg.Consistency.GetConsistency()
   165  	cluster.SerialConsistency = cfg.Consistency.GetSerialConsistency()
   166  	cluster.DisableInitialHostLookup = cfg.DisableInitialHostLookup
   167  
   168  	cluster.ReconnectionPolicy = &gocql.ExponentialReconnectionPolicy{
   169  		MaxRetries:      30,
   170  		InitialInterval: time.Second,
   171  		MaxInterval:     10 * time.Second,
   172  	}
   173  
   174  	cluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(gocql.RoundRobinHostPolicy())
   175  
   176  	if cfg.AddressTranslator != nil && cfg.AddressTranslator.Translator != "" {
   177  		addressTranslator, err := translator.LookupTranslator(cfg.AddressTranslator.Translator)
   178  		if err != nil {
   179  			return err
   180  		}
   181  		cluster.AddressTranslator, err = addressTranslator.GetTranslator(&cfg)
   182  		if err != nil {
   183  			return err
   184  		}
   185  	}
   186  
   187  	return nil
   188  }
   189  
   190  // parseHosts returns parses a list of hosts separated by comma
   191  func parseHosts(input string) []string {
   192  	var hosts []string
   193  	for _, h := range strings.Split(input, ",") {
   194  		if host := strings.TrimSpace(h); len(host) > 0 {
   195  			hosts = append(hosts, host)
   196  		}
   197  	}
   198  	return hosts
   199  }