vitess.io/vitess@v0.16.2/go/mysql/server.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package mysql
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"io"
    23  	"net"
    24  	"strings"
    25  	"sync/atomic"
    26  	"time"
    27  
    28  	"vitess.io/vitess/go/mysql/collations"
    29  	"vitess.io/vitess/go/vt/servenv"
    30  
    31  	"vitess.io/vitess/go/sqlescape"
    32  
    33  	proxyproto "github.com/pires/go-proxyproto"
    34  
    35  	"vitess.io/vitess/go/netutil"
    36  	"vitess.io/vitess/go/sqltypes"
    37  	"vitess.io/vitess/go/stats"
    38  	"vitess.io/vitess/go/sync2"
    39  	"vitess.io/vitess/go/tb"
    40  	"vitess.io/vitess/go/vt/log"
    41  	querypb "vitess.io/vitess/go/vt/proto/query"
    42  	"vitess.io/vitess/go/vt/proto/vtrpc"
    43  	"vitess.io/vitess/go/vt/vterrors"
    44  )
    45  
    46  const (
    47  	// DefaultServerVersion is the default server version we're sending to the client.
    48  	// Can be changed.
    49  
    50  	// timing metric keys
    51  	connectTimingKey  = "Connect"
    52  	queryTimingKey    = "Query"
    53  	versionTLS10      = "TLS10"
    54  	versionTLS11      = "TLS11"
    55  	versionTLS12      = "TLS12"
    56  	versionTLS13      = "TLS13"
    57  	versionTLSUnknown = "UnknownTLSVersion"
    58  	versionNoTLS      = "None"
    59  )
    60  
    61  var (
    62  	// Metrics
    63  	timings    = stats.NewTimings("MysqlServerTimings", "MySQL server timings", "operation")
    64  	connCount  = stats.NewGauge("MysqlServerConnCount", "Active MySQL server connections")
    65  	connAccept = stats.NewCounter("MysqlServerConnAccepted", "Connections accepted by MySQL server")
    66  	connRefuse = stats.NewCounter("MysqlServerConnRefused", "Connections refused by MySQL server")
    67  	connSlow   = stats.NewCounter("MysqlServerConnSlow", "Connections that took more than the configured mysql_slow_connect_warn_threshold to establish")
    68  
    69  	connCountByTLSVer = stats.NewGaugesWithSingleLabel("MysqlServerConnCountByTLSVer", "Active MySQL server connections by TLS version", "tls")
    70  	connCountPerUser  = stats.NewGaugesWithSingleLabel("MysqlServerConnCountPerUser", "Active MySQL server connections per user", "count")
    71  	_                 = stats.NewGaugeFunc("MysqlServerConnCountUnauthenticated", "Active MySQL server connections that haven't authenticated yet", func() int64 {
    72  		totalUsers := int64(0)
    73  		for _, v := range connCountPerUser.Counts() {
    74  			totalUsers += v
    75  		}
    76  		return connCount.Get() - totalUsers
    77  	})
    78  )
    79  
// A Handler receives connection lifecycle events and client commands from a
// Listener. The implementation of this interface may store data in the
// ClientData field of the Connection for its own purposes.
//
// For a given Connection, all these methods are serialized. It means
// only one of these methods will be called concurrently for a given
// Connection. So access to the Connection ClientData does not need to
// be protected by a mutex.
//
// However, each connection is served by its own goroutine, so the methods
// can be called concurrently for different Connections.
type Handler interface {
	// NewConnection is called when a connection is created.
	// It is not established yet. The handler can decide to
	// set StatusFlags that will be returned by the handshake methods.
	// In particular, ServerStatusAutocommit might be set.
	NewConnection(c *Conn)

	// ConnectionReady is called after the connection handshake, but
	// before we begin to process commands.
	ConnectionReady(c *Conn)

	// ConnectionClosed is called when a connection is closed.
	ConnectionClosed(c *Conn)

	// ComQuery is called when a connection receives a query.
	// Note the contents of query may change after the first call to
	// callback, so the Handler should not hang on to query past that
	// point (copy it if it needs to be retained).
	ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error

	// ComPrepare is called when a connection receives a prepared
	// statement query.
	ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error)

	// ComStmtExecute is called when a connection receives a statement
	// execute query.
	ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error

	// ComRegisterReplica is called when a connection receives a ComRegisterReplica request
	ComRegisterReplica(c *Conn, replicaHost string, replicaPort uint16, replicaUser string, replicaPassword string) error

	// ComBinlogDump is called when a connection receives a ComBinlogDump request
	ComBinlogDump(c *Conn, logFile string, binlogPos uint32) error

	// ComBinlogDumpGTID is called when a connection receives a ComBinlogDumpGTID request
	ComBinlogDumpGTID(c *Conn, logFile string, logPos uint64, gtidSet GTIDSet) error

	// WarningCount is called at the end of each query to obtain
	// the value to be returned to the client in the EOF packet.
	// Note that this will be called either in the context of the
	// ComQuery callback if the result does not contain any fields,
	// or after the last ComQuery call completes.
	WarningCount(c *Conn) uint16

	// ComResetConnection is called to let the handler reset its
	// per-connection state; presumably invoked for a
	// COM_RESET_CONNECTION command (the dispatch is outside this file).
	ComResetConnection(c *Conn)
}
   137  
   138  // UnimplementedHandler implemnts all of the optional callbacks so as to satisy
   139  // the Handler interface. Intended to be embedded into your custom Handler
   140  // implementation without needing to define every callback and to help be forwards
   141  // compatible when new functions are added.
   142  type UnimplementedHandler struct{}
   143  
   144  func (UnimplementedHandler) NewConnection(*Conn)      {}
   145  func (UnimplementedHandler) ConnectionReady(*Conn)    {}
   146  func (UnimplementedHandler) ConnectionClosed(*Conn)   {}
   147  func (UnimplementedHandler) ComResetConnection(*Conn) {}
   148  
