github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/server/conn.go (about)

     1  // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     5  // You can obtain one at http://mozilla.org/MPL/2.0/.
     6  
     7  // The MIT License (MIT)
     8  //
     9  // Copyright (c) 2014 wandoulabs
    10  // Copyright (c) 2014 siddontang
    11  //
    12  // Permission is hereby granted, free of charge, to any person obtaining a copy of
    13  // this software and associated documentation files (the "Software"), to deal in
    14  // the Software without restriction, including without limitation the rights to
    15  // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
    16  // the Software, and to permit persons to whom the Software is furnished to do so,
    17  // subject to the following conditions:
    18  //
    19  // The above copyright notice and this permission notice shall be included in all
    20  // copies or substantial portions of the Software.
    21  
    22  // Copyright 2015 PingCAP, Inc.
    23  //
    24  // Licensed under the Apache License, Version 2.0 (the "License");
    25  // you may not use this file except in compliance with the License.
    26  // You may obtain a copy of the License at
    27  //
    28  //     http://www.apache.org/licenses/LICENSE-2.0
    29  //
    30  // Unless required by applicable law or agreed to in writing, software
    31  // distributed under the License is distributed on an "AS IS" BASIS,
    32  // See the License for the specific language governing permissions and
    33  // limitations under the License.
    34  
    35  package server
    36  
    37  import (
    38  	"bytes"
    39  	"encoding/binary"
    40  	"fmt"
    41  	"io"
    42  	"net"
    43  	"runtime"
    44  	"strings"
    45  	"time"
    46  
    47  	"github.com/insionng/yougam/libraries/juju/errors"
    48  	"github.com/insionng/yougam/libraries/ngaut/log"
    49  	"github.com/insionng/yougam/libraries/pingcap/tidb/mysql"
    50  	"github.com/insionng/yougam/libraries/pingcap/tidb/terror"
    51  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/arena"
    52  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/hack"
    53  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/types"
    54  )
    55  
    56  var defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag |
    57  	mysql.ClientConnectWithDB | mysql.ClientProtocol41 |
    58  	mysql.ClientTransactions | mysql.ClientSecureConnection | mysql.ClientFoundRows
    59  
