github.com/square/finch@v0.0.0-20240412205204-6530c03e2b96/dbconn/factory.go (about)

     1  // Copyright 2024 Block, Inc.
     2  
     3  // Package dbconn provides a Factory that makes *sql.DB connections to MySQL.
     4  package dbconn
     5  
     6  import (
     7  	"database/sql"
     8  	"fmt"
     9  	"io/fs"
    10  	"io/ioutil"
    11  	"os"
    12  	"os/exec"
    13  	"regexp"
    14  	"strings"
    15  
    16  	"github.com/go-sql-driver/mysql"
    17  
    18  	"github.com/square/finch"
    19  	"github.com/square/finch/aws"
    20  	"github.com/square/finch/config"
    21  )
    22  
    23  // rdsAddr matches Amazon RDS hostnames with optional :port suffix.
    24  // It's used to automatically load the Amazon RDS CA and enable TLS,
    25  // unless config.aws.disable-auto-tls is true.
    26  var rdsAddr = regexp.MustCompile(`rds\.amazonaws\.com(:\d+)?$`)
    27  
    28  // portSuffix matches optional :port suffix on addresses. It's used to
    29  // strip the port suffix before passing the hostname to LoadTLS.
    30  var portSuffix = regexp.MustCompile(`:\d+$`)
    31  
    32  var f = &factory{}
    33  
    34  type factory struct {
    35  	cfg config.MySQL
    36  	dsn string
    37  }
    38  
    39  func SetConfig(cfg config.MySQL) {
    40  	f.cfg = cfg
    41  	f.dsn = ""
    42  }
    43  
    44  func Make() (*sql.DB, string, error) {
    45  	// Parse MySQL params and set DSN on first call. There's only 1 DSN for
    46  	// all clients, so this only needs to be done once.
    47  	if f.dsn == "" {
    48  		if err := f.setDSN(); err != nil {
    49  			return nil, "", err
    50  		}
    51  	}
    52  	finch.Debug("dsn: %s", RedactedDSN(f.dsn))
    53  
    54  	// Make new sql.DB (conn pool) for each client group; see the call to
    55  	// this func in workload/workload.go.
    56  	db, err := sql.Open("mysql", f.dsn)
    57  	if err != nil {
    58  		return nil, "", err
    59  	}
    60  	return db, RedactedDSN(f.dsn), nil
    61  }
    62  
    63  func (f *factory) setDSN() error {
    64  	// --dsn or mysql.dsn (in that order) overrides all
    65  	if f.cfg.DSN != "" {
    66  		f.dsn = f.cfg.DSN
    67  		return nil
    68  	}
    69  
    70  	// ----------------------------------------------------------------------
    71  	// my.cnf
    72  
    73  	// Set values in cfg from values in my.cnf. This does not overwrite any
    74  	// values in cfg already set. For exmaple, if username is specified in both,
    75  	// the default my.cnf username is ignored and the explicit cfg.Username is
    76  	// kept/used.
    77  	if f.cfg.MyCnf != "" {
    78  		finch.Debug("read mycnf %s", f.cfg.MyCnf)
    79  		def, err := ParseMyCnf(f.cfg.MyCnf)
    80  		if err != nil {
    81  			return err
    82  		}
    83  		f.cfg.With(def)
    84  	}
    85  
    86  	// ----------------------------------------------------------------------
    87  	// TCP or Unix socket
    88  
    89  	net := ""
    90  	addr := ""
    91  	if f.cfg.Socket != "" {
    92  		net = "unix"
    93  		addr = f.cfg.Socket
    94  	} else {
    95  		net = "tcp"
    96  		if f.cfg.Hostname == "" {
    97  			f.cfg.Hostname = "127.0.0.1"
    98  		}
    99  		addr = f.cfg.Hostname
   100  	}
   101  
   102  	// ----------------------------------------------------------------------
   103  	// Load TLS
   104  
   105  	params := []string{"parseTime=true"}
   106  
   107  	// Go says "either ServerName or InsecureSkipVerify must be specified".
   108  	// This is a pathological case: socket and TLS but no hostname to verify
   109  	// and user didn't explicitly set skip-verify=true. So we set this latter
   110  	// automatically because Go will certainly error if we don't.
   111  	if net == "unix" && f.cfg.TLS.Set() && f.cfg.Hostname == "" && !config.True(f.cfg.TLS.SkipVerify) {
   112  		b := true
   113  		f.cfg.TLS.SkipVerify = &b
   114  		finch.Debug("auto-enabled skip-verify on socket with TLS but no hostname")
   115  	}
   116  
   117  	// Load and register TLS, if any
   118  	tlsConfig, err := f.cfg.TLS.LoadTLS(portSuffix.ReplaceAllString(f.cfg.Hostname, ""))
   119  	if err != nil {
   120  		return err
   121  	}
   122  	if tlsConfig != nil {
   123  		mysql.RegisterTLSConfig("benchmark", tlsConfig)
   124  		params = append(params, "tls=benchmark")
   125  		finch.Debug("TLS enabled")
   126  	}
   127  
   128  	// Use built-in Amazon RDS CA
   129  	if rdsAddr.MatchString(addr) && !config.True(f.cfg.DisableAutoTLS) && tlsConfig == nil {
   130  		finch.Debug("auto AWS TLS: hostname has suffix .rds.amazonaws.com")
   131  		aws.RegisterRDSCA() // safe to call multiple times
   132  		params = append(params, "tls=rds")
   133  	}
   134  
   135  	// ----------------------------------------------------------------------
   136  	// Credentials (user:pass)
   137  
   138  	var password = f.cfg.Password
   139  	if f.cfg.PasswordFile != "" {
   140  		bytes, err := ioutil.ReadFile(f.cfg.PasswordFile)
   141  		if err != nil {
   142  			return err
   143  		}
   144  		password = string(bytes)
   145  	}
   146  
   147  	if f.cfg.Username == "" {
   148  		f.cfg.Username = "finch" // default username
   149  		finch.Debug("using default MySQL username")
   150  		if f.cfg.Password == "" && password == "" {
   151  			finch.Debug("using default MySQL password")
   152  			password = "amazing"
   153  		}
   154  	}
   155  	cred := f.cfg.Username
   156  	if password != "" {
   157  		cred += ":" + password
   158  	}
   159  
   160  	// ----------------------------------------------------------------------
   161  	// Set DSN
   162  
   163  	f.dsn = fmt.Sprintf("%s@%s(%s)/%s", cred, net, addr, f.cfg.Db)
   164  	if len(params) > 0 {
   165  		f.dsn += "?" + strings.Join(params, "&")
   166  	}
   167  
   168  	return nil
   169  }
   170  
   171  const (
   172  	default_mysql_socket  = "/tmp/mysql.sock"
   173  	default_distro_socket = "/var/lib/mysql/mysql.sock"
   174  )
   175  
   176  func Sockets() []string {
   177  	sockets := []string{}
   178  	seen := map[string]bool{}
   179  	for _, socket := range strings.Split(socketList(), "\n") {
   180  		socket = strings.TrimSpace(socket)
   181  		if socket == "" {
   182  			continue
   183  		}
   184  		if seen[socket] {
   185  			continue
   186  		}
   187  		seen[socket] = true
   188  		if !isSocket(socket) {
   189  			continue
   190  		}
   191  		sockets = append(sockets, socket)
   192  	}
   193  
   194  	if len(sockets) == 0 {
   195  		finch.Debug("no sockets, using defaults")
   196  		if isSocket(default_mysql_socket) {
   197  			sockets = append(sockets, default_mysql_socket)
   198  		}
   199  		if isSocket(default_distro_socket) {
   200  			sockets = append(sockets, default_distro_socket)
   201  		}
   202  	}
   203  
   204  	finch.Debug("sockets: %v", sockets)
   205  	return sockets
   206  }
   207  
   208  func socketList() string {
   209  	cmd := exec.Command("sh", "-c", "netstat -f unix | grep mysql | grep -v mysqlx | awk '{print $NF}'")
   210  	output, err := cmd.Output()
   211  	if err != nil {
   212  		finch.Debug(err.Error())
   213  	}
   214  	return string(output)
   215  }
   216  
   217  func isSocket(file string) bool {
   218  	fi, err := os.Stat(file)
   219  	if err != nil {
   220  		return false
   221  	}
   222  	return fi.Mode()&fs.ModeSocket != 0
   223  }
   224  
   225  func RedactedDSN(dsn string) string {
   226  	redactedPassword, err := mysql.ParseDSN(dsn)
   227  	if err != nil { // ok to ignore
   228  		finch.Debug("mysql.ParseDSN error: %s", err)
   229  	}
   230  	redactedPassword.Passwd = "..."
   231  	return redactedPassword.FormatDSN()
   232  
   233  }
   234  
   235  // Suppress error messages from Go MySQL driver like "[mysql] 2023/07/27 15:59:30 packets.go:37: unexpected EOF"
   236  type null struct{}
   237  
   238  func (n null) Print(v ...interface{}) {}
   239  
   240  func init() {
   241  	mysql.SetLogger(null{})
   242  }