github.com/XiaoMi/Gaea@v1.2.5/backend/direct_connection.go (about)

     1  // Copyright 2019 The Gaea Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package backend
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/binary"
    20  	"errors"
    21  	"fmt"
    22  	"net"
    23  	"strings"
    24  
    25  	sqlerr "github.com/XiaoMi/Gaea/core/errors"
    26  	"github.com/XiaoMi/Gaea/log"
    27  	"github.com/XiaoMi/Gaea/mysql"
    28  	"github.com/XiaoMi/Gaea/util/sync2"
    29  )
    30  
    31  // DirectConnection means connection to backend mysql
    32  type DirectConnection struct {
    33  	conn *mysql.Conn
    34  
    35  	addr     string
    36  	user     string
    37  	password string
    38  	db       string
    39  
    40  	capability uint32
    41  
    42  	sessionVariables *mysql.SessionVariables
    43  
    44  	status uint16
    45  
    46  	collation mysql.CollationID
    47  	charset   string
    48  	salt      []byte
    49  
    50  	defaultCollation mysql.CollationID
    51  	defaultCharset   string
    52  
    53  	pkgErr error
    54  	closed sync2.AtomicBool
    55  }
    56  
    57  // NewDirectConnection return direct and authorised connection to mysql with real net connection
    58  func NewDirectConnection(addr string, user string, password string, db string, charset string, collationID mysql.CollationID) (*DirectConnection, error) {
    59  	dc := &DirectConnection{
    60  		addr:             addr,
    61  		user:             user,
    62  		password:         password,
    63  		db:               db,
    64  		charset:          charset,
    65  		collation:        collationID,
    66  		defaultCharset:   charset,
    67  		defaultCollation: collationID,
    68  		closed:           sync2.NewAtomicBool(false),
    69  		sessionVariables: mysql.NewSessionVariables(),
    70  	}
    71  	err := dc.connect()
    72  	return dc, err
    73  }
    74  
    75  // connect means real connection to backend mysql after authorization
    76  func (dc *DirectConnection) connect() error {
    77  	if dc.conn != nil {
    78  		dc.conn.Close()
    79  	}
    80  
    81  	typ := "tcp"
    82  	if strings.Contains(dc.addr, "/") {
    83  		typ = "unix"
    84  	}
    85  
    86  	netConn, err := net.Dial(typ, dc.addr)
    87  	if err != nil {
    88  		return err
    89  	}
    90  
    91  	tcpConn := netConn.(*net.TCPConn)
    92  	// SetNoDelay controls whether the operating system should delay packet transmission
    93  	// in hopes of sending fewer packets (Nagle's algorithm).
    94  	// The default is true (no delay),
    95  	// meaning that data is sent as soon as possible after a Write.
    96  	tcpConn.SetNoDelay(true)
    97  	tcpConn.SetKeepAlive(true)
    98  	dc.conn = mysql.NewConn(tcpConn)
    99  
   100  	// step1: read handshake requirements
   101  	if err := dc.readInitialHandshake(); err != nil {
   102  		dc.conn.Close()
   103  		return err
   104  	}
   105  
   106  	// step2: write handshake response
   107  	if err := dc.writeHandshakeResponse41(); err != nil {
   108  		dc.conn.Close()
   109  
   110  		return err
   111  	}
   112  
   113  	response, err := dc.readPacket()
   114  	if err != nil {
   115  		dc.conn.Close()
   116  		return err
   117  	}
   118  
   119  	switch response[0] {
   120  	case mysql.OKHeader:
   121  	default:
   122  		return errors.New("dc connection handshake failed with mysql")
   123  	}
   124  
   125  	// we must always use autocommit
   126  	if !dc.IsAutoCommit() {
   127  		if _, err := dc.exec("set autocommit = 1", 0); err != nil {
   128  			dc.conn.Close()
   129  
   130  			return err
   131  		}
   132  	}
   133  
   134  	return nil
   135  }
   136  
   137  // Close close connection to backend mysql and reset conn structure
   138  func (dc *DirectConnection) Close() {
   139  	if dc.conn != nil {
   140  		dc.conn.Close()
   141  	}
   142  
   143  	dc.conn = nil
   144  	dc.salt = nil
   145  	dc.pkgErr = nil
   146  	dc.closed.Set(true)
   147  
   148  	return
   149  }
   150  
   151  // IsClosed check if connection closed
   152  func (dc *DirectConnection) IsClosed() bool {
   153  	return dc.closed.Get()
   154  }
   155  
   156  // readPacket doesn't use EphemeralBuffer
   157  func (dc *DirectConnection) readPacket() ([]byte, error) {
   158  	data, err := dc.conn.ReadPacket()
   159  	dc.pkgErr = err
   160  	return data, err
   161  }
   162  
   163  // writePacket doesn't use EphemeralBuffer
   164  func (dc *DirectConnection) writePacket(data []byte) error {
   165  	err := dc.conn.WritePacket(data)
   166  	if err != nil && strings.Contains(err.Error(), "broken pipe") {
   167  		// retry 3 times, close dc's conn、reset dc's stats and reconnect
   168  		for i := 0; i < 3; i++ {
   169  			dc.Close()
   170  			e := dc.connect()
   171  			if e == nil { // no need to write data again
   172  				break
   173  			}
   174  		}
   175  
   176  	}
   177  	return err
   178  }
   179  
   180  // writeEphemeralPacket
   181  func (dc *DirectConnection) writeEphemeralPacket() error {
   182  	err := dc.conn.WriteEphemeralPacket()
   183  	if err != nil && strings.Contains(err.Error(), "broken pipe") {
   184  		// retry 3 times, close dc's conn、reset dc's stats and reconnect
   185  		for i := 0; i < 3; i++ {
   186  			dc.Close()
   187  			e := dc.connect()
   188  			if e == nil { // no need to write data again and ephemeral buffer is recycled
   189  				break
   190  			}
   191  		}
   192  	}
   193  	return err
   194  }
   195  
   196  func (dc *DirectConnection) readInitialHandshake() error {
   197  	data, err := dc.readPacket()
   198  	if err != nil {
   199  		return err
   200  	}
   201  
   202  	if data[0] == mysql.ErrHeader {
   203  		return errors.New("read initial handshake error")
   204  	}
   205  
   206  	if data[0] < mysql.MinProtocolVersion {
   207  		return fmt.Errorf("invalid protocol version %d, must >= 10", data[0])
   208  	}
   209  
   210  	//skip mysql version
   211  	//mysql version end with 0x00
   212  	pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1
   213  
   214  	// get connection id
   215  	dc.conn.ConnectionID = binary.LittleEndian.Uint32(data[pos : pos+4])
   216  
   217  	pos += 4
   218  
   219  	dc.salt = append(dc.salt, data[pos:pos+8]...)
   220  
   221  	//skip filter
   222  	pos += 8 + 1
   223  
   224  	//capability lower 2 bytes
   225  	dc.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
   226  
   227  	pos += 2
   228  
   229  	if len(data) > pos {
   230  		//skip server charset
   231  		//c.charset = data[pos]
   232  		pos++
   233  
   234  		dc.status = binary.LittleEndian.Uint16(data[pos : pos+2])
   235  		pos += 2
   236  
   237  		dc.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | dc.capability
   238  
   239  		pos += 2
   240  
   241  		//skip auth data len or [00]
   242  		//skip reserved (all [00])
   243  		pos += 10 + 1
   244  
   245  		// The documentation is ambiguous about the length.
   246  		// The official Python library uses the fixed length 12
   247  		// mysql-proxy also use 12
   248  		// which is not documented but seems to work.
   249  		dc.salt = append(dc.salt, data[pos:pos+12]...)
   250  	}
   251  
   252  	return nil
   253  }
   254  
   255  // writeHandshakeResponse41 writes the handshake response.
   256  func (dc *DirectConnection) writeHandshakeResponse41() error {
   257  	// Adjust client capability flags based on server support
   258  	capability := mysql.ClientProtocol41 | mysql.ClientSecureConnection |
   259  		mysql.ClientLongPassword | mysql.ClientTransactions | mysql.ClientLongFlag
   260  	capability &= dc.capability
   261  
   262  	//we only support secure connection
   263  	auth := mysql.CalcPassword(dc.salt, []byte(dc.password))
   264  
   265  	length := 4 + // Client capability flags
   266  		4 + // Max-packet size.
   267  		1 + // Character set.
   268  		23 + // Reserved.
   269  		mysql.LenNullString(dc.user) + // user
   270  		1 +
   271  		len(auth)
   272  
   273  	if len(dc.db) > 0 {
   274  		capability |= mysql.ClientConnectWithDB
   275  		length += mysql.LenNullString(dc.db)
   276  	}
   277  
   278  	dc.capability = capability
   279  
   280  	data := make([]byte, length, length)
   281  	pos := 0
   282  
   283  	// Client capability flags.
   284  	pos = mysql.WriteUint32(data, pos, capability)
   285  
   286  	// Max-packet size, always 0. See doc.go.
   287  	pos = mysql.WriteZeroes(data, pos, 4)
   288  
   289  	// Character set.
   290  	pos = mysql.WriteByte(data, pos, byte(dc.collation))
   291  
   292  	// 23 reserved bytes, all 0.
   293  	pos = mysql.WriteZeroes(data, pos, 23)
   294  
   295  	// user type: null terminated string
   296  	pos = mysql.WriteNullString(data, pos, dc.user)
   297  
   298  	// auth [length encoded integer]
   299  	data[pos] = byte(len(auth))
   300  	pos++
   301  	pos += copy(data[pos:], auth)
   302  
   303  	// db type: null terminated string
   304  	if len(dc.db) > 0 {
   305  		pos = mysql.WriteNullString(data, pos, dc.db)
   306  	}
   307  
   308  	if err := dc.writePacket(data); err != nil {
   309  		return err
   310  	}
   311  
   312  	return nil
   313  }
   314  
   315  // writeComInitDB changes the default database to use.
   316  // Client -> Server.DirectConnection
   317  // Returns SQLError(CRServerGone) if it can't.
   318  func (dc *DirectConnection) writeComInitDB(db string) error {
   319  	dc.conn.SetSequence(0)
   320  	data := make([]byte, len(db)+1, len(db)+1)
   321  	data[0] = mysql.ComInitDB
   322  	copy(data[1:], db)
   323  	if err := dc.writePacket(data); err != nil {
   324  		return err
   325  	}
   326  	return nil
   327  }
   328  
   329  // writeComQuery send ComQuery request use EphemeralBuffer
   330  func (dc *DirectConnection) writeComQuery(sql string) error {
   331  	dc.conn.SetSequence(0)
   332  	data := dc.conn.StartEphemeralPacket(len(sql) + 1)
   333  	data[0] = mysql.ComQuery
   334  	copy(data[1:], sql)
   335  	if err := dc.writeEphemeralPacket(); err != nil {
   336  		return err
   337  	}
   338  	return nil
   339  }
   340  
   341  func (dc *DirectConnection) writeComFieldList(table string, wildcard string) error {
   342  	dc.conn.SetSequence(0)
   343  	length := 1 +
   344  		mysql.LenNullString(table) +
   345  		mysql.LenNullString(wildcard)
   346  
   347  	data := make([]byte, length, length)
   348  	pos := 0
   349  
   350  	pos = mysql.WriteByte(data, 0, mysql.ComFieldList)
   351  	pos = mysql.WriteNullString(data, pos, table)
   352  	pos = mysql.WriteNullString(data, pos, wildcard)
   353  
   354  	if err := dc.writePacket(data); err != nil {
   355  		return err
   356  	}
   357  
   358  	return nil
   359  }
   360  
   361  // Ping implements mysql ping command.
   362  func (dc *DirectConnection) Ping() error {
   363  	dc.conn.SetSequence(0)
   364  	if err := dc.writePacket([]byte{mysql.ComPing}); err != nil {
   365  		return err
   366  	}
   367  	data, err := dc.readPacket()
   368  	if err != nil {
   369  		return err
   370  	}
   371  	switch data[0] {
   372  	case mysql.OKHeader:
   373  		return nil
   374  	case mysql.ErrHeader:
   375  		return errors.New("dc connection ping failed")
   376  	}
   377  	return fmt.Errorf("unexpected packet type: %d", data[0])
   378  }
   379  
   380  // UseDB send ComInitDB to backend mysql
   381  func (dc *DirectConnection) UseDB(dbName string) error {
   382  	dc.conn.SetSequence(0)
   383  	if dc.db == dbName || len(dbName) == 0 {
   384  		return nil
   385  	}
   386  
   387  	if err := dc.writeComInitDB(dbName); err != nil {
   388  		return err
   389  	}
   390  
   391  	if r, err := dc.readPacket(); err != nil {
   392  		return err
   393  	} else if !mysql.IsOKPacket(r) {
   394  		return errors.New("dc connection use db failed")
   395  	}
   396  
   397  	dc.db = dbName
   398  	return nil
   399  }
   400  
   401  // GetDB return database name
   402  func (dc *DirectConnection) GetDB() string {
   403  	return dc.db
   404  }
   405  
   406  // GetAddr return addr of backend mysql
   407  func (dc *DirectConnection) GetAddr() string {
   408  	return dc.addr
   409  }
   410  
   411  // Execute send ComQuery or ComStmtPrepare/ComStmtExecute/ComStmtClose to backend mysql
   412  func (dc *DirectConnection) Execute(sql string, maxRows int) (*mysql.Result, error) {
   413  	return dc.exec(sql, maxRows)
   414  }
   415  
   416  // Begin send ComQuery with 'begin' to backend mysql to start transaction
   417  func (dc *DirectConnection) Begin() error {
   418  	_, err := dc.exec("begin", 0)
   419  	return err
   420  }
   421  
   422  // Commit send ComQuery with 'commit' to backend mysql to commit transaction
   423  func (dc *DirectConnection) Commit() error {
   424  	_, err := dc.exec("commit", 0)
   425  	return err
   426  }
   427  
   428  // Rollback send ComQuery with 'rollback' to backend mysql to rollback transaction
   429  func (dc *DirectConnection) Rollback() error {
   430  	_, err := dc.exec("rollback", 0)
   431  	return err
   432  }
   433  
   434  // SetAutoCommit trun on/off autocommit
   435  func (dc *DirectConnection) SetAutoCommit(v uint8) error {
   436  	if v == 0 {
   437  		if _, err := dc.exec("set autocommit = 0", 0); err != nil {
   438  			dc.conn.Close()
   439  
   440  			return err
   441  		}
   442  	} else {
   443  		if _, err := dc.exec("set autocommit = 1", 0); err != nil {
   444  			dc.conn.Close()
   445  
   446  			return err
   447  		}
   448  	}
   449  	return nil
   450  }
   451  
   452  // SetCharset set charset of connection to backend mysql
   453  func (dc *DirectConnection) SetCharset(charset string, collation mysql.CollationID) ( /*changed*/ bool, error) {
   454  	charset = strings.Trim(charset, "\"'`")
   455  
   456  	if collation == 0 || collation > 247 {
   457  		collation = mysql.CollationNames[mysql.Charsets[charset]]
   458  	}
   459  
   460  	if dc.charset == charset && dc.collation == collation {
   461  		return false, nil
   462  	}
   463  
   464  	_, ok := mysql.CharsetIds[charset]
   465  	if !ok {
   466  		return false, fmt.Errorf("invalid charset %s", charset)
   467  	}
   468  
   469  	_, ok = mysql.Collations[collation]
   470  	if !ok {
   471  		return false, fmt.Errorf("invalid collation %d", collation)
   472  	}
   473  
   474  	dc.collation = collation
   475  	dc.charset = charset
   476  	return true, nil
   477  }
   478  
   479  // ResetConnection reset connection stattus, include transaction、autocommit、charset、sql_mode .etc
   480  func (dc *DirectConnection) ResetConnection() error {
   481  	if dc.IsInTransaction() {
   482  		log.Debug("get transaction connection from pool, addr: %s, user: %s, db: %s, status: %d", dc.addr, dc.user, dc.db, dc.status)
   483  		if err := dc.Rollback(); err != nil {
   484  			log.Warn("rollback in reset connection error, addr: %s, user: %s, db: %s, status: %d, err: %v", dc.addr, dc.user, dc.db, dc.status, err)
   485  			return err
   486  		}
   487  	}
   488  
   489  	if !dc.IsAutoCommit() {
   490  		log.Debug("get autocommit = 0 connection from pool, addr: %s, user: %s, db: %s, status: %d", dc.addr, dc.user, dc.db, dc.status)
   491  		if err := dc.SetAutoCommit(1); err != nil {
   492  			log.Warn("set autocommit = 1 in reset connection error, addr: %s, user: %s, db: %s, status: %d, err: %v", dc.addr, dc.user, dc.db, dc.status, err)
   493  			return err
   494  		}
   495  	}
   496  
   497  	return nil
   498  }
   499  
   500  // SetSessionVariables set direction variables according to Session
   501  func (dc *DirectConnection) SetSessionVariables(frontend *mysql.SessionVariables) (bool, error) {
   502  	return dc.sessionVariables.SetEqualsWith(frontend)
   503  }
   504  
   505  // WriteSetStatement execute sql
   506  func (dc *DirectConnection) WriteSetStatement() error {
   507  	var setVariableSQL bytes.Buffer
   508  	collation, ok := mysql.Collations[dc.collation]
   509  	if !ok {
   510  		return fmt.Errorf("invalid collationId: %v", dc.collation)
   511  	}
   512  	appendSetCharset(&setVariableSQL, dc.charset, collation)
   513  
   514  	for _, v := range dc.sessionVariables.GetAll() {
   515  		appendSetVariable(&setVariableSQL, v.Name(), v.Get())
   516  	}
   517  
   518  	for _, v := range dc.sessionVariables.GetUnusedAndClear() {
   519  		appendSetVariableToDefault(&setVariableSQL, v.Name())
   520  	}
   521  
   522  	setSQL := setVariableSQL.String()
   523  	if setSQL == "" {
   524  		return nil
   525  	}
   526  	if _, err := dc.exec(setSQL, 0); err != nil {
   527  		return err
   528  	}
   529  	return nil
   530  }
   531  
   532  // FieldList send ComFieldList to backend mysql
   533  func (dc *DirectConnection) FieldList(table string, wildcard string) ([]*mysql.Field, error) {
   534  	if err := dc.writeComFieldList(table, wildcard); err != nil {
   535  		return nil, err
   536  	}
   537  	fs := make([]*mysql.Field, 0, 4)
   538  	var f *mysql.Field
   539  	for {
   540  		data, err := dc.readPacket()
   541  		if err != nil {
   542  			return nil, err
   543  		}
   544  
   545  		// EOF Packet
   546  		if dc.isEOFPacket(data) {
   547  			return fs, nil
   548  		}
   549  
   550  		if data[0] == mysql.ErrHeader {
   551  			return nil, dc.handleErrorPacket(data)
   552  		}
   553  
   554  		if f, err = mysql.FieldData(data).Parse(); err != nil {
   555  			return nil, err
   556  		}
   557  		fs = append(fs, f)
   558  	}
   559  }
   560  
   561  // execute ComQuery command
   562  func (dc *DirectConnection) exec(query string, maxRows int) (*mysql.Result, error) {
   563  	if err := dc.writeComQuery(query); err != nil {
   564  		return nil, err
   565  	}
   566  
   567  	return dc.readResult(false, maxRows)
   568  }
   569  
   570  // read resultset from mysql
   571  func (dc *DirectConnection) readResultSet(data []byte, binary bool, maxRows int) (*mysql.Result, error) {
   572  	result := &mysql.Result{
   573  		Status:       0,
   574  		InsertID:     0,
   575  		AffectedRows: 0,
   576  
   577  		Resultset: &mysql.Resultset{},
   578  	}
   579  
   580  	// column count
   581  	pos := 0
   582  	count, pos, _, _ := mysql.ReadLenEncInt(data, pos)
   583  
   584  	if pos-len(data) != 0 {
   585  		return nil, mysql.ErrMalformPacket
   586  	}
   587  
   588  	result.Fields = make([]*mysql.Field, count)
   589  	result.FieldNames = make(map[string]int, count)
   590  
   591  	if err := dc.readResultColumns(result); err != nil {
   592  		return nil, err
   593  	}
   594  
   595  	if err := dc.readResultRows(result, binary, maxRows); err != nil {
   596  		return nil, err
   597  	}
   598  
   599  	return result, nil
   600  }
   601  
   602  // readResultColumns read column information
   603  func (dc *DirectConnection) readResultColumns(result *mysql.Result) (err error) {
   604  	var i = 0
   605  	var data []byte
   606  
   607  	for {
   608  		data, err = dc.readPacket()
   609  		if err != nil {
   610  			return
   611  		}
   612  
   613  		// EOF Packet
   614  		if dc.isEOFPacket(data) {
   615  			if dc.capability&mysql.ClientProtocol41 > 0 {
   616  				//result.Warnings = binary.LittleEndian.Uint16(data[1:])
   617  				//todo add strict_mode, warning will be treat as error
   618  				result.Status = binary.LittleEndian.Uint16(data[3:])
   619  				dc.status = result.Status
   620  			}
   621  
   622  			if i != len(result.Fields) {
   623  				err = mysql.ErrMalformPacket
   624  			}
   625  
   626  			return
   627  		}
   628  
   629  		if data[0] == mysql.ErrHeader {
   630  			return dc.handleErrorPacket(data)
   631  		}
   632  
   633  		result.Fields[i], err = mysql.FieldData(data).Parse()
   634  		if err != nil {
   635  			return
   636  		}
   637  
   638  		result.FieldNames[string(result.Fields[i].Name)] = i
   639  
   640  		i++
   641  	}
   642  }
   643  
   644  // readResultRows read result rows
   645  func (dc *DirectConnection) readResultRows(result *mysql.Result, isBinary bool, maxRows int) (err error) {
   646  	var data []byte
   647  
   648  	for {
   649  		data, err = dc.readPacket()
   650  		if err != nil {
   651  			return
   652  		}
   653  
   654  		// EOF Packet
   655  		if dc.isEOFPacket(data) {
   656  			if dc.capability&mysql.ClientProtocol41 > 0 {
   657  				//result.Warnings = binary.LittleEndian.Uint16(data[1:])
   658  				//todo add strict_mode, warning will be treat as error
   659  				result.Status = binary.LittleEndian.Uint16(data[3:])
   660  				dc.status = result.Status
   661  			}
   662  
   663  			break
   664  		}
   665  
   666  		if data[0] == mysql.ErrHeader {
   667  			return dc.handleErrorPacket(data)
   668  		}
   669  
   670  		result.RowDatas = append(result.RowDatas, data)
   671  		if maxRows > 0 && len(result.RowDatas) >= maxRows {
   672  			if err := dc.drainResults(); err != nil {
   673  				return fmt.Errorf("%v %d, drain error: %v", sqlerr.ErrRowsLimitExceeded, maxRows, err)
   674  			}
   675  			return fmt.Errorf("%v %d", sqlerr.ErrRowsLimitExceeded, maxRows)
   676  		}
   677  	}
   678  
   679  	result.Values = make([][]interface{}, len(result.RowDatas))
   680  	for i := range result.Values {
   681  		result.Values[i], err = result.RowDatas[i].Parse(result.Fields, isBinary)
   682  		if err != nil {
   683  			return err
   684  		}
   685  	}
   686  
   687  	return nil
   688  }
   689  
   690  // drainResults will read all packets for a result set and ignore them.
   691  func (dc *DirectConnection) drainResults() error {
   692  	for {
   693  		data, err := dc.conn.ReadEphemeralPacket()
   694  		if err != nil {
   695  			dc.conn.RecycleReadPacket()
   696  			return err
   697  		}
   698  
   699  		if dc.isEOFPacket(data) {
   700  			dc.conn.RecycleReadPacket()
   701  			return nil
   702  		} else if data[0] == mysql.ErrHeader {
   703  			err := dc.handleErrorPacket(data)
   704  			dc.conn.RecycleReadPacket()
   705  			return err
   706  		}
   707  		dc.conn.RecycleReadPacket()
   708  	}
   709  }
   710  
   711  func (dc *DirectConnection) isEOFPacket(data []byte) bool {
   712  	return data[0] == mysql.EOFHeader && len(data) <= 5
   713  }
   714  
   715  func (dc *DirectConnection) handleOKPacket(data []byte) (*mysql.Result, error) {
   716  	var pos = 1
   717  
   718  	r := new(mysql.Result)
   719  
   720  	r.AffectedRows, pos, _, _ = mysql.ReadLenEncInt(data, pos)
   721  	r.InsertID, pos, _, _ = mysql.ReadLenEncInt(data, pos)
   722  
   723  	if dc.capability&mysql.ClientProtocol41 > 0 {
   724  		r.Status = binary.LittleEndian.Uint16(data[pos:])
   725  		dc.status = r.Status
   726  		pos += 2
   727  
   728  		// TODO strict_mode, check warnings as error
   729  		// Warnings := binary.LittleEndian.Uint16(data[pos:])
   730  		// pos += 2
   731  	} else if dc.capability&mysql.ClientTransactions > 0 {
   732  		r.Status = binary.LittleEndian.Uint16(data[pos:])
   733  		dc.status = r.Status
   734  		pos += 2
   735  	}
   736  
   737  	//info
   738  	return r, nil
   739  }
   740  
   741  func (dc *DirectConnection) handleErrorPacket(data []byte) error {
   742  	e := new(mysql.SQLError)
   743  
   744  	var pos = 1
   745  
   746  	e.Code = binary.LittleEndian.Uint16(data[pos:])
   747  	pos += 2
   748  
   749  	if dc.capability&mysql.ClientProtocol41 > 0 {
   750  		// skip '#'
   751  		pos++
   752  		e.State = string(data[pos : pos+5])
   753  		pos += 5
   754  	}
   755  
   756  	e.Message = string(data[pos:])
   757  
   758  	return e
   759  }
   760  
   761  func (dc *DirectConnection) readResult(binary bool, maxRows int) (*mysql.Result, error) {
   762  	data, err := dc.readPacket()
   763  	if err != nil {
   764  		return nil, err
   765  	}
   766  	if data[0] == mysql.OKHeader {
   767  		return dc.handleOKPacket(data)
   768  	} else if data[0] == mysql.ErrHeader {
   769  		return nil, dc.handleErrorPacket(data)
   770  	} else if data[0] == mysql.LocalInFileHeader {
   771  		return nil, mysql.ErrMalformPacket
   772  	}
   773  
   774  	return dc.readResultSet(data, binary, maxRows)
   775  }
   776  
   777  // IsAutoCommit check if autocommit
   778  func (dc *DirectConnection) IsAutoCommit() bool {
   779  	return dc.status&mysql.ServerStatusAutocommit > 0
   780  }
   781  
   782  // IsInTransaction check if in transaction
   783  func (dc *DirectConnection) IsInTransaction() bool {
   784  	return dc.status&mysql.ServerStatusInTrans > 0
   785  }
   786  
   787  // GetCharset return charset of specific connection
   788  func (dc *DirectConnection) GetCharset() string {
   789  	return dc.charset
   790  }
   791  
   792  func appendSetCharset(buf *bytes.Buffer, charset string, collation string) {
   793  	if buf.Len() != 0 {
   794  		buf.WriteString(",")
   795  	} else {
   796  		buf.WriteString("SET NAMES '")
   797  	}
   798  	buf.WriteString(charset)
   799  	buf.WriteString("' COLLATE '")
   800  	buf.WriteString(collation)
   801  	buf.WriteString("'")
   802  }
   803  
   804  func appendSetVariable(buf *bytes.Buffer, key string, value interface{}) {
   805  	if buf.Len() != 0 {
   806  		buf.WriteString(",")
   807  	} else {
   808  		buf.WriteString("SET ")
   809  	}
   810  	buf.WriteString(key)
   811  	buf.WriteString(" = ")
   812  	switch v := value.(type) {
   813  	case string:
   814  		if strings.ToLower(v) == mysql.KeywordDefault {
   815  			buf.WriteString(v)
   816  		} else {
   817  			buf.WriteString("'")
   818  			buf.WriteString(v)
   819  			buf.WriteString("'")
   820  		}
   821  	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
   822  		buf.WriteString(fmt.Sprintf("%d", v))
   823  	default:
   824  		buf.WriteString("'")
   825  		buf.WriteString(fmt.Sprintf("%v", v))
   826  		buf.WriteString("'")
   827  	}
   828  }
   829  
   830  func appendSetVariableToDefault(buf *bytes.Buffer, key string) {
   831  	if buf.Len() != 0 {
   832  		buf.WriteString(",")
   833  	} else {
   834  		buf.WriteString("SET ")
   835  	}
   836  	buf.WriteString(key)
   837  	buf.WriteString(" = ")
   838  	buf.WriteString("DEFAULT")
   839  }