// Source: go.temporal.io/server@v1.23.0/common/persistence/sql/sqlplugin/mysql/session/session.go
//
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

24 25 package session 26 27 import ( 28 "crypto/tls" 29 "crypto/x509" 30 "fmt" 31 "os" 32 "strings" 33 34 "github.com/go-sql-driver/mysql" 35 "github.com/iancoleman/strcase" 36 "github.com/jmoiron/sqlx" 37 38 "go.temporal.io/server/common/auth" 39 "go.temporal.io/server/common/config" 40 "go.temporal.io/server/common/resolver" 41 ) 42 43 const ( 44 driverName = "mysql" 45 46 isolationLevelAttrName = "transaction_isolation" 47 isolationLevelAttrNameLegacy = "tx_isolation" 48 defaultIsolationLevel = "'READ-COMMITTED'" 49 // customTLSName is the name used if a custom tls configuration is created 50 customTLSName = "tls-custom" 51 ) 52 53 var dsnAttrOverrides = map[string]string{ 54 "parseTime": "true", 55 "clientFoundRows": "true", 56 } 57 58 type Session struct { 59 *sqlx.DB 60 } 61 62 func NewSession( 63 cfg *config.SQL, 64 resolver resolver.ServiceResolver, 65 ) (*Session, error) { 66 db, err := createConnection(cfg, resolver) 67 if err != nil { 68 return nil, err 69 } 70 return &Session{DB: db}, nil 71 } 72 73 func (s *Session) Close() { 74 if s.DB != nil { 75 _ = s.DB.Close() 76 } 77 } 78 79 func createConnection( 80 cfg *config.SQL, 81 resolver resolver.ServiceResolver, 82 ) (*sqlx.DB, error) { 83 err := registerTLSConfig(cfg) 84 if err != nil { 85 return nil, err 86 } 87 88 db, err := sqlx.Connect(driverName, buildDSN(cfg, resolver)) 89 if err != nil { 90 return nil, err 91 } 92 if cfg.MaxConns > 0 { 93 db.SetMaxOpenConns(cfg.MaxConns) 94 } 95 if cfg.MaxIdleConns > 0 { 96 db.SetMaxIdleConns(cfg.MaxIdleConns) 97 } 98 if cfg.MaxConnLifetime > 0 { 99 db.SetConnMaxLifetime(cfg.MaxConnLifetime) 100 } 101 102 // Maps struct names in CamelCase to snake without need for db struct tags. 
103 db.MapperFunc(strcase.ToSnake) 104 return db, nil 105 } 106 107 func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string { 108 mysqlConfig := mysql.NewConfig() 109 110 mysqlConfig.User = cfg.User 111 mysqlConfig.Passwd = cfg.Password 112 mysqlConfig.Addr = r.Resolve(cfg.ConnectAddr)[0] 113 mysqlConfig.DBName = cfg.DatabaseName 114 mysqlConfig.Net = cfg.ConnectProtocol 115 mysqlConfig.Params = buildDSNAttrs(cfg) 116 117 // https://github.com/go-sql-driver/mysql/blob/v1.5.0/dsn.go#L104-L106 118 // https://github.com/go-sql-driver/mysql/blob/v1.5.0/dsn.go#L182-L189 119 if mysqlConfig.Net == "" { 120 mysqlConfig.Net = "tcp" 121 } 122 123 // https://github.com/go-sql-driver/mysql#rejectreadonly 124 // https://github.com/temporalio/temporal/issues/1703 125 mysqlConfig.RejectReadOnly = true 126 127 return mysqlConfig.FormatDSN() 128 } 129 130 func buildDSNAttrs(cfg *config.SQL) map[string]string { 131 attrs := make(map[string]string, len(dsnAttrOverrides)+len(cfg.ConnectAttributes)+1) 132 for k, v := range cfg.ConnectAttributes { 133 k1, v1 := sanitizeAttr(k, v) 134 attrs[k1] = v1 135 } 136 137 // only override isolation level if not specified 138 if !hasAttr(attrs, isolationLevelAttrName) && 139 !hasAttr(attrs, isolationLevelAttrNameLegacy) { 140 attrs[isolationLevelAttrName] = defaultIsolationLevel 141 } 142 143 // these attrs are always overriden 144 for k, v := range dsnAttrOverrides { 145 attrs[k] = v 146 } 147 148 return attrs 149 } 150 151 func hasAttr(attrs map[string]string, key string) bool { 152 _, ok := attrs[key] 153 return ok 154 } 155 156 func sanitizeAttr(inkey string, invalue string) (string, string) { 157 key := strings.ToLower(strings.TrimSpace(inkey)) 158 value := strings.ToLower(strings.TrimSpace(invalue)) 159 switch key { 160 case isolationLevelAttrName, isolationLevelAttrNameLegacy: 161 if value[0] != '\'' { // mysql sys variable values must be enclosed in single quotes 162 value = "'" + value + "'" 163 } 164 return key, value 165 
default: 166 return inkey, invalue 167 } 168 } 169 170 func registerTLSConfig(cfg *config.SQL) error { 171 if cfg.TLS == nil || !cfg.TLS.Enabled { 172 return nil 173 } 174 175 // TODO: create a way to set MinVersion and CipherSuites via cfg. 176 tlsConfig := auth.NewTLSConfigForServer(cfg.TLS.ServerName, cfg.TLS.EnableHostVerification) 177 178 if cfg.TLS.CaFile != "" { 179 rootCertPool := x509.NewCertPool() 180 pem, err := os.ReadFile(cfg.TLS.CaFile) 181 if err != nil { 182 return fmt.Errorf("failed to load CA files: %v", err) 183 } 184 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 185 return fmt.Errorf("failed to append CA file") 186 } 187 tlsConfig.RootCAs = rootCertPool 188 } 189 190 if cfg.TLS.CertFile != "" && cfg.TLS.KeyFile != "" { 191 clientCert := make([]tls.Certificate, 0, 1) 192 certs, err := tls.LoadX509KeyPair( 193 cfg.TLS.CertFile, 194 cfg.TLS.KeyFile, 195 ) 196 if err != nil { 197 return fmt.Errorf("failed to load tls x509 key pair: %v", err) 198 } 199 clientCert = append(clientCert, certs) 200 tlsConfig.Certificates = clientCert 201 } 202 203 // In order to use the TLS configuration you need to register it. Once registered you use it by specifying 204 // `tls` in the connect attributes. 205 err := mysql.RegisterTLSConfig(customTLSName, tlsConfig) 206 if err != nil { 207 return fmt.Errorf("failed to register tls config: %v", err) 208 } 209 210 if cfg.ConnectAttributes == nil { 211 cfg.ConnectAttributes = map[string]string{} 212 } 213 214 // If no `tls` connect attribute is provided then we override it to our newly registered tls config automatically. 215 // This allows users to simply provide a tls config without needing to remember to also set the connect attribute 216 if cfg.ConnectAttributes["tls"] == "" { 217 cfg.ConnectAttributes["tls"] = customTLSName 218 } 219 220 return nil 221 }