vitess.io/vitess@v0.16.2/go/mysql/conn.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package mysql
    18  
    19  import (
    20  	"bufio"
    21  	"crypto/tls"
    22  	"crypto/x509"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"net"
    27  	"strings"
    28  	"sync"
    29  	"time"
    30  
    31  	"vitess.io/vitess/go/mysql/collations"
    32  
    33  	"vitess.io/vitess/go/sqlescape"
    34  
    35  	"vitess.io/vitess/go/bucketpool"
    36  	"vitess.io/vitess/go/sqltypes"
    37  	"vitess.io/vitess/go/sync2"
    38  	"vitess.io/vitess/go/vt/log"
    39  	querypb "vitess.io/vitess/go/vt/proto/query"
    40  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    41  	"vitess.io/vitess/go/vt/sqlparser"
    42  	"vitess.io/vitess/go/vt/vterrors"
    43  )
    44  
const (
	// connBufferSize is how much we buffer for reading and
	// writing. It is the size given to the pooled bufio readers and
	// writers, and the first argument passed to bucketpool.New for the
	// ephemeral buffer pool.
	connBufferSize = 16 * 1024

	// packetHeaderSize is the size in bytes of the header that precedes
	// every MySQL packet sent over the wire: three bytes of little-endian
	// payload length followed by one byte of sequence number.
	packetHeaderSize = 4
)
    54  
// Values of Conn.currentEphemeralPolicy, describing how the ephemeral
// buffer is currently being used.
const (
	// ephemeralUnused means the ephemeral buffer is not in use at this
	// moment. This is the default value, and is checked so we don't
	// read or write a packet while one is already in use.
	ephemeralUnused = iota

	// ephemeralWrite means we are currently in the process of writing
	// from currentEphemeralBuffer.
	ephemeralWrite

	// ephemeralRead means we are currently in the process of reading
	// into currentEphemeralBuffer.
	ephemeralRead
)
    68  
// Getter is implemented by values that can return a VTGateCallerID.
// It is the type of Conn.UserData.
type Getter interface {
	// Get returns the caller ID held by this value.
	Get() *querypb.VTGateCallerID
}
    73  
// Conn is a connection between a client and a server, using the MySQL
// binary protocol. It is built on top of an existing net.Conn, that
// has already been established.
//
// Use Connect on the client side to create a connection.
// Use NewListener to create a server side and listen for connections.
type Conn struct {
	// fields contains the fields definitions for an on-going
	// streaming query. It is set by ExecuteStreamFetch, and
	// cleared by the last FetchNext().  It is nil if no streaming
	// query is in progress.  If the streaming query returned no
	// fields, this is set to an empty array (but not nil).
	fields []*querypb.Field

	// salt is sent by the server during initial handshake to be used for authentication
	salt []byte

	// authPluginName is the name of server's authentication plugin.
	// It is set during the initial handshake.
	authPluginName AuthMethodDescription

	// schemaName is the default database name to use. It is set
	// during handshake, and by ComInitDb packets. Both client and
	// servers maintain it. This member is private because it's
	// non-authoritative: the client can change the schema name
	// through the 'USE' statement, which will bypass this variable.
	schemaName string

	// ClientData is a place where an application can store any
	// connection-related data. Mostly used on the server side, to
	// avoid maps indexed by ConnectionID for instance.
	ClientData any

	// conn is the underlying network connection.
	// Calling Close() on the Conn will close this connection.
	// If there are any ongoing reads or writes, they may get interrupted.
	conn net.Conn

	// flavor contains the auto-detected flavor for this client
	// connection. It is unused for server-side connections.
	flavor flavor

	// ServerVersion is set during Connect with the server
	// version.  It is not changed afterwards. It is unused for
	// server-side connections.
	ServerVersion string

	// User is the name used by the client to connect.
	// It is set during the initial handshake.
	User string

	// UserData is custom data returned by the AuthServer module.
	// It is set during the initial handshake.
	UserData Getter

	// bufferedReader, when non-nil, wraps conn for buffered reads; when it
	// is nil, getReader falls back to reading from conn directly.
	// flushTimer is the timer armed by startFlushTimer to flush
	// bufferedWriter after a period of inactivity; it is nil when no timer
	// is armed.
	// header is the scratch space readHeaderFrom reads each packet header
	// into.
	bufferedReader *bufio.Reader
	flushTimer     *time.Timer
	header         [packetHeaderSize]byte

	// Keep track of how the ephemeral buffer is being used, and of the
	// buffer we allocated for an ephemeral packet on the read and write
	// sides.
	// These fields are used by:
	// - startEphemeralPacketWithHeader / writeEphemeralPacket methods for writes.
	// - readEphemeralPacket / recycleReadPacket methods for reads.
	currentEphemeralPolicy int
	// currentEphemeralBuffer for tracking allocated temporary buffer for writes and reads respectively.
	// It can be allocated from bufPool or heap and should be recycled in the same manner.
	currentEphemeralBuffer *[]byte

	// For server-side connections, listener points to the server object.
	listener *Listener

	// Buffered writing has a timer which flushes on inactivity.
	// bufferedWriter is nil unless write buffering is active (between
	// startWriterBuffering and endWriterBuffering).
	bufferedWriter *bufio.Writer

	// PrepareData is the map to use a prepared statement.
	PrepareData map[uint32]*PrepareData

	// protects the bufferedWriter and bufferedReader
	// NOTE(review): in the code visible here the lock is taken around
	// bufferedWriter and flushTimer accesses only — confirm whether
	// bufferedReader is really covered.
	bufMu sync.Mutex

	// Capabilities is the current set of features this connection
	// is using.  It is the features that are both supported by
	// the client and the server, and currently in use.
	// It is set during the initial handshake.
	//
	// It is only used for CapabilityClientDeprecateEOF
	// and CapabilityClientFoundRows.
	// NOTE(review): writeOKPacketWithHeader also tests
	// CapabilityClientSessionTrack, so the list above looks out of date.
	Capabilities uint32

	// closed is set to true when Close() is called on the connection.
	closed sync2.AtomicBool

	// ConnectionID is set:
	// - at Connect() time for clients, with the value returned by
	// the server.
	// - at accept time for the server.
	ConnectionID uint32

	// StatementID is the prepared statement ID.
	StatementID uint32

	// StatusFlags are the status flags we will base our returned flags on.
	// This is a bit field, with values documented in constants.go.
	// An interesting value here would be ServerStatusAutocommit.
	// It is only used by the server. These flags can be changed
	// by Handler methods.
	StatusFlags uint16

	// CharacterSet is the charset for this connection, as negotiated
	// in our handshake with the server. Note that although the MySQL protocol lists this
	// as a "character set", the returned byte value is actually a Collation ID,
	// and hence it is cast as such here.
	// If the user has specified a custom Collation in the ConnParams for this
	// connection, once the CharacterSet has been negotiated, we will override
	// it via SQL and update this field accordingly.
	CharacterSet collations.ID

	// Packet encoding variables.
	// sequence is the sequence number expected in the next packet header
	// read, and the one stamped on the next packet header written; it is
	// incremented after each packet in either direction.
	sequence uint8

	// ExpectSemiSyncIndicator is applicable when the connection is used for replication (ComBinlogDump).
	// When 'true', events are assumed to be padded with 2-byte semi-sync information
	// See https://dev.mysql.com/doc/internals/en/semi-sync-binlog-event.html
	ExpectSemiSyncIndicator bool

	// enableQueryInfo controls whether we parse the INFO field in QUERY_OK packets
	// See: ConnParams.EnableQueryInfo
	enableQueryInfo bool
}
   203  
// splitStatementFunction is the function that is used to split the statement
// in case of a multi-statement query. It is a package-level variable that
// defaults to sqlparser.SplitStatementToPieces.
var splitStatementFunction = sqlparser.SplitStatementToPieces
   206  
// PrepareData holds the metadata stored for a prepared statement.
// Conn keeps these in its PrepareData map, keyed by a uint32
// (presumably the statement ID — confirm against the callers).
type PrepareData struct {
	ParamsType  []int32
	ColumnNames []string
	PrepareStmt string
	BindVars    map[string]*querypb.BindVariable
	StatementID uint32
	ParamsCount uint16
}
   216  
// execResult is an enum signifying the result of executing a query.
type execResult byte

