github.com/prebid/prebid-server/v2@v2.18.0/stored_requests/backends/db_provider/postgres_dbprovider.go (about)

     1  package db_provider
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql"
     7  	"errors"
     8  	"fmt"
     9  	"net/url"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/prebid/prebid-server/v2/config"
    14  )
    15  
// PostgresDbProvider bundles the connection settings for a Postgres database
// with the database/sql handle used to talk to it.
type PostgresDbProvider struct {
	// cfg holds the connection settings read by ConnString, generateQueryString and Open.
	cfg config.DatabaseConnection
	// db is the handle assigned by Open and reset to nil by Close.
	db *sql.DB
}
    20  
// Config returns the connection configuration stored on the provider.
func (provider *PostgresDbProvider) Config() config.DatabaseConnection {
	return provider.cfg
}
    24  
    25  func (provider *PostgresDbProvider) Open() error {
    26  	connStr, err := provider.ConnString()
    27  	if err != nil {
    28  		return err
    29  	}
    30  
    31  	db, err := sql.Open(provider.cfg.Driver, connStr)
    32  	if err != nil {
    33  		return err
    34  	}
    35  
    36  	provider.db = db
    37  	return nil
    38  }
    39  
    40  func (provider *PostgresDbProvider) Close() error {
    41  	if provider.db != nil {
    42  		db := provider.db
    43  		provider.db = nil
    44  		return db.Close()
    45  	}
    46  
    47  	return nil
    48  }
    49  
    50  func (provider *PostgresDbProvider) Ping() error {
    51  	return provider.db.Ping()
    52  }
    53  
    54  func (provider *PostgresDbProvider) ConnString() (string, error) {
    55  	buffer := bytes.NewBuffer(nil)
    56  	buffer.WriteString("postgresql://")
    57  
    58  	if provider.cfg.Username != "" {
    59  		buffer.WriteString(provider.cfg.Username)
    60  		if provider.cfg.Password != "" {
    61  			buffer.WriteString(":")
    62  			buffer.WriteString(url.QueryEscape(provider.cfg.Password))
    63  		}
    64  		buffer.WriteString("@")
    65  	}
    66  
    67  	if provider.cfg.Host != "" {
    68  		buffer.WriteString(provider.cfg.Host)
    69  	}
    70  
    71  	if provider.cfg.Port > 0 {
    72  		buffer.WriteString(":")
    73  		buffer.WriteString(strconv.Itoa(provider.cfg.Port))
    74  	}
    75  
    76  	if provider.cfg.Database != "" {
    77  		buffer.WriteString("/")
    78  		buffer.WriteString(provider.cfg.Database)
    79  	}
    80  
    81  	queryStr, err := provider.generateQueryString()
    82  	if err != nil {
    83  		return "", err
    84  	}
    85  
    86  	if queryStr != "" {
    87  		buffer.WriteString("?")
    88  		buffer.WriteString(queryStr)
    89  	}
    90  
    91  	return buffer.String(), nil
    92  }
    93  
    94  func (provider *PostgresDbProvider) generateQueryString() (string, error) {
    95  	isTlsInConfigStruct := provider.cfg.TLS.RootCert != "" ||
    96  		provider.cfg.TLS.ClientCert != "" ||
    97  		provider.cfg.TLS.ClientKey != ""
    98  
    99  	isTlsInQueryString := strings.Contains(provider.cfg.QueryString, "sslrootcert=") ||
   100  		strings.Contains(provider.cfg.QueryString, "sslcert=") ||
   101  		strings.Contains(provider.cfg.QueryString, "sslkey=")
   102  
   103  	if isTlsInConfigStruct && isTlsInQueryString {
   104  		return "", errors.New("TLS cert information must either be specified in the TLS object or the query string but not both.")
   105  	}
   106  
   107  	sslmode := "disable"
   108  	sslrootcert := ""
   109  	sslcert := ""
   110  	sslkey := ""
   111  	queryString := ""
   112  
   113  	if provider.cfg.TLS.RootCert != "" {
   114  		sslmode = "verify-ca"
   115  		sslrootcert = fmt.Sprintf("&sslrootcert=%s", provider.cfg.TLS.RootCert)
   116  
   117  		if provider.cfg.TLS.ClientCert != "" && provider.cfg.TLS.ClientKey != "" {
   118  			sslmode = "verify-full"
   119  			sslcert = fmt.Sprintf("&sslcert=%s", provider.cfg.TLS.ClientCert)
   120  			sslkey = fmt.Sprintf("&sslkey=%s", provider.cfg.TLS.ClientKey)
   121  		}
   122  	}
   123  	sslmode = fmt.Sprintf("&sslmode=%s", sslmode)
   124  
   125  	if len(provider.cfg.QueryString) != 0 {
   126  		queryString = fmt.Sprintf("&%s", provider.cfg.QueryString)
   127  
   128  		if strings.Contains(provider.cfg.QueryString, "sslmode=") {
   129  			sslmode = ""
   130  		}
   131  	}
   132  
   133  	params := strings.Join([]string{sslmode, sslrootcert, sslcert, sslkey, queryString}, "")
   134  	return params[1:], nil
   135  }
   136  
   137  func (provider *PostgresDbProvider) PrepareQuery(template string, params ...QueryParam) (query string, args []interface{}) {
   138  	query = template
   139  	args = []interface{}{}
   140  
   141  	for _, param := range params {
   142  		switch v := param.Value.(type) {
   143  		case []interface{}:
   144  			idList := v
   145  			idListStr := provider.createIdList(len(args), len(idList))
   146  			args = append(args, idList...)
   147  			query = strings.Replace(query, "$"+param.Name, idListStr, -1)
   148  		default:
   149  			args = append(args, param.Value)
   150  			query = strings.Replace(query, "$"+param.Name, fmt.Sprintf("$%d", len(args)), -1)
   151  		}
   152  	}
   153  	return
   154  }
   155  
   156  func (provider *PostgresDbProvider) QueryContext(ctx context.Context, template string, params ...QueryParam) (*sql.Rows, error) {
   157  	query, args := provider.PrepareQuery(template, params...)
   158  	return provider.db.QueryContext(ctx, query, args...)
   159  }
   160  
   161  func (provider *PostgresDbProvider) createIdList(numSoFar int, numArgs int) string {
   162  	// Any empty list like "()" is illegal in Postgres. A (NULL) is the next best thing,
   163  	// though, since `id IN (NULL)` is valid for all "id" column types, and evaluates to an empty set.
   164  	//
   165  	// The query plan also suggests that it's basically free:
   166  	//
   167  	// explain SELECT id, requestData FROM stored_requests WHERE id in $ID_LIST;
   168  	//
   169  	// QUERY PLAN
   170  	// -------------------------------------------
   171  	// Result  (cost=0.00..0.00 rows=0 width=16)
   172  	//	 One-Time Filter: false
   173  	// (2 rows)
   174  	if numArgs == 0 {
   175  		return "(NULL)"
   176  	}
   177  
   178  	final := bytes.NewBuffer(make([]byte, 0, 2+4*numArgs))
   179  	final.WriteString("(")
   180  	for i := numSoFar + 1; i < numSoFar+numArgs; i++ {
   181  		final.WriteString("$")
   182  		final.WriteString(strconv.Itoa(i))
   183  		final.WriteString(", ")
   184  	}
   185  	final.WriteString("$")
   186  	final.WriteString(strconv.Itoa(numSoFar + numArgs))
   187  	final.WriteString(")")
   188  
   189  	return final.String()
   190  }