github.com/XiaoMi/Gaea@v1.2.5/proxy/server/client_conn.go (about)

     1  // Copyright 2019 The Gaea Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package server
    16  
    17  import (
    18  	"fmt"
    19  	"github.com/XiaoMi/Gaea/log"
    20  	"github.com/XiaoMi/Gaea/mysql"
    21  	"strings"
    22  )
    23  
// ClientConn session client connection.
//
// It wraps the MySQL protocol connection of one client and carries the state
// used while authenticating the client and answering its queries.
type ClientConn struct {
	// Embedded packet-level connection; its methods (StartEphemeralPacket,
	// WriteEphemeralPacket, WriteOKPacket, Flush, ...) are promoted.
	*mysql.Conn

	// salt is the random scramble generated in NewClientConn and sent in the
	// initial handshake and in auth switch requests.
	salt []byte

	// manager provides shared proxy services such as the statistic manager.
	manager *Manager

	// capability holds the client capability flags read from the handshake
	// response.
	capability uint32

	// namespace is the key under which write-flow statistics are recorded.
	namespace string // TODO: remove it when refactor is done

	// proxy is the owning server; its ServerVersion and AuthPlugin settings
	// are used when building handshake packets. NOTE(review): not set by
	// NewClientConn — presumably assigned by the caller; confirm.
	proxy *Server
}
    38  