// Listener is the MySQL server protocol listener. It wraps a net.Listener,
// and its Accept loop serves each accepted client connection on its own
// goroutine.
type Listener struct {
	// Construction parameters, set by NewListenerWithConfig (which
	// NewListener and NewFromListener call).

	// authServer is the AuthServer object to use for authentication.
	authServer AuthServer

	// handler is the data handler.
	handler Handler

	// This is the main listener socket.
	listener net.Listener

	// The following parameters are read by multiple connection go
	// routines. They are not protected by a mutex, so they
	// should be set after NewListener, and not changed while
	// Accept is running. (TLSConfig, AllowClearTextWithoutTLS and
	// SlowConnectWarnThreshold are atomic types, so individual loads and
	// stores on them are race-free; ServerVersion is a plain field.)

	// ServerVersion is the version we will advertise in the handshake.
	// NewListenerWithConfig initializes it from servenv.AppVersion.
	ServerVersion string

	// TLSConfig is the server TLS config. If set, we will advertise
	// that we support SSL.
	// atomic value stores *tls.Config
	TLSConfig atomic.Value

	// AllowClearTextWithoutTLS needs to be set for the
	// mysql_clear_password authentication method to be accepted
	// by the server when TLS is not in use. When it is false, handle
	// rejects a non-TLS connection whose negotiated auth method does not
	// itself allow clear text without TLS.
	AllowClearTextWithoutTLS sync2.AtomicBool

	// SlowConnectWarnThreshold if non-zero specifies an amount of time
	// beyond which a warning is logged to identify the slow connection
	// (the MysqlServerConnSlow counter is incremented as well).
	SlowConnectWarnThreshold sync2.AtomicDuration

	// connectionID is the ID that will be handed to the next accepted
	// connection. It starts at 1 and is incremented by the Accept loop.
	connectionID uint32

	// Read timeout on a given connection
	connReadTimeout time.Duration
	// Write timeout on a given connection
	connWriteTimeout time.Duration
	// connReadBufferSize is size of buffer for reads from underlying connection.
	// Reads are unbuffered if it's <=0.
	connReadBufferSize int

	// connBufferPooling configures if vtgate server pools connection buffers
	// (when set, handle returns the connection's reader via returnReader on close).
	connBufferPooling bool

	// shutdown indicates that Shutdown method was called.
	shutdown sync2.AtomicBool

	// RequireSecureTransport configures the server to reject connections
	// from insecure clients: a client that does not upgrade to TLS gets an
	// error packet during the handshake.
	RequireSecureTransport bool

	// PreHandleFunc is called for each incoming connection, immediately after
	// accepting a new connection, on the connection's own goroutine. By
	// default it's nil (no-op). Useful for custom connection inspection or
	// TLS termination. The returned connection is handled further by the
	// MySQL handler. A non-nil error will stop processing the connection by
	// the MySQL handler.
	PreHandleFunc func(context.Context, net.Conn, uint32) (net.Conn, error)
}
   213  
   214  // NewFromListener creates a new mysql listener from an existing net.Listener
   215  func NewFromListener(
   216  	l net.Listener,
   217  	authServer AuthServer,
   218  	handler Handler,
   219  	connReadTimeout time.Duration,
   220  	connWriteTimeout time.Duration,
   221  	connBufferPooling bool,
   222  ) (*Listener, error) {
   223  	cfg := ListenerConfig{
   224  		Listener:           l,
   225  		AuthServer:         authServer,
   226  		Handler:            handler,
   227  		ConnReadTimeout:    connReadTimeout,
   228  		ConnWriteTimeout:   connWriteTimeout,
   229  		ConnReadBufferSize: connBufferSize,
   230  		ConnBufferPooling:  connBufferPooling,
   231  	}
   232  	return NewListenerWithConfig(cfg)
   233  }
   234  
   235  // NewListener creates a new Listener.
   236  func NewListener(
   237  	protocol, address string,
   238  	authServer AuthServer,
   239  	handler Handler,
   240  	connReadTimeout time.Duration,
   241  	connWriteTimeout time.Duration,
   242  	proxyProtocol bool,
   243  	connBufferPooling bool,
   244  ) (*Listener, error) {
   245  	listener, err := net.Listen(protocol, address)
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  	if proxyProtocol {
   250  		proxyListener := &proxyproto.Listener{Listener: listener}
   251  		return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling)
   252  	}
   253  
   254  	return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling)
   255  }
   256  
