vitess.io/vitess@v0.16.2/go/cmd/vtclient/vtclient.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package main 18 19 import ( 20 "context" 21 "database/sql" 22 "encoding/json" 23 "errors" 24 "flag" 25 "fmt" 26 "io" 27 "math/rand" 28 "os" 29 "sort" 30 "sync" 31 "time" 32 33 "github.com/olekukonko/tablewriter" 34 "github.com/spf13/pflag" 35 36 "vitess.io/vitess/go/acl" 37 "vitess.io/vitess/go/vt/concurrency" 38 "vitess.io/vitess/go/vt/grpccommon" 39 "vitess.io/vitess/go/vt/log" 40 "vitess.io/vitess/go/vt/logutil" 41 "vitess.io/vitess/go/vt/servenv" 42 "vitess.io/vitess/go/vt/sqlparser" 43 "vitess.io/vitess/go/vt/vitessdriver" 44 "vitess.io/vitess/go/vt/vterrors" 45 "vitess.io/vitess/go/vt/vtgate/vtgateconn" 46 47 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 48 49 // Include deprecation warnings for soon-to-be-unsupported flag invocations. 50 _flag "vitess.io/vitess/go/internal/flag" 51 ) 52 53 var ( 54 usage = ` 55 vtclient connects to a vtgate server using the standard go driver API. 56 Version 3 of the API is used, we do not send any hint to the server. 57 58 For query bound variables, we assume place-holders in the query string 59 in the form of :v1, :v2, etc. 60 61 Examples: 62 63 $ vtclient --server vtgate:15991 "SELECT * FROM messages" 64 65 $ vtclient --server vtgate:15991 --target '@primary' --bind_variables '[ 12345, 1, "msg 12345" ]' "INSERT INTO messages (page,time_created_ns,message) VALUES (:v1, :v2, :v3)" 66 67 ` 68 server string 69 streaming bool 70 targetString string 71 jsonOutput bool 72 useRandom bool 73 bindVariables *bindvars 74 75 timeout = 30 * time.Second 76 parallel = 1 77 count = 1 78 minSeqID = 0 79 maxSeqID = 0 80 qps = 0 81 ) 82 83 var ( 84 seqChan = make(chan int, 10) 85 ) 86 87 func init() { 88 _flag.SetUsage(flag.CommandLine, _flag.UsageOptions{ 89 Epilogue: func(w io.Writer) { fmt.Fprint(w, usage) }, 90 }) 91 } 92 93 func registerFlags(fs *pflag.FlagSet) { 94 fs.StringVar(&server, "server", server, "vtgate server to connect to") 95 fs.DurationVar(&timeout, "timeout", timeout, "timeout for queries") 96 fs.BoolVar(&streaming, "streaming", streaming, "use a streaming query") 97 fs.StringVar(&targetString, "target", targetString, "keyspace:shard@tablet_type") 98 fs.BoolVar(&jsonOutput, "json", jsonOutput, "Output JSON instead of human-readable table") 99 fs.IntVar(¶llel, "parallel", parallel, "DMLs only: Number of threads executing the same query in parallel. Useful for simple load testing.") 100 fs.IntVar(&count, "count", count, "DMLs only: Number of times each thread executes the query. Useful for simple, sustained load testing.") 101 fs.IntVar(&minSeqID, "min_sequence_id", minSeqID, "min sequence ID to generate. When max_sequence_id > min_sequence_id, for each query, a number is generated in [min_sequence_id, max_sequence_id) and attached to the end of the bind variables.") 102 fs.IntVar(&maxSeqID, "max_sequence_id", maxSeqID, "max sequence ID.") 103 fs.BoolVar(&useRandom, "use_random_sequence", useRandom, "use random sequence for generating [min_sequence_id, max_sequence_id)") 104 fs.IntVar(&qps, "qps", qps, "queries per second to throttle each thread at.") 105 106 acl.RegisterFlags(fs) 107 108 bindVariables = newBindvars(fs, "bind_variables", "bind variables as a json list") 109 } 110 111 type bindvars []any 112 113 func (bv *bindvars) String() string { 114 b, err := json.Marshal(bv) 115 if err != nil { 116 return err.Error() 117 } 118 return string(b) 119 } 120 121 func (bv *bindvars) Set(s string) (err error) { 122 err = json.Unmarshal([]byte(s), &bv) 123 if err != nil { 124 return err 125 } 126 // json reads all numbers as float64 127 // So, we just ditch floats for bindvars 128 for i, v := range *bv { 129 if f, ok := v.(float64); ok { 130 if f > 0 { 131 (*bv)[i] = uint64(f) 132 } else { 133 (*bv)[i] = int64(f) 134 } 135 } 136 } 137 138 return nil 139 } 140 141 // For internal flag compatibility 142 func (bv *bindvars) Get() any { 143 return bv 144 } 145 146 // Type is part of the pflag.Value interface. bindvars.Set() expects all numbers as float64. 147 func (bv *bindvars) Type() string { 148 return "float64" 149 } 150 151 func newBindvars(fs *pflag.FlagSet, name, usage string) *bindvars { 152 var bv bindvars 153 fs.Var(&bv, name, usage) 154 return &bv 155 } 156 157 func main() { 158 defer logutil.Flush() 159 160 qr, err := run() 161 if jsonOutput && qr != nil { 162 data, err := json.MarshalIndent(qr, "", " ") 163 if err != nil { 164 log.Exitf("cannot marshal data: %v", err) 165 } 166 fmt.Print(string(data)) 167 return 168 } 169 170 qr.print() 171 172 if err != nil { 173 log.Exit(err) 174 } 175 } 176 177 func run() (*results, error) { 178 fs := pflag.NewFlagSet("vtclient", pflag.ExitOnError) 179 grpccommon.RegisterFlags(fs) 180 log.RegisterFlags(fs) 181 logutil.RegisterFlags(fs) 182 servenv.RegisterMySQLServerFlags(fs) 183 registerFlags(fs) 184 _flag.Parse(fs) 185 args := _flag.Args() 186 187 logutil.PurgeLogs() 188 189 if len(args) == 0 { 190 pflag.Usage() 191 return nil, errors.New("no arguments provided. See usage above") 192 } 193 if len(args) > 1 { 194 return nil, errors.New("no additional arguments after the query allowed") 195 } 196 197 if maxSeqID > minSeqID { 198 go func() { 199 if useRandom { 200 rand.Seed(time.Now().UnixNano()) 201 for { 202 seqChan <- rand.Intn(maxSeqID-minSeqID) + minSeqID 203 } 204 } else { 205 for i := minSeqID; i < maxSeqID; i++ { 206 seqChan <- i 207 } 208 } 209 }() 210 } 211 212 c := vitessdriver.Configuration{ 213 Protocol: vtgateconn.GetVTGateProtocol(), 214 Address: server, 215 Target: targetString, 216 Streaming: streaming, 217 } 218 db, err := vitessdriver.OpenWithConfiguration(c) 219 if err != nil { 220 return nil, fmt.Errorf("client error: %v", err) 221 } 222 223 log.Infof("Sending the query...") 224 225 ctx, cancel := context.WithTimeout(context.Background(), timeout) 226 defer cancel() 227 return execMulti(ctx, db, args[0]) 228 } 229 230 func prepareBindVariables() []any { 231 bv := make([]any, 0, len(*bindVariables)+1) 232 bv = append(bv, (*bindVariables)...) 233 if maxSeqID > minSeqID { 234 bv = append(bv, <-seqChan) 235 } 236 return bv 237 } 238 239 func execMulti(ctx context.Context, db *sql.DB, sql string) (*results, error) { 240 all := newResults() 241 ec := concurrency.FirstErrorRecorder{} 242 wg := sync.WaitGroup{} 243 isDML := sqlparser.IsDML(sql) 244 245 isThrottled := qps > 0 246 247 start := time.Now() 248 for i := 0; i < parallel; i++ { 249 wg.Add(1) 250 251 go func() { 252 defer wg.Done() 253 254 var ticker *time.Ticker 255 if isThrottled { 256 tickDuration := time.Second / time.Duration(qps) 257 ticker = time.NewTicker(tickDuration) 258 } 259 260 for j := 0; j < count; j++ { 261 var qr *results 262 var err error 263 if isDML { 264 qr, err = execDml(ctx, db, sql) 265 } else { 266 qr, err = execNonDml(ctx, db, sql) 267 } 268 if count == 1 && parallel == 1 { 269 all = qr 270 } else { 271 all.merge(qr) 272 if err != nil { 273 all.recordError(err) 274 } 275 } 276 if err != nil { 277 ec.RecordError(err) 278 // We keep going and do not return early purpose. 279 } 280 281 if ticker != nil { 282 <-ticker.C 283 } 284 } 285 }() 286 } 287 wg.Wait() 288 if all != nil { 289 all.duration = time.Since(start) 290 } 291 292 return all, ec.Error() 293 } 294 295 func execDml(ctx context.Context, db *sql.DB, sql string) (*results, error) { 296 start := time.Now() 297 tx, err := db.Begin() 298 if err != nil { 299 return nil, vterrors.Wrap(err, "BEGIN failed") 300 } 301 302 result, err := tx.ExecContext(ctx, sql, []any(prepareBindVariables())...) 303 if err != nil { 304 return nil, vterrors.Wrap(err, "failed to execute DML") 305 } 306 307 err = tx.Commit() 308 if err != nil { 309 return nil, vterrors.Wrap(err, "COMMIT failed") 310 } 311 312 rowsAffected, _ := result.RowsAffected() 313 lastInsertID, _ := result.LastInsertId() 314 return &results{ 315 rowsAffected: rowsAffected, 316 lastInsertID: lastInsertID, 317 duration: time.Since(start), 318 }, nil 319 } 320 321 func execNonDml(ctx context.Context, db *sql.DB, sql string) (*results, error) { 322 start := time.Now() 323 rows, err := db.QueryContext(ctx, sql, []any(prepareBindVariables())...) 324 if err != nil { 325 return nil, vterrors.Wrap(err, "client error") 326 } 327 defer rows.Close() 328 329 // get the headers 330 var qr results 331 cols, err := rows.Columns() 332 if err != nil { 333 return nil, vterrors.Wrap(err, "client error") 334 } 335 qr.Fields = cols 336 337 // get the rows 338 for rows.Next() { 339 row := make([]any, len(cols)) 340 for i := range row { 341 var col string 342 row[i] = &col 343 } 344 if err := rows.Scan(row...); err != nil { 345 return nil, vterrors.Wrap(err, "client error") 346 } 347 348 // unpack []*string into []string 349 vals := make([]string, 0, len(row)) 350 for _, value := range row { 351 vals = append(vals, *(value.(*string))) 352 } 353 qr.Rows = append(qr.Rows, vals) 354 } 355 qr.rowsAffected = int64(len(qr.Rows)) 356 357 if err := rows.Err(); err != nil { 358 return nil, vterrors.Wrap(err, "Vitess returned an error") 359 } 360 361 qr.duration = time.Since(start) 362 return &qr, nil 363 } 364 365 type results struct { 366 mu sync.Mutex 367 Fields []string `json:"fields"` 368 Rows [][]string `json:"rows"` 369 rowsAffected int64 370 lastInsertID int64 371 duration time.Duration 372 cumulativeDuration time.Duration 373 374 // Multi DML mode: Track total error count, error count per code and the first error. 375 totalErrorCount int 376 errorCount map[vtrpcpb.Code]int 377 firstError map[vtrpcpb.Code]error 378 } 379 380 func newResults() *results { 381 return &results{ 382 errorCount: make(map[vtrpcpb.Code]int), 383 firstError: make(map[vtrpcpb.Code]error), 384 } 385 } 386 387 // merge aggregates "other" into "r". 388 // This is only used for executing DMLs concurrently and repeatedly. 389 // Therefore, "Fields" and "Rows" are not merged. 390 func (r *results) merge(other *results) { 391 if other == nil { 392 return 393 } 394 395 r.mu.Lock() 396 defer r.mu.Unlock() 397 398 r.rowsAffected += other.rowsAffected 399 if other.lastInsertID > r.lastInsertID { 400 r.lastInsertID = other.lastInsertID 401 } 402 r.cumulativeDuration += other.duration 403 } 404 405 func (r *results) recordError(err error) { 406 r.mu.Lock() 407 defer r.mu.Unlock() 408 409 r.totalErrorCount++ 410 code := vterrors.Code(err) 411 r.errorCount[code]++ 412 413 if r.errorCount[code] == 1 { 414 r.firstError[code] = err 415 } 416 } 417 418 func (r *results) print() { 419 if r == nil { 420 return 421 } 422 423 table := tablewriter.NewWriter(os.Stdout) 424 table.SetHeader(r.Fields) 425 table.SetAutoFormatHeaders(false) 426 table.AppendBulk(r.Rows) 427 table.Render() 428 fmt.Printf("%v row(s) affected (%v, cum: %v)\n", r.rowsAffected, r.duration, r.cumulativeDuration) 429 if r.lastInsertID != 0 { 430 fmt.Printf("Last insert ID: %v\n", r.lastInsertID) 431 } 432 433 if r.totalErrorCount == 0 { 434 return 435 } 436 437 fmt.Printf("%d error(s) were returned. Number of errors by error code:\n\n", r.totalErrorCount) 438 // Sort different error codes by count (descending). 439 type errorCounts struct { 440 code vtrpcpb.Code 441 count int 442 } 443 var counts []errorCounts 444 for code, count := range r.errorCount { 445 counts = append(counts, errorCounts{code, count}) 446 } 447 sort.Slice(counts, func(i, j int) bool { return counts[i].count >= counts[j].count }) 448 for _, c := range counts { 449 fmt.Printf("%- 30v= % 5d\n", c.code, c.count) 450 } 451 452 fmt.Printf("\nFirst error per code:\n\n") 453 for code, err := range r.firstError { 454 fmt.Printf("Code: %v\nError: %v\n\n", code, err) 455 } 456 }