// HandshakeResponseInfo handshake response information.
//
// It holds the values readHandshakeResponse extracted from the client's
// handshake response packet (and from an auth switch exchange, if one was
// performed).
type HandshakeResponseInfo struct {
	CollationID  mysql.CollationID // character set byte sent by the client
	User         string            // login user name
	AuthResponse []byte            // client auth response (or auth switch reply)
	Salt         []byte            // the salt this connection sent to the client
	Database     string            // initial schema; only set with ClientConnectWithDB
	AuthPlugin   string            // proxy plugin name when an auth switch was requested, else empty
}
    48  
    49  // NewClientConn constructor of ClientConn
    50  func NewClientConn(c *mysql.Conn, manager *Manager) *ClientConn {
    51  	salt, _ := mysql.RandomBuf(20)
    52  	return &ClientConn{
    53  		Conn:    c,
    54  		salt:    salt,
    55  		manager: manager,
    56  	}
    57  }
    58  
    59  func (cc *ClientConn) CompactVersion(sv string) string {
    60  	version := strings.Trim(sv, " ")
    61  	if version != "" {
    62  		v := strings.Split(sv, ".")
    63  		if len(v) < 3 {
    64  			return mysql.ServerVersion
    65  		}
    66  		return version
    67  	} else {
    68  		return mysql.ServerVersion
    69  	}
    70  }
    71  
    72  func (cc *ClientConn) writeInitialHandshakeV10() error {
    73  	ServerVersion := cc.CompactVersion(cc.proxy.ServerVersion)
    74  	length :=
    75  		1 + // protocol version
    76  			mysql.LenNullString(ServerVersion) +
    77  			4 + // connection ID
    78  			8 + // first part of salt data
    79  			1 + // filler byte
    80  			2 + // capability flags (lower 2 bytes)
    81  			1 + // character set
    82  			2 + // status flag
    83  			2 + // capability flags (upper 2 bytes)
    84  			1 + // length of auth plugin data
    85  			10 + // reserved (0)
    86  			13 // auth-plugin-data
    87  	// mysql.LenNullString(mysql.MysqlNativePassword) // auth-plugin-name
    88  	if cc.proxy.AuthPlugin != "" {
    89  		length += mysql.LenNullString(cc.proxy.AuthPlugin)
    90  	}
    91  
    92  	data := cc.StartEphemeralPacket(length)
    93  	pos := 0
    94  
    95  	// Protocol version.
    96  	pos = mysql.WriteByte(data, pos, mysql.ProtocolVersion)
    97  
    98  	// Copy server version.
    99  	// server version data with terminate character 0x00, type: string[NUL].
   100  	pos = mysql.WriteNullString(data, pos, ServerVersion)
   101  
   102  	// Add connectionID in.
   103  	// connection id type: 4 bytes.
   104  	pos = mysql.WriteUint32(data, pos, cc.GetConnectionID())
   105  
   106  	// auth-plugin-data-part-1 type: string[8].
   107  	pos += copy(data[pos:], cc.salt[:8])
   108  
   109  	// One filler byte, always 0.
   110  	pos = mysql.WriteByte(data, pos, 0)
   111  
   112  	// Lower part of the capability flags, lower 2 bytes.
   113  	pos = mysql.WriteUint16(data, pos, uint16(DefaultCapability))
   114  
   115  	// Character set.
   116  	pos = mysql.WriteByte(data, pos, byte(mysql.DefaultCollationID))
   117  
   118  	// Status flag.
   119  	pos = mysql.WriteUint16(data, pos, initClientConnStatus)
   120  
   121  	// Upper part of the capability flags.
   122  	pos = mysql.WriteUint16(data, pos, uint16(DefaultCapability>>16))
   123  
   124  	// Length of auth plugin data.
   125  	// Always 21 (8 + 13).
   126  	pos = mysql.WriteByte(data, pos, 21)
   127  
   128  	// Reserved 10 bytes: all 0
   129  	pos = mysql.WriteZeroes(data, pos, 10)
   130  
   131  	// Second part of auth plugin data.
   132  	pos += copy(data[pos:], cc.salt[8:])
   133  	data[pos] = 0
   134  	pos++
   135  	//authentication plugin
   136  	if cc.proxy.AuthPlugin != "" {
   137  		pos += copy(data[pos:], cc.proxy.AuthPlugin)
   138  		data[pos] = 0
   139  		pos++
   140  	}
   141  
   142  	// Copy authPluginName. We always start with mysql_native_password.
   143  	// pos = mysql.WriteNullString(data, pos, mysql.MysqlNativePassword)
   144  
   145  	// Sanity check.
   146  	if pos != len(data) {
   147  		return fmt.Errorf("error building Handshake packet: got %v bytes expected %v", pos, len(data))
   148  	}
   149  
   150  	if err := cc.WriteEphemeralPacket(); err != nil {
   151  		return err
   152  	}
   153  
   154  	return nil
   155  }
   156  
   157  func (cc *ClientConn) readHandshakeResponse() (HandshakeResponseInfo, error) {
   158  	info := HandshakeResponseInfo{}
   159  	info.Salt = cc.salt
   160  
   161  	data, err := cc.ReadEphemeralPacketDirect()
   162  	defer cc.RecycleReadPacket()
   163  	if err != nil {
   164  		return info, err
   165  	}
   166  
   167  	pos := 0
   168  
   169  	// Client flags, 4 bytes.
   170  	var ok bool
   171  	var capability uint32
   172  	capability, pos, ok = mysql.ReadUint32(data, pos)
   173  	if !ok {
   174  		return info, fmt.Errorf("readHandshakeResponse: can't read client flags")
   175  	}
   176  	if capability&mysql.ClientProtocol41 == 0 {
   177  		return info, fmt.Errorf("readHandshakeResponse: only support protocol 4.1")
   178  	}
   179  
   180  	cc.capability = capability
   181  	// Max packet size. Don't do anything with this now.
   182  	_, pos, ok = mysql.ReadUint32(data, pos)
   183  	if !ok {
   184  		return info, fmt.Errorf("readHandshakeResponse: can't read maxPacketSize")
   185  	}
   186  
   187  	// Character set
   188  	collationID, pos, ok := mysql.ReadByte(data, pos)
   189  	if !ok {
   190  		return info, fmt.Errorf("readHandshakeResponse: can't read characterSet")
   191  	}
   192  	info.CollationID = mysql.CollationID(collationID)
   193  
   194  	// reserved 23 zero bytes, skipped
   195  	pos += 23
   196  
   197  	// username
   198  	var user string
   199  	user, pos, ok = mysql.ReadNullString(data, pos)
   200  	if !ok {
   201  		return info, fmt.Errorf("readHandshakeResponse: can't read username")
   202  	}
   203  	info.User = user
   204  
   205  	// TODO auth-response can have three forms.
   206  	var authResponse []byte
   207  	var l uint64
   208  	l, pos, _, ok = mysql.ReadLenEncInt(data, pos)
   209  	if !ok {
   210  		return info, fmt.Errorf("readHandshakeResponse: can't read auth-response variable length")
   211  	}
   212  
   213  	if capability&mysql.ClientPluginAuthLenencClientData > 0 || capability&mysql.ClientSecureConnection > 0 {
   214  		authResponse, pos, ok = mysql.ReadBytesCopy(data, pos, int(l))
   215  	} else {
   216  		authResponse, pos, ok = mysql.ReadNullByte(data, pos)
   217  	}
   218  	if !ok {
   219  		return info, fmt.Errorf("readHandshakeResponse: can't read auth-response")
   220  	}
   221  
   222  	info.AuthResponse = authResponse
   223  
   224  	// check if with database
   225  	if capability&mysql.ClientConnectWithDB > 0 {
   226  		var db string
   227  		db, pos, ok = mysql.ReadNullString(data, pos)
   228  		if !ok {
   229  			return info, fmt.Errorf("readHandshakeResponse: can't read db")
   230  		}
   231  		info.Database = db
   232  	}
   233  	if capability&mysql.ClientPluginAuth > 0 {
   234  		var authPlugin string
   235  		authPlugin, pos, ok = mysql.ReadNullString(data, pos)
   236  		if ok && (authPlugin != cc.proxy.AuthPlugin) {
   237  			info.AuthPlugin = cc.proxy.AuthPlugin
   238  			cc.RecycleReadPacket()
   239  			cc.WriteAuthSwitchRequest(info.AuthPlugin)
   240  			// readAuthSwitchRequestResponse
   241  			info.AuthResponse, err = cc.ReadEphemeralPacketDirect()
   242  			if err != nil {
   243  				return info, fmt.Errorf("readHandshakeResponse: can't read auth switch response")
   244  			}
   245  		}
   246  	}
   247  
   248  	// TODO auth plugin name态client conn attrs .etc
   249  	return info, nil
   250  }
   251  
   252  func (cc *ClientConn) writeOK(status uint16) error {
   253  	err := cc.WriteOKPacket(0, 0, status, 0)
   254  	if err != nil {
   255  		log.Warn("write ok packet failed, %v", err)
   256  		return err
   257  	}
   258  	return nil
   259  }
   260  
   261  func (cc *ClientConn) writeOKResult(status uint16, r *mysql.Result) error {
   262  	if r.Resultset == nil {
   263  		return cc.WriteOKPacket(r.AffectedRows, r.InsertID, status, 0)
   264  	}
   265  	return cc.writeResultset(status, r.Resultset)
   266  }
   267  
   268  func (cc *ClientConn) writeEOFPacket(status uint16) error {
   269  	err := cc.WriteEOFPacket(status, 0)
   270  	if err != nil {
   271  		log.Warn("write eof packet failed, %v", err)
   272  		return err
   273  	}
   274  	return nil
   275  }
   276  
   277  func (cc *ClientConn) writeErrorPacket(err error) error {
   278  	e := cc.WriteErrorPacketFromError(err)
   279  	if e != nil {
   280  		log.Warn("write error packet failed, %v", err)
   281  		return e
   282  	}
   283  	return nil
   284  }
   285  
   286  func (cc *ClientConn) writeColumnCount(count uint64) error {
   287  	length := mysql.LenEncIntSize(count)
   288  	data := cc.StartEphemeralPacket(length)
   289  	cc.manager.GetStatisticManager().AddWriteFlowCount(cc.namespace, length)
   290  	mysql.WriteLenEncInt(data, 0, count)
   291  	return cc.WriteEphemeralPacket()
   292  }
   293  
   294  func (cc *ClientConn) writeRow(row []byte) error {
   295  	length := len(row)
   296  	data := cc.StartEphemeralPacket(length)
   297  	pos := 0
   298  	copy(data[pos:], row)
   299  	cc.manager.GetStatisticManager().AddWriteFlowCount(cc.namespace, length)
   300  	return cc.WriteEphemeralPacket()
   301  }
   302  
   303  // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
   304  func (cc *ClientConn) writeResultset(status uint16, r *mysql.Resultset) error {
   305  	var err error
   306  	cc.StartWriterBuffering()
   307  
   308  	// write column count
   309  	columnCount := uint64(len(r.Fields))
   310  	err = cc.writeColumnCount(columnCount)
   311  	if err != nil {
   312  		return err
   313  	}
   314  
   315  	// write columns
   316  	err = cc.writeFieldList(status, r.Fields)
   317  	if err != nil {
   318  		return err
   319  	}
   320  
   321  	// write rows data
   322  	// resultset row, NULL is sent as 0xfb, everything else is converted into a string and is sent as Protocol::LengthEncodedString
   323  	for _, v := range r.RowDatas {
   324  		err = cc.writeRow(v)
   325  		if err != nil {
   326  			return err
   327  		}
   328  	}
   329  
   330  	err = cc.writeEOFPacket(status)
   331  	if err != nil {
   332  		return err
   333  	}
   334  
   335  	err = cc.Flush()
   336  	if err != nil {
   337  		return err
   338  	}
   339  
   340  	return nil
   341  }
   342  
   343  func (cc *ClientConn) writeFieldList(status uint16, fs []*mysql.Field) error {
   344  	var err error
   345  	for _, f := range fs {
   346  		err = cc.writeColumnDefinition(f)
   347  		if err != nil {
   348  			return err
   349  		}
   350  	}
   351  
   352  	err = cc.writeEOFPacket(status)
   353  	return err
   354  }
   355  
   356  func (cc *ClientConn) writeColumnDefinition(field *mysql.Field) error {
   357  	schemaLen := uint64(len(field.Schema))
   358  	tableLen := uint64(len(field.Table))
   359  	orgTableLen := uint64(len(field.OrgTable))
   360  	nameLen := uint64(len(field.Name))
   361  	orgNameLen := uint64(len(field.OrgName))
   362  	length := 4 + // lenEncStringSize("def")
   363  		mysql.LenEncIntSize(schemaLen) +
   364  		len(field.Schema) +
   365  		mysql.LenEncIntSize(tableLen) +
   366  		len(field.Table) +
   367  		mysql.LenEncIntSize(orgTableLen) +
   368  		len(field.OrgTable) +
   369  		mysql.LenEncIntSize(nameLen) +
   370  		len(field.Name) +
   371  		mysql.LenEncIntSize(orgNameLen) +
   372  		len(field.OrgName) +
   373  		1 + // length of fixed length fields
   374  		2 + // character set
   375  		4 + // column length
   376  		1 + // type
   377  		2 + // flags
   378  		1 + // decimals
   379  		2 // filler
   380  	if field.DefaultValue != nil {
   381  		length += mysql.LenEncIntSize(uint64(len(field.DefaultValue))) + len(field.DefaultValue)
   382  	}
   383  
   384  	data := cc.StartEphemeralPacket(length)
   385  	pos := 0
   386  	pos = mysql.WriteLenEncString(data, pos, "def") // Always the same.
   387  
   388  	pos = mysql.WriteLenEncInt(data, pos, schemaLen)
   389  	copy(data[pos:], field.Schema)
   390  	pos += len(field.Schema)
   391  
   392  	pos = mysql.WriteLenEncInt(data, pos, tableLen)
   393  	copy(data[pos:], field.Table)
   394  	pos += len(field.Table)
   395  
   396  	pos = mysql.WriteLenEncInt(data, pos, orgTableLen)
   397  	copy(data[pos:], field.OrgTable)
   398  	pos += len(field.OrgTable)
   399  
   400  	pos = mysql.WriteLenEncInt(data, pos, nameLen)
   401  	copy(data[pos:], field.Name)
   402  	pos += len(field.Name)
   403  
   404  	pos = mysql.WriteLenEncInt(data, pos, orgNameLen)
   405  	copy(data[pos:], field.OrgName)
   406  	pos += len(field.OrgName)
   407  
   408  	pos = mysql.WriteByte(data, pos, 0x0c)
   409  	pos = mysql.WriteUint16(data, pos, field.Charset)
   410  	pos = mysql.WriteUint32(data, pos, field.ColumnLength)
   411  	pos = mysql.WriteByte(data, pos, byte(field.Type))
   412  	pos = mysql.WriteUint16(data, pos, field.Flag)
   413  	pos = mysql.WriteByte(data, pos, byte(field.Decimal))
   414  	pos = mysql.WriteUint16(data, pos, uint16(0x0000))
   415  
   416  	if field.DefaultValue != nil {
   417  		pos = mysql.WriteLenEncInt(data, pos, field.DefaultValueLength)
   418  		copy(data[pos:], field.DefaultValue)
   419  		pos += len(field.DefaultValue)
   420  	}
   421  	if pos != len(data) {
   422  		return fmt.Errorf("internal error: packing of column definition used %v bytes instead of %v", pos, len(data))
   423  	}
   424  	cc.manager.GetStatisticManager().AddWriteFlowCount(cc.namespace, len(data))
   425  
   426  	return cc.WriteEphemeralPacket()
   427  }
   428  
   429  // writePrepareResponse write prepare response
   430  func (cc *ClientConn) writePrepareResponse(status uint16, s *Stmt) error {
   431  	var err error
   432  	length := 1 + // status
   433  		4 + // statement-id
   434  		2 + // number of columns
   435  		2 + // number of params
   436  		1 + // filler
   437  		2 // number of warnings
   438  	data := cc.StartEphemeralPacket(length)
   439  	pos := 0
   440  	// status ok
   441  	pos = mysql.WriteByte(data, pos, 0)
   442  	// stmt id
   443  	pos = mysql.WriteUint32(data, pos, s.id)
   444  	// number columns
   445  	pos = mysql.WriteUint16(data, pos, uint16(s.columnCount))
   446  	// number params
   447  	pos = mysql.WriteUint16(data, pos, uint16(s.paramCount))
   448  	// filler [00]
   449  	pos = mysql.WriteByte(data, pos, 0)
   450  	// number of warnings
   451  	pos = mysql.WriteUint16(data, pos, 0)
   452  	if pos != length {
   453  		return fmt.Errorf("internal error packet row: got %v bytes but expected %v", pos, length)
   454  	}
   455  
   456  	err = cc.WriteEphemeralPacket()
   457  	if err != nil {
   458  		return err
   459  	}
   460  
   461  	if s.paramCount > 0 {
   462  		for i := 0; i < s.paramCount; i++ {
   463  			err = cc.writeColumnDefinition(p)
   464  			if err != nil {
   465  				return err
   466  			}
   467  		}
   468  		err = cc.writeEOFPacket(status)
   469  		return err
   470  	}
   471  
   472  	if s.columnCount > 0 {
   473  		for i := 0; i < s.columnCount; i++ {
   474  			err = cc.writeColumnDefinition(c)
   475  			if err != nil {
   476  				return err
   477  			}
   478  		}
   479  		err = cc.writeEOFPacket(status)
   480  		return err
   481  	}
   482  
   483  	return nil
   484  }
   485  
   486  func (cc *ClientConn) WriteAuthSwitchRequest(authMethod string) error {
   487  	l := 1 + len(authMethod) + 1 + len(cc.salt) + 1
   488  	data := cc.StartEphemeralPacket(l)
   489  	pos := 0
   490  	pos = mysql.WriteByte(data, pos, mysql.AuthSwitchHeader)
   491  	pos = mysql.WriteNullString(data, pos, authMethod)
   492  	pos = mysql.WriteBytes(data, pos, cc.salt)
   493  	mysql.WriteByte(data, pos, 0)
   494  	return cc.WriteEphemeralPacket()
   495  }