go.temporal.io/server@v1.23.0/common/persistence/sql/sqlplugin/postgresql/session/session.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package session
    26  
    27  import (
    28  	"fmt"
    29  	"net/url"
    30  	"strings"
    31  
    32  	"github.com/iancoleman/strcase"
    33  	"github.com/jmoiron/sqlx"
    34  
    35  	"go.temporal.io/server/common/config"
    36  	"go.temporal.io/server/common/persistence/sql/sqlplugin/postgresql/driver"
    37  	"go.temporal.io/server/common/resolver"
    38  )
    39  
    40  const (
    41  	dsnFmt = "postgres://%v:%v@%v/%v?%v"
    42  )
    43  
    44  const (
    45  	sslMode        = "sslmode"
    46  	sslModeNoop    = "disable"
    47  	sslModeRequire = "require"
    48  	sslModeFull    = "verify-full"
    49  
    50  	sslCA   = "sslrootcert"
    51  	sslKey  = "sslkey"
    52  	sslCert = "sslcert"
    53  )
    54  
// Session is the database handle returned by NewSession. It embeds *sqlx.DB,
// so the sqlx / database/sql methods of that handle are promoted onto Session;
// Session defines its own Close, which shadows the promoted one.
type Session struct {
	*sqlx.DB
}
    58  
    59  func NewSession(
    60  	cfg *config.SQL,
    61  	d driver.Driver,
    62  	resolver resolver.ServiceResolver,
    63  ) (*Session, error) {
    64  	db, err := createConnection(cfg, d, resolver)
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  	return &Session{DB: db}, nil
    69  }
    70  
    71  func (s *Session) Close() {
    72  	if s.DB != nil {
    73  		_ = s.DB.Close()
    74  	}
    75  }
    76  
    77  func createConnection(
    78  	cfg *config.SQL,
    79  	d driver.Driver,
    80  	resolver resolver.ServiceResolver,
    81  ) (*sqlx.DB, error) {
    82  	db, err := d.CreateConnection(buildDSN(cfg, resolver))
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	if cfg.MaxConns > 0 {
    87  		db.SetMaxOpenConns(cfg.MaxConns)
    88  	}
    89  	if cfg.MaxIdleConns > 0 {
    90  		db.SetMaxIdleConns(cfg.MaxIdleConns)
    91  	}
    92  	if cfg.MaxConnLifetime > 0 {
    93  		db.SetConnMaxLifetime(cfg.MaxConnLifetime)
    94  	}
    95  
    96  	// Maps struct names in CamelCase to snake without need for db struct tags.
    97  	db.MapperFunc(strcase.ToSnake)
    98  	return db, nil
    99  }
   100  
   101  func buildDSN(
   102  	cfg *config.SQL,
   103  	r resolver.ServiceResolver,
   104  ) string {
   105  	tlsAttrs := buildDSNAttr(cfg).Encode()
   106  	resolvedAddr := r.Resolve(cfg.ConnectAddr)[0]
   107  	dsn := fmt.Sprintf(
   108  		dsnFmt,
   109  		cfg.User,
   110  		url.QueryEscape(cfg.Password),
   111  		resolvedAddr,
   112  		cfg.DatabaseName,
   113  		tlsAttrs,
   114  	)
   115  	return dsn
   116  }
   117  
   118  func buildDSNAttr(cfg *config.SQL) url.Values {
   119  	parameters := url.Values{}
   120  	if cfg.TLS != nil && cfg.TLS.Enabled {
   121  		if !cfg.TLS.EnableHostVerification {
   122  			parameters.Set(sslMode, sslModeRequire)
   123  		} else {
   124  			parameters.Set(sslMode, sslModeFull)
   125  		}
   126  
   127  		if cfg.TLS.CaFile != "" {
   128  			parameters.Set(sslCA, cfg.TLS.CaFile)
   129  		}
   130  		if cfg.TLS.KeyFile != "" && cfg.TLS.CertFile != "" {
   131  			parameters.Set(sslKey, cfg.TLS.KeyFile)
   132  			parameters.Set(sslCert, cfg.TLS.CertFile)
   133  		}
   134  	} else {
   135  		parameters.Set(sslMode, sslModeNoop)
   136  	}
   137  
   138  	for k, v := range cfg.ConnectAttributes {
   139  		key := strings.TrimSpace(k)
   140  		value := strings.TrimSpace(v)
   141  		if parameters.Get(key) != "" {
   142  			panic(fmt.Sprintf("duplicate connection attr: %v:%v, %v:%v",
   143  				key,
   144  				parameters.Get(key),
   145  				key, value,
   146  			))
   147  		}
   148  		parameters.Set(key, value)
   149  	}
   150  	return parameters
   151  }