vitess.io/vitess@v0.16.2/go/cmd/vtclient/vtclient.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package main
    18  
    19  import (
    20  	"context"
    21  	"database/sql"
    22  	"encoding/json"
    23  	"errors"
    24  	"flag"
    25  	"fmt"
    26  	"io"
    27  	"math/rand"
    28  	"os"
    29  	"sort"
    30  	"sync"
    31  	"time"
    32  
    33  	"github.com/olekukonko/tablewriter"
    34  	"github.com/spf13/pflag"
    35  
    36  	"vitess.io/vitess/go/acl"
    37  	"vitess.io/vitess/go/vt/concurrency"
    38  	"vitess.io/vitess/go/vt/grpccommon"
    39  	"vitess.io/vitess/go/vt/log"
    40  	"vitess.io/vitess/go/vt/logutil"
    41  	"vitess.io/vitess/go/vt/servenv"
    42  	"vitess.io/vitess/go/vt/sqlparser"
    43  	"vitess.io/vitess/go/vt/vitessdriver"
    44  	"vitess.io/vitess/go/vt/vterrors"
    45  	"vitess.io/vitess/go/vt/vtgate/vtgateconn"
    46  
    47  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    48  
    49  	// Include deprecation warnings for soon-to-be-unsupported flag invocations.
    50  	_flag "vitess.io/vitess/go/internal/flag"
    51  )
    52  
    53  var (
    54  	usage = `
    55  vtclient connects to a vtgate server using the standard go driver API.
    56  Version 3 of the API is used, we do not send any hint to the server.
    57  
    58  For query bound variables, we assume place-holders in the query string
    59  in the form of :v1, :v2, etc.
    60  
    61  Examples:
    62  
    63    $ vtclient --server vtgate:15991 "SELECT * FROM messages"
    64  
    65    $ vtclient --server vtgate:15991 --target '@primary' --bind_variables '[ 12345, 1, "msg 12345" ]' "INSERT INTO messages (page,time_created_ns,message) VALUES (:v1, :v2, :v3)"
    66  
    67  `
    68  	server        string
    69  	streaming     bool
    70  	targetString  string
    71  	jsonOutput    bool
    72  	useRandom     bool
    73  	bindVariables *bindvars
    74  
    75  	timeout  = 30 * time.Second
    76  	parallel = 1
    77  	count    = 1
    78  	minSeqID = 0
    79  	maxSeqID = 0
    80  	qps      = 0
    81  )
    82  
    83  var (
    84  	seqChan = make(chan int, 10)
    85  )
    86  
    87  func init() {
    88  	_flag.SetUsage(flag.CommandLine, _flag.UsageOptions{
    89  		Epilogue: func(w io.Writer) { fmt.Fprint(w, usage) },
    90  	})
    91  }
    92  
    93  func registerFlags(fs *pflag.FlagSet) {
    94  	fs.StringVar(&server, "server", server, "vtgate server to connect to")
    95  	fs.DurationVar(&timeout, "timeout", timeout, "timeout for queries")
    96  	fs.BoolVar(&streaming, "streaming", streaming, "use a streaming query")
    97  	fs.StringVar(&targetString, "target", targetString, "keyspace:shard@tablet_type")
    98  	fs.BoolVar(&jsonOutput, "json", jsonOutput, "Output JSON instead of human-readable table")
    99  	fs.IntVar(&parallel, "parallel", parallel, "DMLs only: Number of threads executing the same query in parallel. Useful for simple load testing.")
   100  	fs.IntVar(&count, "count", count, "DMLs only: Number of times each thread executes the query. Useful for simple, sustained load testing.")
   101  	fs.IntVar(&minSeqID, "min_sequence_id", minSeqID, "min sequence ID to generate. When max_sequence_id > min_sequence_id, for each query, a number is generated in [min_sequence_id, max_sequence_id) and attached to the end of the bind variables.")
   102  	fs.IntVar(&maxSeqID, "max_sequence_id", maxSeqID, "max sequence ID.")
   103  	fs.BoolVar(&useRandom, "use_random_sequence", useRandom, "use random sequence for generating [min_sequence_id, max_sequence_id)")
   104  	fs.IntVar(&qps, "qps", qps, "queries per second to throttle each thread at.")
   105  
   106  	acl.RegisterFlags(fs)
   107  
   108  	bindVariables = newBindvars(fs, "bind_variables", "bind variables as a json list")
   109  }
   110  
   111  type bindvars []any
   112  
   113  func (bv *bindvars) String() string {
   114  	b, err := json.Marshal(bv)
   115  	if err != nil {
   116  		return err.Error()
   117  	}
   118  	return string(b)
   119  }
   120  
   121  func (bv *bindvars) Set(s string) (err error) {
   122  	err = json.Unmarshal([]byte(s), &bv)
   123  	if err != nil {
   124  		return err
   125  	}
   126  	// json reads all numbers as float64
   127  	// So, we just ditch floats for bindvars
   128  	for i, v := range *bv {
   129  		if f, ok := v.(float64); ok {
   130  			if f > 0 {
   131  				(*bv)[i] = uint64(f)
   132  			} else {
   133  				(*bv)[i] = int64(f)
   134  			}
   135  		}
   136  	}
   137  
   138  	return nil
   139  }
   140  
   141  // For internal flag compatibility
   142  func (bv *bindvars) Get() any {
   143  	return bv
   144  }
   145  
   146  // Type is part of the pflag.Value interface. bindvars.Set() expects all numbers as float64.
   147  func (bv *bindvars) Type() string {
   148  	return "float64"
   149  }
   150  
   151  func newBindvars(fs *pflag.FlagSet, name, usage string) *bindvars {
   152  	var bv bindvars
   153  	fs.Var(&bv, name, usage)
   154  	return &bv
   155  }
   156  
   157  func main() {
   158  	defer logutil.Flush()
   159  
   160  	qr, err := run()
   161  	if jsonOutput && qr != nil {
   162  		data, err := json.MarshalIndent(qr, "", "  ")
   163  		if err != nil {
   164  			log.Exitf("cannot marshal data: %v", err)
   165  		}
   166  		fmt.Print(string(data))
   167  		return
   168  	}
   169  
   170  	qr.print()
   171  
   172  	if err != nil {
   173  		log.Exit(err)
   174  	}
   175  }
   176  
   177  func run() (*results, error) {
   178  	fs := pflag.NewFlagSet("vtclient", pflag.ExitOnError)
   179  	grpccommon.RegisterFlags(fs)
   180  	log.RegisterFlags(fs)
   181  	logutil.RegisterFlags(fs)
   182  	servenv.RegisterMySQLServerFlags(fs)
   183  	registerFlags(fs)
   184  	_flag.Parse(fs)
   185  	args := _flag.Args()
   186  
   187  	logutil.PurgeLogs()
   188  
   189  	if len(args) == 0 {
   190  		pflag.Usage()
   191  		return nil, errors.New("no arguments provided. See usage above")
   192  	}
   193  	if len(args) > 1 {
   194  		return nil, errors.New("no additional arguments after the query allowed")
   195  	}
   196  
   197  	if maxSeqID > minSeqID {
   198  		go func() {
   199  			if useRandom {
   200  				rand.Seed(time.Now().UnixNano())
   201  				for {
   202  					seqChan <- rand.Intn(maxSeqID-minSeqID) + minSeqID
   203  				}
   204  			} else {
   205  				for i := minSeqID; i < maxSeqID; i++ {
   206  					seqChan <- i
   207  				}
   208  			}
   209  		}()
   210  	}
   211  
   212  	c := vitessdriver.Configuration{
   213  		Protocol:  vtgateconn.GetVTGateProtocol(),
   214  		Address:   server,
   215  		Target:    targetString,
   216  		Streaming: streaming,
   217  	}
   218  	db, err := vitessdriver.OpenWithConfiguration(c)
   219  	if err != nil {
   220  		return nil, fmt.Errorf("client error: %v", err)
   221  	}
   222  
   223  	log.Infof("Sending the query...")
   224  
   225  	ctx, cancel := context.WithTimeout(context.Background(), timeout)
   226  	defer cancel()
   227  	return execMulti(ctx, db, args[0])
   228  }
   229  
   230  func prepareBindVariables() []any {
   231  	bv := make([]any, 0, len(*bindVariables)+1)
   232  	bv = append(bv, (*bindVariables)...)
   233  	if maxSeqID > minSeqID {
   234  		bv = append(bv, <-seqChan)
   235  	}
   236  	return bv
   237  }
   238  
   239  func execMulti(ctx context.Context, db *sql.DB, sql string) (*results, error) {
   240  	all := newResults()
   241  	ec := concurrency.FirstErrorRecorder{}
   242  	wg := sync.WaitGroup{}
   243  	isDML := sqlparser.IsDML(sql)
   244  
   245  	isThrottled := qps > 0
   246  
   247  	start := time.Now()
   248  	for i := 0; i < parallel; i++ {
   249  		wg.Add(1)
   250  
   251  		go func() {
   252  			defer wg.Done()
   253  
   254  			var ticker *time.Ticker
   255  			if isThrottled {
   256  				tickDuration := time.Second / time.Duration(qps)
   257  				ticker = time.NewTicker(tickDuration)
   258  			}
   259  
   260  			for j := 0; j < count; j++ {
   261  				var qr *results
   262  				var err error
   263  				if isDML {
   264  					qr, err = execDml(ctx, db, sql)
   265  				} else {
   266  					qr, err = execNonDml(ctx, db, sql)
   267  				}
   268  				if count == 1 && parallel == 1 {
   269  					all = qr
   270  				} else {
   271  					all.merge(qr)
   272  					if err != nil {
   273  						all.recordError(err)
   274  					}
   275  				}
   276  				if err != nil {
   277  					ec.RecordError(err)
   278  					// We keep going and do not return early purpose.
   279  				}
   280  
   281  				if ticker != nil {
   282  					<-ticker.C
   283  				}
   284  			}
   285  		}()
   286  	}
   287  	wg.Wait()
   288  	if all != nil {
   289  		all.duration = time.Since(start)
   290  	}
   291  
   292  	return all, ec.Error()
   293  }
   294  
   295  func execDml(ctx context.Context, db *sql.DB, sql string) (*results, error) {
   296  	start := time.Now()
   297  	tx, err := db.Begin()
   298  	if err != nil {
   299  		return nil, vterrors.Wrap(err, "BEGIN failed")
   300  	}
   301  
   302  	result, err := tx.ExecContext(ctx, sql, []any(prepareBindVariables())...)
   303  	if err != nil {
   304  		return nil, vterrors.Wrap(err, "failed to execute DML")
   305  	}
   306  
   307  	err = tx.Commit()
   308  	if err != nil {
   309  		return nil, vterrors.Wrap(err, "COMMIT failed")
   310  	}
   311  
   312  	rowsAffected, _ := result.RowsAffected()
   313  	lastInsertID, _ := result.LastInsertId()
   314  	return &results{
   315  		rowsAffected: rowsAffected,
   316  		lastInsertID: lastInsertID,
   317  		duration:     time.Since(start),
   318  	}, nil
   319  }
   320  
   321  func execNonDml(ctx context.Context, db *sql.DB, sql string) (*results, error) {
   322  	start := time.Now()
   323  	rows, err := db.QueryContext(ctx, sql, []any(prepareBindVariables())...)
   324  	if err != nil {
   325  		return nil, vterrors.Wrap(err, "client error")
   326  	}
   327  	defer rows.Close()
   328  
   329  	// get the headers
   330  	var qr results
   331  	cols, err := rows.Columns()
   332  	if err != nil {
   333  		return nil, vterrors.Wrap(err, "client error")
   334  	}
   335  	qr.Fields = cols
   336  
   337  	// get the rows
   338  	for rows.Next() {
   339  		row := make([]any, len(cols))
   340  		for i := range row {
   341  			var col string
   342  			row[i] = &col
   343  		}
   344  		if err := rows.Scan(row...); err != nil {
   345  			return nil, vterrors.Wrap(err, "client error")
   346  		}
   347  
   348  		// unpack []*string into []string
   349  		vals := make([]string, 0, len(row))
   350  		for _, value := range row {
   351  			vals = append(vals, *(value.(*string)))
   352  		}
   353  		qr.Rows = append(qr.Rows, vals)
   354  	}
   355  	qr.rowsAffected = int64(len(qr.Rows))
   356  
   357  	if err := rows.Err(); err != nil {
   358  		return nil, vterrors.Wrap(err, "Vitess returned an error")
   359  	}
   360  
   361  	qr.duration = time.Since(start)
   362  	return &qr, nil
   363  }
   364  
   365  type results struct {
   366  	mu                 sync.Mutex
   367  	Fields             []string   `json:"fields"`
   368  	Rows               [][]string `json:"rows"`
   369  	rowsAffected       int64
   370  	lastInsertID       int64
   371  	duration           time.Duration
   372  	cumulativeDuration time.Duration
   373  
   374  	// Multi DML mode: Track total error count, error count per code and the first error.
   375  	totalErrorCount int
   376  	errorCount      map[vtrpcpb.Code]int
   377  	firstError      map[vtrpcpb.Code]error
   378  }
   379  
   380  func newResults() *results {
   381  	return &results{
   382  		errorCount: make(map[vtrpcpb.Code]int),
   383  		firstError: make(map[vtrpcpb.Code]error),
   384  	}
   385  }
   386  
   387  // merge aggregates "other" into "r".
   388  // This is only used for executing DMLs concurrently and repeatedly.
   389  // Therefore, "Fields" and "Rows" are not merged.
   390  func (r *results) merge(other *results) {
   391  	if other == nil {
   392  		return
   393  	}
   394  
   395  	r.mu.Lock()
   396  	defer r.mu.Unlock()
   397  
   398  	r.rowsAffected += other.rowsAffected
   399  	if other.lastInsertID > r.lastInsertID {
   400  		r.lastInsertID = other.lastInsertID
   401  	}
   402  	r.cumulativeDuration += other.duration
   403  }
   404  
   405  func (r *results) recordError(err error) {
   406  	r.mu.Lock()
   407  	defer r.mu.Unlock()
   408  
   409  	r.totalErrorCount++
   410  	code := vterrors.Code(err)
   411  	r.errorCount[code]++
   412  
   413  	if r.errorCount[code] == 1 {
   414  		r.firstError[code] = err
   415  	}
   416  }
   417  
   418  func (r *results) print() {
   419  	if r == nil {
   420  		return
   421  	}
   422  
   423  	table := tablewriter.NewWriter(os.Stdout)
   424  	table.SetHeader(r.Fields)
   425  	table.SetAutoFormatHeaders(false)
   426  	table.AppendBulk(r.Rows)
   427  	table.Render()
   428  	fmt.Printf("%v row(s) affected (%v, cum: %v)\n", r.rowsAffected, r.duration, r.cumulativeDuration)
   429  	if r.lastInsertID != 0 {
   430  		fmt.Printf("Last insert ID: %v\n", r.lastInsertID)
   431  	}
   432  
   433  	if r.totalErrorCount == 0 {
   434  		return
   435  	}
   436  
   437  	fmt.Printf("%d error(s) were returned. Number of errors by error code:\n\n", r.totalErrorCount)
   438  	// Sort different error codes by count (descending).
   439  	type errorCounts struct {
   440  		code  vtrpcpb.Code
   441  		count int
   442  	}
   443  	var counts []errorCounts
   444  	for code, count := range r.errorCount {
   445  		counts = append(counts, errorCounts{code, count})
   446  	}
   447  	sort.Slice(counts, func(i, j int) bool { return counts[i].count >= counts[j].count })
   448  	for _, c := range counts {
   449  		fmt.Printf("%- 30v= % 5d\n", c.code, c.count)
   450  	}
   451  
   452  	fmt.Printf("\nFirst error per code:\n\n")
   453  	for code, err := range r.firstError {
   454  		fmt.Printf("Code:  %v\nError: %v\n\n", code, err)
   455  	}
   456  }