github.com/square/finch@v0.0.0-20240412205204-6530c03e2b96/client/client.go (about)

     1  // Copyright 2024 Block, Inc.
     2  
     3  package client
     4  
     5  import (
     6  	"context"
     7  	"database/sql"
     8  	"errors"
     9  	"fmt"
    10  	"log"
    11  	"runtime"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	myerr "github.com/go-mysql/errors"
    16  
    17  	"github.com/square/finch"
    18  	"github.com/square/finch/data"
    19  	"github.com/square/finch/stats"
    20  	"github.com/square/finch/trx"
    21  )
    22  
    23  var (
    24  	ConnectTimeout   = 500 * time.Millisecond
    25  	ConnectRetryWait = 200 * time.Millisecond
    26  )
    27  
// Client executes SQL statements. Each client is created in workload.Allocator.Clients
// and run in Stage.Run. Client.Init must be called once before calling Client.Run once.
type Client struct {
	// Required args:
	//   DB          pool from which Connect obtains the client's own connection
	//   Data        per-statement data, indexed in parallel with Statements
	//   DoneChan    Run sends the client on this channel when it returns
	//   RunLevel    identifies the client (used for log messages and run counts)
	//   Statements  statements executed in order on every iteration
	//   Stats       per-trx stats, indexed by finch trx number; a nil entry
	//               means nothing is recorded for that trx
	DB         *sql.DB `deep:"-"`
	Data       []StatementData
	DoneChan   chan *Client
	RunLevel   finch.RunLevel
	Statements []*trx.Statement
	Stats      []*stats.Trx `deep:"-"`

	// Optional, usually from stage config:
	//   DefaultDb         if not empty, Connect executes USE on it after connecting
	//   IterExecGroup     iteration limit counted through IterExecGroupPtr (0 = no limit)
	//   IterExecGroupPtr  counter incremented atomically at the start of each iteration
	//   IterClients       iteration limit counted through IterClientsPtr (0 = no limit)
	//   IterClientsPtr    counter incremented atomically at the start of each iteration
	//   Iter              iteration limit for this client alone (0 = no limit)
	//   QPS               if not nil, Run receives from it before every statement
	//   TPS               if not nil, Run receives from it before every statement
	//                     marked Begin
	DefaultDb        string
	IterExecGroup    uint32
	IterExecGroupPtr *uint32
	IterClients      uint32
	IterClientsPtr   *uint32
	Iter             uint
	QPS              <-chan bool
	TPS              <-chan bool

	// Return value to DoneChan: Run sets Error before sending the client
	// on DoneChan.
	Error Error

	// --
	// Internal state:
	//   ps      prepared statements, indexed like Statements (nil = not prepared)
	//   values  reusable input value buffers, indexed like Statements
	//   conn    the client's current connection, set by Connect
	ps     []*sql.Stmt
	values [][]interface{}
	conn   *sql.Conn
}
    57  
// Error is the result a Client reports when Run returns. Err is nil when the
// client stopped normally or because the context was canceled or timed out.
// StatementNo is the index in Client.Statements of the statement being
// prepared or executed when the client stopped on an error.
type Error struct {
	Err         error
	StatementNo int
}
    62  
