github.com/prebid/prebid-server/v2@v2.18.0/stored_requests/backends/db_provider/postgres_dbprovider.go (about) 1 package db_provider 2 3 import ( 4 "bytes" 5 "context" 6 "database/sql" 7 "errors" 8 "fmt" 9 "net/url" 10 "strconv" 11 "strings" 12 13 "github.com/prebid/prebid-server/v2/config" 14 ) 15 16 type PostgresDbProvider struct { 17 cfg config.DatabaseConnection 18 db *sql.DB 19 } 20 21 func (provider *PostgresDbProvider) Config() config.DatabaseConnection { 22 return provider.cfg 23 } 24 25 func (provider *PostgresDbProvider) Open() error { 26 connStr, err := provider.ConnString() 27 if err != nil { 28 return err 29 } 30 31 db, err := sql.Open(provider.cfg.Driver, connStr) 32 if err != nil { 33 return err 34 } 35 36 provider.db = db 37 return nil 38 } 39 40 func (provider *PostgresDbProvider) Close() error { 41 if provider.db != nil { 42 db := provider.db 43 provider.db = nil 44 return db.Close() 45 } 46 47 return nil 48 } 49 50 func (provider *PostgresDbProvider) Ping() error { 51 return provider.db.Ping() 52 } 53 54 func (provider *PostgresDbProvider) ConnString() (string, error) { 55 buffer := bytes.NewBuffer(nil) 56 buffer.WriteString("postgresql://") 57 58 if provider.cfg.Username != "" { 59 buffer.WriteString(provider.cfg.Username) 60 if provider.cfg.Password != "" { 61 buffer.WriteString(":") 62 buffer.WriteString(url.QueryEscape(provider.cfg.Password)) 63 } 64 buffer.WriteString("@") 65 } 66 67 if provider.cfg.Host != "" { 68 buffer.WriteString(provider.cfg.Host) 69 } 70 71 if provider.cfg.Port > 0 { 72 buffer.WriteString(":") 73 buffer.WriteString(strconv.Itoa(provider.cfg.Port)) 74 } 75 76 if provider.cfg.Database != "" { 77 buffer.WriteString("/") 78 buffer.WriteString(provider.cfg.Database) 79 } 80 81 queryStr, err := provider.generateQueryString() 82 if err != nil { 83 return "", err 84 } 85 86 if queryStr != "" { 87 buffer.WriteString("?") 88 buffer.WriteString(queryStr) 89 } 90 91 return buffer.String(), nil 92 } 93 94 func (provider *PostgresDbProvider) generateQueryString() (string, error) { 95 isTlsInConfigStruct := provider.cfg.TLS.RootCert != "" || 96 provider.cfg.TLS.ClientCert != "" || 97 provider.cfg.TLS.ClientKey != "" 98 99 isTlsInQueryString := strings.Contains(provider.cfg.QueryString, "sslrootcert=") || 100 strings.Contains(provider.cfg.QueryString, "sslcert=") || 101 strings.Contains(provider.cfg.QueryString, "sslkey=") 102 103 if isTlsInConfigStruct && isTlsInQueryString { 104 return "", errors.New("TLS cert information must either be specified in the TLS object or the query string but not both.") 105 } 106 107 sslmode := "disable" 108 sslrootcert := "" 109 sslcert := "" 110 sslkey := "" 111 queryString := "" 112 113 if provider.cfg.TLS.RootCert != "" { 114 sslmode = "verify-ca" 115 sslrootcert = fmt.Sprintf("&sslrootcert=%s", provider.cfg.TLS.RootCert) 116 117 if provider.cfg.TLS.ClientCert != "" && provider.cfg.TLS.ClientKey != "" { 118 sslmode = "verify-full" 119 sslcert = fmt.Sprintf("&sslcert=%s", provider.cfg.TLS.ClientCert) 120 sslkey = fmt.Sprintf("&sslkey=%s", provider.cfg.TLS.ClientKey) 121 } 122 } 123 sslmode = fmt.Sprintf("&sslmode=%s", sslmode) 124 125 if len(provider.cfg.QueryString) != 0 { 126 queryString = fmt.Sprintf("&%s", provider.cfg.QueryString) 127 128 if strings.Contains(provider.cfg.QueryString, "sslmode=") { 129 sslmode = "" 130 } 131 } 132 133 params := strings.Join([]string{sslmode, sslrootcert, sslcert, sslkey, queryString}, "") 134 return params[1:], nil 135 } 136 137 func (provider *PostgresDbProvider) PrepareQuery(template string, params ...QueryParam) (query string, args []interface{}) { 138 query = template 139 args = []interface{}{} 140 141 for _, param := range params { 142 switch v := param.Value.(type) { 143 case []interface{}: 144 idList := v 145 idListStr := provider.createIdList(len(args), len(idList)) 146 args = append(args, idList...) 147 query = strings.Replace(query, "$"+param.Name, idListStr, -1) 148 default: 149 args = append(args, param.Value) 150 query = strings.Replace(query, "$"+param.Name, fmt.Sprintf("$%d", len(args)), -1) 151 } 152 } 153 return 154 } 155 156 func (provider *PostgresDbProvider) QueryContext(ctx context.Context, template string, params ...QueryParam) (*sql.Rows, error) { 157 query, args := provider.PrepareQuery(template, params...) 158 return provider.db.QueryContext(ctx, query, args...) 159 } 160 161 func (provider *PostgresDbProvider) createIdList(numSoFar int, numArgs int) string { 162 // Any empty list like "()" is illegal in Postgres. A (NULL) is the next best thing, 163 // though, since `id IN (NULL)` is valid for all "id" column types, and evaluates to an empty set. 164 // 165 // The query plan also suggests that it's basically free: 166 // 167 // explain SELECT id, requestData FROM stored_requests WHERE id in $ID_LIST; 168 // 169 // QUERY PLAN 170 // ------------------------------------------- 171 // Result (cost=0.00..0.00 rows=0 width=16) 172 // One-Time Filter: false 173 // (2 rows) 174 if numArgs == 0 { 175 return "(NULL)" 176 } 177 178 final := bytes.NewBuffer(make([]byte, 0, 2+4*numArgs)) 179 final.WriteString("(") 180 for i := numSoFar + 1; i < numSoFar+numArgs; i++ { 181 final.WriteString("$") 182 final.WriteString(strconv.Itoa(i)) 183 final.WriteString(", ") 184 } 185 final.WriteString("$") 186 final.WriteString(strconv.Itoa(numSoFar + numArgs)) 187 final.WriteString(")") 188 189 return final.String() 190 }