// Possible execResult values.
// NOTE(review): presumably execErr is a query-level failure and connErr a
// connection-level failure — confirm at the call sites, which are not
// visible here.
const (
	execSuccess execResult = iota
	execErr
	connErr
)
   225  
// bufPool is used to allocate and free buffers in an efficient way.
// It is created with connBufferSize and MaxPacketSize as its bounds.
var bufPool = bucketpool.New(connBufferSize, MaxPacketSize)

// writersPool is used for pooling bufio.Writer objects. New writers are
// created detached (nil destination) with a connBufferSize buffer and must
// be Reset onto a connection before use.
var writersPool = sync.Pool{New: func() any { return bufio.NewWriterSize(nil, connBufferSize) }}

// readersPool is used for pooling bufio.Reader objects. New readers are
// created detached (nil source) with a connBufferSize buffer and must be
// Reset onto a connection before use.
var readersPool = sync.Pool{New: func() any { return bufio.NewReaderSize(nil, connBufferSize) }}
   233  
   234  // newConn is an internal method to create a Conn. Used by client and server
   235  // side for common creation code.
   236  func newConn(conn net.Conn) *Conn {
   237  	return &Conn{
   238  		conn:           conn,
   239  		closed:         sync2.NewAtomicBool(false),
   240  		bufferedReader: bufio.NewReaderSize(conn, connBufferSize),
   241  	}
   242  }
   243  
   244  // newServerConn should be used to create server connections.
   245  //
   246  // It stashes a reference to the listener to be able to determine if
   247  // the server is shutting down, and has the ability to control buffer
   248  // size for reads.
   249  func newServerConn(conn net.Conn, listener *Listener) *Conn {
   250  	c := &Conn{
   251  		conn:        conn,
   252  		listener:    listener,
   253  		closed:      sync2.NewAtomicBool(false),
   254  		PrepareData: make(map[uint32]*PrepareData),
   255  	}
   256  
   257  	if listener.connReadBufferSize > 0 {
   258  		var buf *bufio.Reader
   259  		if listener.connBufferPooling {
   260  			buf = readersPool.Get().(*bufio.Reader)
   261  			buf.Reset(conn)
   262  		} else {
   263  			buf = bufio.NewReaderSize(conn, listener.connReadBufferSize)
   264  		}
   265  
   266  		c.bufferedReader = buf
   267  	}
   268  
   269  	return c
   270  }
   271  
   272  // startWriterBuffering starts using buffered writes. This should
   273  // be terminated by a call to endWriteBuffering.
   274  func (c *Conn) startWriterBuffering() {
   275  	c.bufMu.Lock()
   276  	defer c.bufMu.Unlock()
   277  
   278  	c.bufferedWriter = writersPool.Get().(*bufio.Writer)
   279  	c.bufferedWriter.Reset(c.conn)
   280  }
   281  
   282  // endWriterBuffering must be called to terminate startWriteBuffering.
   283  func (c *Conn) endWriterBuffering() error {
   284  	c.bufMu.Lock()
   285  	defer c.bufMu.Unlock()
   286  
   287  	if c.bufferedWriter == nil {
   288  		return nil
   289  	}
   290  
   291  	defer func() {
   292  		c.bufferedWriter.Reset(nil)
   293  		writersPool.Put(c.bufferedWriter)
   294  		c.bufferedWriter = nil
   295  	}()
   296  
   297  	c.stopFlushTimer()
   298  	return c.bufferedWriter.Flush()
   299  }
   300  
   301  func (c *Conn) returnReader() {
   302  	if c.bufferedReader == nil {
   303  		return
   304  	}
   305  	c.bufferedReader.Reset(nil)
   306  	readersPool.Put(c.bufferedReader)
   307  }
   308  
   309  // getWriter returns the current writer. It may be either
   310  // the original connection or a wrapper. The returned unget
   311  // function must be invoked after the writing is finished.
   312  // In buffered mode, the unget starts a timer to flush any
   313  // buffered data.
   314  func (c *Conn) getWriter() (w io.Writer, unget func()) {
   315  	c.bufMu.Lock()
   316  	if c.bufferedWriter != nil {
   317  		return c.bufferedWriter, func() {
   318  			c.startFlushTimer()
   319  			c.bufMu.Unlock()
   320  		}
   321  	}
   322  	c.bufMu.Unlock()
   323  	return c.conn, func() {}
   324  }
   325  
   326  // startFlushTimer must be called while holding lock on bufMu.
   327  func (c *Conn) startFlushTimer() {
   328  	c.stopFlushTimer()
   329  	c.flushTimer = time.AfterFunc(mysqlServerFlushDelay, func() {
   330  		c.bufMu.Lock()
   331  		defer c.bufMu.Unlock()
   332  
   333  		if c.bufferedWriter == nil {
   334  			return
   335  		}
   336  		c.stopFlushTimer()
   337  		c.bufferedWriter.Flush()
   338  	})
   339  }
   340  
   341  // stopFlushTimer must be called while holding lock on bufMu.
   342  func (c *Conn) stopFlushTimer() {
   343  	if c.flushTimer != nil {
   344  		c.flushTimer.Stop()
   345  		c.flushTimer = nil
   346  	}
   347  }
   348  
   349  // getReader returns reader for connection. It can be *bufio.Reader or net.Conn
   350  // depending on which buffer size was passed to newServerConn.
   351  func (c *Conn) getReader() io.Reader {
   352  	if c.bufferedReader != nil {
   353  		return c.bufferedReader
   354  	}
   355  	return c.conn
   356  }
   357  
   358  func (c *Conn) readHeaderFrom(r io.Reader) (int, error) {
   359  	// Note io.ReadFull will return two different types of errors:
   360  	// 1. if the socket is already closed, and the go runtime knows it,
   361  	//   then ReadFull will return an error (different than EOF),
   362  	//   something like 'read: connection reset by peer'.
   363  	// 2. if the socket is not closed while we start the read,
   364  	//   but gets closed after the read is started, we'll get io.EOF.
   365  	if _, err := io.ReadFull(r, c.header[:]); err != nil {
   366  		// The special casing of propagating io.EOF up
   367  		// is used by the server side only, to suppress an error
   368  		// message if a client just disconnects.
   369  		if err == io.EOF {
   370  			return 0, err
   371  		}
   372  		if strings.HasSuffix(err.Error(), "read: connection reset by peer") {
   373  			return 0, io.EOF
   374  		}
   375  		return 0, vterrors.Wrapf(err, "io.ReadFull(header size) failed")
   376  	}
   377  
   378  	sequence := uint8(c.header[3])
   379  	if sequence != c.sequence {
   380  		return 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence, expected %v got %v", c.sequence, sequence)
   381  	}
   382  
   383  	c.sequence++
   384  
   385  	return int(uint32(c.header[0]) | uint32(c.header[1])<<8 | uint32(c.header[2])<<16), nil
   386  }
   387  
   388  // readEphemeralPacket attempts to read a packet into buffer from sync.Pool.  Do
   389  // not use this method if the contents of the packet needs to be kept
   390  // after the next readEphemeralPacket.
   391  //
   392  // Note if the connection is closed already, an error will be
   393  // returned, and it may not be io.EOF. If the connection closes while
   394  // we are stuck waiting for data, an error will also be returned, and
   395  // it most likely will be io.EOF.
   396  func (c *Conn) readEphemeralPacket() ([]byte, error) {
   397  	if c.currentEphemeralPolicy != ephemeralUnused {
   398  		panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "readEphemeralPacket: unexpected currentEphemeralPolicy: %v", c.currentEphemeralPolicy))
   399  	}
   400  
   401  	r := c.getReader()
   402  
   403  	length, err := c.readHeaderFrom(r)
   404  	if err != nil {
   405  		return nil, err
   406  	}
   407  
   408  	c.currentEphemeralPolicy = ephemeralRead
   409  	if length == 0 {
   410  		// This can be caused by the packet after a packet of
   411  		// exactly size MaxPacketSize.
   412  		return nil, nil
   413  	}
   414  
   415  	// Use the bufPool.
   416  	if length < MaxPacketSize {
   417  		c.currentEphemeralBuffer = bufPool.Get(length)
   418  		if _, err := io.ReadFull(r, *c.currentEphemeralBuffer); err != nil {
   419  			return nil, vterrors.Wrapf(err, "io.ReadFull(packet body of length %v) failed", length)
   420  		}
   421  		return *c.currentEphemeralBuffer, nil
   422  	}
   423  
   424  	// Much slower path, revert to allocating everything from scratch.
   425  	// We're going to concatenate a lot of data anyway, can't really
   426  	// optimize this code path easily.
   427  	data := make([]byte, length)
   428  	if _, err := io.ReadFull(r, data); err != nil {
   429  		return nil, vterrors.Wrapf(err, "io.ReadFull(packet body of length %v) failed", length)
   430  	}
   431  	for {
   432  		next, err := c.readOnePacket()
   433  		if err != nil {
   434  			return nil, err
   435  		}
   436  
   437  		if len(next) == 0 {
   438  			// Again, the packet after a packet of exactly size MaxPacketSize.
   439  			break
   440  		}
   441  
   442  		data = append(data, next...)
   443  		if len(next) < MaxPacketSize {
   444  			break
   445  		}
   446  	}
   447  
   448  	return data, nil
   449  }
   450  
   451  // readEphemeralPacketDirect attempts to read a packet from the socket directly.
   452  // It needs to be used for the first handshake packet the server receives,
   453  // so we do't buffer the SSL negotiation packet. As a shortcut, only
   454  // packets smaller than MaxPacketSize can be read here.
   455  // This function usually shouldn't be used - use readEphemeralPacket.
   456  func (c *Conn) readEphemeralPacketDirect() ([]byte, error) {
   457  	if c.currentEphemeralPolicy != ephemeralUnused {
   458  		panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "readEphemeralPacketDirect: unexpected currentEphemeralPolicy: %v", c.currentEphemeralPolicy))
   459  	}
   460  
   461  	var r io.Reader = c.conn
   462  
   463  	length, err := c.readHeaderFrom(r)
   464  	if err != nil {
   465  		return nil, err
   466  	}
   467  
   468  	c.currentEphemeralPolicy = ephemeralRead
   469  	if length == 0 {
   470  		// This can be caused by the packet after a packet of
   471  		// exactly size MaxPacketSize.
   472  		return nil, nil
   473  	}
   474  
   475  	if length < MaxPacketSize {
   476  		c.currentEphemeralBuffer = bufPool.Get(length)
   477  		if _, err := io.ReadFull(r, *c.currentEphemeralBuffer); err != nil {
   478  			return nil, vterrors.Wrapf(err, "io.ReadFull(packet body of length %v) failed", length)
   479  		}
   480  		return *c.currentEphemeralBuffer, nil
   481  	}
   482  
   483  	return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "readEphemeralPacketDirect doesn't support more than one packet")
   484  }
   485  
   486  // recycleReadPacket recycles the read packet. It needs to be called
   487  // after readEphemeralPacket was called.
   488  func (c *Conn) recycleReadPacket() {
   489  	if c.currentEphemeralPolicy != ephemeralRead {
   490  		// Programming error.
   491  		panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "trying to call recycleReadPacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy))
   492  	}
   493  	if c.currentEphemeralBuffer != nil {
   494  		// We are using the pool, put the buffer back in.
   495  		bufPool.Put(c.currentEphemeralBuffer)
   496  		c.currentEphemeralBuffer = nil
   497  	}
   498  	c.currentEphemeralPolicy = ephemeralUnused
   499  }
   500  
   501  // readOnePacket reads a single packet into a newly allocated buffer.
   502  func (c *Conn) readOnePacket() ([]byte, error) {
   503  	r := c.getReader()
   504  	length, err := c.readHeaderFrom(r)
   505  	if err != nil {
   506  		return nil, err
   507  	}
   508  	if length == 0 {
   509  		// This can be caused by the packet after a packet of
   510  		// exactly size MaxPacketSize.
   511  		return nil, nil
   512  	}
   513  
   514  	data := make([]byte, length)
   515  	if _, err := io.ReadFull(r, data); err != nil {
   516  		return nil, vterrors.Wrapf(err, "io.ReadFull(packet body of length %v) failed", length)
   517  	}
   518  	return data, nil
   519  }
   520  
   521  // readPacket reads a packet from the underlying connection.
   522  // It re-assembles packets that span more than one message.
   523  // This method returns a generic error, not a SQLError.
   524  func (c *Conn) readPacket() ([]byte, error) {
   525  	// Optimize for a single packet case.
   526  	data, err := c.readOnePacket()
   527  	if err != nil {
   528  		return nil, err
   529  	}
   530  
   531  	// This is a single packet.
   532  	if len(data) < MaxPacketSize {
   533  		return data, nil
   534  	}
   535  
   536  	// There is more than one packet, read them all.
   537  	for {
   538  		next, err := c.readOnePacket()
   539  		if err != nil {
   540  			return nil, err
   541  		}
   542  
   543  		if len(next) == 0 {
   544  			// Again, the packet after a packet of exactly size MaxPacketSize.
   545  			break
   546  		}
   547  
   548  		data = append(data, next...)
   549  		if len(next) < MaxPacketSize {
   550  			break
   551  		}
   552  	}
   553  
   554  	return data, nil
   555  }
   556  
   557  // ReadPacket reads a packet from the underlying connection.
   558  // it is the public API version, that returns a SQLError.
   559  // The memory for the packet is always allocated, and it is owned by the caller
   560  // after this function returns.
   561  func (c *Conn) ReadPacket() ([]byte, error) {
   562  	result, err := c.readPacket()
   563  	if err != nil {
   564  		return nil, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
   565  	}
   566  	return result, err
   567  }
   568  
