github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/go-sql-driver/mysql/packets.go (about)

     1  // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
     2  //
     3  // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
     4  //
     5  // This Source Code Form is subject to the terms of the Mozilla Public
     6  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     7  // You can obtain one at http://mozilla.org/MPL/2.0/.
     8  
     9  package mysql
    10  
    11  import (
    12  	"bytes"
    13  	"crypto/tls"
    14  	"database/sql/driver"
    15  	"encoding/binary"
    16  	"errors"
    17  	"fmt"
    18  	"io"
    19  	"math"
    20  	"time"
    21  )
    22  
    23  // Packets documentation:
    24  // http://dev.mysql.com/doc/internals/en/client-server-protocol.html
    25  
    26  // Read packet to buffer 'data'
    27  func (mc *mysqlConn) readPacket() ([]byte, error) {
    28  	var payload []byte
    29  	for {
    30  		// Read packet header
    31  		data, err := mc.buf.readNext(4)
    32  		if err != nil {
    33  			errLog.Print(err)
    34  			mc.Close()
    35  			return nil, driver.ErrBadConn
    36  		}
    37  
    38  		// Packet Length [24 bit]
    39  		pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
    40  
    41  		if pktLen < 1 {
    42  			errLog.Print(ErrMalformPkt)
    43  			mc.Close()
    44  			return nil, driver.ErrBadConn
    45  		}
    46  
    47  		// Check Packet Sync [8 bit]
    48  		if data[3] != mc.sequence {
    49  			if data[3] > mc.sequence {
    50  				return nil, ErrPktSyncMul
    51  			}
    52  			return nil, ErrPktSync
    53  		}
    54  		mc.sequence++
    55  
    56  		// Read packet body [pktLen bytes]
    57  		data, err = mc.buf.readNext(pktLen)
    58  		if err != nil {
    59  			errLog.Print(err)
    60  			mc.Close()
    61  			return nil, driver.ErrBadConn
    62  		}
    63  
    64  		isLastPacket := (pktLen < maxPacketSize)
    65  
    66  		// Zero allocations for non-splitting packets
    67  		if isLastPacket && payload == nil {
    68  			return data, nil
    69  		}
    70  
    71  		payload = append(payload, data...)
    72  
    73  		if isLastPacket {
    74  			return payload, nil
    75  		}
    76  	}
    77  }
    78  
    79  // Write packet buffer 'data'
    80  func (mc *mysqlConn) writePacket(data []byte) error {
    81  	pktLen := len(data) - 4
    82  
    83  	if pktLen > mc.maxPacketAllowed {
    84  		return ErrPktTooLarge
    85  	}
    86  
    87  	for {
    88  		var size int
    89  		if pktLen >= maxPacketSize {
    90  			data[0] = 0xff
    91  			data[1] = 0xff
    92  			data[2] = 0xff
    93  			size = maxPacketSize
    94  		} else {
    95  			data[0] = byte(pktLen)
    96  			data[1] = byte(pktLen >> 8)
    97  			data[2] = byte(pktLen >> 16)
    98  			size = pktLen
    99  		}
   100  		data[3] = mc.sequence
   101  
   102  		// Write packet
   103  		if mc.writeTimeout > 0 {
   104  			if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
   105  				return err
   106  			}
   107  		}
   108  
   109  		n, err := mc.netConn.Write(data[:4+size])
   110  		if err == nil && n == 4+size {
   111  			mc.sequence++
   112  			if size != maxPacketSize {
   113  				return nil
   114  			}
   115  			pktLen -= size
   116  			data = data[size:]
   117  			continue
   118  		}
   119  
   120  		// Handle error
   121  		if err == nil { // n != len(data)
   122  			errLog.Print(ErrMalformPkt)
   123  		} else {
   124  			errLog.Print(err)
   125  		}
   126  		return driver.ErrBadConn
   127  	}
   128  }
   129  
   130  /******************************************************************************
   131  *                           Initialisation Process                            *
   132  ******************************************************************************/
   133  
   134  // Handshake Initialization Packet
   135  // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
   136  func (mc *mysqlConn) readInitPacket() ([]byte, error) {
   137  	data, err := mc.readPacket()
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  
   142  	if data[0] == iERR {
   143  		return nil, mc.handleErrorPacket(data)
   144  	}
   145  
   146  	// protocol version [1 byte]
   147  	if data[0] < minProtocolVersion {
   148  		return nil, fmt.Errorf(
   149  			"unsupported protocol version %d. Version %d or higher is required",
   150  			data[0],
   151  			minProtocolVersion,
   152  		)
   153  	}
   154  
   155  	// server version [null terminated string]
   156  	// connection id [4 bytes]
   157  	pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
   158  
   159  	// first part of the password cipher [8 bytes]
   160  	cipher := data[pos : pos+8]
   161  
   162  	// (filler) always 0x00 [1 byte]
   163  	pos += 8 + 1
   164  
   165  	// capability flags (lower 2 bytes) [2 bytes]
   166  	mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
   167  	if mc.flags&clientProtocol41 == 0 {
   168  		return nil, ErrOldProtocol
   169  	}
   170  	if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
   171  		return nil, ErrNoTLS
   172  	}
   173  	pos += 2
   174  
   175  	if len(data) > pos {
   176  		// character set [1 byte]
   177  		// status flags [2 bytes]
   178  		// capability flags (upper 2 bytes) [2 bytes]
   179  		// length of auth-plugin-data [1 byte]
   180  		// reserved (all [00]) [10 bytes]
   181  		pos += 1 + 2 + 2 + 1 + 10
   182  
   183  		// second part of the password cipher [mininum 13 bytes],
   184  		// where len=MAX(13, length of auth-plugin-data - 8)
   185  		//
   186  		// The web documentation is ambiguous about the length. However,
   187  		// according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
   188  		// the 13th byte is "\0 byte, terminating the second part of
   189  		// a scramble". So the second part of the password cipher is
   190  		// a NULL terminated string that's at least 13 bytes with the
   191  		// last byte being NULL.
   192  		//
   193  		// The official Python library uses the fixed length 12
   194  		// which seems to work but technically could have a hidden bug.
   195  		cipher = append(cipher, data[pos:pos+12]...)
   196  
   197  		// TODO: Verify string termination
   198  		// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
   199  		// \NUL otherwise
   200  		//
   201  		//if data[len(data)-1] == 0 {
   202  		//	return
   203  		//}
   204  		//return ErrMalformPkt
   205  
   206  		// make a memory safe copy of the cipher slice
   207  		var b [20]byte
   208  		copy(b[:], cipher)
   209  		return b[:], nil
   210  	}
   211  
   212  	// make a memory safe copy of the cipher slice
   213  	var b [8]byte
   214  	copy(b[:], cipher)
   215  	return b[:], nil
   216  }
   217  
   218  // Client Authentication Packet
   219  // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
   220  func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
   221  	// Adjust client flags based on server support
   222  	clientFlags := clientProtocol41 |
   223  		clientSecureConn |
   224  		clientLongPassword |
   225  		clientTransactions |
   226  		clientLocalFiles |
   227  		clientPluginAuth |
   228  		clientMultiResults |
   229  		mc.flags&clientLongFlag
   230  
   231  	if mc.cfg.ClientFoundRows {
   232  		clientFlags |= clientFoundRows
   233  	}
   234  
   235  	// To enable TLS / SSL
   236  	if mc.cfg.tls != nil {
   237  		clientFlags |= clientSSL
   238  	}
   239  
   240  	if mc.cfg.MultiStatements {
   241  		clientFlags |= clientMultiStatements
   242  	}
   243  
   244  	// User Password
   245  	scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
   246  
   247  	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1
   248  
   249  	// To specify a db name
   250  	if n := len(mc.cfg.DBName); n > 0 {
   251  		clientFlags |= clientConnectWithDB
   252  		pktLen += n + 1
   253  	}
   254  
   255  	// Calculate packet length and get buffer with that size
   256  	data := mc.buf.takeSmallBuffer(pktLen + 4)
   257  	if data == nil {
   258  		// can not take the buffer. Something must be wrong with the connection
   259  		errLog.Print(ErrBusyBuffer)
   260  		return driver.ErrBadConn
   261  	}
   262  
   263  	// ClientFlags [32 bit]
   264  	data[4] = byte(clientFlags)
   265  	data[5] = byte(clientFlags >> 8)
   266  	data[6] = byte(clientFlags >> 16)
   267  	data[7] = byte(clientFlags >> 24)
   268  
   269  	// MaxPacketSize [32 bit] (none)
   270  	data[8] = 0x00
   271  	data[9] = 0x00
   272  	data[10] = 0x00
   273  	data[11] = 0x00
   274  
   275  	// Charset [1 byte]
   276  	var found bool
   277  	data[12], found = collations[mc.cfg.Collation]
   278  	if !found {
   279  		// Note possibility for false negatives:
   280  		// could be triggered  although the collation is valid if the
   281  		// collations map does not contain entries the server supports.
   282  		return errors.New("unknown collation")
   283  	}
   284  
   285  	// SSL Connection Request Packet
   286  	// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
   287  	if mc.cfg.tls != nil {
   288  		// Send TLS / SSL request packet
   289  		if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
   290  			return err
   291  		}
   292  
   293  		// Switch to TLS
   294  		tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
   295  		if err := tlsConn.Handshake(); err != nil {
   296  			return err
   297  		}
   298  		mc.netConn = tlsConn
   299  		mc.buf.nc = tlsConn
   300  	}
   301  
   302  	// Filler [23 bytes] (all 0x00)
   303  	pos := 13
   304  	for ; pos < 13+23; pos++ {
   305  		data[pos] = 0
   306  	}
   307  
   308  	// User [null terminated string]
   309  	if len(mc.cfg.User) > 0 {
   310  		pos += copy(data[pos:], mc.cfg.User)
   311  	}
   312  	data[pos] = 0x00
   313  	pos++
   314  
   315  	// ScrambleBuffer [length encoded integer]
   316  	data[pos] = byte(len(scrambleBuff))
   317  	pos += 1 + copy(data[pos+1:], scrambleBuff)
   318  
   319  	// Databasename [null terminated string]
   320  	if len(mc.cfg.DBName) > 0 {
   321  		pos += copy(data[pos:], mc.cfg.DBName)
   322  		data[pos] = 0x00
   323  		pos++
   324  	}
   325  
   326  	// Assume native client during response
   327  	pos += copy(data[pos:], "mysql_native_password")
   328  	data[pos] = 0x00
   329  
   330  	// Send Auth packet
   331  	return mc.writePacket(data)
   332  }
   333  
   334  //  Client old authentication packet
   335  // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
   336  func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
   337  	// User password
   338  	scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd))
   339  
   340  	// Calculate the packet length and add a tailing 0
   341  	pktLen := len(scrambleBuff) + 1
   342  	data := mc.buf.takeSmallBuffer(4 + pktLen)
   343  	if data == nil {
   344  		// can not take the buffer. Something must be wrong with the connection
   345  		errLog.Print(ErrBusyBuffer)
   346  		return driver.ErrBadConn
   347  	}
   348  
   349  	// Add the scrambled password [null terminated string]
   350  	copy(data[4:], scrambleBuff)
   351  	data[4+pktLen-1] = 0x00
   352  
   353  	return mc.writePacket(data)
   354  }
   355  
   356  //  Client clear text authentication packet
   357  // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
   358  func (mc *mysqlConn) writeClearAuthPacket() error {
   359  	// Calculate the packet length and add a tailing 0
   360  	pktLen := len(mc.cfg.Passwd) + 1
   361  	data := mc.buf.takeSmallBuffer(4 + pktLen)
   362  	if data == nil {
   363  		// can not take the buffer. Something must be wrong with the connection
   364  		errLog.Print(ErrBusyBuffer)
   365  		return driver.ErrBadConn
   366  	}
   367  
   368  	// Add the clear password [null terminated string]
   369  	copy(data[4:], mc.cfg.Passwd)
   370  	data[4+pktLen-1] = 0x00
   371  
   372  	return mc.writePacket(data)
   373  }
   374  
   375  /******************************************************************************
   376  *                             Command Packets                                 *
   377  ******************************************************************************/
   378  
   379  func (mc *mysqlConn) writeCommandPacket(command byte) error {
   380  	// Reset Packet Sequence
   381  	mc.sequence = 0
   382  
   383  	data := mc.buf.takeSmallBuffer(4 + 1)
   384  	if data == nil {
   385  		// can not take the buffer. Something must be wrong with the connection
   386  		errLog.Print(ErrBusyBuffer)
   387  		return driver.ErrBadConn
   388  	}
   389  
   390  	// Add command byte
   391  	data[4] = command
   392  
   393  	// Send CMD packet
   394  	return mc.writePacket(data)
   395  }
   396  
   397  func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
   398  	// Reset Packet Sequence
   399  	mc.sequence = 0
   400  
   401  	pktLen := 1 + len(arg)
   402  	data := mc.buf.takeBuffer(pktLen + 4)
   403  	if data == nil {
   404  		// can not take the buffer. Something must be wrong with the connection
   405  		errLog.Print(ErrBusyBuffer)
   406  		return driver.ErrBadConn
   407  	}
   408  
   409  	// Add command byte
   410  	data[4] = command
   411  
   412  	// Add arg
   413  	copy(data[5:], arg)
   414  
   415  	// Send CMD packet
   416  	return mc.writePacket(data)
   417  }
   418  
   419  func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
   420  	// Reset Packet Sequence
   421  	mc.sequence = 0
   422  
   423  	data := mc.buf.takeSmallBuffer(4 + 1 + 4)
   424  	if data == nil {
   425  		// can not take the buffer. Something must be wrong with the connection
   426  		errLog.Print(ErrBusyBuffer)
   427  		return driver.ErrBadConn
   428  	}
   429  
   430  	// Add command byte
   431  	data[4] = command
   432  
   433  	// Add arg [32 bit]
   434  	data[5] = byte(arg)
   435  	data[6] = byte(arg >> 8)
   436  	data[7] = byte(arg >> 16)
   437  	data[8] = byte(arg >> 24)
   438  
   439  	// Send CMD packet
   440  	return mc.writePacket(data)
   441  }
   442  
   443  /******************************************************************************
   444  *                              Result Packets                                 *
   445  ******************************************************************************/
   446  
   447  // Returns error if Packet is not an 'Result OK'-Packet
   448  func (mc *mysqlConn) readResultOK() error {
   449  	data, err := mc.readPacket()
   450  	if err == nil {
   451  		// packet indicator
   452  		switch data[0] {
   453  
   454  		case iOK:
   455  			return mc.handleOkPacket(data)
   456  
   457  		case iEOF:
   458  			if len(data) > 1 {
   459  				plugin := string(data[1:bytes.IndexByte(data, 0x00)])
   460  				if plugin == "mysql_old_password" {
   461  					// using old_passwords
   462  					return ErrOldPassword
   463  				} else if plugin == "mysql_clear_password" {
   464  					// using clear text password
   465  					return ErrCleartextPassword
   466  				} else {
   467  					return ErrUnknownPlugin
   468  				}
   469  			} else {
   470  				return ErrOldPassword
   471  			}
   472  
   473  		default: // Error otherwise
   474  			return mc.handleErrorPacket(data)
   475  		}
   476  	}
   477  	return err
   478  }
   479  
   480  // Result Set Header Packet
   481  // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
   482  func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
   483  	data, err := mc.readPacket()
   484  	if err == nil {
   485  		switch data[0] {
   486  
   487  		case iOK:
   488  			return 0, mc.handleOkPacket(data)
   489  
   490  		case iERR:
   491  			return 0, mc.handleErrorPacket(data)
   492  
   493  		case iLocalInFile:
   494  			return 0, mc.handleInFileRequest(string(data[1:]))
   495  		}
   496  
   497  		// column count
   498  		num, _, n := readLengthEncodedInteger(data)
   499  		if n-len(data) == 0 {
   500  			return int(num), nil
   501  		}
   502  
   503  		return 0, ErrMalformPkt
   504  	}
   505  	return 0, err
   506  }
   507  
   508  // Error Packet
   509  // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
   510  func (mc *mysqlConn) handleErrorPacket(data []byte) error {
   511  	if data[0] != iERR {
   512  		return ErrMalformPkt
   513  	}
   514  
   515  	// 0xff [1 byte]
   516  
   517  	// Error Number [16 bit uint]
   518  	errno := binary.LittleEndian.Uint16(data[1:3])
   519  
   520  	pos := 3
   521  
   522  	// SQL State [optional: # + 5bytes string]
   523  	if data[3] == 0x23 {
   524  		//sqlstate := string(data[4 : 4+5])
   525  		pos = 9
   526  	}
   527  
   528  	// Error Message [string]
   529  	return &MySQLError{
   530  		Number:  errno,
   531  		Message: string(data[pos:]),
   532  	}
   533  }
   534  
   535  func readStatus(b []byte) statusFlag {
   536  	return statusFlag(b[0]) | statusFlag(b[1])<<8
   537  }
   538  
   539  // Ok Packet
   540  // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
   541  func (mc *mysqlConn) handleOkPacket(data []byte) error {
   542  	var n, m int
   543  
   544  	// 0x00 [1 byte]
   545  
   546  	// Affected rows [Length Coded Binary]
   547  	mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
   548  
   549  	// Insert id [Length Coded Binary]
   550  	mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
   551  
   552  	// server_status [2 bytes]
   553  	mc.status = readStatus(data[1+n+m : 1+n+m+2])
   554  	if err := mc.discardResults(); err != nil {
   555  		return err
   556  	}
   557  
   558  	// warning count [2 bytes]
   559  	if !mc.strict {
   560  		return nil
   561  	}
   562  
   563  	pos := 1 + n + m + 2
   564  	if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
   565  		return mc.getWarnings()
   566  	}
   567  	return nil
   568  }
   569  
   570  // Read Packets as Field Packets until EOF-Packet or an Error appears
   571  // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
   572  func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
   573  	columns := make([]mysqlField, count)
   574  
   575  	for i := 0; ; i++ {
   576  		data, err := mc.readPacket()
   577  		if err != nil {
   578  			return nil, err
   579  		}
   580  
   581  		// EOF Packet
   582  		if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
   583  			if i == count {
   584  				return columns, nil
   585  			}
   586  			return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
   587  		}
   588  
   589  		// Catalog
   590  		pos, err := skipLengthEncodedString(data)
   591  		if err != nil {
   592  			return nil, err
   593  		}
   594  
   595  		// Database [len coded string]
   596  		n, err := skipLengthEncodedString(data[pos:])
   597  		if err != nil {
   598  			return nil, err
   599  		}
   600  		pos += n
   601  
   602  		// Table [len coded string]
   603  		if mc.cfg.ColumnsWithAlias {
   604  			tableName, _, n, err := readLengthEncodedString(data[pos:])
   605  			if err != nil {
   606  				return nil, err
   607  			}
   608  			pos += n
   609  			columns[i].tableName = string(tableName)
   610  		} else {
   611  			n, err = skipLengthEncodedString(data[pos:])
   612  			if err != nil {
   613  				return nil, err
   614  			}
   615  			pos += n
   616  		}
   617  
   618  		// Original table [len coded string]
   619  		n, err = skipLengthEncodedString(data[pos:])
   620  		if err != nil {
   621  			return nil, err
   622  		}
   623  		pos += n
   624  
   625  		// Name [len coded string]
   626  		name, _, n, err := readLengthEncodedString(data[pos:])
   627  		if err != nil {
   628  			return nil, err
   629  		}
   630  		columns[i].name = string(name)
   631  		pos += n
   632  
   633  		// Original name [len coded string]
   634  		n, err = skipLengthEncodedString(data[pos:])
   635  		if err != nil {
   636  			return nil, err
   637  		}
   638  
   639  		// Filler [uint8]
   640  		// Charset [charset, collation uint8]
   641  		// Length [uint32]
   642  		pos += n + 1 + 2 + 4
   643  
   644  		// Field type [uint8]
   645  		columns[i].fieldType = data[pos]
   646  		pos++
   647  
   648  		// Flags [uint16]
   649  		columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
   650  		pos += 2
   651  
   652  		// Decimals [uint8]
   653  		columns[i].decimals = data[pos]
   654  		//pos++
   655  
   656  		// Default value [len coded binary]
   657  		//if pos < len(data) {
   658  		//	defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
   659  		//}
   660  	}
   661  }
   662  
   663  // Read Packets as Field Packets until EOF-Packet or an Error appears
   664  // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
   665  func (rows *textRows) readRow(dest []driver.Value) error {
   666  	mc := rows.mc
   667  
   668  	data, err := mc.readPacket()
   669  	if err != nil {
   670  		return err
   671  	}
   672  
   673  	// EOF Packet
   674  	if data[0] == iEOF && len(data) == 5 {
   675  		// server_status [2 bytes]
   676  		rows.mc.status = readStatus(data[3:])
   677  		if err := rows.mc.discardResults(); err != nil {
   678  			return err
   679  		}
   680  		rows.mc = nil
   681  		return io.EOF
   682  	}
   683  	if data[0] == iERR {
   684  		rows.mc = nil
   685  		return mc.handleErrorPacket(data)
   686  	}
   687  
   688  	// RowSet Packet
   689  	var n int
   690  	var isNull bool
   691  	pos := 0
   692  
   693  	for i := range dest {
   694  		// Read bytes and convert to string
   695  		dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
   696  		pos += n
   697  		if err == nil {
   698  			if !isNull {
   699  				if !mc.parseTime {
   700  					continue
   701  				} else {
   702  					switch rows.columns[i].fieldType {
   703  					case fieldTypeTimestamp, fieldTypeDateTime,
   704  						fieldTypeDate, fieldTypeNewDate:
   705  						dest[i], err = parseDateTime(
   706  							string(dest[i].([]byte)),
   707  							mc.cfg.Loc,
   708  						)
   709  						if err == nil {
   710  							continue
   711  						}
   712  					default:
   713  						continue
   714  					}
   715  				}
   716  
   717  			} else {
   718  				dest[i] = nil
   719  				continue
   720  			}
   721  		}
   722  		return err // err != nil
   723  	}
   724  
   725  	return nil
   726  }
   727  
   728  // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
   729  func (mc *mysqlConn) readUntilEOF() error {
   730  	for {
   731  		data, err := mc.readPacket()
   732  
   733  		// No Err and no EOF Packet
   734  		if err == nil && data[0] != iEOF {
   735  			continue
   736  		}
   737  		if err == nil && data[0] == iEOF && len(data) == 5 {
   738  			mc.status = readStatus(data[3:])
   739  		}
   740  
   741  		return err // Err or EOF
   742  	}
   743  }
   744  
   745  /******************************************************************************
   746  *                           Prepared Statements                               *
   747  ******************************************************************************/
   748  
   749  // Prepare Result Packets
   750  // http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
   751  func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
   752  	data, err := stmt.mc.readPacket()
   753  	if err == nil {
   754  		// packet indicator [1 byte]
   755  		if data[0] != iOK {
   756  			return 0, stmt.mc.handleErrorPacket(data)
   757  		}
   758  
   759  		// statement id [4 bytes]
   760  		stmt.id = binary.LittleEndian.Uint32(data[1:5])
   761  
   762  		// Column count [16 bit uint]
   763  		columnCount := binary.LittleEndian.Uint16(data[5:7])
   764  
   765  		// Param count [16 bit uint]
   766  		stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
   767  
   768  		// Reserved [8 bit]
   769  
   770  		// Warning count [16 bit uint]
   771  		if !stmt.mc.strict {
   772  			return columnCount, nil
   773  		}
   774  
   775  		// Check for warnings count > 0, only available in MySQL > 4.1
   776  		if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
   777  			return columnCount, stmt.mc.getWarnings()
   778  		}
   779  		return columnCount, nil
   780  	}
   781  	return 0, err
   782  }
   783  
   784  // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
   785  func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
   786  	maxLen := stmt.mc.maxPacketAllowed - 1
   787  	pktLen := maxLen
   788  
   789  	// After the header (bytes 0-3) follows before the data:
   790  	// 1 byte command
   791  	// 4 bytes stmtID
   792  	// 2 bytes paramID
   793  	const dataOffset = 1 + 4 + 2
   794  
   795  	// Can not use the write buffer since
   796  	// a) the buffer is too small
   797  	// b) it is in use
   798  	data := make([]byte, 4+1+4+2+len(arg))
   799  
   800  	copy(data[4+dataOffset:], arg)
   801  
   802  	for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
   803  		if dataOffset+argLen < maxLen {
   804  			pktLen = dataOffset + argLen
   805  		}
   806  
   807  		stmt.mc.sequence = 0
   808  		// Add command byte [1 byte]
   809  		data[4] = comStmtSendLongData
   810  
   811  		// Add stmtID [32 bit]
   812  		data[5] = byte(stmt.id)
   813  		data[6] = byte(stmt.id >> 8)
   814  		data[7] = byte(stmt.id >> 16)
   815  		data[8] = byte(stmt.id >> 24)
   816  
   817  		// Add paramID [16 bit]
   818  		data[9] = byte(paramID)
   819  		data[10] = byte(paramID >> 8)
   820  
   821  		// Send CMD packet
   822  		err := stmt.mc.writePacket(data[:4+pktLen])
   823  		if err == nil {
   824  			data = data[pktLen-dataOffset:]
   825  			continue
   826  		}
   827  		return err
   828  
   829  	}
   830  
   831  	// Reset Packet Sequence
   832  	stmt.mc.sequence = 0
   833  	return nil
   834  }
   835  
   836  // Execute Prepared Statement
   837  // http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
   838  func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
   839  	if len(args) != stmt.paramCount {
   840  		return fmt.Errorf(
   841  			"argument count mismatch (got: %d; has: %d)",
   842  			len(args),
   843  			stmt.paramCount,
   844  		)
   845  	}
   846  
   847  	const minPktLen = 4 + 1 + 4 + 1 + 4
   848  	mc := stmt.mc
   849  
   850  	// Reset packet-sequence
   851  	mc.sequence = 0
   852  
   853  	var data []byte
   854  
   855  	if len(args) == 0 {
   856  		data = mc.buf.takeBuffer(minPktLen)
   857  	} else {
   858  		data = mc.buf.takeCompleteBuffer()
   859  	}
   860  	if data == nil {
   861  		// can not take the buffer. Something must be wrong with the connection
   862  		errLog.Print(ErrBusyBuffer)
   863  		return driver.ErrBadConn
   864  	}
   865  
   866  	// command [1 byte]
   867  	data[4] = comStmtExecute
   868  
   869  	// statement_id [4 bytes]
   870  	data[5] = byte(stmt.id)
   871  	data[6] = byte(stmt.id >> 8)
   872  	data[7] = byte(stmt.id >> 16)
   873  	data[8] = byte(stmt.id >> 24)
   874  
   875  	// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
   876  	data[9] = 0x00
   877  
   878  	// iteration_count (uint32(1)) [4 bytes]
   879  	data[10] = 0x01
   880  	data[11] = 0x00
   881  	data[12] = 0x00
   882  	data[13] = 0x00
   883  
   884  	if len(args) > 0 {
   885  		pos := minPktLen
   886  
   887  		var nullMask []byte
   888  		if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
   889  			// buffer has to be extended but we don't know by how much so
   890  			// we depend on append after all data with known sizes fit.
   891  			// We stop at that because we deal with a lot of columns here
   892  			// which makes the required allocation size hard to guess.
   893  			tmp := make([]byte, pos+maskLen+typesLen)
   894  			copy(tmp[:pos], data[:pos])
   895  			data = tmp
   896  			nullMask = data[pos : pos+maskLen]
   897  			pos += maskLen
   898  		} else {
   899  			nullMask = data[pos : pos+maskLen]
   900  			for i := 0; i < maskLen; i++ {
   901  				nullMask[i] = 0
   902  			}
   903  			pos += maskLen
   904  		}
   905  
   906  		// newParameterBoundFlag 1 [1 byte]
   907  		data[pos] = 0x01
   908  		pos++
   909  
   910  		// type of each parameter [len(args)*2 bytes]
   911  		paramTypes := data[pos:]
   912  		pos += len(args) * 2
   913  
   914  		// value of each parameter [n bytes]
   915  		paramValues := data[pos:pos]
   916  		valuesCap := cap(paramValues)
   917  
   918  		for i, arg := range args {
   919  			// build NULL-bitmap
   920  			if arg == nil {
   921  				nullMask[i/8] |= 1 << (uint(i) & 7)
   922  				paramTypes[i+i] = fieldTypeNULL
   923  				paramTypes[i+i+1] = 0x00
   924  				continue
   925  			}
   926  
   927  			// cache types and values
   928  			switch v := arg.(type) {
   929  			case int64:
   930  				paramTypes[i+i] = fieldTypeLongLong
   931  				paramTypes[i+i+1] = 0x00
   932  
   933  				if cap(paramValues)-len(paramValues)-8 >= 0 {
   934  					paramValues = paramValues[:len(paramValues)+8]
   935  					binary.LittleEndian.PutUint64(
   936  						paramValues[len(paramValues)-8:],
   937  						uint64(v),
   938  					)
   939  				} else {
   940  					paramValues = append(paramValues,
   941  						uint64ToBytes(uint64(v))...,
   942  					)
   943  				}
   944  
   945  			case float64:
   946  				paramTypes[i+i] = fieldTypeDouble
   947  				paramTypes[i+i+1] = 0x00
   948  
   949  				if cap(paramValues)-len(paramValues)-8 >= 0 {
   950  					paramValues = paramValues[:len(paramValues)+8]
   951  					binary.LittleEndian.PutUint64(
   952  						paramValues[len(paramValues)-8:],
   953  						math.Float64bits(v),
   954  					)
   955  				} else {
   956  					paramValues = append(paramValues,
   957  						uint64ToBytes(math.Float64bits(v))...,
   958  					)
   959  				}
   960  
   961  			case bool:
   962  				paramTypes[i+i] = fieldTypeTiny
   963  				paramTypes[i+i+1] = 0x00
   964  
   965  				if v {
   966  					paramValues = append(paramValues, 0x01)
   967  				} else {
   968  					paramValues = append(paramValues, 0x00)
   969  				}
   970  
   971  			case []byte:
   972  				// Common case (non-nil value) first
   973  				if v != nil {
   974  					paramTypes[i+i] = fieldTypeString
   975  					paramTypes[i+i+1] = 0x00
   976  
   977  					if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
   978  						paramValues = appendLengthEncodedInteger(paramValues,
   979  							uint64(len(v)),
   980  						)
   981  						paramValues = append(paramValues, v...)
   982  					} else {
   983  						if err := stmt.writeCommandLongData(i, v); err != nil {
   984  							return err
   985  						}
   986  					}
   987  					continue
   988  				}
   989  
   990  				// Handle []byte(nil) as a NULL value
   991  				nullMask[i/8] |= 1 << (uint(i) & 7)
   992  				paramTypes[i+i] = fieldTypeNULL
   993  				paramTypes[i+i+1] = 0x00
   994  
   995  			case string:
   996  				paramTypes[i+i] = fieldTypeString
   997  				paramTypes[i+i+1] = 0x00
   998  
   999  				if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
  1000  					paramValues = appendLengthEncodedInteger(paramValues,
  1001  						uint64(len(v)),
  1002  					)
  1003  					paramValues = append(paramValues, v...)
  1004  				} else {
  1005  					if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
  1006  						return err
  1007  					}
  1008  				}
  1009  
  1010  			case time.Time:
  1011  				paramTypes[i+i] = fieldTypeString
  1012  				paramTypes[i+i+1] = 0x00
  1013  
  1014  				var val []byte
  1015  				if v.IsZero() {
  1016  					val = []byte("0000-00-00")
  1017  				} else {
  1018  					val = []byte(v.In(mc.cfg.Loc).Format(timeFormat))
  1019  				}
  1020  
  1021  				paramValues = appendLengthEncodedInteger(paramValues,
  1022  					uint64(len(val)),
  1023  				)
  1024  				paramValues = append(paramValues, val...)
  1025  
  1026  			default:
  1027  				return fmt.Errorf("can not convert type: %T", arg)
  1028  			}
  1029  		}
  1030  
  1031  		// Check if param values exceeded the available buffer
  1032  		// In that case we must build the data packet with the new values buffer
  1033  		if valuesCap != cap(paramValues) {
  1034  			data = append(data[:pos], paramValues...)
  1035  			mc.buf.buf = data
  1036  		}
  1037  
  1038  		pos += len(paramValues)
  1039  		data = data[:pos]
  1040  	}
  1041  
  1042  	return mc.writePacket(data)
  1043  }
  1044  
  1045  func (mc *mysqlConn) discardResults() error {
  1046  	for mc.status&statusMoreResultsExists != 0 {
  1047  		resLen, err := mc.readResultSetHeaderPacket()
  1048  		if err != nil {
  1049  			return err
  1050  		}
  1051  		if resLen > 0 {
  1052  			// columns
  1053  			if err := mc.readUntilEOF(); err != nil {
  1054  				return err
  1055  			}
  1056  			// rows
  1057  			if err := mc.readUntilEOF(); err != nil {
  1058  				return err
  1059  			}
  1060  		} else {
  1061  			mc.status &^= statusMoreResultsExists
  1062  		}
  1063  	}
  1064  	return nil
  1065  }
  1066  
  1067  // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
  1068  func (rows *binaryRows) readRow(dest []driver.Value) error {
  1069  	data, err := rows.mc.readPacket()
  1070  	if err != nil {
  1071  		return err
  1072  	}
  1073  
  1074  	// packet indicator [1 byte]
  1075  	if data[0] != iOK {
  1076  		// EOF Packet
  1077  		if data[0] == iEOF && len(data) == 5 {
  1078  			rows.mc.status = readStatus(data[3:])
  1079  			if err := rows.mc.discardResults(); err != nil {
  1080  				return err
  1081  			}
  1082  			rows.mc = nil
  1083  			return io.EOF
  1084  		}
  1085  		rows.mc = nil
  1086  
  1087  		// Error otherwise
  1088  		return rows.mc.handleErrorPacket(data)
  1089  	}
  1090  
  1091  	// NULL-bitmap,  [(column-count + 7 + 2) / 8 bytes]
  1092  	pos := 1 + (len(dest)+7+2)>>3
  1093  	nullMask := data[1:pos]
  1094  
  1095  	for i := range dest {
  1096  		// Field is NULL
  1097  		// (byte >> bit-pos) % 2 == 1
  1098  		if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
  1099  			dest[i] = nil
  1100  			continue
  1101  		}
  1102  
  1103  		// Convert to byte-coded string
  1104  		switch rows.columns[i].fieldType {
  1105  		case fieldTypeNULL:
  1106  			dest[i] = nil
  1107  			continue
  1108  
  1109  		// Numeric Types
  1110  		case fieldTypeTiny:
  1111  			if rows.columns[i].flags&flagUnsigned != 0 {
  1112  				dest[i] = int64(data[pos])
  1113  			} else {
  1114  				dest[i] = int64(int8(data[pos]))
  1115  			}
  1116  			pos++
  1117  			continue
  1118  
  1119  		case fieldTypeShort, fieldTypeYear:
  1120  			if rows.columns[i].flags&flagUnsigned != 0 {
  1121  				dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
  1122  			} else {
  1123  				dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
  1124  			}
  1125  			pos += 2
  1126  			continue
  1127  
  1128  		case fieldTypeInt24, fieldTypeLong:
  1129  			if rows.columns[i].flags&flagUnsigned != 0 {
  1130  				dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
  1131  			} else {
  1132  				dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
  1133  			}
  1134  			pos += 4
  1135  			continue
  1136  
  1137  		case fieldTypeLongLong:
  1138  			if rows.columns[i].flags&flagUnsigned != 0 {
  1139  				val := binary.LittleEndian.Uint64(data[pos : pos+8])
  1140  				if val > math.MaxInt64 {
  1141  					dest[i] = uint64ToString(val)
  1142  				} else {
  1143  					dest[i] = int64(val)
  1144  				}
  1145  			} else {
  1146  				dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
  1147  			}
  1148  			pos += 8
  1149  			continue
  1150  
  1151  		case fieldTypeFloat:
  1152  			dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
  1153  			pos += 4
  1154  			continue
  1155  
  1156  		case fieldTypeDouble:
  1157  			dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
  1158  			pos += 8
  1159  			continue
  1160  
  1161  		// Length coded Binary Strings
  1162  		case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
  1163  			fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
  1164  			fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
  1165  			fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON:
  1166  			var isNull bool
  1167  			var n int
  1168  			dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
  1169  			pos += n
  1170  			if err == nil {
  1171  				if !isNull {
  1172  					continue
  1173  				} else {
  1174  					dest[i] = nil
  1175  					continue
  1176  				}
  1177  			}
  1178  			return err
  1179  
  1180  		case
  1181  			fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD
  1182  			fieldTypeTime,                         // Time [-][H]HH:MM:SS[.fractal]
  1183  			fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
  1184  
  1185  			num, isNull, n := readLengthEncodedInteger(data[pos:])
  1186  			pos += n
  1187  
  1188  			switch {
  1189  			case isNull:
  1190  				dest[i] = nil
  1191  				continue
  1192  			case rows.columns[i].fieldType == fieldTypeTime:
  1193  				// database/sql does not support an equivalent to TIME, return a string
  1194  				var dstlen uint8
  1195  				switch decimals := rows.columns[i].decimals; decimals {
  1196  				case 0x00, 0x1f:
  1197  					dstlen = 8
  1198  				case 1, 2, 3, 4, 5, 6:
  1199  					dstlen = 8 + 1 + decimals
  1200  				default:
  1201  					return fmt.Errorf(
  1202  						"protocol error, illegal decimals value %d",
  1203  						rows.columns[i].decimals,
  1204  					)
  1205  				}
  1206  				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
  1207  			case rows.mc.parseTime:
  1208  				dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
  1209  			default:
  1210  				var dstlen uint8
  1211  				if rows.columns[i].fieldType == fieldTypeDate {
  1212  					dstlen = 10
  1213  				} else {
  1214  					switch decimals := rows.columns[i].decimals; decimals {
  1215  					case 0x00, 0x1f:
  1216  						dstlen = 19
  1217  					case 1, 2, 3, 4, 5, 6:
  1218  						dstlen = 19 + 1 + decimals
  1219  					default:
  1220  						return fmt.Errorf(
  1221  							"protocol error, illegal decimals value %d",
  1222  							rows.columns[i].decimals,
  1223  						)
  1224  					}
  1225  				}
  1226  				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false)
  1227  			}
  1228  
  1229  			if err == nil {
  1230  				pos += int(num)
  1231  				continue
  1232  			} else {
  1233  				return err
  1234  			}
  1235  
  1236  		// Please report if this happens!
  1237  		default:
  1238  			return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType)
  1239  		}
  1240  	}
  1241  
  1242  	return nil
  1243  }