// clientConn holds the per-connection state of one client speaking the MySQL
// wire protocol to the server: the network connection, the values negotiated
// during the handshake, and the session context used to execute commands.
type clientConn struct {
	pkg          *packetIO       // packet-level I/O; readPacket, writePacket and flush delegate to it
	conn         net.Conn        // underlying network connection to the client
	server       *Server         // owning server (client registry, tokens, driver)
	capability   uint32          // capability flags sent by the client in its handshake response
	connectionID uint32          // ID sent in the initial handshake; key in server.clients
	collation    uint8           // collation byte sent by the client in its handshake response
	charset      string          // reported by String; not assigned in this part of the file
	user         string          // user name from the handshake response
	dbname       string          // current database: set from the handshake response or by useDB
	salt         []byte          // auth challenge bytes sent in the initial handshake and passed to ctx.Auth
	alloc        arena.Allocator // scratch allocator for outgoing packets; reset before each command in Run
	lastCmd      string          // payload of the most recently dispatched command, logged if Run recovers a panic
	ctx          IContext        // session context opened in readHandshakeResponse
}
    75  
    76  func (cc *clientConn) String() string {
    77  	return fmt.Sprintf("conn: %s, status: %d, charset: %s, user: %s, lastInsertId: %d",
    78  		cc.conn.RemoteAddr(), cc.ctx.Status(), cc.charset, cc.user, cc.ctx.LastInsertID(),
    79  	)
    80  }
    81  
    82  func (cc *clientConn) handshake() error {
    83  	if err := cc.writeInitialHandshake(); err != nil {
    84  		return errors.Trace(err)
    85  	}
    86  	if err := cc.readHandshakeResponse(); err != nil {
    87  		cc.writeError(err)
    88  		return errors.Trace(err)
    89  	}
    90  	data := cc.alloc.AllocWithLen(4, 32)
    91  	data = append(data, mysql.OKHeader)
    92  	data = append(data, 0, 0)
    93  	if cc.capability&mysql.ClientProtocol41 > 0 {
    94  		data = append(data, dumpUint16(mysql.ServerStatusAutocommit)...)
    95  		data = append(data, 0, 0)
    96  	}
    97  
    98  	err := cc.writePacket(data)
    99  	cc.pkg.sequence = 0
   100  	if err != nil {
   101  		return errors.Trace(err)
   102  	}
   103  
   104  	return errors.Trace(cc.flush())
   105  }
   106  
   107  func (cc *clientConn) Close() error {
   108  	cc.server.rwlock.Lock()
   109  	delete(cc.server.clients, cc.connectionID)
   110  	cc.server.rwlock.Unlock()
   111  	cc.conn.Close()
   112  	if cc.ctx != nil {
   113  		return cc.ctx.Close()
   114  	}
   115  	return nil
   116  }
   117  
   118  func (cc *clientConn) writeInitialHandshake() error {
   119  	data := make([]byte, 4, 128)
   120  
   121  	// min version 10
   122  	data = append(data, 10)
   123  	// server version[00]
   124  	data = append(data, mysql.ServerVersion...)
   125  	data = append(data, 0)
   126  	// connection id
   127  	data = append(data, byte(cc.connectionID), byte(cc.connectionID>>8), byte(cc.connectionID>>16), byte(cc.connectionID>>24))
   128  	// auth-plugin-data-part-1
   129  	data = append(data, cc.salt[0:8]...)
   130  	// filler [00]
   131  	data = append(data, 0)
   132  	// capability flag lower 2 bytes, using default capability here
   133  	data = append(data, byte(defaultCapability), byte(defaultCapability>>8))
   134  	// charset, utf-8 default
   135  	data = append(data, uint8(mysql.DefaultCollationID))
   136  	//status
   137  	data = append(data, dumpUint16(mysql.ServerStatusAutocommit)...)
   138  	// below 13 byte may not be used
   139  	// capability flag upper 2 bytes, using default capability here
   140  	data = append(data, byte(defaultCapability>>16), byte(defaultCapability>>24))
   141  	// filler [0x15], for wireshark dump, value is 0x15
   142  	data = append(data, 0x15)
   143  	// reserved 10 [00]
   144  	data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
   145  	// auth-plugin-data-part-2
   146  	data = append(data, cc.salt[8:]...)
   147  	// filler [00]
   148  	data = append(data, 0)
   149  	err := cc.writePacket(data)
   150  	if err != nil {
   151  		return errors.Trace(err)
   152  	}
   153  	return errors.Trace(cc.flush())
   154  }
   155  
   156  func (cc *clientConn) readPacket() ([]byte, error) {
   157  	return cc.pkg.readPacket()
   158  }
   159  
   160  func (cc *clientConn) writePacket(data []byte) error {
   161  	return cc.pkg.writePacket(data)
   162  }
   163  
   164  func (cc *clientConn) readHandshakeResponse() error {
   165  	data, err := cc.readPacket()
   166  	if err != nil {
   167  		return errors.Trace(err)
   168  	}
   169  
   170  	pos := 0
   171  	// capability
   172  	cc.capability = binary.LittleEndian.Uint32(data[:4])
   173  	pos += 4
   174  	// skip max packet size
   175  	pos += 4
   176  	// charset, skip, if you want to use another charset, use set names
   177  	cc.collation = data[pos]
   178  	pos++
   179  	// skip reserved 23[00]
   180  	pos += 23
   181  	// user name
   182  	cc.user = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)])
   183  	pos += len(cc.user) + 1
   184  	// auth length and auth
   185  	authLen := int(data[pos])
   186  	pos++
   187  	auth := data[pos : pos+authLen]
   188  	pos += authLen
   189  	if cc.capability&mysql.ClientConnectWithDB > 0 {
   190  		if len(data[pos:]) > 0 {
   191  			idx := bytes.IndexByte(data[pos:], 0)
   192  			cc.dbname = string(data[pos : pos+idx])
   193  		}
   194  	}
   195  	// Open session and do auth
   196  	cc.ctx, err = cc.server.driver.OpenCtx(uint64(cc.connectionID), cc.capability, uint8(cc.collation), cc.dbname)
   197  	if err != nil {
   198  		cc.Close()
   199  		return errors.Trace(err)
   200  	}
   201  	if !cc.server.skipAuth() {
   202  		// Do Auth
   203  		addr := cc.conn.RemoteAddr().String()
   204  		host, _, err1 := net.SplitHostPort(addr)
   205  		if err1 != nil {
   206  			return errors.Trace(mysql.NewErr(mysql.ErrAccessDenied, cc.user, addr, "Yes"))
   207  		}
   208  		user := fmt.Sprintf("%s@%s", cc.user, host)
   209  		if !cc.ctx.Auth(user, auth, cc.salt) {
   210  			return errors.Trace(mysql.NewErr(mysql.ErrAccessDenied, cc.user, host, "Yes"))
   211  		}
   212  	}
   213  	return nil
   214  }
   215  
   216  func (cc *clientConn) Run() {
   217  	defer func() {
   218  		r := recover()
   219  		if r != nil {
   220  			const size = 4096
   221  			buf := make([]byte, size)
   222  			buf = buf[:runtime.Stack(buf, false)]
   223  			log.Errorf("lastCmd %s, %v, %s", cc.lastCmd, r, buf)
   224  		}
   225  		cc.Close()
   226  	}()
   227  
   228  	for {
   229  		cc.alloc.Reset()
   230  		data, err := cc.readPacket()
   231  		if err != nil {
   232  			if terror.ErrorNotEqual(err, io.EOF) {
   233  				log.Error(err)
   234  			}
   235  			return
   236  		}
   237  
   238  		if err := cc.dispatch(data); err != nil {
   239  			if terror.ErrorEqual(err, io.EOF) {
   240  				return
   241  			}
   242  			log.Errorf("dispatch error %s, %s", errors.ErrorStack(err), cc)
   243  			log.Errorf("cmd: %s", string(data[1:]))
   244  			cc.writeError(err)
   245  		}
   246  
   247  		cc.pkg.sequence = 0
   248  	}
   249  }
   250  
   251  func (cc *clientConn) dispatch(data []byte) error {
   252  	cmd := data[0]
   253  	data = data[1:]
   254  	cc.lastCmd = hack.String(data)
   255  
   256  	token := cc.server.getToken()
   257  
   258  	startTs := time.Now()
   259  	defer func() {
   260  		cc.server.releaseToken(token)
   261  		log.Debugf("[TIME_CMD] %v %d", time.Now().Sub(startTs), cmd)
   262  	}()
   263  
   264  	switch cmd {
   265  	case mysql.ComQuit:
   266  		return io.EOF
   267  	case mysql.ComQuery:
   268  		return cc.handleQuery(hack.String(data))
   269  	case mysql.ComPing:
   270  		return cc.writeOK()
   271  	case mysql.ComInitDB:
   272  		log.Debug("init db", hack.String(data))
   273  		if err := cc.useDB(hack.String(data)); err != nil {
   274  			return errors.Trace(err)
   275  		}
   276  		return cc.writeOK()
   277  	case mysql.ComFieldList:
   278  		return cc.handleFieldList(hack.String(data))
   279  	case mysql.ComStmtPrepare:
   280  		return cc.handleStmtPrepare(hack.String(data))
   281  	case mysql.ComStmtExecute:
   282  		return cc.handleStmtExecute(data)
   283  	case mysql.ComStmtClose:
   284  		return cc.handleStmtClose(data)
   285  	case mysql.ComStmtSendLongData:
   286  		return cc.handleStmtSendLongData(data)
   287  	case mysql.ComStmtReset:
   288  		return cc.handleStmtReset(data)
   289  	default:
   290  		return mysql.NewErrf(mysql.ErrUnknown, "command %d not supported now", cmd)
   291  	}
   292  }
   293  
   294  func (cc *clientConn) useDB(db string) (err error) {
   295  	_, err = cc.ctx.Execute("use " + db)
   296  	if err != nil {
   297  		return errors.Trace(err)
   298  	}
   299  	cc.dbname = db
   300  	return
   301  }
   302  
   303  func (cc *clientConn) flush() error {
   304  	return cc.pkg.flush()
   305  }
   306  
   307  func (cc *clientConn) writeOK() error {
   308  	data := cc.alloc.AllocWithLen(4, 32)
   309  	data = append(data, mysql.OKHeader)
   310  	data = append(data, dumpLengthEncodedInt(uint64(cc.ctx.AffectedRows()))...)
   311  	data = append(data, dumpLengthEncodedInt(uint64(cc.ctx.LastInsertID()))...)
   312  	if cc.capability&mysql.ClientProtocol41 > 0 {
   313  		data = append(data, dumpUint16(cc.ctx.Status())...)
   314  		data = append(data, dumpUint16(cc.ctx.WarningCount())...)
   315  	}
   316  
   317  	err := cc.writePacket(data)
   318  	if err != nil {
   319  		return errors.Trace(err)
   320  	}
   321  
   322  	return errors.Trace(cc.flush())
   323  }
   324  
   325  func (cc *clientConn) writeError(e error) error {
   326  	var (
   327  		m  *mysql.SQLError
   328  		te *terror.Error
   329  		ok bool
   330  	)
   331  	originErr := errors.Cause(e)
   332  	if te, ok = originErr.(*terror.Error); ok {
   333  		m = te.ToSQLError()
   334  	} else {
   335  		m = mysql.NewErrf(mysql.ErrUnknown, e.Error())
   336  	}
   337  
   338  	data := make([]byte, 4, 16+len(m.Message))
   339  	data = append(data, mysql.ErrHeader)
   340  	data = append(data, byte(m.Code), byte(m.Code>>8))
   341  	if cc.capability&mysql.ClientProtocol41 > 0 {
   342  		data = append(data, '#')
   343  		data = append(data, m.State...)
   344  	}
   345  
   346  	data = append(data, m.Message...)
   347  
   348  	err := cc.writePacket(data)
   349  	if err != nil {
   350  		return errors.Trace(err)
   351  	}
   352  	return errors.Trace(cc.flush())
   353  }
   354  
   355  func (cc *clientConn) writeEOF() error {
   356  	data := cc.alloc.AllocWithLen(4, 9)
   357  
   358  	data = append(data, mysql.EOFHeader)
   359  	if cc.capability&mysql.ClientProtocol41 > 0 {
   360  		data = append(data, dumpUint16(cc.ctx.WarningCount())...)
   361  		data = append(data, dumpUint16(cc.ctx.Status())...)
   362  	}
   363  
   364  	err := cc.writePacket(data)
   365  	return errors.Trace(err)
   366  }
   367  
   368  func (cc *clientConn) handleQuery(sql string) (err error) {
   369  	startTs := time.Now()
   370  	rs, err := cc.ctx.Execute(sql)
   371  	if err != nil {
   372  		return errors.Trace(err)
   373  	}
   374  	if rs != nil {
   375  		err = cc.writeResultset(rs, false)
   376  	} else {
   377  		err = cc.writeOK()
   378  	}
   379  	log.Debugf("[TIME_QUERY] %v %s", time.Now().Sub(startTs), sql)
   380  	return errors.Trace(err)
   381  }
   382  
   383  func (cc *clientConn) handleFieldList(sql string) (err error) {
   384  	parts := strings.Split(sql, "\x00")
   385  	columns, err := cc.ctx.FieldList(parts[0])
   386  	if err != nil {
   387  		return errors.Trace(err)
   388  	}
   389  	data := make([]byte, 4, 1024)
   390  	for _, v := range columns {
   391  		data = data[0:4]
   392  		data = append(data, v.Dump(cc.alloc)...)
   393  		if err := cc.writePacket(data); err != nil {
   394  			return errors.Trace(err)
   395  		}
   396  	}
   397  	if err := cc.writeEOF(); err != nil {
   398  		return errors.Trace(err)
   399  	}
   400  	return errors.Trace(cc.flush())
   401  }
   402  
   403  func (cc *clientConn) writeResultset(rs ResultSet, binary bool) error {
   404  	defer rs.Close()
   405  	// We need to call Next before we get columns.
   406  	// Otherwise, we will get incorrect columns info.
   407  	row, err := rs.Next()
   408  	if err != nil {
   409  		return errors.Trace(err)
   410  	}
   411  
   412  	columns, err := rs.Columns()
   413  	if err != nil {
   414  		return errors.Trace(err)
   415  	}
   416  	columnLen := dumpLengthEncodedInt(uint64(len(columns)))
   417  	data := cc.alloc.AllocWithLen(4, 1024)
   418  	data = append(data, columnLen...)
   419  	if err = cc.writePacket(data); err != nil {
   420  		return errors.Trace(err)
   421  	}
   422  
   423  	for _, v := range columns {
   424  		data = data[0:4]
   425  		data = append(data, v.Dump(cc.alloc)...)
   426  		if err = cc.writePacket(data); err != nil {
   427  			return errors.Trace(err)
   428  		}
   429  	}
   430  
   431  	if err = cc.writeEOF(); err != nil {
   432  		return errors.Trace(err)
   433  	}
   434  
   435  	for {
   436  		if err != nil {
   437  			return errors.Trace(err)
   438  		}
   439  		if row == nil {
   440  			break
   441  		}
   442  		data = data[0:4]
   443  		if binary {
   444  			var rowData []byte
   445  			rowData, err = dumpRowValuesBinary(cc.alloc, columns, row)
   446  			if err != nil {
   447  				return errors.Trace(err)
   448  			}
   449  			data = append(data, rowData...)
   450  		} else {
   451  			for i, value := range row {
   452  				if value.Kind() == types.KindNull {
   453  					data = append(data, 0xfb)
   454  					continue
   455  				}
   456  				var valData []byte
   457  				valData, err = dumpTextValue(columns[i].Type, value)
   458  				if err != nil {
   459  					return errors.Trace(err)
   460  				}
   461  				data = append(data, dumpLengthEncodedString(valData, cc.alloc)...)
   462  			}
   463  		}
   464  
   465  		if err = cc.writePacket(data); err != nil {
   466  			return errors.Trace(err)
   467  		}
   468  		row, err = rs.Next()
   469  	}
   470  
   471  	err = cc.writeEOF()
   472  	if err != nil {
   473  		return errors.Trace(err)
   474  	}
   475  
   476  	return errors.Trace(cc.flush())
   477  }