// writePacket writes a packet, possibly cutting it into multiple
// chunks.  Note this is not very efficient, as the client probably
// has to build the []byte and that makes a memory copy.
// Try to use startEphemeralPacketWithHeader/writeEphemeralPacket instead.
//
// data must start with packetHeaderSize bytes reserved for the header,
// followed by the payload. Each chunk's header is written in place over the
// four bytes of data that precede the chunk, and those bytes are put back
// once the chunk has been written. If a write fails, the function returns
// before restoring them.
//
// This method returns a generic error, not a SQLError.
func (c *Conn) writePacket(data []byte) error {
	// index is the offset in data where the header of the current chunk goes.
	index := 0
	// dataLength is the number of payload bytes still to be sent.
	dataLength := len(data) - packetHeaderSize

	w, unget := c.getWriter()
	defer unget()

	var header [packetHeaderSize]byte
	for {
		// toBeSent is capped to MaxPacketSize.
		toBeSent := dataLength
		if toBeSent > MaxPacketSize {
			toBeSent = MaxPacketSize
		}

		// save the first 4 bytes of the payload, we will overwrite them with the
		// header below
		copy(header[0:packetHeaderSize], data[index:index+packetHeaderSize])

		// Compute and write the header: three bytes of little-endian
		// length, then the sequence number.
		data[index] = byte(toBeSent)
		data[index+1] = byte(toBeSent >> 8)
		data[index+2] = byte(toBeSent >> 16)
		data[index+3] = c.sequence

		// Write the header and the body in a single call.
		if n, err := w.Write(data[index : index+toBeSent+packetHeaderSize]); err != nil {
			return vterrors.Wrapf(err, "Write(packet) failed")
		} else if n != (toBeSent + packetHeaderSize) {
			return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Write(packet) returned a short write: %v < %v", n, (toBeSent + packetHeaderSize))
		}

		// restore the first 4 bytes once the network send is done
		copy(data[index:index+packetHeaderSize], header[0:packetHeaderSize])

		// Update our state.
		c.sequence++
		dataLength -= toBeSent
		if dataLength == 0 {
			if toBeSent == MaxPacketSize {
				// The packet we just sent had exactly
				// MaxPacketSize size, we need to
				// send a zero-size packet too.
				header[0] = 0
				header[1] = 0
				header[2] = 0
				header[3] = c.sequence
				if n, err := w.Write(header[:]); err != nil {
					return vterrors.Wrapf(err, "Write(empty header) failed")
				} else if n != packetHeaderSize {
					return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Write(empty header) returned a short write: %v < 4", n)
				}
				c.sequence++
			}
			return nil
		}
		// The next header overwrites the last four bytes of the chunk
		// that was just sent.
		index += toBeSent
	}
}
   634  
   635  func (c *Conn) startEphemeralPacketWithHeader(length int) ([]byte, int) {
   636  	if c.currentEphemeralPolicy != ephemeralUnused {
   637  		panic("startEphemeralPacketWithHeader cannot be used while a packet is already started.")
   638  	}
   639  
   640  	c.currentEphemeralPolicy = ephemeralWrite
   641  	// get buffer from pool or it'll be allocated if length is too big
   642  	c.currentEphemeralBuffer = bufPool.Get(length + packetHeaderSize)
   643  	return *c.currentEphemeralBuffer, packetHeaderSize
   644  }
   645  
   646  // writeEphemeralPacket writes the packet that was allocated by
   647  // startEphemeralPacketWithHeader.
   648  func (c *Conn) writeEphemeralPacket() error {
   649  	defer c.recycleWritePacket()
   650  
   651  	switch c.currentEphemeralPolicy {
   652  	case ephemeralWrite:
   653  		if err := c.writePacket(*c.currentEphemeralBuffer); err != nil {
   654  			return vterrors.Wrapf(err, "conn %v", c.ID())
   655  		}
   656  	case ephemeralUnused, ephemeralRead:
   657  		// Programming error.
   658  		panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "conn %v: trying to call writeEphemeralPacket while currentEphemeralPolicy is %v", c.ID(), c.currentEphemeralPolicy))
   659  	}
   660  
   661  	return nil
   662  }
   663  
   664  // recycleWritePacket recycles the write packet. It needs to be called
   665  // after writeEphemeralPacket was called.
   666  func (c *Conn) recycleWritePacket() {
   667  	if c.currentEphemeralPolicy != ephemeralWrite {
   668  		// Programming error.
   669  		panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "trying to call recycleWritePacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy))
   670  	}
   671  	// Release our reference so the buffer can be gced
   672  	bufPool.Put(c.currentEphemeralBuffer)
   673  	c.currentEphemeralBuffer = nil
   674  	c.currentEphemeralPolicy = ephemeralUnused
   675  }
   676  
   677  // writeComQuit writes a Quit message for the server, to indicate we
   678  // want to close the connection.
   679  // Client -> Server.
   680  // Returns SQLError(CRServerGone) if it can't.
   681  func (c *Conn) writeComQuit() error {
   682  	// This is a new command, need to reset the sequence.
   683  	c.sequence = 0
   684  
   685  	data, pos := c.startEphemeralPacketWithHeader(1)
   686  	data[pos] = ComQuit
   687  	if err := c.writeEphemeralPacket(); err != nil {
   688  		return NewSQLError(CRServerGone, SSUnknownSQLState, err.Error())
   689  	}
   690  	return nil
   691  }
   692  
