github.com/prebid/prebid-server/v2@v2.18.0/stored_requests/backends/db_provider/mysql_dbprovider.go (about)

     1  package db_provider
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"database/sql"
     9  	"errors"
    10  	"fmt"
    11  	"os"
    12  	"regexp"
    13  	"sort"
    14  	"strconv"
    15  	"strings"
    16  
    17  	"github.com/go-sql-driver/mysql"
    18  	"github.com/prebid/prebid-server/v2/config"
    19  )
    20  
    21  const customTLSKey = "prebid-tls"
    22  
    23  type MySqlDbProvider struct {
    24  	cfg config.DatabaseConnection
    25  	db  *sql.DB
    26  }
    27  
    28  func (provider *MySqlDbProvider) Config() config.DatabaseConnection {
    29  	return provider.cfg
    30  }
    31  
    32  func (provider *MySqlDbProvider) Open() error {
    33  	connStr, err := provider.ConnString()
    34  	if err != nil {
    35  		return err
    36  	}
    37  
    38  	db, err := sql.Open(provider.cfg.Driver, connStr)
    39  	if err != nil {
    40  		return err
    41  	}
    42  
    43  	provider.db = db
    44  	return nil
    45  }
    46  
    47  func (provider *MySqlDbProvider) Close() error {
    48  	if provider.db != nil {
    49  		db := provider.db
    50  		provider.db = nil
    51  		return db.Close()
    52  	}
    53  
    54  	return nil
    55  }
    56  
    57  func (provider *MySqlDbProvider) Ping() error {
    58  	return provider.db.Ping()
    59  }
    60  
    61  func (provider *MySqlDbProvider) ConnString() (string, error) {
    62  	buffer := bytes.NewBuffer(nil)
    63  
    64  	if provider.cfg.Username != "" {
    65  		buffer.WriteString(provider.cfg.Username)
    66  		if provider.cfg.Password != "" {
    67  			buffer.WriteString(":")
    68  			buffer.WriteString(provider.cfg.Password)
    69  		}
    70  		buffer.WriteString("@")
    71  	}
    72  
    73  	buffer.WriteString("tcp(")
    74  	if provider.cfg.Host != "" {
    75  		buffer.WriteString(provider.cfg.Host)
    76  	}
    77  
    78  	if provider.cfg.Port > 0 {
    79  		buffer.WriteString(":")
    80  		buffer.WriteString(strconv.Itoa(provider.cfg.Port))
    81  	}
    82  	buffer.WriteString(")")
    83  
    84  	buffer.WriteString("/")
    85  
    86  	if provider.cfg.Database != "" {
    87  		buffer.WriteString(provider.cfg.Database)
    88  	}
    89  
    90  	queryStr := provider.generateQueryString()
    91  
    92  	if provider.cfg.TLS.RootCert != "" {
    93  		if err := setupTLSConfig(provider); err != nil {
    94  			return "", err
    95  		}
    96  	}
    97  
    98  	if queryStr != "" {
    99  		buffer.WriteString("?")
   100  		buffer.WriteString(queryStr)
   101  	}
   102  
   103  	return buffer.String(), nil
   104  }
   105  
   106  func setupTLSConfig(provider *MySqlDbProvider) error {
   107  	rootCertPool := x509.NewCertPool()
   108  
   109  	pem, err := os.ReadFile(provider.cfg.TLS.RootCert)
   110  	if err != nil {
   111  		return err
   112  	}
   113  
   114  	if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
   115  		return fmt.Errorf("failed to parse certificate: %s", provider.cfg.TLS.RootCert)
   116  	}
   117  
   118  	var clientCert []tls.Certificate
   119  	if provider.cfg.TLS.ClientCert != "" && provider.cfg.TLS.ClientKey != "" {
   120  		clientCert = make([]tls.Certificate, 0, 1)
   121  		certs, err := tls.LoadX509KeyPair(provider.cfg.TLS.ClientCert, provider.cfg.TLS.ClientKey)
   122  		if err != nil {
   123  			return err
   124  		}
   125  
   126  		clientCert = append(clientCert, certs)
   127  	}
   128  
   129  	mysql.RegisterTLSConfig(provider.getTLSKey(), &tls.Config{
   130  		RootCAs:               rootCertPool,
   131  		Certificates:          clientCert,
   132  		InsecureSkipVerify:    true,
   133  		VerifyPeerCertificate: verifyPeerCertFunc(rootCertPool),
   134  	})
   135  
   136  	return nil
   137  }
   138  
   139  // verifyPeerCertFunc returns a function that verifies the peer certificate is
   140  // in the cert pool.
   141  func verifyPeerCertFunc(pool *x509.CertPool) func([][]byte, [][]*x509.Certificate) error {
   142  	return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
   143  		if len(rawCerts) == 0 {
   144  			return errors.New("no certificates available to verify")
   145  		}
   146  
   147  		cert, err := x509.ParseCertificate(rawCerts[0])
   148  		if err != nil {
   149  			return err
   150  		}
   151  
   152  		opts := x509.VerifyOptions{Roots: pool}
   153  		if _, err = cert.Verify(opts); err != nil {
   154  			return err
   155  		}
   156  		return nil
   157  	}
   158  }
   159  
   160  func (provider *MySqlDbProvider) generateQueryString() string {
   161  	tls := ""
   162  
   163  	if provider.cfg.TLS.RootCert != "" {
   164  		tls = provider.getTLSKey()
   165  	}
   166  
   167  	if tls != "" {
   168  		if len(provider.cfg.QueryString) == 0 {
   169  			return "tls=" + tls
   170  		}
   171  		if !strings.Contains(provider.cfg.QueryString, "tls=") {
   172  			return "tls=" + tls + "&" + provider.cfg.QueryString
   173  		}
   174  	}
   175  
   176  	return provider.cfg.QueryString
   177  }
   178  
   179  func (provider *MySqlDbProvider) getTLSKey() string {
   180  	pairs := strings.Split(provider.cfg.QueryString, "&")
   181  
   182  	for _, pair := range pairs {
   183  		if strings.HasPrefix(pair, "tls=") {
   184  			return strings.Split(pair, "=")[1]
   185  		}
   186  	}
   187  
   188  	return customTLSKey
   189  }
   190  
   191  func (provider *MySqlDbProvider) PrepareQuery(template string, params ...QueryParam) (query string, args []interface{}) {
   192  	query = template
   193  	args = []interface{}{}
   194  
   195  	type occurrence struct {
   196  		startIndex int
   197  		param      QueryParam
   198  	}
   199  	occurrences := []occurrence{}
   200  
   201  	for _, param := range params {
   202  		re := regexp.MustCompile("\\$" + param.Name)
   203  		matches := re.FindAllIndex([]byte(query), -1)
   204  		for _, match := range matches {
   205  			occurrences = append(occurrences,
   206  				occurrence{
   207  					startIndex: match[0],
   208  					param:      param,
   209  				})
   210  		}
   211  	}
   212  	sort.Slice(occurrences, func(i, j int) bool {
   213  		return occurrences[i].startIndex < occurrences[j].startIndex
   214  	})
   215  
   216  	for _, occurrence := range occurrences {
   217  		switch occurrence.param.Value.(type) {
   218  		case []interface{}:
   219  			idList := occurrence.param.Value.([]interface{})
   220  			args = append(args, idList...)
   221  		default:
   222  			args = append(args, occurrence.param.Value)
   223  		}
   224  	}
   225  
   226  	for _, param := range params {
   227  		switch param.Value.(type) {
   228  		case []interface{}:
   229  			len := len(param.Value.([]interface{}))
   230  			idList := provider.createIdList(len)
   231  			query = strings.Replace(query, "$"+param.Name, idList, -1)
   232  		default:
   233  			query = strings.Replace(query, "$"+param.Name, "?", -1)
   234  		}
   235  	}
   236  	return
   237  }
   238  
   239  func (provider *MySqlDbProvider) QueryContext(ctx context.Context, template string, params ...QueryParam) (*sql.Rows, error) {
   240  	query, args := provider.PrepareQuery(template, params...)
   241  	return provider.db.QueryContext(ctx, query, args...)
   242  }
   243  
   244  func (provider *MySqlDbProvider) createIdList(numArgs int) string {
   245  	// Any empty list like "()" is illegal in MySql. A (NULL) is the next best thing,
   246  	// though, since `id IN (NULL)` is valid for all "id" column types, and evaluates to an empty set.
   247  	if numArgs == 0 {
   248  		return "(NULL)"
   249  	}
   250  
   251  	result := bytes.NewBuffer(make([]byte, 0, 2+3*numArgs))
   252  	result.WriteString("(")
   253  	for i := 1; i < numArgs; i++ {
   254  		result.WriteString("?")
   255  		result.WriteString(", ")
   256  	}
   257  	result.WriteString("?")
   258  	result.WriteString(")")
   259  
   260  	return result.String()
   261  }