// ListenerConfig should be used with NewListenerWithConfig to specify listener
// parameters. No defaults are filled in, so every field the caller relies on
// must be set explicitly.
type ListenerConfig struct {
	// Protocol-Address pair and Listener are mutually exclusive parameters:
	// when Listener is non-nil it is used as-is and Protocol/Address are
	// ignored; otherwise net.Listen(Protocol, Address) opens a new socket.
	Protocol           string
	Address            string
	Listener           net.Listener
	// AuthServer authenticates clients; Handler receives connection
	// lifecycle events and client commands.
	AuthServer         AuthServer
	Handler            Handler
	// Per-connection read/write timeouts. When either is non-zero, each
	// accepted connection is wrapped with netutil.NewConnWithTimeouts.
	ConnReadTimeout    time.Duration
	ConnWriteTimeout   time.Duration
	// ConnReadBufferSize is the size of the buffer for reads from the
	// underlying connection; reads are unbuffered if it is <= 0.
	ConnReadBufferSize int
	// ConnBufferPooling enables returning connection read buffers for reuse
	// when a connection is closed.
	ConnBufferPooling  bool
}
   270  
   271  // NewListenerWithConfig creates new listener using provided config. There are
   272  // no default values for config, so caller should ensure its correctness.
   273  func NewListenerWithConfig(cfg ListenerConfig) (*Listener, error) {
   274  	var l net.Listener
   275  	if cfg.Listener != nil {
   276  		l = cfg.Listener
   277  	} else {
   278  		listener, err := net.Listen(cfg.Protocol, cfg.Address)
   279  		if err != nil {
   280  			return nil, err
   281  		}
   282  		l = listener
   283  	}
   284  
   285  	return &Listener{
   286  		authServer:         cfg.AuthServer,
   287  		handler:            cfg.Handler,
   288  		listener:           l,
   289  		ServerVersion:      servenv.AppVersion.MySQLVersion(),
   290  		connectionID:       1,
   291  		connReadTimeout:    cfg.ConnReadTimeout,
   292  		connWriteTimeout:   cfg.ConnWriteTimeout,
   293  		connReadBufferSize: cfg.ConnReadBufferSize,
   294  		connBufferPooling:  cfg.ConnBufferPooling,
   295  	}, nil
   296  }
   297  
   298  // Addr returns the listener address.
   299  func (l *Listener) Addr() net.Addr {
   300  	return l.listener.Addr()
   301  }
   302  
   303  // Accept runs an accept loop until the listener is closed.
   304  func (l *Listener) Accept() {
   305  	ctx := context.Background()
   306  
   307  	for {
   308  		conn, err := l.listener.Accept()
   309  		if err != nil {
   310  			// Close() was probably called.
   311  			connRefuse.Add(1)
   312  			return
   313  		}
   314  
   315  		acceptTime := time.Now()
   316  
   317  		connectionID := l.connectionID
   318  		l.connectionID++
   319  
   320  		connCount.Add(1)
   321  		connAccept.Add(1)
   322  
   323  		go func() {
   324  			if l.PreHandleFunc != nil {
   325  				conn, err = l.PreHandleFunc(ctx, conn, connectionID)
   326  				if err != nil {
   327  					log.Errorf("mysql_server pre hook: %s", err)
   328  					return
   329  				}
   330  			}
   331  
   332  			l.handle(conn, connectionID, acceptTime)
   333  		}()
   334  	}
   335  }
   336  