// RemoteAddr returns the remote address of the underlying network
// connection, as reported by its RemoteAddr() method.
func (c *Conn) RemoteAddr() net.Addr {
	return c.conn.RemoteAddr()
}
   697  
// ID returns the MySQL connection ID for this connection, i.e.
// ConnectionID widened to an int64.
func (c *Conn) ID() int64 {
	return int64(c.ConnectionID)
}
   702  
   703  // Ident returns a useful identification string for error logging
   704  func (c *Conn) String() string {
   705  	return fmt.Sprintf("client %v (%s)", c.ConnectionID, c.RemoteAddr().String())
   706  }
   707  
   708  // Close closes the connection. It can be called from a different go
   709  // routine to interrupt the current connection.
   710  func (c *Conn) Close() {
   711  	if c.closed.CompareAndSwap(false, true) {
   712  		c.conn.Close()
   713  	}
   714  }
   715  
// IsClosed returns true if this connection was ever closed by the
// Close() method.  Note if the other side closes the connection, but
// Close() wasn't called, this will return false.
func (c *Conn) IsClosed() bool {
	// closed is only flipped by Close().
	return c.closed.Get()
}
   722  
   723  //
   724  // Packet writing methods, for generic packets.
   725  //
   726  
// writeOKPacket writes an OK packet, i.e. one whose header byte is
// OKPacket.
// Server -> Client.
// This method returns a generic error, not a SQLError.
func (c *Conn) writeOKPacket(packetOk *PacketOK) error {
	return c.writeOKPacketWithHeader(packetOk, OKPacket)
}
   733  