// StatementData holds the data generators bound to one statement. Client.Data
// is indexed in parallel with Client.Statements. InsertId, if set, is passed
// the last insert ID after the statement executes. TrxBoundary is a bit field
// tested against trx.BEGIN and trx.END to track finch trx (file) boundaries.
type StatementData struct {
	Inputs      []data.ValueFunc `deep:"-"` // input to query
	Outputs     []interface{}    `deep:"-"` // output from query; values are data.Generator
	InsertId    data.Generator   `deep:"-"`
	TrxBoundary byte
}
    69  
    70  func (c *Client) Init() error {
    71  	c.ps = make([]*sql.Stmt, len(c.Statements))
    72  	c.values = make([][]interface{}, len(c.Statements))
    73  	for i, s := range c.Statements {
    74  		if len(s.Inputs) > 0 {
    75  			c.values[i] = make([]interface{}, len(s.Inputs))
    76  		}
    77  	}
    78  	c.Error = Error{}
    79  	return nil
    80  }
    81  
    82  func (c *Client) Connect(ctx context.Context, cerr error, stmtNo int, trxActive bool) error {
    83  	if ctx.Err() != nil { // finch terminated (CTRL-C)?
    84  		return ctx.Err()
    85  	}
    86  
    87  	// @todo: handled errors aren't printed, can't tell what went wrong
    88  	//        when errors col != 0
    89  
    90  	silent := false
    91  	// Connect called due to error on query execution?
    92  	if cerr != nil {
    93  		errFlags, handled := finch.MySQLErrorHandling[myerr.MySQLErrorCode(cerr)]
    94  		if c.Statements[stmtNo].DDL && !handled {
    95  			return fmt.Errorf("DDL: %s", cerr)
    96  		}
    97  		if handled {
    98  			if errFlags&finch.Eabort != 0 {
    99  				return cerr // stop client
   100  			}
   101  			if errFlags&finch.Erollback != 0 && trxActive {
   102  				finch.Debug("%s: rollback", c.RunLevel.ClientId())
   103  				if _, err := c.conn.ExecContext(ctx, "ROLLBACK"); err != nil {
   104  					return fmt.Errorf("ROLLBACK failed: %s (on err: %s) (query: %s)", err, cerr, c.Statements[stmtNo].Query)
   105  				}
   106  			}
   107  			if errFlags&finch.Econtinue != 0 {
   108  				return nil // keep conn, next iter, keep executing
   109  			}
   110  		}
   111  		silent = (errFlags&finch.Esilent != 0) // log the error (here and below)? uhandled errors are logged
   112  		if !silent {
   113  			log.Printf("Client %s reconnect on error: %s (%s)", c.RunLevel.ClientId(), cerr, c.Statements[stmtNo].Query)
   114  		}
   115  	}
   116  
   117  	if c.conn != nil {
   118  		c.conn.Close()
   119  		c.conn = nil
   120  		time.Sleep(ConnectRetryWait)
   121  	}
   122  
   123  	t0 := time.Now()
   124  	for ctx.Err() == nil {
   125  		ctxConn, cancel := context.WithTimeout(ctx, ConnectTimeout)
   126  		c.conn, _ = c.DB.Conn(ctxConn)
   127  		cancel()
   128  		if c.conn != nil {
   129  			break // success
   130  		}
   131  		time.Sleep(ConnectRetryWait)
   132  	}
   133  
   134  	if ctx.Err() != nil { // finch terminated (CTRL-C)?
   135  		return ctx.Err()
   136  	}
   137  
   138  	if cerr != nil && !silent {
   139  		log.Printf("Client %s reconnected in %.3fs", c.RunLevel.ClientId(), time.Now().Sub(t0).Seconds())
   140  	}
   141  
   142  	if c.DefaultDb != "" {
   143  		_, err := c.conn.ExecContext(ctx, "USE `"+c.DefaultDb+"`")
   144  		if err != nil {
   145  			return err
   146  		}
   147  	}
   148  
   149  	var err error
   150  	for i, s := range c.Statements {
   151  		if !s.Prepare {
   152  			continue
   153  		}
   154  		if c.ps[i] != nil {
   155  			continue // prepare multi
   156  		}
   157  		c.ps[i], err = c.conn.PrepareContext(ctx, s.Query)
   158  		if err != nil {
   159  			c.Error.StatementNo = i
   160  			return fmt.Errorf("prepare: %s", err)
   161  		}
   162  
   163  		// If s.PrepareMulti = 3, it means this ps should be used for 3 statments
   164  		// including this one, so copy it into the next 2 statements. If = 0, this
   165  		// loop doesn't run becuase j = 1; j < 0 is immediately false.
   166  		for j := 1; j < s.PrepareMulti; j++ {
   167  			c.ps[i+j] = c.ps[i]
   168  		}
   169  	}
   170  	return nil
   171  }
   172  
   173  func (c *Client) Run(ctxExec context.Context) {
   174  	finch.Debug("run client %s: %d stmts, iter %d/%d/%d", c.RunLevel.ClientId(), len(c.Statements), c.IterExecGroup, c.IterClients, c.Iter)
   175  	var err error
   176  	defer func() {
   177  		if r := recover(); r != nil {
   178  			b := make([]byte, 4096)
   179  			n := runtime.Stack(b, false)
   180  			err = fmt.Errorf("PANIC: %v\n%s", r, string(b[0:n]))
   181  		}
   182  		for i := range c.ps {
   183  			if c.ps[i] == nil {
   184  				continue
   185  			}
   186  			c.ps[i].Close()
   187  		}
   188  		if c.conn != nil {
   189  			c.conn.Close()
   190  		}
   191  		// Context cancellation is not an error it's runtime elapsing or CTRL-C
   192  		if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
   193  			c.Error.Err = err
   194  		}
   195  		c.DoneChan <- c
   196  	}()
   197  
   198  	if err = c.Connect(ctxExec, nil, -1, false); err != nil {
   199  		return
   200  	}
   201  
   202  	var rc data.RunCount
   203  	rc[data.CONN] = 1 // first MySQL connection ^
   204  
   205  	// Not counts but passed with RunCount in case a data.Generator wants to know
   206  	rc[data.CLIENT] = c.RunLevel.Client
   207  	rc[data.CLIENT_GROUP] = c.RunLevel.ClientGroup
   208  	rc[data.EXEC_GROUP] = c.RunLevel.ExecGroup
   209  	rc[data.STAGE] = c.RunLevel.Stage
   210  
   211  	var rows *sql.Rows
   212  	var res sql.Result
   213  	var t time.Time
   214  
   215  	// trxNo indexes into c.Stats and resets to 0 on each iteration. Remember:
   216  	// these are finch trx (files), not MySQL trx, so trx boundaries mark the
   217  	// beginning and end of a finch trx (file). User is expected to make finch
   218  	// trx boundaries meaningful.
   219  	trxNo := -1
   220  	trxActive := false
   221  
   222  	//
   223  	// CRITICAL LOOP: no debug or superfluous function calls
   224  	//
   225  ITER:
   226  	for {
   227  		if c.IterExecGroup > 0 && atomic.AddUint32(c.IterExecGroupPtr, 1) > c.IterExecGroup {
   228  			return
   229  		}
   230  		if c.IterClients > 0 && atomic.AddUint32(c.IterClientsPtr, 1) > c.IterClients {
   231  			return
   232  		}
   233  		if c.Iter > 0 && rc[data.ITER] == c.Iter {
   234  			return
   235  		}
   236  		rc[data.ITER] += 1
   237  		trxNo = -1
   238  		trxActive = false
   239  
   240  		for i := range c.Statements {
   241  			// Idle time
   242  			if c.Statements[i].Idle != 0 {
   243  				time.Sleep(c.Statements[i].Idle)
   244  				continue
   245  			}
   246  
   247  			// Is this query the start of a new (finch) trx file? This is not
   248  			// a MySQL trx (either BEGIN or implicit). It marks finch trx scope
   249  			// "trx" is a trx file in the config assigned to this client.
   250  			if c.Data[i].TrxBoundary&trx.BEGIN != 0 {
   251  				rc[data.TRX] += 1
   252  				trxNo += 1
   253  				trxActive = true
   254  			} else if c.Data[i].TrxBoundary&trx.END != 0 {
   255  				trxActive = false
   256  			}
   257  
   258  			// If BEGIN, check TPS rate limiter
   259  			if c.TPS != nil && c.Statements[i].Begin {
   260  				<-c.TPS
   261  			}
   262  
   263  			// If query, check QPS
   264  			if c.QPS != nil {
   265  				<-c.QPS
   266  			}
   267  
   268  			// Generate new data values for this query. A single data generator
   269  			// can return multiple values, so d makes copy() append, else copy()
   270  			// would start at [0:] each time
   271  			rc[data.STATEMENT] += 1
   272  			d := 0
   273  			for _, f := range c.Data[i].Inputs {
   274  				d += copy(c.values[i][d:], f(rc))
   275  			}
   276  
   277  			if c.Statements[i].ResultSet {
   278  				//
   279  				// SELECT
   280  				//
   281  				t = time.Now()
   282  				if c.ps[i] != nil {
   283  					rows, err = c.ps[i].QueryContext(ctxExec, c.values[i]...)
   284  				} else {
   285  					rows, err = c.conn.QueryContext(ctxExec, fmt.Sprintf(c.Statements[i].Query, c.values[i]...))
   286  				}
   287  				if c.Stats[trxNo] != nil {
   288  					c.Stats[trxNo].Record(stats.READ, time.Now().Sub(t).Microseconds())
   289  				}
   290  				if err != nil {
   291  					goto ERROR
   292  				}
   293  				if c.Data[i].Outputs != nil {
   294  					// @todo what if no row match? This loop won't happen,
   295  					// and the column generator won't be called, which will
   296  					// make it return nil later when used as input to another
   297  					// query.
   298  					for rows.Next() {
   299  						if err = rows.Scan(c.Data[i].Outputs...); err != nil {
   300  							rows.Close()
   301  							goto ERROR
   302  						}
   303  					}
   304  				}
   305  				rows.Close()
   306  			} else {
   307  				//
   308  				// Write or query without result set (e.g. BEGIN, SET, etc.)
   309  				//
   310  				if c.Statements[i].Limit != nil { // limit rows -------------
   311  					if !c.Statements[i].Limit.More(c.conn) {
   312  						return // chan closed = no more writes
   313  					}
   314  				}
   315  				t = time.Now()
   316  				if c.ps[i] != nil { // exec ---------------------------------
   317  					res, err = c.ps[i].ExecContext(ctxExec, c.values[i]...)
   318  				} else {
   319  					res, err = c.conn.ExecContext(ctxExec, fmt.Sprintf(c.Statements[i].Query, c.values[i]...))
   320  				}
   321  				if c.Stats[trxNo] != nil { // record stats ------------------
   322  					switch {
   323  					case c.Statements[i].Write:
   324  						c.Stats[trxNo].Record(stats.WRITE, time.Now().Sub(t).Microseconds())
   325  					case c.Statements[i].Commit:
   326  						c.Stats[trxNo].Record(stats.COMMIT, time.Now().Sub(t).Microseconds())
   327  					default:
   328  						// BEGIN, SET, and other statements that aren't reads or writes
   329  						// but count and response time will be included in total
   330  						c.Stats[trxNo].Record(stats.TOTAL, time.Now().Sub(t).Microseconds())
   331  					}
   332  				}
   333  				if err != nil { // handle err, if any -----------------------
   334  					goto ERROR
   335  				}
   336  				if c.Statements[i].Limit != nil { // limit rows -------------
   337  					n, _ := res.RowsAffected()
   338  					c.Statements[i].Limit.Affected(n)
   339  				}
   340  				if c.Data[i].InsertId != nil { // insert ID -----------------
   341  					id, _ := res.LastInsertId()
   342  					c.Data[i].InsertId.Scan(id)
   343  				}
   344  			} // execute
   345  			continue // next query
   346  
   347  		ERROR:
   348  			if c.Stats[trxNo] != nil && ctxExec.Err() == nil {
   349  				c.Stats[trxNo].Error(myerr.MySQLErrorCode(err))
   350  			}
   351  			if err = c.Connect(ctxExec, err, i, trxActive); err != nil {
   352  				c.Error.StatementNo = i
   353  				return // unrecoverable error or runtime elapsed (context timeout/cancel)
   354  			}
   355  			rc[data.CONN] += 1 // reconnected or recovered after query error
   356  			continue ITER
   357  		} // statements
   358  	} // iterations
   359  }