// handle serves a single client connection from handshake to close. Accept
// calls it on a dedicated goroutine for each accepted connection.
//
// Sequence: send the HandshakeV10 packet; read the client's handshake
// response (and, if the client asked for SSL, re-read it over the upgraded
// TLS connection); negotiate an auth method, sending an AuthSwitchRequest
// when the client gave no usable method or no auth data; authenticate; run
// "use <db>" through the Handler if the client supplied an initial schema;
// send OK; then dispatch commands via handleNextCommand until it returns
// false.
//
// Once NewConnection has returned, every exit path (including a recovered
// panic) runs the deferred cleanup: the gauges incremented along the way are
// decremented (connCount was incremented by Accept), the Handler is told via
// ConnectionClosed, writer buffering is ended, the pooled reader is returned
// when pooling is enabled, and conn is closed.
//
// FIXME(alainjobart) handle per-connection logs in a way that makes sense.
func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Time) {
	// Wrap the socket so the configured read/write timeouts apply.
	if l.connReadTimeout != 0 || l.connWriteTimeout != 0 {
		conn = netutil.NewConnWithTimeouts(conn, l.connReadTimeout, l.connWriteTimeout)
	}
	c := newServerConn(conn, l)
	c.ConnectionID = connectionID

	// Catch panics, and close the connection in any case.
	// This is the first defer registered, so it runs last, after the
	// ConnectionClosed and gauge defers below.
	defer func() {
		if x := recover(); x != nil {
			log.Errorf("mysql_server caught panic:\n%v\n%s", x, tb.Stack(4))
		}
		// We call endWriterBuffering here in case there's a premature return after
		// startWriterBuffering is called
		c.endWriterBuffering()

		if l.connBufferPooling {
			c.returnReader()
		}

		// NOTE(review): this closes conn, not c.conn. After a TLS upgrade
		// c.conn is a *tls.Conn wrapping conn that is never closed itself
		// (so no TLS close_notify is sent) — confirm this is intended.
		conn.Close()
	}()

	// Tell the handler about the connection coming and going.
	l.handler.NewConnection(c)
	defer l.handler.ConnectionClosed(c)

	// Adjust the count of open connections (Accept incremented it).
	defer connCount.Add(-1)

	// First build and send the server handshake packet.
	// serverAuthPluginData is the salt sent in that packet; it is reused
	// below unless an auth switch replaces it.
	serverAuthPluginData, err := c.writeHandshakeV10(l.ServerVersion, l.authServer, l.TLSConfig.Load() != nil)
	if err != nil {
		// writeHandshakeV10 maps "connection reset"/"broken pipe" to io.EOF.
		if err != io.EOF {
			log.Errorf("Cannot send HandshakeV10 packet to %s: %v", c, err)
		}
		return
	}

	// Wait for the client response. This has to be a direct read,
	// so we don't buffer the TLS negotiation packets.
	response, err := c.readEphemeralPacketDirect()
	if err != nil {
		// Don't log EOF errors. They cause too much spam, same as main read loop.
		if err != io.EOF {
			log.Infof("Cannot read client handshake response from %s: %v, it may not be a valid MySQL client", c, err)
		}
		return
	}
	// If this packet is an SSLRequest, parseClientHandshakePacket upgrades
	// c.conn to TLS and returns empty values; they are re-read below.
	user, clientAuthMethod, clientAuthResponse, err := l.parseClientHandshakePacket(c, true, response)
	if err != nil {
		log.Errorf("Cannot parse client handshake response from %s: %v", c, err)
		return
	}

	c.recycleReadPacket()

	if c.TLSEnabled() {
		// SSL was enabled. We need to re-read the auth packet.
		response, err = c.readEphemeralPacket()
		if err != nil {
			log.Errorf("Cannot read post-SSL client handshake response from %s: %v", c, err)
			return
		}

		// Returns copies of the data, so we can recycle the buffer.
		user, clientAuthMethod, clientAuthResponse, err = l.parseClientHandshakePacket(c, false, response)
		if err != nil {
			log.Errorf("Cannot parse post-SSL client handshake response from %s: %v", c, err)
			return
		}
		c.recycleReadPacket()

		// Track the negotiated TLS version for MysqlServerConnCountByTLSVer.
		if con, ok := c.conn.(*tls.Conn); ok {
			connState := con.ConnectionState()
			tlsVerStr := tlsVersionToString(connState.Version)
			// tlsVersionToString never returns "" (unknown versions map to
			// versionTLSUnknown), so this guard is currently always true.
			if tlsVerStr != "" {
				connCountByTLSVer.Add(tlsVerStr, 1)
				defer connCountByTLSVer.Add(tlsVerStr, -1)
			}
		}
	} else {
		if l.RequireSecureTransport {
			c.writeErrorPacketFromError(vterrors.Errorf(vtrpc.Code_UNAVAILABLE, "server does not allow insecure connections, client must use SSL/TLS"))
			return
		}
		connCountByTLSVer.Add(versionNoTLS, 1)
		defer connCountByTLSVer.Add(versionNoTLS, -1)
	}

	// See what auth method the AuthServer wants to use for that user.
	negotiatedAuthMethod, err := negotiateAuthMethod(c, l.authServer, user, clientAuthMethod)

	// We need to send down an additional packet if we either have no negotiated method
	// at all or incomplete authentication data.
	//
	// The latter case happens for example for MySQL 8.0 clients until 8.0.25 who advertise
	// support for caching_sha2_password by default but with no plugin data.
	if err != nil || len(clientAuthResponse) == 0 {
		// If we have no negotiated method yet, we pick the first one
		// we know about ourselves as that's the last resort option we have here.
		if err != nil {
			// The client will disconnect if it doesn't understand
			// the first auth method that we send, so we only have to send the
			// first one that we allow for the user.
			for _, m := range l.authServer.AuthMethods() {
				if m.HandleUser(c, user) {
					negotiatedAuthMethod = m
					break
				}
			}
		}

		if negotiatedAuthMethod == nil {
			c.writeErrorPacket(CRServerHandshakeErr, SSUnknownSQLState, "No authentication methods available for authentication.")
			return
		}

		// Refuse methods that would expose credentials over plaintext unless
		// the listener explicitly allows it.
		if !l.AllowClearTextWithoutTLS.Get() && !c.TLSEnabled() && !negotiatedAuthMethod.AllowClearTextWithoutTLS() {
			c.writeErrorPacket(CRServerHandshakeErr, SSUnknownSQLState, "Cannot use clear text authentication over non-SSL connections.")
			return
		}

		// The switch request carries fresh plugin data, which replaces the
		// salt from the initial handshake for the verification below.
		serverAuthPluginData, err = negotiatedAuthMethod.AuthPluginData()
		if err != nil {
			log.Errorf("Error generating auth switch packet for %s: %v", c, err)
			return
		}

		if err := c.writeAuthSwitchRequest(string(negotiatedAuthMethod.Name()), serverAuthPluginData); err != nil {
			log.Errorf("Error writing auth switch packet for %s: %v", c, err)
			return
		}

		clientAuthResponse, err = c.readEphemeralPacket()
		if err != nil {
			log.Errorf("Error reading auth switch response for %s: %v", c, err)
			return
		}
		c.recycleReadPacket()
	}

	userData, err := negotiatedAuthMethod.HandleAuthPluginData(c, user, serverAuthPluginData, clientAuthResponse, conn.RemoteAddr())
	if err != nil {
		// The error itself is sent to the client rather than logged.
		log.Warningf("Error authenticating user %s using: %s", user, negotiatedAuthMethod.Name())
		c.writeErrorPacketFromError(err)
		return
	}

	c.User = user
	c.UserData = userData

	// Users with an empty name are not tracked per user.
	if c.User != "" {
		connCountPerUser.Add(c.User, 1)
		defer connCountPerUser.Add(c.User, -1)
	}

	// Set initial db name.
	// c.schemaName comes from the handshake response (CONNECT_WITH_DB);
	// the switch goes through the Handler with a result-discarding callback.
	if c.schemaName != "" {
		err = l.handler.ComQuery(c, "use "+sqlescape.EscapeID(c.schemaName), func(result *sqltypes.Result) error {
			return nil
		})
		if err != nil {
			c.writeErrorPacketFromError(err)
			return
		}
	}

	// Negotiation worked, send OK packet.
	if err := c.writeOKPacket(&PacketOK{statusFlags: c.StatusFlags}); err != nil {
		log.Errorf("Cannot write OK packet to %s: %v", c, err)
		return
	}

	// Record how long we took to establish the connection
	timings.Record(connectTimingKey, acceptTime)

	// Log a warning if it took too long to connect
	connectTime := time.Since(acceptTime)
	if threshold := l.SlowConnectWarnThreshold.Get(); threshold != 0 && connectTime > threshold {
		connSlow.Add(1)
		log.Warningf("Slow connection from %s: %v", c, connectTime)
	}

	// Tell our handler that we're finished handshake and are ready to
	// process commands.
	l.handler.ConnectionReady(c)

	// Command loop: handleNextCommand returns false when the connection
	// should be torn down.
	for {
		kontinue := c.handleNextCommand(l.handler)
		if !kontinue {
			return
		}
	}
}
   534  
   535  // Close stops the listener, which prevents accept of any new connections. Existing connections won't be closed.
   536  func (l *Listener) Close() {
   537  	l.listener.Close()
   538  }
   539  
   540  // Shutdown closes listener and fails any Ping requests from existing connections.
   541  // This can be used for graceful shutdown, to let clients know that they should reconnect to another server.
   542  func (l *Listener) Shutdown() {
   543  	if l.shutdown.CompareAndSwap(false, true) {
   544  		l.Close()
   545  	}
   546  }
   547  
   548  func (l *Listener) isShutdown() bool {
   549  	return l.shutdown.Get()
   550  }
   551  
