github.com/XiaoMi/Gaea@v1.2.5/mysql/conn.go (about)

     1  /*
     2  Copyright 2017 Google Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Copyright 2019 The Gaea Authors. All Rights Reserved.
    18  //
    19  // Licensed under the Apache License, Version 2.0 (the "License");
    20  // you may not use this file except in compliance with the License.
    21  // You may obtain a copy of the License at
    22  //
    23  //     http://www.apache.org/licenses/LICENSE-2.0
    24  //
    25  // Unless required by applicable law or agreed to in writing, software
    26  // distributed under the License is distributed on an "AS IS" BASIS,
    27  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    28  // See the License for the specific language governing permissions and
    29  // limitations under the License.
    30  
    31  package mysql
    32  
    33  import (
    34  	"bufio"
    35  	"errors"
    36  	"fmt"
    37  	"io"
    38  	"net"
    39  	"strings"
    40  	"sync"
    41  
    42  	"github.com/XiaoMi/Gaea/util/bucketpool"
    43  	"github.com/XiaoMi/Gaea/util/sync2"
    44  )
    45  
    46  const (
    47  	// connBufferSize is how much we buffer for reading and
    48  	// writing. It is also how much we allocate for ephemeral buffers.
    49  	connBufferSize = 128
    50  
    51  	// MaxPacketSize is the maximum payload length of a packet(16MB)
    52  	// the server supports.
    53  	MaxPacketSize = (1 << 24) - 1
    54  )
    55  
    56  // Constants for how ephemeral buffers were used for reading / writing.
    57  const (
    58  	// ephemeralUnused means the ephemeral buffer is not in use at this
    59  	// moment. This is the default value, and is checked so we don't
    60  	// read or write a packet while one is already used.
    61  	ephemeralUnused = iota
    62  
    63  	// ephemeralWrite means we currently in process of writing from  currentEphemeralBuffer
    64  	ephemeralWrite
    65  
    66  	// ephemeralRead means we currently in process of reading into currentEphemeralBuffer
    67  	ephemeralRead
    68  )
    69  
    70  // Conn is a connection between a client and a server, using the MySQL
    71  // binary protocol. It is built on top of an existing net.Conn, that
    72  // has already been established.
    73  //
    74  // Use Connect on the client side to create a connection.
    75  // Use NewListener to create a server side and listen for connections.
    76  type Conn struct {
    77  	// conn is the underlying network connection.
    78  	// Calling Close() on the Conn will close this connection.
    79  	// If there are any ongoing reads or writes, they may get interrupted.
    80  	conn net.Conn
    81  
    82  	// ConnectionID is set:
    83  	// - at Connect() time for clients, with the value returned by
    84  	// the server.
    85  	// - at accept time for the server.
    86  	ConnectionID uint32
    87  
    88  	// closed is set to true when Close() is called on the connection.
    89  	closed sync2.AtomicBool
    90  
    91  	// Packet encoding variables.
    92  	bufferedReader *bufio.Reader
    93  	bufferedWriter *bufio.Writer
    94  	sequence       uint8
    95  
    96  	// Keep track of how and of the buffer we allocated for an
    97  	// ephemeral packet on the read and write sides.
    98  	// These fields are used by:
    99  	// - StartEphemeralPacket / writeEphemeralPacket methods for writes.
   100  	// - ReadEphemeralPacket / RecycleReadPacket methods for reads.
   101  	currentEphemeralPolicy int
   102  	// currentEphemeralBuffer for tracking allocated temporary buffer for writes and reads respectively.
   103  	// It can be allocated from bufPool or heap and should be recycled in the same manner.
   104  	currentEphemeralBuffer *[]byte
   105  }
   106  
   107  // bufPool is used to allocate and free buffers in an efficient way.
   108  var bufPool = bucketpool.New(connBufferSize, MaxPacketSize)
   109  
   110  // writersPool is used for pooling bufio.Writer objects.
   111  var writersPool = sync.Pool{New: func() interface{} { return bufio.NewWriterSize(nil, connBufferSize) }}
   112  
   113  // NewConn is an internal method to create a Conn. Used by client and server
   114  // side for common creation code.
   115  func NewConn(conn net.Conn) *Conn {
   116  	return &Conn{
   117  		conn:           conn,
   118  		closed:         sync2.NewAtomicBool(false),
   119  		bufferedReader: bufio.NewReaderSize(conn, connBufferSize),
   120  	}
   121  }
   122  
   123  // StartWriterBuffering starts using buffered writes. This should
   124  // be terminated by a call to flush.
   125  func (c *Conn) StartWriterBuffering() {
   126  	c.bufferedWriter = writersPool.Get().(*bufio.Writer)
   127  	c.bufferedWriter.Reset(c.conn)
   128  }
   129  
   130  // Flush flushes the written data to the socket.
   131  // This must be called to terminate startBuffering.
   132  func (c *Conn) Flush() error {
   133  	if c.bufferedWriter == nil {
   134  		return nil
   135  	}
   136  
   137  	defer func() {
   138  		c.bufferedWriter.Reset(nil)
   139  		writersPool.Put(c.bufferedWriter)
   140  		c.bufferedWriter = nil
   141  	}()
   142  
   143  	return c.bufferedWriter.Flush()
   144  }
   145  
   146  // getWriter returns the current writer. It may be either
   147  // the original connection or a wrapper.
   148  func (c *Conn) getWriter() io.Writer {
   149  	if c.bufferedWriter != nil {
   150  		return c.bufferedWriter
   151  	}
   152  	return c.conn
   153  }
   154  
   155  // getReader returns reader for connection. It can be *bufio.Reader or net.Conn
   156  // depending on which buffer size was passed to newServerConn.
   157  func (c *Conn) getReader() io.Reader {
   158  	if c.bufferedReader != nil {
   159  		return c.bufferedReader
   160  	}
   161  	return c.conn
   162  }
   163  
   164  func (c *Conn) readHeaderFrom(r io.Reader) (int, error) {
   165  	var header [4]byte
   166  	// Note io.ReadFull will return two different types of errors:
   167  	// 1. if the socket is already closed, and the go runtime knows it,
   168  	//   then ReadFull will return an error (different than EOF),
   169  	//   someting like 'read: connection reset by peer'.
   170  	// 2. if the socket is not closed while we start the read,
   171  	//   but gets closed after the read is started, we'll get io.EOF.
   172  	if _, err := io.ReadFull(r, header[:]); err != nil {
   173  		// The special casing of propagating io.EOF up
   174  		// is used by the server side only, to suppress an error
   175  		// message if a client just disconnects.
   176  		if err == io.EOF {
   177  			return 0, err
   178  		}
   179  		if strings.HasSuffix(err.Error(), "read: connection reset by peer") {
   180  			return 0, io.EOF
   181  		}
   182  		return 0, fmt.Errorf("io.ReadFull(header size) failed: %v", err)
   183  	}
   184  
   185  	sequence := uint8(header[3])
   186  	if sequence != c.sequence {
   187  		return 0, fmt.Errorf("invalid sequence, expected %v got %v", c.sequence, sequence)
   188  	}
   189  
   190  	c.sequence++
   191  
   192  	return int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16), nil
   193  }
   194  
   195  // ReadEphemeralPacket attempts to read a packet into buffer from sync.Pool.  Do
   196  // not use this method if the contents of the packet needs to be kept
   197  // after the next ReadEphemeralPacket.
   198  //
   199  // Note if the connection is closed already, an error will be
   200  // returned, and it may not be io.EOF. If the connection closes while
   201  // we are stuck waiting for data, an error will also be returned, and
   202  // it most likely will be io.EOF.
   203  func (c *Conn) ReadEphemeralPacket() ([]byte, error) {
   204  	if c.currentEphemeralPolicy != ephemeralUnused {
   205  		panic(fmt.Errorf("ReadEphemeralPacket: unexpected currentEphemeralPolicy: %v", c.currentEphemeralPolicy))
   206  	}
   207  	c.currentEphemeralPolicy = ephemeralRead
   208  
   209  	r := c.getReader()
   210  	length, err := c.readHeaderFrom(r)
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  
   215  	if length == 0 {
   216  		// This can be caused by the packet after a packet of
   217  		// exactly size MaxPacketSize.
   218  		return nil, nil
   219  	}
   220  
   221  	// Use the bufPool.
   222  	if length < MaxPacketSize {
   223  		c.currentEphemeralBuffer = bufPool.Get(length)
   224  		if _, err := io.ReadFull(r, *c.currentEphemeralBuffer); err != nil {
   225  			return nil, fmt.Errorf("io.ReadFull(packet body of length %v) failed: %v", length, err)
   226  		}
   227  		return *c.currentEphemeralBuffer, nil
   228  	}
   229  
   230  	// Much slower path, revert to allocating everything from scratch.
   231  	// We're going to concatenate a lot of data anyway, can't really
   232  	// optimize this code path easily.
   233  	data := make([]byte, length)
   234  	if _, err := io.ReadFull(r, data); err != nil {
   235  		return nil, fmt.Errorf("io.ReadFull(packet body of length %v) failed: %v", length, err)
   236  	}
   237  	for {
   238  		next, err := c.readOnePacket()
   239  		if err != nil {
   240  			return nil, err
   241  		}
   242  
   243  		if len(next) == 0 {
   244  			// Again, the packet after a packet of exactly size MaxPacketSize.
   245  			break
   246  		}
   247  
   248  		data = append(data, next...)
   249  		if len(next) < MaxPacketSize {
   250  			break
   251  		}
   252  	}
   253  
   254  	return data, nil
   255  }
   256  
   257  // ReadEphemeralPacketDirect attempts to read a packet from the socket directly.
   258  // It needs to be used for the first handshake packet the server receives,
   259  // so we do't buffer the SSL negotiation packet. As a shortcut, only
   260  // packets smaller than MaxPacketSize can be read here.
   261  // This function usually shouldn't be used - use ReadEphemeralPacket.
   262  func (c *Conn) ReadEphemeralPacketDirect() ([]byte, error) {
   263  	if c.currentEphemeralPolicy != ephemeralUnused {
   264  		panic(fmt.Errorf("ReadEphemeralPacketDirect: unexpected currentEphemeralPolicy: %v", c.currentEphemeralPolicy))
   265  	}
   266  	c.currentEphemeralPolicy = ephemeralRead
   267  
   268  	var r io.Reader = c.conn
   269  	length, err := c.readHeaderFrom(r)
   270  	if err != nil {
   271  		return nil, err
   272  	}
   273  
   274  	if length == 0 {
   275  		// This can be caused by the packet after a packet of
   276  		// exactly size MaxPacketSize.
   277  		return nil, nil
   278  	}
   279  
   280  	if length < MaxPacketSize {
   281  		c.currentEphemeralBuffer = bufPool.Get(length)
   282  		if _, err := io.ReadFull(r, *c.currentEphemeralBuffer); err != nil {
   283  			return nil, fmt.Errorf("io.ReadFull(packet body of length %v) failed: %v", length, err)
   284  		}
   285  		return *c.currentEphemeralBuffer, nil
   286  	}
   287  
   288  	return nil, fmt.Errorf("ReadEphemeralPacketDirect doesn't support more than one packet")
   289  }
   290  
   291  // RecycleReadPacket recycles the read packet. It needs to be called
   292  // after ReadEphemeralPacket was called.
   293  func (c *Conn) RecycleReadPacket() {
   294  	if c.currentEphemeralPolicy != ephemeralRead {
   295  		// Programming error.
   296  		panic(fmt.Errorf("trying to call RecycleReadPacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy))
   297  	}
   298  	if c.currentEphemeralBuffer != nil {
   299  		// We are using the pool, put the buffer back in.
   300  		bufPool.Put(c.currentEphemeralBuffer)
   301  		c.currentEphemeralBuffer = nil
   302  	}
   303  	c.currentEphemeralPolicy = ephemeralUnused
   304  }
   305  
   306  // readOnePacket reads a single packet into a newly allocated buffer.
   307  func (c *Conn) readOnePacket() ([]byte, error) {
   308  	r := c.getReader()
   309  	length, err := c.readHeaderFrom(r)
   310  	if err != nil {
   311  		return nil, err
   312  	}
   313  	if length == 0 {
   314  		// This can be caused by the packet after a packet of
   315  		// exactly size MaxPacketSize.
   316  		return nil, nil
   317  	}
   318  
   319  	data := make([]byte, length)
   320  	if _, err := io.ReadFull(r, data); err != nil {
   321  		return nil, fmt.Errorf("io.ReadFull(packet body of length %v) failed: %v", length, err)
   322  	}
   323  	return data, nil
   324  }
   325  
   326  // readPacket reads a packet from the underlying connection.
   327  // It re-assembles packets that span more than one message.
   328  // This method returns a generic error, not a SQLError.
   329  func (c *Conn) readPacket() ([]byte, error) {
   330  	// Optimize for a single packet case.
   331  	data, err := c.readOnePacket()
   332  	if err != nil {
   333  		return nil, err
   334  	}
   335  
   336  	// This is a single packet.
   337  	if len(data) < MaxPacketSize {
   338  		return data, nil
   339  	}
   340  
   341  	// There is more than one packet, read them all.
   342  	for {
   343  		next, err := c.readOnePacket()
   344  		if err != nil {
   345  			return nil, err
   346  		}
   347  
   348  		if len(next) == 0 {
   349  			// Again, the packet after a packet of exactly size MaxPacketSize.
   350  			break
   351  		}
   352  
   353  		data = append(data, next...)
   354  		if len(next) < MaxPacketSize {
   355  			break
   356  		}
   357  	}
   358  
   359  	return data, nil
   360  }
   361  
   362  // ReadPacket reads a packet from the underlying connection.
   363  // it is the public API version, that returns a SQLError.
   364  // The memory for the packet is always allocated, and it is owned by the caller
   365  // after this function returns.
   366  func (c *Conn) ReadPacket() ([]byte, error) {
   367  	result, err := c.readPacket()
   368  	if err != nil {
   369  		return nil, err
   370  	}
   371  	return result, err
   372  }
   373  
   374  // WritePacket writes a packet, possibly cutting it into multiple
   375  // chunks.  Note this is not very efficient, as the client probably
   376  // has to build the []byte and that makes a memory copy.
   377  // Try to use StartEphemeralPacket/writeEphemeralPacket instead.
   378  //
   379  // This method returns a generic error, not a SQLError.
   380  func (c *Conn) WritePacket(data []byte) error {
   381  	index := 0
   382  	length := len(data)
   383  
   384  	w := c.getWriter()
   385  
   386  	for {
   387  		// Packet length is capped to MaxPacketSize.
   388  		packetLength := length
   389  		if packetLength > MaxPacketSize {
   390  			packetLength = MaxPacketSize
   391  		}
   392  
   393  		// Compute and write the header.
   394  		var header [4]byte
   395  		header[0] = byte(packetLength)
   396  		header[1] = byte(packetLength >> 8)
   397  		header[2] = byte(packetLength >> 16)
   398  		header[3] = c.sequence
   399  		if n, err := w.Write(header[:]); err != nil {
   400  			return fmt.Errorf("Write(header) failed: %v", err)
   401  		} else if n != 4 {
   402  			return fmt.Errorf("Write(header) returned a short write: %v < 4", n)
   403  		}
   404  
   405  		// Write the body.
   406  		if n, err := w.Write(data[index : index+packetLength]); err != nil {
   407  			return fmt.Errorf("Write(packet) failed: %v", err)
   408  		} else if n != packetLength {
   409  			return fmt.Errorf("Write(packet) returned a short write: %v < %v", n, packetLength)
   410  		}
   411  
   412  		// Update our state.
   413  		c.sequence++
   414  		length -= packetLength
   415  		if length == 0 {
   416  			if packetLength == MaxPacketSize {
   417  				// The packet we just sent had exactly
   418  				// MaxPacketSize size, we need to
   419  				// sent a zero-size packet too.
   420  				header[0] = 0
   421  				header[1] = 0
   422  				header[2] = 0
   423  				header[3] = c.sequence
   424  				if n, err := w.Write(header[:]); err != nil {
   425  					return fmt.Errorf("Write(empty header) failed: %v", err)
   426  				} else if n != 4 {
   427  					return fmt.Errorf("Write(empty header) returned a short write: %v < 4", n)
   428  				}
   429  				c.sequence++
   430  			}
   431  			return nil
   432  		}
   433  		index += packetLength
   434  	}
   435  }
   436  
   437  // StartEphemeralPacket get []byte from pool
   438  func (c *Conn) StartEphemeralPacket(length int) []byte {
   439  	if c.currentEphemeralPolicy != ephemeralUnused {
   440  		panic("StartEphemeralPacket cannot be used while a packet is already started.")
   441  	}
   442  
   443  	c.currentEphemeralPolicy = ephemeralWrite
   444  	// get buffer from pool or it'll be allocated if length is too big
   445  	c.currentEphemeralBuffer = bufPool.Get(length)
   446  	return *c.currentEphemeralBuffer
   447  }
   448  
   449  // WriteEphemeralPacket writes the packet that was allocated by
   450  // StartEphemeralPacket.
   451  func (c *Conn) WriteEphemeralPacket() error {
   452  	defer c.recycleWritePacket()
   453  
   454  	switch c.currentEphemeralPolicy {
   455  	case ephemeralWrite:
   456  		if err := c.WritePacket(*c.currentEphemeralBuffer); err != nil {
   457  			return fmt.Errorf("Conn %v: %v", c.GetConnectionID(), err)
   458  		}
   459  	case ephemeralUnused, ephemeralRead:
   460  		// Programming error.
   461  		panic(fmt.Errorf("Conn %v: trying to call writeEphemeralPacket while currentEphemeralPolicy is %v", c.GetConnectionID(), c.currentEphemeralPolicy))
   462  	}
   463  
   464  	return nil
   465  }
   466  
   467  // recycleWritePacket recycles the write packet. It needs to be called
   468  // after writeEphemeralPacket was called.
   469  func (c *Conn) recycleWritePacket() {
   470  	if c.currentEphemeralPolicy != ephemeralWrite {
   471  		// Programming error.
   472  		panic(fmt.Errorf("trying to call recycleWritePacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy))
   473  	}
   474  	// Release our reference so the buffer can be gced
   475  	bufPool.Put(c.currentEphemeralBuffer)
   476  	c.currentEphemeralBuffer = nil
   477  	c.currentEphemeralPolicy = ephemeralUnused
   478  }
   479  
   480  // writeComQuit writes a Quit message for the server, to indicate we
   481  // want to close the connection.
   482  // Client -> Server.
   483  // Returns SQLError(CRServerGone) if it can't.
   484  func (c *Conn) writeComQuit() error {
   485  	// This is a new command, need to reset the sequence.
   486  	c.sequence = 0
   487  
   488  	data := c.StartEphemeralPacket(1)
   489  	data[0] = ComQuit
   490  	if err := c.WriteEphemeralPacket(); err != nil {
   491  		return err
   492  	}
   493  	return nil
   494  }
   495  
   496  // RemoteAddr returns the underlying socket RemoteAddr().
   497  func (c *Conn) RemoteAddr() net.Addr {
   498  	return c.conn.RemoteAddr()
   499  }
   500  
   501  // GetConnectionID returns the MySQL connection ID for this connection.
   502  func (c *Conn) GetConnectionID() uint32 {
   503  	return c.ConnectionID
   504  }
   505  
   506  // SetConnectionID set connection id of conn.
   507  func (c *Conn) SetConnectionID(connectionID uint32) {
   508  	c.ConnectionID = connectionID
   509  }
   510  
   511  // SetSequence set sequence of conn
   512  func (c *Conn) SetSequence(sequence uint8) {
   513  	c.sequence = sequence
   514  }
   515  
   516  // GetSequence return sequence of conn
   517  func (c *Conn) GetSequence() uint8 {
   518  	return c.sequence
   519  }
   520  
   521  // Ident returns a useful identification string for error logging
   522  func (c *Conn) String() string {
   523  	return fmt.Sprintf("client %v (%s)", c.ConnectionID, c.RemoteAddr().String())
   524  }
   525  
   526  // Close closes the connection. It can be called from a different go
   527  // routine to interrupt the current connection.
   528  func (c *Conn) Close() {
   529  	if c.closed.CompareAndSwap(false, true) {
   530  		c.conn.Close()
   531  	}
   532  }
   533  
   534  // IsClosed returns true if this connection was ever closed by the
   535  // Close() method.  Note if the other side closes the connection, but
   536  // Close() wasn't called, this will return false.
   537  func (c *Conn) IsClosed() bool {
   538  	return c.closed.Get()
   539  }
   540  
   541  //
   542  // Packet writing methods, for generic packets.
   543  //
   544  
   545  // WriteOKPacket writes an OK packet.
   546  // Server -> Client.
   547  // This method returns a generic error, not a SQLError.
   548  func (c *Conn) WriteOKPacket(affectedRows, lastInsertID uint64, flags uint16, warnings uint16) error {
   549  	length := 1 + // OKHeader
   550  		LenEncIntSize(affectedRows) +
   551  		LenEncIntSize(lastInsertID) +
   552  		2 + // flags
   553  		2 // warnings
   554  	data := c.StartEphemeralPacket(length)
   555  	pos := 0
   556  	pos = WriteByte(data, pos, OKHeader)
   557  	pos = WriteLenEncInt(data, pos, affectedRows)
   558  	pos = WriteLenEncInt(data, pos, lastInsertID)
   559  	pos = WriteUint16(data, pos, flags)
   560  	pos = WriteUint16(data, pos, warnings)
   561  
   562  	return c.WriteEphemeralPacket()
   563  }
   564  
   565  // WriteOKPacketWithEOFHeader writes an OK packet with an EOF header.
   566  // This is used at the end of a result set if
   567  // CapabilityClientDeprecateEOF is set.
   568  // Server -> Client.
   569  // This method returns a generic error, not a SQLError.
   570  func (c *Conn) WriteOKPacketWithEOFHeader(affectedRows, lastInsertID uint64, flags uint16, warnings uint16) error {
   571  	length := 1 + // EOFHeader
   572  		LenEncIntSize(affectedRows) +
   573  		LenEncIntSize(lastInsertID) +
   574  		2 + // flags
   575  		2 // warnings
   576  	data := c.StartEphemeralPacket(length)
   577  	pos := 0
   578  	pos = WriteByte(data, pos, EOFHeader)
   579  	pos = WriteLenEncInt(data, pos, affectedRows)
   580  	pos = WriteLenEncInt(data, pos, lastInsertID)
   581  	pos = WriteUint16(data, pos, flags)
   582  	pos = WriteUint16(data, pos, warnings)
   583  
   584  	return c.WriteEphemeralPacket()
   585  }
   586  
   587  // WriteErrorPacket writes an error packet.
   588  // Server -> Client.
   589  // This method returns a generic error, not a SQLError.
   590  func (c *Conn) WriteErrorPacket(errorCode uint16, sqlState string, format string, args ...interface{}) error {
   591  	errorMessage := fmt.Sprintf(format, args...)
   592  	length := 1 + 2 + 1 + 5 + len(errorMessage)
   593  	data := c.StartEphemeralPacket(length)
   594  	pos := 0
   595  	pos = WriteByte(data, pos, ErrHeader)
   596  	pos = WriteUint16(data, pos, errorCode)
   597  	pos = WriteByte(data, pos, '#')
   598  	if sqlState == "" {
   599  		sqlState = DefaultMySQLState
   600  	}
   601  	if len(sqlState) != 5 {
   602  		panic("sqlState has to be 5 characters long")
   603  	}
   604  	pos = writeEOFString(data, pos, sqlState)
   605  	pos = writeEOFString(data, pos, errorMessage)
   606  
   607  	return c.WriteEphemeralPacket()
   608  }
   609  
   610  // WriteErrorPacketFromError writes an error packet, from a regular error.
   611  // See writeErrorPacket for other info.
   612  func (c *Conn) WriteErrorPacketFromError(err error) error {
   613  	if se, ok := err.(*SQLError); ok {
   614  		return c.WriteErrorPacket(se.SQLCode(), se.SQLState(), "%v", se.Message)
   615  	}
   616  
   617  	return c.WriteErrorPacket(ErrUnknown, DefaultMySQLState, "unknown error: %v", err)
   618  }
   619  
   620  // WriteEOFPacket writes an EOF packet, through the buffer, and
   621  // doesn't flush (as it is used as part of a query result).
   622  func (c *Conn) WriteEOFPacket(flags uint16, warnings uint16) error {
   623  	length := 5
   624  	data := c.StartEphemeralPacket(length)
   625  	pos := 0
   626  	pos = WriteByte(data, pos, EOFHeader)
   627  	pos = WriteUint16(data, pos, warnings)
   628  	pos = WriteUint16(data, pos, flags)
   629  
   630  	return c.WriteEphemeralPacket()
   631  }
   632  
   633  //
   634  // Packet parsing methods, for generic packets.
   635  //
   636  
   637  // isEOFHeader determines whether or not a data packet is a "true" EOF. DO NOT blindly compare the
   638  // first byte of a packet to EOFHeader as you might do for other packet types, as 0xfe is overloaded
   639  // as a first byte.
   640  //
   641  // Per https://dev.mysql.com/doc/internals/en/packet-EOF_Packet.html, a packet starting with 0xfe
   642  // but having length >= 9 (on top of 4 byte header) is not a true EOF but a LengthEncodedInteger
   643  // (typically preceding a LengthEncodedString). Thus, all EOF checks must validate the payload size
   644  // before exiting.
   645  //
   646  // More specifically, an EOF packet can have 3 different lengths (1, 5, 7) depending on the client
   647  // flags that are set. 7 comes from server versions of 5.7.5 or greater where ClientDeprecateEOF is
   648  // set (i.e. uses an OK packet starting with 0xfe instead of 0x00 to signal EOF). Regardless, 8 is
   649  // an upper bound otherwise it would be ambiguous w.r.t. LengthEncodedIntegers.
   650  //
   651  // More docs here:
   652  // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_response_packets.html
   653  func isEOFHeader(data []byte) bool {
   654  	return data[0] == EOFHeader && len(data) < 9
   655  }
   656  
   657  // parseEOFHeader returns the warning count and a boolean to indicate if there
   658  // are more results to receive.
   659  //
   660  // Note: This is only valid on actual EOF packets and not on OK packets with the EOF
   661  // type code set, i.e. should not be used if ClientDeprecateEOF is set.
   662  func parseEOFHeader(data []byte) (warnings uint16, more bool, err error) {
   663  	// The warning count is in position 2 & 3
   664  	warnings, _, ok := ReadUint16(data, 1)
   665  
   666  	// The status flag is in position 4 & 5
   667  	statusFlags, _, ok := ReadUint16(data, 3)
   668  	if !ok {
   669  		return 0, false, fmt.Errorf("invalid EOF packet statusFlags: %v", data)
   670  	}
   671  	return warnings, (statusFlags & ServerMoreResultsExists) != 0, nil
   672  }
   673  
   674  func parseOKHeader(data []byte) (uint64, uint64, uint16, uint16, error) {
   675  	// We already read the type.
   676  	pos := 1
   677  
   678  	// Affected rows.
   679  	affectedRows, pos, _, ok := ReadLenEncInt(data, pos)
   680  	if !ok {
   681  		return 0, 0, 0, 0, fmt.Errorf("invalid OK packet affectedRows: %v", data)
   682  	}
   683  
   684  	// Last Insert ID.
   685  	lastInsertID, pos, _, ok := ReadLenEncInt(data, pos)
   686  	if !ok {
   687  		return 0, 0, 0, 0, fmt.Errorf("invalid OK packet lastInsertID: %v", data)
   688  	}
   689  
   690  	// Status flags.
   691  	statusFlags, pos, ok := ReadUint16(data, pos)
   692  	if !ok {
   693  		return 0, 0, 0, 0, fmt.Errorf("invalid OK packet statusFlags: %v", data)
   694  	}
   695  
   696  	// Warnings.
   697  	warnings, pos, ok := ReadUint16(data, pos)
   698  	if !ok {
   699  		return 0, 0, 0, 0, fmt.Errorf("invalid OK packet warnings: %v", data)
   700  	}
   701  
   702  	return affectedRows, lastInsertID, statusFlags, warnings, nil
   703  }
   704  
   705  // IsErrorPacket determines whether or not the packet is an error packet. Mostly here for
   706  // consistency with isEOFHeader
   707  func IsErrorPacket(data []byte) bool {
   708  	return data[0] == ErrHeader
   709  }
   710  
   711  // IsOKPacket determines whether or not the packet is an ok packet.
   712  func IsOKPacket(data []byte) bool {
   713  	return data[0] == OKHeader
   714  }
   715  
   716  // ParseErrorPacket parses the error packet and returns a SQLError.
   717  func ParseErrorPacket(data []byte) error {
   718  	// We already read the type.
   719  	pos := 1
   720  
   721  	// Error code is 2 bytes.
   722  	code, pos, ok := ReadUint16(data, pos)
   723  	if !ok {
   724  		return errors.New("invalid error packet code")
   725  	}
   726  
   727  	// '#' marker of the SQL state is 1 byte. Ignored.
   728  	pos++
   729  
   730  	// SQL state can be calculated
   731  	_, pos, ok = ReadBytes(data, pos, 5)
   732  	if !ok {
   733  		return errors.New("invalid error packet sqlState")
   734  	}
   735  
   736  	// Human readable error message is the rest.
   737  	msg := string(data[pos:])
   738  
   739  	return NewError(code, msg)
   740  }