// writeOKPacketWithEOFHeader writes an OK packet with an EOF header, i.e.
// one whose header byte is EOFPacket.
// This is used at the end of a result set if
// CapabilityClientDeprecateEOF is set.
// Server -> Client.
// This method returns a generic error, not a SQLError.
func (c *Conn) writeOKPacketWithEOFHeader(packetOk *PacketOK) error {
	return c.writeOKPacketWithHeader(packetOk, EOFPacket)
}
   742  
   743  // writeOKPacketWithEOFHeader writes an OK packet with an EOF header.
   744  // This is used at the end of a result set if
   745  // CapabilityClientDeprecateEOF is set.
   746  // Server -> Client.
   747  // This method returns a generic error, not a SQLError.
   748  func (c *Conn) writeOKPacketWithHeader(packetOk *PacketOK, headerType byte) error {
   749  	length := 1 + // OKPacket
   750  		lenEncIntSize(packetOk.affectedRows) +
   751  		lenEncIntSize(packetOk.lastInsertID)
   752  	// assuming CapabilityClientProtocol41
   753  	length += 4 // status_flags + warnings
   754  
   755  	var gtidData []byte
   756  	if c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack {
   757  		length += lenEncStringSize(packetOk.info) // info
   758  		if packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged {
   759  			gtidData = getLenEncString([]byte(packetOk.sessionStateData))
   760  			gtidData = append([]byte{0x00}, gtidData...)
   761  			gtidData = getLenEncString(gtidData)
   762  			gtidData = append([]byte{0x03}, gtidData...)
   763  			gtidData = append(getLenEncInt(uint64(len(gtidData))), gtidData...)
   764  			length += len(gtidData)
   765  		}
   766  	} else {
   767  		length += len(packetOk.info) // info
   768  	}
   769  
   770  	bytes, pos := c.startEphemeralPacketWithHeader(length)
   771  	data := &coder{data: bytes, pos: pos}
   772  	data.writeByte(headerType) //header - OK or EOF
   773  	data.writeLenEncInt(packetOk.affectedRows)
   774  	data.writeLenEncInt(packetOk.lastInsertID)
   775  	data.writeUint16(packetOk.statusFlags)
   776  	data.writeUint16(packetOk.warnings)
   777  	if c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack {
   778  		data.writeLenEncString(packetOk.info)
   779  		if packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged {
   780  			data.writeEOFString(string(gtidData))
   781  		}
   782  	} else {
   783  		data.writeEOFString(packetOk.info)
   784  	}
   785  	return c.writeEphemeralPacket()
   786  }
   787  
   788  func getLenEncString(value []byte) []byte {
   789  	data := getLenEncInt(uint64(len(value)))
   790  	return append(data, value...)
   791  }
   792  
   793  func getLenEncInt(i uint64) []byte {
   794  	var data []byte
   795  	switch {
   796  	case i < 251:
   797  		data = append(data, byte(i))
   798  	case i < 1<<16:
   799  		data = append(data, 0xfc)
   800  		data = append(data, byte(i))
   801  		data = append(data, byte(i>>8))
   802  	case i < 1<<24:
   803  		data = append(data, 0xfd)
   804  		data = append(data, byte(i))
   805  		data = append(data, byte(i>>8))
   806  		data = append(data, byte(i>>16))
   807  	default:
   808  		data = append(data, 0xfe)
   809  		data = append(data, byte(i))
   810  		data = append(data, byte(i>>8))
   811  		data = append(data, byte(i>>16))
   812  		data = append(data, byte(i>>24))
   813  		data = append(data, byte(i>>32))
   814  		data = append(data, byte(i>>40))
   815  		data = append(data, byte(i>>48))
   816  		data = append(data, byte(i>>56))
   817  	}
   818  	return data
   819  }
   820  
   821  func (c *Conn) WriteErrorAndLog(format string, args ...interface{}) bool {
   822  	return c.writeErrorAndLog(ERUnknownComError, SSNetError, format, args...)
   823  }
   824  
   825  func (c *Conn) writeErrorAndLog(errorCode uint16, sqlState string, format string, args ...any) bool {
   826  	if err := c.writeErrorPacket(errorCode, sqlState, format, args...); err != nil {
   827  		log.Errorf("Error writing error to %s: %v", c, err)
   828  		return false
   829  	}
   830  	return true
   831  }
   832  
   833  func (c *Conn) writeErrorPacketFromErrorAndLog(err error) bool {
   834  	werr := c.writeErrorPacketFromError(err)
   835  	if werr != nil {
   836  		log.Errorf("Error writing error to %s: %v", c, werr)
   837  		return false
   838  	}
   839  	return true
   840  }
   841  
   842  // writeErrorPacket writes an error packet.
   843  // Server -> Client.
   844  // This method returns a generic error, not a SQLError.
   845  func (c *Conn) writeErrorPacket(errorCode uint16, sqlState string, format string, args ...any) error {
   846  	errorMessage := fmt.Sprintf(format, args...)
   847  	length := 1 + 2 + 1 + 5 + len(errorMessage)
   848  	data, pos := c.startEphemeralPacketWithHeader(length)
   849  	pos = writeByte(data, pos, ErrPacket)
   850  	pos = writeUint16(data, pos, errorCode)
   851  	pos = writeByte(data, pos, '#')
   852  	if sqlState == "" {
   853  		sqlState = SSUnknownSQLState
   854  	}
   855  	if len(sqlState) != 5 {
   856  		panic("sqlState has to be 5 characters long")
   857  	}
   858  	pos = writeEOFString(data, pos, sqlState)
   859  	_ = writeEOFString(data, pos, errorMessage)
   860  
   861  	return c.writeEphemeralPacket()
   862  }
   863  
   864  // writeErrorPacketFromError writes an error packet, from a regular error.
   865  // See writeErrorPacket for other info.
   866  func (c *Conn) writeErrorPacketFromError(err error) error {
   867  	if se, ok := err.(*SQLError); ok {
   868  		return c.writeErrorPacket(uint16(se.Num), se.State, "%v", se.Message)
   869  	}
   870  
   871  	return c.writeErrorPacket(ERUnknownError, SSUnknownSQLState, "unknown error: %v", err)
   872  }
   873  
   874  // writeEOFPacket writes an EOF packet, through the buffer, and
   875  // doesn't flush (as it is used as part of a query result).
   876  func (c *Conn) writeEOFPacket(flags uint16, warnings uint16) error {
   877  	length := 5
   878  	data, pos := c.startEphemeralPacketWithHeader(length)
   879  	pos = writeByte(data, pos, EOFPacket)
   880  	pos = writeUint16(data, pos, warnings)
   881  	_ = writeUint16(data, pos, flags)
   882  
   883  	return c.writeEphemeralPacket()
   884  }
   885  
   886  // handleNextCommand is called in the server loop to process
   887  // incoming packets.
   888  func (c *Conn) handleNextCommand(handler Handler) bool {
   889  	c.sequence = 0
   890  	data, err := c.readEphemeralPacket()
   891  	if err != nil {
   892  		// Don't log EOF errors. They cause too much spam.
   893  		if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
   894  			log.Errorf("Error reading packet from %s: %v", c, err)
   895  		}
   896  		return false
   897  	}
   898  	if len(data) == 0 {
   899  		return false
   900  	}
   901  
   902  	switch data[0] {
   903  	case ComQuit:
   904  		c.recycleReadPacket()
   905  		return false
   906  	case ComInitDB:
   907  		db := c.parseComInitDB(data)
   908  		c.recycleReadPacket()
   909  		res := c.execQuery("use "+sqlescape.EscapeID(db), handler, false)
   910  		return res != connErr
   911  	case ComQuery:
   912  		return c.handleComQuery(handler, data)
   913  	case ComPing:
   914  		return c.handleComPing()
   915  	case ComSetOption:
   916  		return c.handleComSetOption(data)
   917  	case ComPrepare:
   918  		return c.handleComPrepare(handler, data)
   919  	case ComStmtExecute:
   920  		return c.handleComStmtExecute(handler, data)
   921  	case ComStmtSendLongData:
   922  		return c.handleComStmtSendLongData(data)
   923  	case ComStmtClose:
   924  		stmtID, ok := c.parseComStmtClose(data)
   925  		c.recycleReadPacket()
   926  		if ok {
   927  			delete(c.PrepareData, stmtID)
   928  		}
   929  	case ComStmtReset:
   930  		return c.handleComStmtReset(data)
   931  	case ComResetConnection:
   932  		c.handleComResetConnection(handler)
   933  		return true
   934  	case ComFieldList:
   935  		c.recycleReadPacket()
   936  		if !c.writeErrorAndLog(ERUnknownComError, SSNetError, "command handling not implemented yet: %v", data[0]) {
   937  			return false
   938  		}
   939  	case ComBinlogDump:
   940  		return c.handleComBinlogDump(handler, data)
   941  	case ComBinlogDumpGTID:
   942  		return c.handleComBinlogDumpGTID(handler, data)
   943  	case ComRegisterReplica:
   944  		return c.handleComRegisterReplica(handler, data)
   945  	default:
   946  		log.Errorf("Got unhandled packet (default) from %s, returning error: %v", c, data)
   947  		c.recycleReadPacket()
   948  		if !c.writeErrorAndLog(ERUnknownComError, SSNetError, "command handling not implemented yet: %v", data[0]) {
   949  			return false
   950  		}
   951  	}
   952  
   953  	return true
   954  }
   955  
   956  func (c *Conn) handleComRegisterReplica(handler Handler, data []byte) (kontinue bool) {
   957  	c.recycleReadPacket()
   958  
   959  	replicaHost, replicaPort, replicaUser, replicaPassword, err := c.parseComRegisterReplica(data)
   960  	if err != nil {
   961  		log.Errorf("conn %v: parseComRegisterReplica failed: %v", c.ID(), err)
   962  		return false
   963  	}
   964  	if err := handler.ComRegisterReplica(c, replicaHost, replicaPort, replicaUser, replicaPassword); err != nil {
   965  		c.writeErrorPacketFromError(err)
   966  		return false
   967  	}
   968  	if err := c.writeOKPacket(&PacketOK{}); err != nil {
   969  		c.writeErrorPacketFromError(err)
   970  	}
   971  	return true
   972  }
   973  
   974  func (c *Conn) handleComBinlogDump(handler Handler, data []byte) (kontinue bool) {
   975  	c.recycleReadPacket()
   976  	kontinue = true
   977  
   978  	c.startWriterBuffering()
   979  	defer func() {
   980  		if err := c.endWriterBuffering(); err != nil {
   981  			log.Errorf("conn %v: flush() failed: %v", c.ID(), err)
   982  			kontinue = false
   983  		}
   984  	}()
   985  
   986  	logfile, binlogPos, err := c.parseComBinlogDump(data)
   987  	if err != nil {
   988  		log.Errorf("conn %v: parseComBinlogDumpGTID failed: %v", c.ID(), err)
   989  		return false
   990  	}
   991  	if err := handler.ComBinlogDump(c, logfile, binlogPos); err != nil {
   992  		log.Error(err.Error())
   993  		return false
   994  	}
   995  	return kontinue
   996  }
   997  
   998  func (c *Conn) handleComBinlogDumpGTID(handler Handler, data []byte) (kontinue bool) {
   999  	c.recycleReadPacket()
  1000  	kontinue = true
  1001  
  1002  	c.startWriterBuffering()
  1003  	defer func() {
  1004  		if err := c.endWriterBuffering(); err != nil {
  1005  			log.Errorf("conn %v: flush() failed: %v", c.ID(), err)
  1006  			kontinue = false
  1007  		}
  1008  	}()
  1009  
  1010  	logFile, logPos, position, err := c.parseComBinlogDumpGTID(data)
  1011  	if err != nil {
  1012  		log.Errorf("conn %v: parseComBinlogDumpGTID failed: %v", c.ID(), err)
  1013  		return false
  1014  	}
  1015  	if err := handler.ComBinlogDumpGTID(c, logFile, logPos, position.GTIDSet); err != nil {
  1016  		log.Error(err.Error())
  1017  		return false
  1018  	}
  1019  	return kontinue
  1020  }
  1021  
  1022  func (c *Conn) handleComResetConnection(handler Handler) {
  1023  	// Clean up and reset the connection
  1024  	c.recycleReadPacket()
  1025  	handler.ComResetConnection(c)
  1026  	// Reset prepared statements
  1027  	c.PrepareData = make(map[uint32]*PrepareData)
  1028  	err := c.writeOKPacket(&PacketOK{})
  1029  	if err != nil {
  1030  		c.writeErrorPacketFromError(err)
  1031  	}
  1032  }
  1033  
  1034  func (c *Conn) handleComStmtReset(data []byte) bool {
  1035  	stmtID, ok := c.parseComStmtReset(data)
  1036  	c.recycleReadPacket()
  1037  	if !ok {
  1038  		log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data)
  1039  		if !c.writeErrorAndLog(ERUnknownComError, SSNetError, "error handling packet: %v", data) {
  1040  			return false
  1041  		}
  1042  	}
  1043  
  1044  	prepare, ok := c.PrepareData[stmtID]
  1045  	if !ok {
  1046  		log.Error("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data)
  1047  		if !c.writeErrorAndLog(CRCommandsOutOfSync, SSNetError, "commands were executed in an improper order: %v", data) {
  1048  			return false
  1049  		}
  1050  	}
  1051  
  1052  	if prepare.BindVars != nil {
  1053  		for k := range prepare.BindVars {
  1054  			prepare.BindVars[k] = nil
  1055  		}
  1056  	}
  1057  
  1058  	if err := c.writeOKPacket(&PacketOK{statusFlags: c.StatusFlags}); err != nil {
  1059  		log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err)
  1060  		return false
  1061  	}
  1062  	return true
  1063  }
  1064  
  1065  func (c *Conn) handleComStmtSendLongData(data []byte) bool {
  1066  	stmtID, paramID, chunk, ok := c.parseComStmtSendLongData(data)
  1067  	c.recycleReadPacket()
  1068  	if !ok {
  1069  		err := fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data)
  1070  		return c.writeErrorPacketFromErrorAndLog(err)
  1071  	}
  1072  
  1073  	prepare, ok := c.PrepareData[stmtID]
  1074  	if !ok {
  1075  		err := fmt.Errorf("got wrong statement id from client %v, statement ID(%v) is not found from record", c.ConnectionID, stmtID)
  1076  		return c.writeErrorPacketFromErrorAndLog(err)
  1077  	}
  1078  
  1079  	if prepare.BindVars == nil ||
  1080  		prepare.ParamsCount == uint16(0) ||
  1081  		paramID >= prepare.ParamsCount {
  1082  		err := fmt.Errorf("invalid parameter Number from client %v, statement: %v", c.ConnectionID, prepare.PrepareStmt)
  1083  		return c.writeErrorPacketFromErrorAndLog(err)
  1084  	}
  1085  
  1086  	key := fmt.Sprintf("v%d", paramID+1)
  1087  	if val, ok := prepare.BindVars[key]; ok {
  1088  		val.Value = append(val.Value, chunk...)
  1089  	} else {
  1090  		prepare.BindVars[key] = sqltypes.BytesBindVariable(chunk)
  1091  	}
  1092  	return true
  1093  }
  1094  
  1095  func (c *Conn) handleComStmtExecute(handler Handler, data []byte) (kontinue bool) {
  1096  	c.startWriterBuffering()
  1097  	defer func() {
  1098  		if err := c.endWriterBuffering(); err != nil {
  1099  			log.Errorf("conn %v: flush() failed: %v", c.ID(), err)
  1100  			kontinue = false
  1101  		}
  1102  	}()
  1103  	queryStart := time.Now()
  1104  	stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data)
  1105  	c.recycleReadPacket()
  1106  
  1107  	if stmtID != uint32(0) {
  1108  		defer func() {
  1109  			// Allocate a new bindvar map every time since VTGate.Execute() mutates it.
  1110  			prepare := c.PrepareData[stmtID]
  1111  			prepare.BindVars = make(map[string]*querypb.BindVariable, prepare.ParamsCount)
  1112  		}()
  1113  	}
  1114  
  1115  	if err != nil {
  1116  		return c.writeErrorPacketFromErrorAndLog(err)
  1117  	}
  1118  
  1119  	fieldSent := false
  1120  	// sendFinished is set if the response should just be an OK packet.
  1121  	sendFinished := false
  1122  	prepare := c.PrepareData[stmtID]
  1123  	err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error {
  1124  		if sendFinished {
  1125  			// Failsafe: Unreachable if server is well-behaved.
  1126  			return io.EOF
  1127  		}
  1128  
  1129  		if !fieldSent {
  1130  			fieldSent = true
  1131  
  1132  			if len(qr.Fields) == 0 {
  1133  				sendFinished = true
  1134  				// We should not send any more packets after this.
  1135  				ok := PacketOK{
  1136  					affectedRows:     qr.RowsAffected,
  1137  					lastInsertID:     qr.InsertID,
  1138  					statusFlags:      c.StatusFlags,
  1139  					warnings:         0,
  1140  					info:             "",
  1141  					sessionStateData: qr.SessionStateChanges,
  1142  				}
  1143  				return c.writeOKPacket(&ok)
  1144  			}
  1145  			if err := c.writeFields(qr); err != nil {
  1146  				return err
  1147  			}
  1148  		}
  1149  
  1150  		return c.writeBinaryRows(qr)
  1151  	})
  1152  
  1153  	// If no field was sent, we expect an error.
  1154  	if !fieldSent {
  1155  		// This is just a failsafe. Should never happen.
  1156  		if err == nil || err == io.EOF {
  1157  			err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error"))
  1158  		}
  1159  		if !c.writeErrorPacketFromErrorAndLog(err) {
  1160  			return false
  1161  		}
  1162  	} else {
  1163  		if err != nil {
  1164  			// We can't send an error in the middle of a stream.
  1165  			// All we can do is abort the send, which will cause a 2013.
  1166  			log.Errorf("Error in the middle of a stream to %s: %v", c, err)
  1167  			return false
  1168  		}
  1169  
  1170  		// Send the end packet only sendFinished is false (results were streamed).
  1171  		// In this case the affectedRows and lastInsertID are always 0 since it
  1172  		// was a read operation.
  1173  		if !sendFinished {
  1174  			if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil {
  1175  				log.Errorf("Error writing result to %s: %v", c, err)
  1176  				return false
  1177  			}
  1178  		}
  1179  	}
  1180  
  1181  	timings.Record(queryTimingKey, queryStart)
  1182  	return true
  1183  }
  1184  
  1185  func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) {
  1186  	c.startWriterBuffering()
  1187  	defer func() {
  1188  		if err := c.endWriterBuffering(); err != nil {
  1189  			log.Errorf("conn %v: flush() failed: %v", c.ID(), err)
  1190  			kontinue = false
  1191  		}
  1192  	}()
  1193  
  1194  	query := c.parseComPrepare(data)
  1195  	c.recycleReadPacket()
  1196  
  1197  	var queries []string
  1198  	if c.Capabilities&CapabilityClientMultiStatements != 0 {
  1199  		var err error
  1200  		queries, err = splitStatementFunction(query)
  1201  		if err != nil {
  1202  			log.Errorf("Conn %v: Error splitting query: %v", c, err)
  1203  			return c.writeErrorPacketFromErrorAndLog(err)
  1204  		}
  1205  		if len(queries) != 1 {
  1206  			log.Errorf("Conn %v: can not prepare multiple statements", c)
  1207  			return c.writeErrorPacketFromErrorAndLog(err)
  1208  		}
  1209  	} else {
  1210  		queries = []string{query}
  1211  	}
  1212  
  1213  	// Popoulate PrepareData
  1214  	c.StatementID++
  1215  	prepare := &PrepareData{
  1216  		StatementID: c.StatementID,
  1217  		PrepareStmt: queries[0],
  1218  	}
  1219  
  1220  	statement, err := sqlparser.ParseStrictDDL(query)
  1221  	if err != nil {
  1222  		log.Errorf("Conn %v: Error parsing prepared statement: %v", c, err)
  1223  		if !c.writeErrorPacketFromErrorAndLog(err) {
  1224  			return false
  1225  		}
  1226  	}
  1227  
  1228  	paramsCount := uint16(0)
  1229  	_ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
  1230  		switch node := node.(type) {
  1231  		case sqlparser.Argument:
  1232  			if strings.HasPrefix(string(node), "v") {
  1233  				paramsCount++
  1234  			}
  1235  		}
  1236  		return true, nil
  1237  	}, statement)
  1238  
  1239  	if paramsCount > 0 {
  1240  		prepare.ParamsCount = paramsCount
  1241  		prepare.ParamsType = make([]int32, paramsCount)
  1242  		prepare.BindVars = make(map[string]*querypb.BindVariable, paramsCount)
  1243  	}
  1244  
  1245  	bindVars := make(map[string]*querypb.BindVariable, paramsCount)
  1246  	for i := uint16(0); i < paramsCount; i++ {
  1247  		parameterID := fmt.Sprintf("v%d", i+1)
  1248  		bindVars[parameterID] = &querypb.BindVariable{}
  1249  	}
  1250  
  1251  	c.PrepareData[c.StatementID] = prepare
  1252  
  1253  	fld, err := handler.ComPrepare(c, queries[0], bindVars)
  1254  
  1255  	if err != nil {
  1256  		return c.writeErrorPacketFromErrorAndLog(err)
  1257  	}
  1258  
  1259  	if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil {
  1260  		log.Error("Error writing prepare data to client %v: %v", c.ConnectionID, err)
  1261  		return false
  1262  	}
  1263  	return true
  1264  }
  1265  
  1266  func (c *Conn) handleComSetOption(data []byte) bool {
  1267  	operation, ok := c.parseComSetOption(data)
  1268  	c.recycleReadPacket()
  1269  	if ok {
  1270  		switch operation {
  1271  		case 0:
  1272  			c.Capabilities |= CapabilityClientMultiStatements
  1273  		case 1:
  1274  			c.Capabilities &^= CapabilityClientMultiStatements
  1275  		default:
  1276  			log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", c.ConnectionID, data)
  1277  			if !c.writeErrorAndLog(ERUnknownComError, SSNetError, "error handling packet: %v", data) {
  1278  				return false
  1279  			}
  1280  		}
  1281  		if err := c.writeEndResult(false, 0, 0, 0); err != nil {
  1282  			log.Errorf("Error writeEndResult error %v ", err)
  1283  			return false
  1284  		}
  1285  	} else {
  1286  		log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", c.ConnectionID, data)
  1287  		if !c.writeErrorAndLog(ERUnknownComError, SSNetError, "error handling packet: %v", data) {
  1288  			return false
  1289  		}
  1290  	}
  1291  	return true
  1292  }
  1293  
  1294  func (c *Conn) handleComPing() bool {
  1295  	c.recycleReadPacket()
  1296  	// Return error if listener was shut down and OK otherwise
  1297  	if c.listener.isShutdown() {
  1298  		if !c.writeErrorAndLog(ERServerShutdown, SSNetError, "Server shutdown in progress") {
  1299  			return false
  1300  		}
  1301  	} else {
  1302  		if err := c.writeOKPacket(&PacketOK{statusFlags: c.StatusFlags}); err != nil {
  1303  			log.Errorf("Error writing ComPing result to %s: %v", c, err)
  1304  			return false
  1305  		}
  1306  	}
  1307  	return true
  1308  }
  1309  
