vitess.io/vitess@v0.16.2/go/mysql/client.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package mysql
    18  
    19  import (
    20  	"context"
    21  	"crypto/rsa"
    22  	"crypto/tls"
    23  	"crypto/x509"
    24  	"encoding/pem"
    25  	"fmt"
    26  	"net"
    27  	"time"
    28  
    29  	"vitess.io/vitess/go/mysql/collations"
    30  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    31  	"vitess.io/vitess/go/vt/vterrors"
    32  	"vitess.io/vitess/go/vt/vttls"
    33  )
    34  
    35  // connectResult is used by Connect.
    36  type connectResult struct {
    37  	c   *Conn
    38  	err error
    39  }
    40  
    41  // Connect creates a connection to a server.
    42  // It then handles the initial handshake.
    43  //
    44  // If context is canceled before the end of the process, this function
    45  // will return nil, ctx.Err().
    46  //
    47  // FIXME(alainjobart) once we have more of a server side, add test cases
    48  // to cover all failure scenarios.
    49  func Connect(ctx context.Context, params *ConnParams) (*Conn, error) {
    50  	if params.ConnectTimeoutMs != 0 {
    51  		var cancel context.CancelFunc
    52  		ctx, cancel = context.WithTimeout(ctx, time.Duration(params.ConnectTimeoutMs)*time.Millisecond)
    53  		defer cancel()
    54  	}
    55  	netProto := "tcp"
    56  	addr := ""
    57  	if params.UnixSocket != "" {
    58  		netProto = "unix"
    59  		addr = params.UnixSocket
    60  	} else {
    61  		addr = net.JoinHostPort(params.Host, fmt.Sprintf("%v", params.Port))
    62  	}
    63  
    64  	// Start a background connection routine.  It first
    65  	// establishes a network connection, returns it on the channel,
    66  	// then starts the negotiation, and returns the result on the channel.
    67  	// It can send on the channel, before closing it:
    68  	// - a connectResult with an error and nothing else (when dial fails).
    69  	// - a connectResult with a *Conn and no error, then another one
    70  	//   with possibly an error.
    71  	status := make(chan connectResult)
    72  	go func() {
    73  		defer close(status)
    74  		var err error
    75  		var conn net.Conn
    76  
    77  		// Cap the Dial time with the context deadline, plus a
    78  		// few seconds. We want to reclaim resources quickly
    79  		// and not let this go routine stuck in Dial forever.
    80  		//
    81  		// We add a few seconds so we detect the context is
    82  		// Done() before timing out the Dial. That way we'll
    83  		// return the right error to the client (ctx.Err(), vs
    84  		// DialTimeout() error).
    85  		if deadline, ok := ctx.Deadline(); ok {
    86  			timeout := time.Until(deadline) + 5*time.Second
    87  			conn, err = net.DialTimeout(netProto, addr, timeout)
    88  		} else {
    89  			conn, err = net.Dial(netProto, addr)
    90  		}
    91  		if err != nil {
    92  			// If we get an error, the connection to a Unix socket
    93  			// should return a 2002, but for a TCP socket it
    94  			// should return a 2003.
    95  			if netProto == "tcp" {
    96  				status <- connectResult{
    97  					err: NewSQLError(CRConnHostError, SSUnknownSQLState, "net.Dial(%v) failed: %v", addr, err),
    98  				}
    99  			} else {
   100  				status <- connectResult{
   101  					err: NewSQLError(CRConnectionError, SSUnknownSQLState, "net.Dial(%v) to local server failed: %v", addr, err),
   102  				}
   103  			}
   104  			return
   105  		}
   106  
   107  		// Send the connection back, so the other side can close it.
   108  		c := newConn(conn)
   109  		status <- connectResult{
   110  			c: c,
   111  		}
   112  
   113  		// During the handshake, and if the context is
   114  		// canceled, the connection will be closed. That will
   115  		// make any read or write just return with an error
   116  		// right away.
   117  		status <- connectResult{
   118  			err: c.clientHandshake(params),
   119  		}
   120  	}()
   121  
   122  	// Wait on the context and the status, for the connection to happen.
   123  	var c *Conn
   124  	select {
   125  	case <-ctx.Done():
   126  		// The background routine may send us a few things,
   127  		// wait for them and terminate them properly in the
   128  		// background.
   129  		go func() {
   130  			dialCR := <-status // This one can take a while.
   131  			if dialCR.err != nil {
   132  				// Dial failed, nothing else to do.
   133  				return
   134  			}
   135  			// Dial worked, close the connection, wait for the end.
   136  			// We wait as not to leave a channel with an unread value.
   137  			dialCR.c.Close()
   138  			<-status
   139  		}()
   140  		return nil, ctx.Err()
   141  	case cr := <-status:
   142  		if cr.err != nil {
   143  			// Dial failed, no connection was ever established.
   144  			return nil, cr.err
   145  		}
   146  
   147  		// Dial worked, we have a connection. Keep going.
   148  		c = cr.c
   149  	}
   150  
   151  	// Wait for the end of the handshake.
   152  	select {
   153  	case <-ctx.Done():
   154  		// We are interrupted. Close the connection, wait for
   155  		// the handshake to finish in the background.
   156  		c.Close()
   157  		go func() {
   158  			// Since we closed the connection, this one should be fast.
   159  			// We wait as not to leave a channel with an unread value.
   160  			<-status
   161  		}()
   162  		return nil, ctx.Err()
   163  	case cr := <-status:
   164  		if cr.err != nil {
   165  			c.Close()
   166  			return nil, cr.err
   167  		}
   168  	}
   169  
   170  	return c, nil
   171  }
   172  
   173  // Ping implements mysql ping command.
   174  func (c *Conn) Ping() error {
   175  	// This is a new command, need to reset the sequence.
   176  	c.sequence = 0
   177  	data, pos := c.startEphemeralPacketWithHeader(1)
   178  	data[pos] = ComPing
   179  
   180  	if err := c.writeEphemeralPacket(); err != nil {
   181  		return NewSQLError(CRServerGone, SSUnknownSQLState, "%v", err)
   182  	}
   183  	data, err := c.readEphemeralPacket()
   184  	if err != nil {
   185  		return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
   186  	}
   187  	defer c.recycleReadPacket()
   188  	switch data[0] {
   189  	case OKPacket:
   190  		return nil
   191  	case ErrPacket:
   192  		return ParseErrorPacket(data)
   193  	}
   194  	return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected packet type: %d", data[0])
   195  }
   196  
   197  // clientHandshake handles the client side of the handshake.
   198  // Note the connection can be closed while this is running.
   199  // Returns a SQLError.
   200  func (c *Conn) clientHandshake(params *ConnParams) error {
   201  	// if EnableQueryInfo is set, make sure that all queries starting with the handshake
   202  	// will actually process the INFO fields in QUERY_OK packets
   203  	if params.EnableQueryInfo {
   204  		c.enableQueryInfo = true
   205  	}
   206  
   207  	// Wait for the server initial handshake packet, and parse it.
   208  	data, err := c.readPacket()
   209  	if err != nil {
   210  		return NewSQLError(CRServerLost, "", "initial packet read failed: %v", err)
   211  	}
   212  	capabilities, salt, err := c.parseInitialHandshakePacket(data)
   213  	if err != nil {
   214  		return err
   215  	}
   216  	c.fillFlavor(params)
   217  	c.salt = salt
   218  
   219  	// Sanity check.
   220  	if capabilities&CapabilityClientProtocol41 == 0 {
   221  		return NewSQLError(CRVersionError, SSUnknownSQLState, "cannot connect to servers earlier than 4.1")
   222  	}
   223  
   224  	// Remember a subset of the capabilities, so we can use them
   225  	// later in the protocol.
   226  	c.Capabilities = 0
   227  	if !params.DisableClientDeprecateEOF {
   228  		c.Capabilities = capabilities & (CapabilityClientDeprecateEOF)
   229  	}
   230  
   231  	charset, err := collations.Local().ParseConnectionCharset(params.Charset)
   232  	if err != nil {
   233  		return err
   234  	}
   235  
   236  	// Handle switch to SSL if necessary.
   237  	if params.SslEnabled() {
   238  		// If client asked for SSL, but server doesn't support it,
   239  		// stop right here.
   240  		if params.SslRequired() && capabilities&CapabilityClientSSL == 0 {
   241  			return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "server doesn't support SSL but client asked for it")
   242  		}
   243  
   244  		// The ServerName to verify depends on what the hostname is.
   245  		// We use the params's ServerName if specified. Otherwise:
   246  		// - If using a socket, we use "localhost".
   247  		// - If it is an IP address, we need to prefix it with 'IP:'.
   248  		// - If not, we can just use it as is.
   249  		serverName := "localhost"
   250  		if params.ServerName != "" {
   251  			serverName = params.ServerName
   252  		} else if params.Host != "" {
   253  			if net.ParseIP(params.Host) != nil {
   254  				serverName = "IP:" + params.Host
   255  			} else {
   256  				serverName = params.Host
   257  			}
   258  		}
   259  
   260  		tlsVersion, err := vttls.TLSVersionToNumber(params.TLSMinVersion)
   261  		if err != nil {
   262  			return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "error parsing minimal TLS version: %v", err)
   263  		}
   264  
   265  		// Build the TLS config.
   266  		clientConfig, err := vttls.ClientConfig(params.EffectiveSslMode(), params.SslCert, params.SslKey, params.SslCa, params.SslCrl, serverName, tlsVersion)
   267  		if err != nil {
   268  			return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "error loading client cert and ca: %v", err)
   269  		}
   270  
   271  		// Send the SSLRequest packet.
   272  		if err := c.writeSSLRequest(capabilities, charset, params); err != nil {
   273  			return err
   274  		}
   275  
   276  		// Switch to SSL.
   277  		conn := tls.Client(c.conn, clientConfig)
   278  		c.conn = conn
   279  		c.bufferedReader.Reset(conn)
   280  		c.Capabilities |= CapabilityClientSSL
   281  	}
   282  
   283  	// Password encryption.
   284  	var scrambledPassword []byte
   285  	if c.authPluginName == CachingSha2Password {
   286  		scrambledPassword = ScrambleCachingSha2Password(salt, []byte(params.Pass))
   287  	} else {
   288  		scrambledPassword = ScrambleMysqlNativePassword(salt, []byte(params.Pass))
   289  	}
   290  
   291  	// Client Session Tracking Capability.
   292  	if capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack {
   293  		// If the server also supports it, we will have enabled
   294  		// it so we also add it to our capabilities.
   295  		c.Capabilities |= CapabilityClientSessionTrack
   296  	} else if params.Flags&CapabilityClientSessionTrack == CapabilityClientSessionTrack {
   297  		// If client asked for ClientSessionTrack, but server doesn't support it,
   298  		// stop right here.
   299  		return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "server doesn't support ClientSessionTrack but client asked for it")
   300  	}
   301  
   302  	// Build and send our handshake response 41.
   303  	// Note this one will never have SSL flag on.
   304  	if err := c.writeHandshakeResponse41(capabilities, scrambledPassword, charset, params); err != nil {
   305  		return err
   306  	}
   307  
   308  	// Read the server response.
   309  	if err := c.handleAuthResponse(params); err != nil {
   310  		return err
   311  	}
   312  
   313  	// If the server didn't support DbName in its handshake, set
   314  	// it now. This is what the 'mysql' client does.
   315  	if capabilities&CapabilityClientConnectWithDB == 0 && params.DbName != "" {
   316  		// Write the packet.
   317  		if err := c.writeComInitDB(params.DbName); err != nil {
   318  			return err
   319  		}
   320  
   321  		// Wait for response, should be OK.
   322  		response, err := c.readPacket()
   323  		if err != nil {
   324  			return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
   325  		}
   326  		switch response[0] {
   327  		case OKPacket:
   328  			// OK packet, we are authenticated.
   329  			return nil
   330  		case ErrPacket:
   331  			return ParseErrorPacket(response)
   332  		default:
   333  			// FIXME(alainjobart) handle extra auth cases and so on.
   334  			return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response is asking for more information, not implemented yet: %v", response)
   335  		}
   336  	}
   337  
   338  	return nil
   339  }
   340  
   341  // parseInitialHandshakePacket parses the initial handshake from the server.
   342  // It returns a SQLError with the right code.
   343  func (c *Conn) parseInitialHandshakePacket(data []byte) (uint32, []byte, error) {
   344  	pos := 0
   345  
   346  	// Protocol version.
   347  	pver, pos, ok := readByte(data, pos)
   348  	if !ok {
   349  		return 0, nil, NewSQLError(CRVersionError, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no protocol version")
   350  	}
   351  
   352  	// Server is allowed to immediately send ERR packet
   353  	if pver == ErrPacket {
   354  		errorCode, pos, _ := readUint16(data, pos)
   355  		// Normally there would be a 1-byte sql_state_marker field and a 5-byte
   356  		// sql_state field here, but docs say these will not be present in this case.
   357  		errorMsg, _, _ := readEOFString(data, pos)
   358  		return 0, nil, NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "immediate error from server errorCode=%v errorMsg=%v", errorCode, errorMsg)
   359  	}
   360  
   361  	if pver != protocolVersion {
   362  		return 0, nil, NewSQLError(CRVersionError, SSUnknownSQLState, "bad protocol version: %v", pver)
   363  	}
   364  
   365  	// Read the server version.
   366  	c.ServerVersion, pos, ok = readNullString(data, pos)
   367  	if !ok {
   368  		return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no server version")
   369  	}
   370  
   371  	// Read the connection id.
   372  	c.ConnectionID, pos, ok = readUint32(data, pos)
   373  	if !ok {
   374  		return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no connection id")
   375  	}
   376  
   377  	// Read the first part of the auth-plugin-data
   378  	authPluginData, pos, ok := readBytes(data, pos, 8)
   379  	if !ok {
   380  		return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no auth-plugin-data-part-1")
   381  	}
   382  
   383  	// One byte filler, 0. We don't really care about the value.
   384  	_, pos, ok = readByte(data, pos)
   385  	if !ok {
   386  		return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no filler")
   387  	}
   388  
   389  	// Lower 2 bytes of the capability flags.
   390  	capLower, pos, ok := readUint16(data, pos)
   391  	if !ok {
   392  		return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no capability flags (lower 2 bytes)")
   393  	}
   394  	var capabilities = uint32(capLower)
   395  
   396  	// The packet can end here.
   397  	if pos == len(data) {
   398  		return capabilities, authPluginData, nil
   399  	}
   400  
   401  	// Character set.
   402  	characterSet, pos, ok := readByte(data, pos)
   403  	if !ok {
   404  		return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no character set")
   405  	}
   406  	c.CharacterSet = collations.ID(characterSet)
   407  
   408  	// Status flags. Ignored.
   409  	_, pos, ok = readUint16(data, pos)
   410  	if !ok {
   411  		return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no status flags")
   412  	}
   413  
   414  	// Upper 2 bytes of the capability flags.
   415  	capUpper, pos, ok := readUint16(data, pos)
   416  	if !ok {
   417  		return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no capability flags (upper 2 bytes)")
   418  	}
   419  	capabilities += uint32(capUpper) << 16
   420  
   421  	// Length of auth-plugin-data, or 0.
   422  	// Only with CLIENT_PLUGIN_AUTH capability.
   423  	var authPluginDataLength byte
   424  	if capabilities&CapabilityClientPluginAuth != 0 {
   425  		authPluginDataLength, pos, ok = readByte(data, pos)
   426  		if !ok {
   427  			return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no length of auth-plugin-data")
   428  		}
   429  	} else {
   430  		// One byte filler, 0. We don't really care about the value.
   431  		_, pos, ok = readByte(data, pos)
   432  		if !ok {
   433  			return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no length of auth-plugin-data filler")
   434  		}
   435  	}
   436  
   437  	// 10 reserved 0 bytes.
   438  	pos += 10
   439  
   440  	if capabilities&CapabilityClientSecureConnection != 0 {
   441  		// The next part of the auth-plugin-data.
   442  		// The length is max(13, length of auth-plugin-data - 8).
   443  		l := int(authPluginDataLength) - 8
   444  		if l > 13 {
   445  			l = 13
   446  		}
   447  		var authPluginDataPart2 []byte
   448  		authPluginDataPart2, pos, ok = readBytes(data, pos, l)
   449  		if !ok {
   450  			return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no auth-plugin-data-part-2")
   451  		}
   452  
   453  		// The last byte has to be 0, and is not part of the data.
   454  		if authPluginDataPart2[l-1] != 0 {
   455  			return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: auth-plugin-data-part-2 is not 0 terminated")
   456  		}
   457  		authPluginData = append(authPluginData, authPluginDataPart2[0:l-1]...)
   458  	}
   459  
   460  	// Auth-plugin name.
   461  	if capabilities&CapabilityClientPluginAuth != 0 {
   462  		authPluginName, _, ok := readNullString(data, pos)
   463  		if !ok {
   464  			// Fallback for versions prior to 5.5.10 and
   465  			// 5.6.2 that don't have a null terminated string.
   466  			authPluginName = string(data[pos : len(data)-1])
   467  		}
   468  		c.authPluginName = AuthMethodDescription(authPluginName)
   469  	}
   470  
   471  	return capabilities, authPluginData, nil
   472  }
   473  
   474  // writeSSLRequest writes the SSLRequest packet. It's just a truncated
   475  // HandshakeResponse41.
   476  func (c *Conn) writeSSLRequest(capabilities uint32, characterSet uint8, params *ConnParams) error {
   477  	// Build our flags, with CapabilityClientSSL.
   478  	capabilityFlags := CapabilityFlagsSsl |
   479  		// If the server supported
   480  		// CapabilityClientDeprecateEOF, we also support it.
   481  		c.Capabilities&CapabilityClientDeprecateEOF |
   482  		// If the server supported
   483  		// CapabilityClientSessionTrack, we also support it.
   484  		c.Capabilities&CapabilityClientSessionTrack |
   485  		// Pass-through ClientFoundRows flag.
   486  		CapabilityClientFoundRows&uint32(params.Flags)
   487  
   488  	length :=
   489  		4 + // Client capability flags.
   490  			4 + // Max-packet size.
   491  			1 + // Character set.
   492  			23 // Reserved.
   493  
   494  	// Add the DB name if the server supports it.
   495  	if params.DbName != "" && (capabilities&CapabilityClientConnectWithDB != 0) {
   496  		capabilityFlags |= CapabilityClientConnectWithDB
   497  	}
   498  
   499  	data, pos := c.startEphemeralPacketWithHeader(length)
   500  
   501  	// Client capability flags.
   502  	pos = writeUint32(data, pos, capabilityFlags)
   503  
   504  	// Max-packet size, always 0. See doc.go.
   505  	pos = writeZeroes(data, pos, 4)
   506  
   507  	// Character set.
   508  	_ = writeByte(data, pos, characterSet)
   509  
   510  	// And send it as is.
   511  	if err := c.writeEphemeralPacket(); err != nil {
   512  		return NewSQLError(CRServerLost, SSUnknownSQLState, "cannot send SSLRequest: %v", err)
   513  	}
   514  	return nil
   515  }
   516  
   517  // CapabilityFlags are client capability flag sent to mysql on connect
   518  const CapabilityFlags uint32 = CapabilityClientLongPassword |
   519  	CapabilityClientLongFlag |
   520  	CapabilityClientProtocol41 |
   521  	CapabilityClientTransactions |
   522  	CapabilityClientSecureConnection |
   523  	CapabilityClientMultiStatements |
   524  	CapabilityClientMultiResults |
   525  	CapabilityClientPluginAuth |
   526  	CapabilityClientPluginAuthLenencClientData
   527  
   528  // CapabilityFlagsSsl signals that we can handle SSL as well
   529  const CapabilityFlagsSsl = CapabilityFlags |
   530  	CapabilityClientSSL
   531  
   532  // writeHandshakeResponse41 writes the handshake response.
   533  // Returns a SQLError.
   534  func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword []byte, characterSet uint8, params *ConnParams) error {
   535  	// Build our flags.
   536  	capabilityFlags := CapabilityFlags |
   537  		// If the server supported
   538  		// CapabilityClientDeprecateEOF, we also support it.
   539  		c.Capabilities&CapabilityClientDeprecateEOF |
   540  		// Pass-through ClientFoundRows flag.
   541  		CapabilityClientFoundRows&uint32(params.Flags) |
   542  		// If the server supported
   543  		// CapabilityClientSessionTrack, we also support it.
   544  		c.Capabilities&CapabilityClientSessionTrack
   545  
   546  	// FIXME(alainjobart) add multi statement.
   547  
   548  	length :=
   549  		4 + // Client capability flags.
   550  			4 + // Max-packet size.
   551  			1 + // Character set.
   552  			23 + // Reserved.
   553  			lenNullString(params.Uname) +
   554  			// length of scrambled password is handled below.
   555  			len(scrambledPassword) +
   556  			len(c.authPluginName) +
   557  			1 // terminating zero.
   558  
   559  	// Add the DB name if the server supports it.
   560  	if params.DbName != "" && (capabilities&CapabilityClientConnectWithDB != 0) {
   561  		capabilityFlags |= CapabilityClientConnectWithDB
   562  		length += lenNullString(params.DbName)
   563  	}
   564  
   565  	if capabilities&CapabilityClientPluginAuthLenencClientData != 0 {
   566  		length += lenEncIntSize(uint64(len(scrambledPassword)))
   567  	} else {
   568  		length++
   569  	}
   570  
   571  	data, pos := c.startEphemeralPacketWithHeader(length)
   572  
   573  	// Client capability flags.
   574  	pos = writeUint32(data, pos, capabilityFlags)
   575  
   576  	// Max-packet size, always 0. See doc.go.
   577  	pos = writeZeroes(data, pos, 4)
   578  
   579  	// Character set.
   580  	pos = writeByte(data, pos, characterSet)
   581  
   582  	// 23 reserved bytes, all 0.
   583  	pos = writeZeroes(data, pos, 23)
   584  
   585  	// Username
   586  	pos = writeNullString(data, pos, params.Uname)
   587  
   588  	// Scrambled password.  The length is encoded as variable length if
   589  	// CapabilityClientPluginAuthLenencClientData is set.
   590  	if capabilities&CapabilityClientPluginAuthLenencClientData != 0 {
   591  		pos = writeLenEncInt(data, pos, uint64(len(scrambledPassword)))
   592  	} else {
   593  		data[pos] = byte(len(scrambledPassword))
   594  		pos++
   595  	}
   596  	pos += copy(data[pos:], scrambledPassword)
   597  
   598  	// DbName, only if server supports it.
   599  	if params.DbName != "" && (capabilities&CapabilityClientConnectWithDB != 0) {
   600  		pos = writeNullString(data, pos, params.DbName)
   601  		c.schemaName = params.DbName
   602  	}
   603  
   604  	// Assume native client during response
   605  	pos = writeNullString(data, pos, string(c.authPluginName))
   606  
   607  	// Sanity-check the length.
   608  	if pos != len(data) {
   609  		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "writeHandshakeResponse41: only packed %v bytes, out of %v allocated", pos, len(data))
   610  	}
   611  
   612  	if err := c.writeEphemeralPacket(); err != nil {
   613  		return NewSQLError(CRServerLost, SSUnknownSQLState, "cannot send HandshakeResponse41: %v", err)
   614  	}
   615  	return nil
   616  }
   617  
   618  // handleAuthResponse parses server's response after client sends the password for authentication
   619  // and handles next steps for AuthSwitchRequestPacket and AuthMoreDataPacket.
   620  func (c *Conn) handleAuthResponse(params *ConnParams) error {
   621  	response, err := c.readPacket()
   622  	if err != nil {
   623  		return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
   624  	}
   625  
   626  	switch response[0] {
   627  	case OKPacket:
   628  		// OK packet, we are authenticated. Save the user, keep going.
   629  		c.User = params.Uname
   630  	case AuthSwitchRequestPacket:
   631  		// Server is asking to use a different auth method
   632  		if err = c.handleAuthSwitchPacket(params, response); err != nil {
   633  			return err
   634  		}
   635  	case AuthMoreDataPacket:
   636  		// Server is requesting more data - maybe un-scrambled password
   637  		if err := c.handleAuthMoreDataPacket(response[1], params); err != nil {
   638  			return err
   639  		}
   640  	case ErrPacket:
   641  		return ParseErrorPacket(response)
   642  	default:
   643  		return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response cannot be parsed: %v", response)
   644  	}
   645  
   646  	return nil
   647  }
   648  
   649  // handleAuthSwitchPacket scrambles password for the plugin requested by the server and retries authentication
   650  func (c *Conn) handleAuthSwitchPacket(params *ConnParams, response []byte) error {
   651  	var err error
   652  	var salt []byte
   653  	c.authPluginName, salt, err = parseAuthSwitchRequest(response)
   654  	if err != nil {
   655  		return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse auth switch request: %v", err)
   656  	}
   657  	if salt != nil {
   658  		c.salt = salt
   659  	}
   660  	switch c.authPluginName {
   661  	case MysqlClearPassword:
   662  		if err := c.writeClearTextPassword(params); err != nil {
   663  			return err
   664  		}
   665  	case MysqlNativePassword:
   666  		scrambledPassword := ScrambleMysqlNativePassword(c.salt, []byte(params.Pass))
   667  		if err := c.writeScrambledPassword(scrambledPassword); err != nil {
   668  			return err
   669  		}
   670  	case CachingSha2Password:
   671  		scrambledPassword := ScrambleCachingSha2Password(c.salt, []byte(params.Pass))
   672  		if err := c.writeScrambledPassword(scrambledPassword); err != nil {
   673  			return err
   674  		}
   675  	default:
   676  		return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "server asked for unsupported auth method: %v", c.authPluginName)
   677  	}
   678  
   679  	// The response could be an OKPacket, AuthMoreDataPacket or ErrPacket
   680  	return c.handleAuthResponse(params)
   681  }
   682  
   683  // handleAuthMoreDataPacket handles response of CachingSha2Password authentication and sends full password to the
   684  // server if requested
   685  func (c *Conn) handleAuthMoreDataPacket(data byte, params *ConnParams) error {
   686  	switch data {
   687  	case CachingSha2FastAuth:
   688  		// User credentials are verified using the cache ("Fast" path).
   689  		// Next packet should be an OKPacket
   690  		return c.handleAuthResponse(params)
   691  	case CachingSha2FullAuth:
   692  		// User credentials are not cached, we have to exchange full password.
   693  		if c.Capabilities&CapabilityClientSSL > 0 || params.UnixSocket != "" {
   694  			// If we are using an SSL connection or Unix socket, write clear text password
   695  			if err := c.writeClearTextPassword(params); err != nil {
   696  				return err
   697  			}
   698  		} else {
   699  			// If we are not using an SSL connection or Unix socket, we have to fetch a public key
   700  			// from the server to encrypt password
   701  			pub, err := c.requestPublicKey()
   702  			if err != nil {
   703  				return err
   704  			}
   705  			// Encrypt password with public key
   706  			enc, err := EncryptPasswordWithPublicKey(c.salt, []byte(params.Pass), pub)
   707  			if err != nil {
   708  				return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error encrypting password with public key: %v", err)
   709  			}
   710  			// Write encrypted password
   711  			if err := c.writeScrambledPassword(enc); err != nil {
   712  				return err
   713  			}
   714  		}
   715  		// Next packet should either be an OKPacket or ErrPacket
   716  		return c.handleAuthResponse(params)
   717  	default:
   718  		return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse AuthMoreDataPacket: %v", data)
   719  	}
   720  }
   721  
   722  func parseAuthSwitchRequest(data []byte) (AuthMethodDescription, []byte, error) {
   723  	pos := 1
   724  	pluginName, pos, ok := readNullString(data, pos)
   725  	if !ok {
   726  		return "", nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot get plugin name from AuthSwitchRequest: %v", data)
   727  	}
   728  
   729  	// If this was a request with a salt in it, max 20 bytes
   730  	salt := data[pos:]
   731  	if len(salt) > 20 {
   732  		salt = salt[:20]
   733  	}
   734  	return AuthMethodDescription(pluginName), salt, nil
   735  }
   736  
   737  // requestPublicKey requests a public key from the server
   738  func (c *Conn) requestPublicKey() (rsaKey *rsa.PublicKey, err error) {
   739  	// get public key from server
   740  	data, pos := c.startEphemeralPacketWithHeader(1)
   741  	data[pos] = 0x02
   742  	if err := c.writeEphemeralPacket(); err != nil {
   743  		return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error sending public key request packet: %v", err)
   744  	}
   745  
   746  	response, err := c.readPacket()
   747  	if err != nil {
   748  		return nil, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
   749  	}
   750  
   751  	// Server should respond with a AuthMoreDataPacket containing the public key
   752  	if response[0] != AuthMoreDataPacket {
   753  		return nil, ParseErrorPacket(response)
   754  	}
   755  
   756  	block, _ := pem.Decode(response[1:])
   757  	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
   758  	if err != nil {
   759  		return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to parse public key from server: %v", err)
   760  	}
   761  
   762  	return pub.(*rsa.PublicKey), nil
   763  }
   764  
   765  // writeClearTextPassword writes the clear text password.
   766  // Returns a SQLError.
   767  func (c *Conn) writeClearTextPassword(params *ConnParams) error {
   768  	length := len(params.Pass) + 1
   769  	data, pos := c.startEphemeralPacketWithHeader(length)
   770  	pos = writeNullString(data, pos, params.Pass)
   771  	// Sanity check.
   772  	if pos != len(data) {
   773  		return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error building ClearTextPassword packet: got %v bytes expected %v", pos, len(data))
   774  	}
   775  	return c.writeEphemeralPacket()
   776  }
   777  
   778  // writeScrambledPassword writes the encrypted mysql_native_password format
   779  // Returns a SQLError.
   780  func (c *Conn) writeScrambledPassword(scrambledPassword []byte) error {
   781  	data, pos := c.startEphemeralPacketWithHeader(len(scrambledPassword))
   782  	pos += copy(data[pos:], scrambledPassword)
   783  	// Sanity check.
   784  	if pos != len(data) {
   785  		return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error building %v packet: got %v bytes expected %v", c.authPluginName, pos, len(data))
   786  	}
   787  	return c.writeEphemeralPacket()
   788  }