// Source: go.temporal.io/server@v1.23.0/common/persistence/nosql/nosqlplugin/cassandra/gocql/client.go

// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

24 25 package gocql 26 27 import ( 28 "crypto/tls" 29 "crypto/x509" 30 "encoding/base64" 31 "errors" 32 "fmt" 33 "os" 34 "strings" 35 "time" 36 37 "github.com/gocql/gocql" 38 39 "go.temporal.io/server/common/auth" 40 "go.temporal.io/server/common/config" 41 "go.temporal.io/server/common/debug" 42 "go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/translator" 43 "go.temporal.io/server/common/resolver" 44 ) 45 46 func NewCassandraCluster( 47 cfg config.Cassandra, 48 resolver resolver.ServiceResolver, 49 ) (*gocql.ClusterConfig, error) { 50 var resolvedHosts []string 51 for _, host := range parseHosts(cfg.Hosts) { 52 resolvedHosts = append(resolvedHosts, resolver.Resolve(host)...) 53 } 54 55 cluster := gocql.NewCluster(resolvedHosts...) 56 if err := ConfigureCassandraCluster(cfg, cluster); err != nil { 57 return nil, err 58 } 59 60 return cluster, nil 61 } 62 63 // Modifies the input cluster config in place. 64 // 65 //nolint:revive // cognitive complexity 61 (> max enabled 25) 66 func ConfigureCassandraCluster(cfg config.Cassandra, cluster *gocql.ClusterConfig) error { 67 cluster.ProtoVersion = 4 68 if cfg.Port > 0 { 69 cluster.Port = cfg.Port 70 } 71 if cfg.User != "" && cfg.Password != "" { 72 cluster.Authenticator = gocql.PasswordAuthenticator{ 73 Username: cfg.User, 74 Password: cfg.Password, 75 } 76 } 77 if cfg.Keyspace != "" { 78 cluster.Keyspace = cfg.Keyspace 79 } 80 if cfg.Datacenter != "" { 81 cluster.HostFilter = gocql.DataCentreHostFilter(cfg.Datacenter) 82 } 83 if cfg.TLS != nil && cfg.TLS.Enabled { 84 if cfg.TLS.CertData != "" && cfg.TLS.CertFile != "" { 85 return errors.New("only one of certData or certFile properties should be specified") 86 } 87 88 if cfg.TLS.KeyData != "" && cfg.TLS.KeyFile != "" { 89 return errors.New("only one of keyData or keyFile properties should be specified") 90 } 91 92 if cfg.TLS.CaData != "" && cfg.TLS.CaFile != "" { 93 return errors.New("only one of caData or caFile properties should be specified") 94 } 
95 96 cluster.SslOpts = &gocql.SslOptions{ 97 CaPath: cfg.TLS.CaFile, 98 EnableHostVerification: cfg.TLS.EnableHostVerification, 99 Config: auth.NewTLSConfigForServer(cfg.TLS.ServerName, cfg.TLS.EnableHostVerification), 100 } 101 102 var certBytes []byte 103 var keyBytes []byte 104 var err error 105 106 if cfg.TLS.CertFile != "" { 107 certBytes, err = os.ReadFile(cfg.TLS.CertFile) 108 if err != nil { 109 return fmt.Errorf("error reading client certificate file: %w", err) 110 } 111 } else if cfg.TLS.CertData != "" { 112 certBytes, err = base64.StdEncoding.DecodeString(cfg.TLS.CertData) 113 if err != nil { 114 return fmt.Errorf("client certificate could not be decoded: %w", err) 115 } 116 } 117 118 if cfg.TLS.KeyFile != "" { 119 keyBytes, err = os.ReadFile(cfg.TLS.KeyFile) 120 if err != nil { 121 return fmt.Errorf("error reading client certificate private key file: %w", err) 122 } 123 } else if cfg.TLS.KeyData != "" { 124 keyBytes, err = base64.StdEncoding.DecodeString(cfg.TLS.KeyData) 125 if err != nil { 126 return fmt.Errorf("client certificate private key could not be decoded: %w", err) 127 } 128 } 129 130 if len(certBytes) > 0 { 131 clientCert, err := tls.X509KeyPair(certBytes, keyBytes) 132 if err != nil { 133 return fmt.Errorf("unable to generate x509 key pair: %w", err) 134 } 135 136 cluster.SslOpts.Certificates = []tls.Certificate{clientCert} 137 } 138 139 if cfg.TLS.CaData != "" { 140 cluster.SslOpts.RootCAs = x509.NewCertPool() 141 pem, err := base64.StdEncoding.DecodeString(cfg.TLS.CaData) 142 if err != nil { 143 return fmt.Errorf("caData could not be decoded: %w", err) 144 } 145 if !cluster.SslOpts.RootCAs.AppendCertsFromPEM(pem) { 146 return errors.New("failed to load decoded CA Cert as PEM") 147 } 148 } 149 } 150 151 if cfg.MaxConns > 0 { 152 cluster.NumConns = cfg.MaxConns 153 } 154 155 if cfg.ConnectTimeout > 0 { 156 cluster.Timeout = cfg.ConnectTimeout 157 cluster.ConnectTimeout = cfg.ConnectTimeout 158 } else { 159 cluster.Timeout = 10 * time.Second 
* debug.TimeoutMultiplier 160 cluster.ConnectTimeout = 10 * time.Second * debug.TimeoutMultiplier 161 } 162 163 cluster.ProtoVersion = 4 164 cluster.Consistency = cfg.Consistency.GetConsistency() 165 cluster.SerialConsistency = cfg.Consistency.GetSerialConsistency() 166 cluster.DisableInitialHostLookup = cfg.DisableInitialHostLookup 167 168 cluster.ReconnectionPolicy = &gocql.ExponentialReconnectionPolicy{ 169 MaxRetries: 30, 170 InitialInterval: time.Second, 171 MaxInterval: 10 * time.Second, 172 } 173 174 cluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(gocql.RoundRobinHostPolicy()) 175 176 if cfg.AddressTranslator != nil && cfg.AddressTranslator.Translator != "" { 177 addressTranslator, err := translator.LookupTranslator(cfg.AddressTranslator.Translator) 178 if err != nil { 179 return err 180 } 181 cluster.AddressTranslator, err = addressTranslator.GetTranslator(&cfg) 182 if err != nil { 183 return err 184 } 185 } 186 187 return nil 188 } 189 190 // parseHosts returns parses a list of hosts separated by comma 191 func parseHosts(input string) []string { 192 var hosts []string 193 for _, h := range strings.Split(input, ",") { 194 if host := strings.TrimSpace(h); len(host) > 0 { 195 hosts = append(hosts, host) 196 } 197 } 198 return hosts 199 }