// errEmptyStatement is the SQL error (EREmptyQuery, SSClientError) sent to the
// client when a query packet yields no statement to execute.
var errEmptyStatement = NewSQLError(EREmptyQuery, SSClientError, "Query was empty")
  1311  
  1312  func (c *Conn) handleComQuery(handler Handler, data []byte) (kontinue bool) {
  1313  	c.startWriterBuffering()
  1314  	defer func() {
  1315  		if err := c.endWriterBuffering(); err != nil {
  1316  			log.Errorf("conn %v: flush() failed: %v", c.ID(), err)
  1317  			kontinue = false
  1318  		}
  1319  	}()
  1320  
  1321  	queryStart := time.Now()
  1322  	query := c.parseComQuery(data)
  1323  	c.recycleReadPacket()
  1324  
  1325  	var queries []string
  1326  	var err error
  1327  	if c.Capabilities&CapabilityClientMultiStatements != 0 {
  1328  		queries, err = splitStatementFunction(query)
  1329  		if err != nil {
  1330  			log.Errorf("Conn %v: Error splitting query: %v", c, err)
  1331  			return c.writeErrorPacketFromErrorAndLog(err)
  1332  		}
  1333  	} else {
  1334  		queries = []string{query}
  1335  	}
  1336  
  1337  	if len(queries) == 0 {
  1338  		return c.writeErrorPacketFromErrorAndLog(errEmptyStatement)
  1339  	}
  1340  
  1341  	for index, sql := range queries {
  1342  		more := false
  1343  		if index != len(queries)-1 {
  1344  			more = true
  1345  		}
  1346  		res := c.execQuery(sql, handler, more)
  1347  		if res != execSuccess {
  1348  			return res != connErr
  1349  		}
  1350  	}
  1351  
  1352  	timings.Record(queryTimingKey, queryStart)
  1353  	return true
  1354  }
  1355  
  1356  func (c *Conn) execQuery(query string, handler Handler, more bool) execResult {
  1357  	callbackCalled := false
  1358  	// sendFinished is set if the response should just be an OK packet.
  1359  	sendFinished := false
  1360  
  1361  	err := handler.ComQuery(c, query, func(qr *sqltypes.Result) error {
  1362  		flag := c.StatusFlags
  1363  		if more {
  1364  			flag |= ServerMoreResultsExists
  1365  		}
  1366  		if sendFinished {
  1367  			// Failsafe: Unreachable if server is well-behaved.
  1368  			return io.EOF
  1369  		}
  1370  
  1371  		if !callbackCalled {
  1372  			callbackCalled = true
  1373  
  1374  			if len(qr.Fields) == 0 {
  1375  				sendFinished = true
  1376  
  1377  				// A successful callback with no fields means that this was a
  1378  				// DML or other write-only operation.
  1379  				//
  1380  				// We should not send any more packets after this, but make sure
  1381  				// to extract the affected rows and last insert id from the result
  1382  				// struct here since clients expect it.
  1383  				ok := PacketOK{
  1384  					affectedRows:     qr.RowsAffected,
  1385  					lastInsertID:     qr.InsertID,
  1386  					statusFlags:      flag,
  1387  					warnings:         handler.WarningCount(c),
  1388  					info:             "",
  1389  					sessionStateData: qr.SessionStateChanges,
  1390  				}
  1391  				return c.writeOKPacket(&ok)
  1392  			}
  1393  			if err := c.writeFields(qr); err != nil {
  1394  				return err
  1395  			}
  1396  		}
  1397  
  1398  		return c.writeRows(qr)
  1399  	})
  1400  
  1401  	// If callback was not called, we expect an error.
  1402  	if !callbackCalled {
  1403  		// This is just a failsafe. Should never happen.
  1404  		if err == nil || err == io.EOF {
  1405  			err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error"))
  1406  		}
  1407  		if !c.writeErrorPacketFromErrorAndLog(err) {
  1408  			return connErr
  1409  		}
  1410  		return execErr
  1411  	}
  1412  	if err != nil {
  1413  		// We can't send an error in the middle of a stream.
  1414  		// All we can do is abort the send, which will cause a 2013.
  1415  		log.Errorf("Error in the middle of a stream to %s: %v", c, err)
  1416  		return connErr
  1417  	}
  1418  
  1419  	// Send the end packet only sendFinished is false (results were streamed).
  1420  	// In this case the affectedRows and lastInsertID are always 0 since it
  1421  	// was a read operation.
  1422  	if !sendFinished {
  1423  		if err := c.writeEndResult(more, 0, 0, handler.WarningCount(c)); err != nil {
  1424  			log.Errorf("Error writing result to %s: %v", c, err)
  1425  			return connErr
  1426  		}
  1427  	}
  1428  
  1429  	return execSuccess
  1430  }
  1431  
  1432  //
  1433  // Packet parsing methods, for generic packets.
  1434  //
  1435  
  1436  // isEOFPacket determines whether a data packet is an EOF. In case the client capabilities
  1437  // do not have DEPRECATE_EOF set, DO NOT blindly compare the first byte of a packet to EOFPacket
  1438  // as you might do for other packet types, as 0xfe is overloaded as a first byte.
  1439  
  1440  // In case that DEPRECATE_EOF is set, we have really an OK packet which is always maximum a single
  1441  // packet and not multiple, but otherwise 0xfe definitely indicates it is an EOF.
  1442  //
  1443  // Per https://dev.mysql.com/doc/internals/en/packet-EOF_Packet.html, a packet starting with 0xfe
  1444  // but having length >= 9 (on top of 4 byte header)  without DEPRECATE_EOF set is not a true EOF but
  1445  // a LengthEncodedInteger (typically preceding a LengthEncodedString). Thus, all EOF checks without
  1446  // DEPRECATE_EOF must validate the payload size before exiting.
  1447  //
  1448  // More docs here:
  1449  // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_response_packets.html
  1450  func (c *Conn) isEOFPacket(data []byte) bool {
  1451  	if data[0] != EOFPacket {
  1452  		return false
  1453  	}
  1454  	if c.Capabilities&CapabilityClientDeprecateEOF == 0 {
  1455  		return len(data) < 9
  1456  	}
  1457  	return len(data) < MaxPacketSize
  1458  }
  1459  
  1460  // parseEOFPacket returns the warning count and a boolean to indicate if there
  1461  // are more results to receive.
  1462  //
  1463  // Note: This is only valid on actual EOF packets and not on OK packets with the EOF
  1464  // type code set, i.e. should not be used if ClientDeprecateEOF is set.
  1465  func parseEOFPacket(data []byte) (warnings uint16, statusFlags uint16, err error) {
  1466  	// The warning count is in position 2 & 3
  1467  	warnings, _, _ = readUint16(data, 1)
  1468  
  1469  	// The status flag is in position 4 & 5
  1470  	statusFlags, _, ok := readUint16(data, 3)
  1471  	if !ok {
  1472  		return 0, 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid EOF packet statusFlags: %v", data)
  1473  	}
  1474  	return warnings, statusFlags, nil
  1475  }
  1476  
