go.temporal.io/server@v1.23.0/common/persistence/sql/sqlplugin/mysql/session/session.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package session
    26  
    27  import (
    28  	"crypto/tls"
    29  	"crypto/x509"
    30  	"fmt"
    31  	"os"
    32  	"strings"
    33  
    34  	"github.com/go-sql-driver/mysql"
    35  	"github.com/iancoleman/strcase"
    36  	"github.com/jmoiron/sqlx"
    37  
    38  	"go.temporal.io/server/common/auth"
    39  	"go.temporal.io/server/common/config"
    40  	"go.temporal.io/server/common/resolver"
    41  )
    42  
const (
	// driverName is the driver name handed to sqlx.Connect when opening the connection.
	driverName = "mysql"

	// isolationLevelAttrName and isolationLevelAttrNameLegacy are the two connect-attribute
	// keys treated as "isolation level" settings: their values are normalised by sanitizeAttr,
	// and buildDSNAttrs only injects defaultIsolationLevel when neither key is present.
	// defaultIsolationLevel already carries the single quotes that sanitizeAttr adds to
	// user-supplied isolation values.
	isolationLevelAttrName       = "transaction_isolation"
	isolationLevelAttrNameLegacy = "tx_isolation"
	defaultIsolationLevel        = "'READ-COMMITTED'"
	// customTLSName is the name used if a custom tls configuration is created
	customTLSName = "tls-custom"
)
    52  
// dsnAttrOverrides lists DSN parameters that buildDSNAttrs writes last, so they
// replace any value for the same key supplied via the configured connect attributes.
var dsnAttrOverrides = map[string]string{
	"parseTime":       "true",
	"clientFoundRows": "true",
}
    57  