// writeHandshakeV10 writes the Initial Handshake Packet, server side.
//
// serverVersion is advertised as-is. The auth plugin named in the packet is
// authServer's default method when that is mysql_native_password or
// caching_sha2_password, and mysql_native_password otherwise. When enableTLS
// is true, CapabilityClientSSL is advertised.
//
// It returns the salt data: a freshly generated salt with a trailing NUL
// byte appended, which is exactly what was sent as auth-plugin-data.
// "connection reset by peer" and "broken pipe" write errors are reported as
// io.EOF.
func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, enableTLS bool) ([]byte, error) {
	capabilities := CapabilityClientLongPassword |
		CapabilityClientFoundRows |
		CapabilityClientLongFlag |
		CapabilityClientConnectWithDB |
		CapabilityClientProtocol41 |
		CapabilityClientTransactions |
		CapabilityClientSecureConnection |
		CapabilityClientMultiStatements |
		CapabilityClientMultiResults |
		CapabilityClientPluginAuth |
		CapabilityClientPluginAuthLenencClientData |
		CapabilityClientDeprecateEOF |
		CapabilityClientConnAttr
	if enableTLS {
		capabilities |= CapabilityClientSSL
	}

	// Grab the default auth method. Only mysql_native_password or
	// caching_sha2_password may be advertised here, and both need the salt
	// to be present.
	//
	// Any other auth method will cause clients to throw a
	// handshake error, so anything else falls back to mysql_native_password.
	authMethod := authServer.DefaultAuthMethodDescription()

	if authMethod != MysqlNativePassword && authMethod != CachingSha2Password {
		authMethod = MysqlNativePassword
	}

	// The 13 bytes of auth-plugin-data below assume newSalt returns 20
	// bytes (8 + 12, plus the appended NUL); the sanity check at the end
	// catches a size mismatch.
	length :=
		1 + // protocol version
			lenNullString(serverVersion) +
			4 + // connection ID
			8 + // first part of plugin auth data
			1 + // filler byte
			2 + // capability flags (lower 2 bytes)
			1 + // character set
			2 + // status flag
			2 + // capability flags (upper 2 bytes)
			1 + // length of auth plugin data
			10 + // reserved (0)
			13 + // auth-plugin-data
			lenNullString(string(authMethod)) // auth-plugin-name

	data, pos := c.startEphemeralPacketWithHeader(length)

	// Protocol version.
	pos = writeByte(data, pos, protocolVersion)

	// Copy server version.
	pos = writeNullString(data, pos, serverVersion)

	// Add connectionID in.
	pos = writeUint32(data, pos, c.ConnectionID)

	// Generate the salt as the plugin data. Will be reused
	// later on if no auth method switch happens and the real
	// auth method is also mysql_native_password or caching_sha2_password.
	pluginData, err := newSalt()
	if err != nil {
		return nil, err
	}
	// Plugin data is always defined as having a trailing NULL
	pluginData = append(pluginData, 0)

	pos += copy(data[pos:], pluginData[:8])

	// One filler byte, always 0.
	pos = writeByte(data, pos, 0)

	// Lower part of the capability flags.
	pos = writeUint16(data, pos, uint16(capabilities))

	// Character set.
	pos = writeByte(data, pos, collations.Local().DefaultConnectionCharset())

	// Status flag.
	pos = writeUint16(data, pos, c.StatusFlags)

	// Upper part of the capability flags.
	pos = writeUint16(data, pos, uint16(capabilities>>16))

	// Length of auth plugin data.
	// Always 21 (8 + 13).
	pos = writeByte(data, pos, 21)

	// Reserved 10 bytes: all 0
	pos = writeZeroes(data, pos, 10)

	// Second part of auth plugin data.
	pos += copy(data[pos:], pluginData[8:])

	// Auth plugin name: the (possibly defaulted) authMethod chosen above.
	pos = writeNullString(data, pos, string(authMethod))

	// Sanity check.
	if pos != len(data) {
		return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "error building Handshake packet: got %v bytes expected %v", pos, len(data))
	}

	if err := c.writeEphemeralPacket(); err != nil {
		// A client that hangs up during the handshake is reported as a
		// plain EOF, which handle does not log.
		if strings.HasSuffix(err.Error(), "write: connection reset by peer") {
			return nil, io.EOF
		}
		if strings.HasSuffix(err.Error(), "write: broken pipe") {
			return nil, io.EOF
		}
		return nil, err
	}

	return pluginData, nil
}
   668  
   669  // parseClientHandshakePacket parses the handshake sent by the client.
   670  // Returns the username, auth method, auth data, error.
   671  // The original data is not pointed at, and can be freed.
   672  func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []byte) (string, AuthMethodDescription, []byte, error) {
   673  	pos := 0
   674  
   675  	// Client flags, 4 bytes.
   676  	clientFlags, pos, ok := readUint32(data, pos)
   677  	if !ok {
   678  		return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read client flags")
   679  	}
   680  	if clientFlags&CapabilityClientProtocol41 == 0 {
   681  		return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: only support protocol 4.1")
   682  	}
   683  
   684  	// Remember a subset of the capabilities, so we can use them
   685  	// later in the protocol. If we re-received the handshake packet
   686  	// after SSL negotiation, do not overwrite capabilities.
   687  	if firstTime {
   688  		c.Capabilities = clientFlags & (CapabilityClientDeprecateEOF | CapabilityClientFoundRows)
   689  	}
   690  
   691  	// set connection capability for executing multi statements
   692  	if clientFlags&CapabilityClientMultiStatements > 0 {
   693  		c.Capabilities |= CapabilityClientMultiStatements
   694  	}
   695  
   696  	// Max packet size. Don't do anything with this now.
   697  	// See doc.go for more information.
   698  	_, pos, ok = readUint32(data, pos)
   699  	if !ok {
   700  		return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read maxPacketSize")
   701  	}
   702  
   703  	// Character set. Need to handle it.
   704  	characterSet, pos, ok := readByte(data, pos)
   705  	if !ok {
   706  		return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read characterSet")
   707  	}
   708  	c.CharacterSet = collations.ID(characterSet)
   709  
   710  	// 23x reserved zero bytes.
   711  	pos += 23
   712  
   713  	// Check for SSL.
   714  	if firstTime && l.TLSConfig.Load() != nil && clientFlags&CapabilityClientSSL > 0 {
   715  		// Need to switch to TLS, and then re-read the packet.
   716  		conn := tls.Server(c.conn, l.TLSConfig.Load().(*tls.Config))
   717  		c.conn = conn
   718  		c.bufferedReader.Reset(conn)
   719  		c.Capabilities |= CapabilityClientSSL
   720  		return "", "", nil, nil
   721  	}
   722  
   723  	// username
   724  	username, pos, ok := readNullString(data, pos)
   725  	if !ok {
   726  		return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read username")
   727  	}
   728  
   729  	// auth-response can have three forms.
   730  	var authResponse []byte
   731  	if clientFlags&CapabilityClientPluginAuthLenencClientData != 0 {
   732  		var l uint64
   733  		l, pos, ok = readLenEncInt(data, pos)
   734  		if !ok {
   735  			return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response variable length")
   736  		}
   737  		authResponse, pos, ok = readBytesCopy(data, pos, int(l))
   738  		if !ok {
   739  			return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
   740  		}
   741  
   742  	} else if clientFlags&CapabilityClientSecureConnection != 0 {
   743  		var l byte
   744  		l, pos, ok = readByte(data, pos)
   745  		if !ok {
   746  			return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response length")
   747  		}
   748  
   749  		authResponse, pos, ok = readBytesCopy(data, pos, int(l))
   750  		if !ok {
   751  			return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
   752  		}
   753  	} else {
   754  		a := ""
   755  		a, pos, ok = readNullString(data, pos)
   756  		if !ok {
   757  			return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
   758  		}
   759  		authResponse = []byte(a)
   760  	}
   761  
   762  	// db name.
   763  	if clientFlags&CapabilityClientConnectWithDB != 0 {
   764  		dbname := ""
   765  		dbname, pos, ok = readNullString(data, pos)
   766  		if !ok {
   767  			return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read dbname")
   768  		}
   769  		c.schemaName = dbname
   770  	}
   771  
   772  	// authMethod (with default)
   773  	authMethod := MysqlNativePassword
   774  	if clientFlags&CapabilityClientPluginAuth != 0 {
   775  		var authMethodStr string
   776  		authMethodStr, pos, ok = readNullString(data, pos)
   777  		if !ok {
   778  			return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read authMethod")
   779  		}
   780  		// The JDBC driver sometimes sends an empty string as the auth method when it wants to use mysql_native_password
   781  		if authMethodStr != "" {
   782  			authMethod = AuthMethodDescription(authMethodStr)
   783  		}
   784  	}
   785  
   786  	// Decode connection attributes send by the client
   787  	if clientFlags&CapabilityClientConnAttr != 0 {
   788  		if _, _, err := parseConnAttrs(data, pos); err != nil {
   789  			log.Warningf("Decode connection attributes send by the client: %v", err)
   790  		}
   791  	}
   792  
   793  	return username, AuthMethodDescription(authMethod), authResponse, nil
   794  }
   795  
   796  func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) {
   797  	var attrLen uint64
   798  
   799  	attrLen, pos, ok := readLenEncInt(data, pos)
   800  	if !ok {
   801  		return nil, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attributes variable length")
   802  	}
   803  
   804  	var attrLenRead uint64
   805  
   806  	attrs := make(map[string]string)
   807  
   808  	for attrLenRead < attrLen {
   809  		var keyLen byte
   810  		keyLen, pos, ok = readByte(data, pos)
   811  		if !ok {
   812  			return nil, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attribute key length")
   813  		}
   814  		attrLenRead += uint64(keyLen) + 1
   815  
   816  		var connAttrKey []byte
   817  		connAttrKey, pos, ok = readBytes(data, pos, int(keyLen))
   818  		if !ok {
   819  			return nil, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attribute key")
   820  		}
   821  
   822  		var valLen byte
   823  		valLen, pos, ok = readByte(data, pos)
   824  		if !ok {
   825  			return nil, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attribute value length")
   826  		}
   827  		attrLenRead += uint64(valLen) + 1
   828  
   829  		var connAttrVal []byte
   830  		connAttrVal, pos, ok = readBytes(data, pos, int(valLen))
   831  		if !ok {
   832  			return nil, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attribute value")
   833  		}
   834  
   835  		attrs[string(connAttrKey[:])] = string(connAttrVal[:])
   836  	}
   837  
   838  	return attrs, pos, nil
   839  
   840  }
   841  
   842  // writeAuthSwitchRequest writes an auth switch request packet.
   843  func (c *Conn) writeAuthSwitchRequest(pluginName string, pluginData []byte) error {
   844  	length := 1 + // AuthSwitchRequestPacket
   845  		len(pluginName) + 1 + // 0-terminated pluginName
   846  		len(pluginData)
   847  
   848  	data, pos := c.startEphemeralPacketWithHeader(length)
   849  
   850  	// Packet header.
   851  	pos = writeByte(data, pos, AuthSwitchRequestPacket)
   852  
   853  	// Copy server version.
   854  	pos = writeNullString(data, pos, pluginName)
   855  
   856  	// Copy auth data.
   857  	pos += copy(data[pos:], pluginData)
   858  
   859  	// Sanity check.
   860  	if pos != len(data) {
   861  		return vterrors.Errorf(vtrpc.Code_INTERNAL, "error building AuthSwitchRequestPacket packet: got %v bytes expected %v", pos, len(data))
   862  	}
   863  	return c.writeEphemeralPacket()
   864  }
   865  
   866  // Whenever we move to a new version of go, we will need add any new supported TLS versions here
   867  func tlsVersionToString(version uint16) string {
   868  	switch version {
   869  	case tls.VersionTLS10:
   870  		return versionTLS10
   871  	case tls.VersionTLS11:
   872  		return versionTLS11
   873  	case tls.VersionTLS12:
   874  		return versionTLS12
   875  	case tls.VersionTLS13:
   876  		return versionTLS13
   877  	default:
   878  		return versionTLSUnknown
   879  	}
   880  }