// PacketOK contains the ok packet details
type PacketOK struct {
	// affectedRows and lastInsertID are read from the packet as
	// length-encoded integers.
	affectedRows uint64
	lastInsertID uint64
	// statusFlags and warnings are 16-bit values in the packet.
	statusFlags uint16
	warnings    uint16
	// info is the info string of the packet; parseOKPacket only stores it
	// when query info is enabled on the connection.
	info string

	// at the moment, we only store GTID information in this field
	sessionStateData string
}
  1488  
  1489  func (c *Conn) parseOKPacket(in []byte) (*PacketOK, error) {
  1490  	data := &coder{
  1491  		data: in,
  1492  		pos:  1, // We already read the type.
  1493  	}
  1494  	packetOK := &PacketOK{}
  1495  
  1496  	fail := func(format string, args ...any) (*PacketOK, error) {
  1497  		return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, format, args...)
  1498  	}
  1499  
  1500  	// Affected rows.
  1501  	affectedRows, ok := data.readLenEncInt()
  1502  	if !ok {
  1503  		return fail("invalid OK packet affectedRows: %v", data)
  1504  	}
  1505  	packetOK.affectedRows = affectedRows
  1506  
  1507  	// Last Insert ID.
  1508  	lastInsertID, ok := data.readLenEncInt()
  1509  	if !ok {
  1510  		return fail("invalid OK packet lastInsertID: %v", data)
  1511  	}
  1512  	packetOK.lastInsertID = lastInsertID
  1513  
  1514  	// Status flags.
  1515  	statusFlags, ok := data.readUint16()
  1516  	if !ok {
  1517  		return fail("invalid OK packet statusFlags: %v", data)
  1518  	}
  1519  	packetOK.statusFlags = statusFlags
  1520  
  1521  	// assuming CapabilityClientProtocol41
  1522  	// Warnings.
  1523  	warnings, ok := data.readUint16()
  1524  	if !ok {
  1525  		return fail("invalid OK packet warnings: %v", data)
  1526  	}
  1527  	packetOK.warnings = warnings
  1528  
  1529  	// info
  1530  	info, _ := data.readLenEncInfo()
  1531  	if c.enableQueryInfo {
  1532  		packetOK.info = info
  1533  	}
  1534  
  1535  	if c.Capabilities&uint32(CapabilityClientSessionTrack) == CapabilityClientSessionTrack {
  1536  		// session tracking
  1537  		if statusFlags&ServerSessionStateChanged == ServerSessionStateChanged {
  1538  			length, ok := data.readLenEncInt()
  1539  			if !ok || length == 0 {
  1540  				// In case we have no more data or a zero length string, there's no additional information so
  1541  				// we can return the packet.
  1542  				return packetOK, nil
  1543  			}
  1544  
  1545  			// Alright, now we need to read each sub packet from the session state change.
  1546  			for {
  1547  				sscType, ok := data.readByte()
  1548  				if !ok {
  1549  					// We're done, there's no more session state parts in the packet.
  1550  					break
  1551  				}
  1552  				sessionLen, ok := data.readLenEncInt()
  1553  				if !ok {
  1554  					return fail("invalid OK packet session state change length for type %v", sscType)
  1555  				}
  1556  
  1557  				if sscType != SessionTrackGtids {
  1558  					// Still need to increase the pointer here to indicate we're consuming
  1559  					// but otherwise ignoring the rest of this packet
  1560  					data.pos = data.pos + int(sessionLen)
  1561  					continue
  1562  				}
  1563  
  1564  				// read (and ignore for now) the GTIDS encoding specification code: 1 byte
  1565  				_, ok = data.readByte()
  1566  				if !ok {
  1567  					return fail("invalid OK packet gtids type: %v", data)
  1568  				}
  1569  
  1570  				gtids, ok := data.readLenEncString()
  1571  				if !ok {
  1572  					return fail("invalid OK packet gtids: %v", data)
  1573  				}
  1574  				packetOK.sessionStateData = gtids
  1575  			}
  1576  		}
  1577  	}
  1578  
  1579  	return packetOK, nil
  1580  }
  1581  
  1582  // isErrorPacket determines whether or not the packet is an error packet. Mostly here for
  1583  // consistency with isEOFPacket
  1584  func isErrorPacket(data []byte) bool {
  1585  	return data[0] == ErrPacket
  1586  }
  1587  
  1588  // ParseErrorPacket parses the error packet and returns a SQLError.
  1589  func ParseErrorPacket(data []byte) error {
  1590  	// We already read the type.
  1591  	pos := 1
  1592  
  1593  	// Error code is 2 bytes.
  1594  	code, pos, ok := readUint16(data, pos)
  1595  	if !ok {
  1596  		return NewSQLError(CRUnknownError, SSUnknownSQLState, "invalid error packet code: %v", data)
  1597  	}
  1598  
  1599  	// '#' marker of the SQL state is 1 byte. Ignored.
  1600  	pos++
  1601  
  1602  	// SQL state is 5 bytes
  1603  	sqlState, pos, ok := readBytes(data, pos, 5)
  1604  	if !ok {
  1605  		return NewSQLError(CRUnknownError, SSUnknownSQLState, "invalid error packet sqlState: %v", data)
  1606  	}
  1607  
  1608  	// Human readable error message is the rest.
  1609  	msg := string(data[pos:])
  1610  
  1611  	return NewSQLError(int(code), string(sqlState), "%v", msg)
  1612  }
  1613  
  1614  // GetTLSClientCerts gets TLS certificates.
  1615  func (c *Conn) GetTLSClientCerts() []*x509.Certificate {
  1616  	if tlsConn, ok := c.conn.(*tls.Conn); ok {
  1617  		return tlsConn.ConnectionState().PeerCertificates
  1618  	}
  1619  	return nil
  1620  }
  1621  
  1622  // TLSEnabled returns true if this connection is using TLS.
  1623  func (c *Conn) TLSEnabled() bool {
  1624  	return c.Capabilities&CapabilityClientSSL > 0
  1625  }
  1626  
  1627  // IsUnixSocket returns true if this connection is over a Unix socket.
  1628  func (c *Conn) IsUnixSocket() bool {
  1629  	_, ok := c.listener.listener.(*net.UnixListener)
  1630  	return ok
  1631  }
  1632  
  1633  // GetRawConn returns the raw net.Conn for nefarious purposes.
  1634  func (c *Conn) GetRawConn() net.Conn {
  1635  	return c.conn
  1636  }