github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/workload/connection.go (about) 1 // Copyright 2018 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package workload 12 13 import ( 14 "fmt" 15 "net/url" 16 "runtime" 17 "strings" 18 19 "github.com/spf13/pflag" 20 ) 21 22 // ConnFlags is helper of common flags that are relevant to QueryLoads. 23 type ConnFlags struct { 24 *pflag.FlagSet 25 DBOverride string 26 Concurrency int 27 // Method for issuing queries; see SQLRunner. 28 Method string 29 } 30 31 // NewConnFlags returns an initialized ConnFlags. 32 func NewConnFlags(genFlags *Flags) *ConnFlags { 33 c := &ConnFlags{} 34 c.FlagSet = pflag.NewFlagSet(`conn`, pflag.ContinueOnError) 35 c.StringVar(&c.DBOverride, `db`, ``, 36 `Override for the SQL database to use. If empty, defaults to the generator name`) 37 c.IntVar(&c.Concurrency, `concurrency`, 2*runtime.NumCPU(), 38 `Number of concurrent workers`) 39 c.StringVar(&c.Method, `method`, `prepare`, `SQL issue method (prepare, noprepare, simple)`) 40 genFlags.AddFlagSet(c.FlagSet) 41 if genFlags.Meta == nil { 42 genFlags.Meta = make(map[string]FlagMeta) 43 } 44 genFlags.Meta[`db`] = FlagMeta{RuntimeOnly: true} 45 genFlags.Meta[`concurrency`] = FlagMeta{RuntimeOnly: true} 46 genFlags.Meta[`method`] = FlagMeta{RuntimeOnly: true} 47 return c 48 } 49 50 // SanitizeUrls verifies that the give SQL connection strings have the correct 51 // SQL database set, rewriting them in place if necessary. This database name is 52 // returned. 53 func SanitizeUrls(gen Generator, dbOverride string, urls []string) (string, error) { 54 dbName := gen.Meta().Name 55 if dbOverride != `` { 56 dbName = dbOverride 57 } 58 for i := range urls { 59 parsed, err := url.Parse(urls[i]) 60 if err != nil { 61 return "", err 62 } 63 if d := strings.TrimPrefix(parsed.Path, `/`); d != `` && d != dbName { 64 return "", fmt.Errorf(`%s specifies database %q, but database %q is expected`, 65 urls[i], d, dbName) 66 } 67 parsed.Path = dbName 68 69 q := parsed.Query() 70 q.Set("application_name", gen.Meta().Name) 71 parsed.RawQuery = q.Encode() 72 73 switch parsed.Scheme { 74 case "postgres", "postgresql": 75 urls[i] = parsed.String() 76 default: 77 return ``, fmt.Errorf(`unsupported scheme: %s`, parsed.Scheme) 78 } 79 } 80 return dbName, nil 81 }