// Session wraps the *sqlx.DB handle opened for the MySQL driver. The handle is
// embedded, so its methods are available directly on a Session.
type Session struct {
	*sqlx.DB
}
    61  
    62  func NewSession(
    63  	cfg *config.SQL,
    64  	resolver resolver.ServiceResolver,
    65  ) (*Session, error) {
    66  	db, err := createConnection(cfg, resolver)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  	return &Session{DB: db}, nil
    71  }
    72  
    73  func (s *Session) Close() {
    74  	if s.DB != nil {
    75  		_ = s.DB.Close()
    76  	}
    77  }
    78  
    79  func createConnection(
    80  	cfg *config.SQL,
    81  	resolver resolver.ServiceResolver,
    82  ) (*sqlx.DB, error) {
    83  	err := registerTLSConfig(cfg)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  
    88  	db, err := sqlx.Connect(driverName, buildDSN(cfg, resolver))
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	if cfg.MaxConns > 0 {
    93  		db.SetMaxOpenConns(cfg.MaxConns)
    94  	}
    95  	if cfg.MaxIdleConns > 0 {
    96  		db.SetMaxIdleConns(cfg.MaxIdleConns)
    97  	}
    98  	if cfg.MaxConnLifetime > 0 {
    99  		db.SetConnMaxLifetime(cfg.MaxConnLifetime)
   100  	}
   101  
   102  	// Maps struct names in CamelCase to snake without need for db struct tags.
   103  	db.MapperFunc(strcase.ToSnake)
   104  	return db, nil
   105  }
   106  
   107  func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string {
   108  	mysqlConfig := mysql.NewConfig()
   109  
   110  	mysqlConfig.User = cfg.User
   111  	mysqlConfig.Passwd = cfg.Password
   112  	mysqlConfig.Addr = r.Resolve(cfg.ConnectAddr)[0]
   113  	mysqlConfig.DBName = cfg.DatabaseName
   114  	mysqlConfig.Net = cfg.ConnectProtocol
   115  	mysqlConfig.Params = buildDSNAttrs(cfg)
   116  
   117  	// https://github.com/go-sql-driver/mysql/blob/v1.5.0/dsn.go#L104-L106
   118  	// https://github.com/go-sql-driver/mysql/blob/v1.5.0/dsn.go#L182-L189
   119  	if mysqlConfig.Net == "" {
   120  		mysqlConfig.Net = "tcp"
   121  	}
   122  
   123  	// https://github.com/go-sql-driver/mysql#rejectreadonly
   124  	// https://github.com/temporalio/temporal/issues/1703
   125  	mysqlConfig.RejectReadOnly = true
   126  
   127  	return mysqlConfig.FormatDSN()
   128  }
   129  
   130  func buildDSNAttrs(cfg *config.SQL) map[string]string {
   131  	attrs := make(map[string]string, len(dsnAttrOverrides)+len(cfg.ConnectAttributes)+1)
   132  	for k, v := range cfg.ConnectAttributes {
   133  		k1, v1 := sanitizeAttr(k, v)
   134  		attrs[k1] = v1
   135  	}
   136  
   137  	// only override isolation level if not specified
   138  	if !hasAttr(attrs, isolationLevelAttrName) &&
   139  		!hasAttr(attrs, isolationLevelAttrNameLegacy) {
   140  		attrs[isolationLevelAttrName] = defaultIsolationLevel
   141  	}
   142  
   143  	// these attrs are always overriden
   144  	for k, v := range dsnAttrOverrides {
   145  		attrs[k] = v
   146  	}
   147  
   148  	return attrs
   149  }
   150  
   151  func hasAttr(attrs map[string]string, key string) bool {
   152  	_, ok := attrs[key]
   153  	return ok
   154  }
   155  
   156  func sanitizeAttr(inkey string, invalue string) (string, string) {
   157  	key := strings.ToLower(strings.TrimSpace(inkey))
   158  	value := strings.ToLower(strings.TrimSpace(invalue))
   159  	switch key {
   160  	case isolationLevelAttrName, isolationLevelAttrNameLegacy:
   161  		if value[0] != '\'' { // mysql sys variable values must be enclosed in single quotes
   162  			value = "'" + value + "'"
   163  		}
   164  		return key, value
   165  	default:
   166  		return inkey, invalue
   167  	}
   168  }
   169  
   170  func registerTLSConfig(cfg *config.SQL) error {
   171  	if cfg.TLS == nil || !cfg.TLS.Enabled {
   172  		return nil
   173  	}
   174  
   175  	// TODO: create a way to set MinVersion and CipherSuites via cfg.
   176  	tlsConfig := auth.NewTLSConfigForServer(cfg.TLS.ServerName, cfg.TLS.EnableHostVerification)
   177  
   178  	if cfg.TLS.CaFile != "" {
   179  		rootCertPool := x509.NewCertPool()
   180  		pem, err := os.ReadFile(cfg.TLS.CaFile)
   181  		if err != nil {
   182  			return fmt.Errorf("failed to load CA files: %v", err)
   183  		}
   184  		if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
   185  			return fmt.Errorf("failed to append CA file")
   186  		}
   187  		tlsConfig.RootCAs = rootCertPool
   188  	}
   189  
   190  	if cfg.TLS.CertFile != "" && cfg.TLS.KeyFile != "" {
   191  		clientCert := make([]tls.Certificate, 0, 1)
   192  		certs, err := tls.LoadX509KeyPair(
   193  			cfg.TLS.CertFile,
   194  			cfg.TLS.KeyFile,
   195  		)
   196  		if err != nil {
   197  			return fmt.Errorf("failed to load tls x509 key pair: %v", err)
   198  		}
   199  		clientCert = append(clientCert, certs)
   200  		tlsConfig.Certificates = clientCert
   201  	}
   202  
   203  	// In order to use the TLS configuration you need to register it. Once registered you use it by specifying
   204  	// `tls` in the connect attributes.
   205  	err := mysql.RegisterTLSConfig(customTLSName, tlsConfig)
   206  	if err != nil {
   207  		return fmt.Errorf("failed to register tls config: %v", err)
   208  	}
   209  
   210  	if cfg.ConnectAttributes == nil {
   211  		cfg.ConnectAttributes = map[string]string{}
   212  	}
   213  
   214  	// If no `tls` connect attribute is provided then we override it to our newly registered tls config automatically.
   215  	// This allows users to simply provide a tls config without needing to remember to also set the connect attribute
   216  	if cfg.ConnectAttributes["tls"] == "" {
   217  		cfg.ConnectAttributes["tls"] = customTLSName
   218  	}
   219  
   220  	return nil
   221  }