github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/workload/connection.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package workload
    12  
    13  import (
    14  	"fmt"
    15  	"net/url"
    16  	"runtime"
    17  	"strings"
    18  
    19  	"github.com/spf13/pflag"
    20  )
    21  
    22  // ConnFlags is helper of common flags that are relevant to QueryLoads.
    23  type ConnFlags struct {
    24  	*pflag.FlagSet
    25  	DBOverride  string
    26  	Concurrency int
    27  	// Method for issuing queries; see SQLRunner.
    28  	Method string
    29  }
    30  
    31  // NewConnFlags returns an initialized ConnFlags.
    32  func NewConnFlags(genFlags *Flags) *ConnFlags {
    33  	c := &ConnFlags{}
    34  	c.FlagSet = pflag.NewFlagSet(`conn`, pflag.ContinueOnError)
    35  	c.StringVar(&c.DBOverride, `db`, ``,
    36  		`Override for the SQL database to use. If empty, defaults to the generator name`)
    37  	c.IntVar(&c.Concurrency, `concurrency`, 2*runtime.NumCPU(),
    38  		`Number of concurrent workers`)
    39  	c.StringVar(&c.Method, `method`, `prepare`, `SQL issue method (prepare, noprepare, simple)`)
    40  	genFlags.AddFlagSet(c.FlagSet)
    41  	if genFlags.Meta == nil {
    42  		genFlags.Meta = make(map[string]FlagMeta)
    43  	}
    44  	genFlags.Meta[`db`] = FlagMeta{RuntimeOnly: true}
    45  	genFlags.Meta[`concurrency`] = FlagMeta{RuntimeOnly: true}
    46  	genFlags.Meta[`method`] = FlagMeta{RuntimeOnly: true}
    47  	return c
    48  }
    49  
    50  // SanitizeUrls verifies that the give SQL connection strings have the correct
    51  // SQL database set, rewriting them in place if necessary. This database name is
    52  // returned.
    53  func SanitizeUrls(gen Generator, dbOverride string, urls []string) (string, error) {
    54  	dbName := gen.Meta().Name
    55  	if dbOverride != `` {
    56  		dbName = dbOverride
    57  	}
    58  	for i := range urls {
    59  		parsed, err := url.Parse(urls[i])
    60  		if err != nil {
    61  			return "", err
    62  		}
    63  		if d := strings.TrimPrefix(parsed.Path, `/`); d != `` && d != dbName {
    64  			return "", fmt.Errorf(`%s specifies database %q, but database %q is expected`,
    65  				urls[i], d, dbName)
    66  		}
    67  		parsed.Path = dbName
    68  
    69  		q := parsed.Query()
    70  		q.Set("application_name", gen.Meta().Name)
    71  		parsed.RawQuery = q.Encode()
    72  
    73  		switch parsed.Scheme {
    74  		case "postgres", "postgresql":
    75  			urls[i] = parsed.String()
    76  		default:
    77  			return ``, fmt.Errorf(`unsupported scheme: %s`, parsed.Scheme)
    78  		}
    79  	}
    80  	return dbName, nil
    81  }