github.com/prebid/prebid-server/v2@v2.18.0/stored_requests/backends/db_provider/mysql_dbprovider.go (about) 1 package db_provider 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "crypto/x509" 8 "database/sql" 9 "errors" 10 "fmt" 11 "os" 12 "regexp" 13 "sort" 14 "strconv" 15 "strings" 16 17 "github.com/go-sql-driver/mysql" 18 "github.com/prebid/prebid-server/v2/config" 19 ) 20 21 const customTLSKey = "prebid-tls" 22 23 type MySqlDbProvider struct { 24 cfg config.DatabaseConnection 25 db *sql.DB 26 } 27 28 func (provider *MySqlDbProvider) Config() config.DatabaseConnection { 29 return provider.cfg 30 } 31 32 func (provider *MySqlDbProvider) Open() error { 33 connStr, err := provider.ConnString() 34 if err != nil { 35 return err 36 } 37 38 db, err := sql.Open(provider.cfg.Driver, connStr) 39 if err != nil { 40 return err 41 } 42 43 provider.db = db 44 return nil 45 } 46 47 func (provider *MySqlDbProvider) Close() error { 48 if provider.db != nil { 49 db := provider.db 50 provider.db = nil 51 return db.Close() 52 } 53 54 return nil 55 } 56 57 func (provider *MySqlDbProvider) Ping() error { 58 return provider.db.Ping() 59 } 60 61 func (provider *MySqlDbProvider) ConnString() (string, error) { 62 buffer := bytes.NewBuffer(nil) 63 64 if provider.cfg.Username != "" { 65 buffer.WriteString(provider.cfg.Username) 66 if provider.cfg.Password != "" { 67 buffer.WriteString(":") 68 buffer.WriteString(provider.cfg.Password) 69 } 70 buffer.WriteString("@") 71 } 72 73 buffer.WriteString("tcp(") 74 if provider.cfg.Host != "" { 75 buffer.WriteString(provider.cfg.Host) 76 } 77 78 if provider.cfg.Port > 0 { 79 buffer.WriteString(":") 80 buffer.WriteString(strconv.Itoa(provider.cfg.Port)) 81 } 82 buffer.WriteString(")") 83 84 buffer.WriteString("/") 85 86 if provider.cfg.Database != "" { 87 buffer.WriteString(provider.cfg.Database) 88 } 89 90 queryStr := provider.generateQueryString() 91 92 if provider.cfg.TLS.RootCert != "" { 93 if err := setupTLSConfig(provider); err != nil { 94 return "", err 95 } 96 } 97 98 if queryStr != "" { 99 buffer.WriteString("?") 100 buffer.WriteString(queryStr) 101 } 102 103 return buffer.String(), nil 104 } 105 106 func setupTLSConfig(provider *MySqlDbProvider) error { 107 rootCertPool := x509.NewCertPool() 108 109 pem, err := os.ReadFile(provider.cfg.TLS.RootCert) 110 if err != nil { 111 return err 112 } 113 114 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 115 return fmt.Errorf("failed to parse certificate: %s", provider.cfg.TLS.RootCert) 116 } 117 118 var clientCert []tls.Certificate 119 if provider.cfg.TLS.ClientCert != "" && provider.cfg.TLS.ClientKey != "" { 120 clientCert = make([]tls.Certificate, 0, 1) 121 certs, err := tls.LoadX509KeyPair(provider.cfg.TLS.ClientCert, provider.cfg.TLS.ClientKey) 122 if err != nil { 123 return err 124 } 125 126 clientCert = append(clientCert, certs) 127 } 128 129 mysql.RegisterTLSConfig(provider.getTLSKey(), &tls.Config{ 130 RootCAs: rootCertPool, 131 Certificates: clientCert, 132 InsecureSkipVerify: true, 133 VerifyPeerCertificate: verifyPeerCertFunc(rootCertPool), 134 }) 135 136 return nil 137 } 138 139 // verifyPeerCertFunc returns a function that verifies the peer certificate is 140 // in the cert pool. 141 func verifyPeerCertFunc(pool *x509.CertPool) func([][]byte, [][]*x509.Certificate) error { 142 return func(rawCerts [][]byte, _ [][]*x509.Certificate) error { 143 if len(rawCerts) == 0 { 144 return errors.New("no certificates available to verify") 145 } 146 147 cert, err := x509.ParseCertificate(rawCerts[0]) 148 if err != nil { 149 return err 150 } 151 152 opts := x509.VerifyOptions{Roots: pool} 153 if _, err = cert.Verify(opts); err != nil { 154 return err 155 } 156 return nil 157 } 158 } 159 160 func (provider *MySqlDbProvider) generateQueryString() string { 161 tls := "" 162 163 if provider.cfg.TLS.RootCert != "" { 164 tls = provider.getTLSKey() 165 } 166 167 if tls != "" { 168 if len(provider.cfg.QueryString) == 0 { 169 return "tls=" + tls 170 } 171 if !strings.Contains(provider.cfg.QueryString, "tls=") { 172 return "tls=" + tls + "&" + provider.cfg.QueryString 173 } 174 } 175 176 return provider.cfg.QueryString 177 } 178 179 func (provider *MySqlDbProvider) getTLSKey() string { 180 pairs := strings.Split(provider.cfg.QueryString, "&") 181 182 for _, pair := range pairs { 183 if strings.HasPrefix(pair, "tls=") { 184 return strings.Split(pair, "=")[1] 185 } 186 } 187 188 return customTLSKey 189 } 190 191 func (provider *MySqlDbProvider) PrepareQuery(template string, params ...QueryParam) (query string, args []interface{}) { 192 query = template 193 args = []interface{}{} 194 195 type occurrence struct { 196 startIndex int 197 param QueryParam 198 } 199 occurrences := []occurrence{} 200 201 for _, param := range params { 202 re := regexp.MustCompile("\\$" + param.Name) 203 matches := re.FindAllIndex([]byte(query), -1) 204 for _, match := range matches { 205 occurrences = append(occurrences, 206 occurrence{ 207 startIndex: match[0], 208 param: param, 209 }) 210 } 211 } 212 sort.Slice(occurrences, func(i, j int) bool { 213 return occurrences[i].startIndex < occurrences[j].startIndex 214 }) 215 216 for _, occurrence := range occurrences { 217 switch occurrence.param.Value.(type) { 218 case []interface{}: 219 idList := occurrence.param.Value.([]interface{}) 220 args = append(args, idList...) 221 default: 222 args = append(args, occurrence.param.Value) 223 } 224 } 225 226 for _, param := range params { 227 switch param.Value.(type) { 228 case []interface{}: 229 len := len(param.Value.([]interface{})) 230 idList := provider.createIdList(len) 231 query = strings.Replace(query, "$"+param.Name, idList, -1) 232 default: 233 query = strings.Replace(query, "$"+param.Name, "?", -1) 234 } 235 } 236 return 237 } 238 239 func (provider *MySqlDbProvider) QueryContext(ctx context.Context, template string, params ...QueryParam) (*sql.Rows, error) { 240 query, args := provider.PrepareQuery(template, params...) 241 return provider.db.QueryContext(ctx, query, args...) 242 } 243 244 func (provider *MySqlDbProvider) createIdList(numArgs int) string { 245 // Any empty list like "()" is illegal in MySql. A (NULL) is the next best thing, 246 // though, since `id IN (NULL)` is valid for all "id" column types, and evaluates to an empty set. 247 if numArgs == 0 { 248 return "(NULL)" 249 } 250 251 result := bytes.NewBuffer(make([]byte, 0, 2+3*numArgs)) 252 result.WriteString("(") 253 for i := 1; i < numArgs; i++ { 254 result.WriteString("?") 255 result.WriteString(", ") 256 } 257 result.WriteString("?") 258 result.WriteString(")") 259 260 return result.String() 261 }