/*
Copyright 2019 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package mysql

import (
	"context"
	"crypto/tls"
	"io"
	"net"
	"strings"
	"sync/atomic"
	"time"

	"vitess.io/vitess/go/mysql/collations"
	"vitess.io/vitess/go/vt/servenv"

	"vitess.io/vitess/go/sqlescape"

	proxyproto "github.com/pires/go-proxyproto"

	"vitess.io/vitess/go/netutil"
	"vitess.io/vitess/go/sqltypes"
	"vitess.io/vitess/go/stats"
	"vitess.io/vitess/go/sync2"
	"vitess.io/vitess/go/tb"
	"vitess.io/vitess/go/vt/log"
	querypb "vitess.io/vitess/go/vt/proto/query"
	"vitess.io/vitess/go/vt/proto/vtrpc"
	"vitess.io/vitess/go/vt/vterrors"
)

const (
	// NOTE(review): this block previously carried a doc comment for a
	// DefaultServerVersion constant, but no such constant is declared here.
	// The advertised version is taken from servenv.AppVersion.MySQLVersion()
	// when the Listener is built.

	// timing metric keys
	connectTimingKey = "Connect"
	queryTimingKey   = "Query"

	// Labels reported on the per-TLS-version connection gauge.
	versionTLS10      = "TLS10"
	versionTLS11      = "TLS11"
	versionTLS12      = "TLS12"
	versionTLS13      = "TLS13"
	versionTLSUnknown = "UnknownTLSVersion"
	versionNoTLS      = "None"
)

var (
	// Metrics
	timings    = stats.NewTimings("MysqlServerTimings", "MySQL server timings", "operation")
	connCount  = stats.NewGauge("MysqlServerConnCount", "Active MySQL server connections")
	connAccept = stats.NewCounter("MysqlServerConnAccepted", "Connections accepted by MySQL server")
	connRefuse = stats.NewCounter("MysqlServerConnRefused", "Connections refused by MySQL server")
	connSlow   = stats.NewCounter("MysqlServerConnSlow", "Connections that took more than the configured mysql_slow_connect_warn_threshold to establish")

	// Per-label gauges: one keyed by TLS version label, one keyed by user name.
	connCountByTLSVer = stats.NewGaugesWithSingleLabel("MysqlServerConnCountByTLSVer", "Active MySQL server connections by TLS version", "tls")
	connCountPerUser  = stats.NewGaugesWithSingleLabel("MysqlServerConnCountPerUser", "Active MySQL server connections per user", "count")

	// Derived gauge: total open connections minus the sum of all per-user
	// counts, i.e. connections that have not been attributed to a user yet.
	_ = stats.NewGaugeFunc("MysqlServerConnCountUnauthenticated", "Active MySQL server connections that haven't authenticated yet", func() int64 {
		totalUsers := int64(0)
		for _, v := range connCountPerUser.Counts() {
			totalUsers += v
		}
		return connCount.Get() - totalUsers
	})
)

// A Handler is an interface used by Listener to send queries.
// The implementation of this interface may store data in the ClientData
// field of the Connection for its own purposes.
//
// For a given Connection, all these methods are serialized. It means
// only one of these methods will be called concurrently for a given
// Connection. So access to the Connection ClientData does not need to
// be protected by a mutex.
//
// However, each connection is using one go routine, so multiple
// Connection objects can call these concurrently, for different Connections.
type Handler interface {
	// NewConnection is called when a connection is created.
	// It is not established yet. The handler can decide to
	// set StatusFlags that will be returned by the handshake methods.
	// In particular, ServerStatusAutocommit might be set.
	NewConnection(c *Conn)

	// ConnectionReady is called after the connection handshake, but
	// before we begin to process commands.
	ConnectionReady(c *Conn)

	// ConnectionClosed is called when a connection is closed.
	ConnectionClosed(c *Conn)

	// ComQuery is called when a connection receives a query.
	// Note the contents of the query slice may change after
	// the first call to callback. So the Handler should not
	// hang on to the byte slice.
	// NOTE(review): the parameter is a string, not a byte slice, so the
	// warning above looks stale — confirm before relying on it.
	ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error

	// ComPrepare is called when a connection receives a prepared
	// statement query.
	ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error)

	// ComStmtExecute is called when a connection receives a statement
	// execute query.
	ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error

	// ComRegisterReplica is called when a connection receives a ComRegisterReplica request
	ComRegisterReplica(c *Conn, replicaHost string, replicaPort uint16, replicaUser string, replicaPassword string) error

	// ComBinlogDump is called when a connection receives a ComBinlogDump request
	ComBinlogDump(c *Conn, logFile string, binlogPos uint32) error

	// ComBinlogDumpGTID is called when a connection receives a ComBinlogDumpGTID request
	ComBinlogDumpGTID(c *Conn, logFile string, logPos uint64, gtidSet GTIDSet) error

	// WarningCount is called at the end of each query to obtain
	// the value to be returned to the client in the EOF packet.
	// Note that this will be called either in the context of the
	// ComQuery callback if the result does not contain any fields,
	// or after the last ComQuery call completes.
	WarningCount(c *Conn) uint16

	// ComResetConnection is presumably called when a connection receives a
	// reset-connection command — the call site is not in this file; confirm
	// against the command dispatch code.
	ComResetConnection(c *Conn)
}

// UnimplementedHandler implements all of the optional callbacks so as to satisfy
// the Handler interface. Intended to be embedded into your custom Handler
// implementation without needing to define every callback and to help be forwards
// compatible when new functions are added.
type UnimplementedHandler struct{}

// NewConnection is a no-op.
func (UnimplementedHandler) NewConnection(*Conn) {}

// ConnectionReady is a no-op.
func (UnimplementedHandler) ConnectionReady(*Conn) {}

// ConnectionClosed is a no-op.
func (UnimplementedHandler) ConnectionClosed(*Conn) {}

// ComResetConnection is a no-op.
func (UnimplementedHandler) ComResetConnection(*Conn) {}

// Listener is the MySQL server protocol listener.
type Listener struct {
	// Construction parameters, set by NewListener.

	// authServer is the AuthServer object to use for authentication.
	authServer AuthServer

	// handler is the data handler.
	handler Handler

	// This is the main listener socket.
	listener net.Listener

	// The following parameters are read by multiple connection go
	// routines. They are not protected by a mutex, so they
	// should be set after NewListener, and not changed while
	// Accept is running.

	// ServerVersion is the version we will advertise.
	ServerVersion string

	// TLSConfig is the server TLS config. If set, we will advertise
	// that we support SSL.
	// atomic value stores *tls.Config
	TLSConfig atomic.Value

	// AllowClearTextWithoutTLS needs to be set for the
	// mysql_clear_password authentication method to be accepted
	// by the server when TLS is not in use.
	AllowClearTextWithoutTLS sync2.AtomicBool

	// SlowConnectWarnThreshold if non-zero specifies an amount of time
	// beyond which a warning is logged to identify the slow connection
	SlowConnectWarnThreshold sync2.AtomicDuration

	// The following parameters are changed by the Accept routine.

	// Incrementing ID for connection id.
	connectionID uint32

	// Read timeout on a given connection
	connReadTimeout time.Duration
	// Write timeout on a given connection
	connWriteTimeout time.Duration
	// connReadBufferSize is size of buffer for reads from underlying connection.
	// Reads are unbuffered if it's <=0.
	connReadBufferSize int

	// connBufferPooling configures if vtgate server pools connection buffers
	connBufferPooling bool

	// shutdown indicates that Shutdown method was called.
	shutdown sync2.AtomicBool

	// RequireSecureTransport configures the server to reject connections from insecure clients
	RequireSecureTransport bool

	// PreHandleFunc is called for each incoming connection, immediately after
	// accepting a new connection. By default it's no-op. Useful for custom
	// connection inspection or TLS termination. The returned connection is
	// handled further by the MySQL handler. A non-nil error will stop
	// processing the connection by the MySQL handler.
	PreHandleFunc func(context.Context, net.Conn, uint32) (net.Conn, error)
}

// NewFromListener creates a new mysql listener from an existing net.Listener.
// It fills a ListenerConfig from its arguments and delegates to
// NewListenerWithConfig; the read buffer size is taken from the package-level
// connBufferSize, which is declared outside this file.
func NewFromListener(
	l net.Listener,
	authServer AuthServer,
	handler Handler,
	connReadTimeout time.Duration,
	connWriteTimeout time.Duration,
	connBufferPooling bool,
) (*Listener, error) {
	cfg := ListenerConfig{
		Listener:           l,
		AuthServer:         authServer,
		Handler:            handler,
		ConnReadTimeout:    connReadTimeout,
		ConnWriteTimeout:   connWriteTimeout,
		ConnReadBufferSize: connBufferSize,
		ConnBufferPooling:  connBufferPooling,
	}
	return NewListenerWithConfig(cfg)
}

// NewListener creates a new Listener.
// It opens a socket with net.Listen(protocol, address) and, when
// proxyProtocol is true, wraps it in a proxyproto.Listener before handing it
// to NewFromListener. The net.Listen error is returned unchanged.
func NewListener(
	protocol, address string,
	authServer AuthServer,
	handler Handler,
	connReadTimeout time.Duration,
	connWriteTimeout time.Duration,
	proxyProtocol bool,
	connBufferPooling bool,
) (*Listener, error) {
	listener, err := net.Listen(protocol, address)
	if err != nil {
		return nil, err
	}
	if proxyProtocol {
		proxyListener := &proxyproto.Listener{Listener: listener}
		return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling)
	}

	return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling)
}

// ListenerConfig should be used with NewListenerWithConfig to specify listener parameters.
258 type ListenerConfig struct { 259 // Protocol-Address pair and Listener are mutually exclusive parameters 260 Protocol string 261 Address string 262 Listener net.Listener 263 AuthServer AuthServer 264 Handler Handler 265 ConnReadTimeout time.Duration 266 ConnWriteTimeout time.Duration 267 ConnReadBufferSize int 268 ConnBufferPooling bool 269 } 270 271 // NewListenerWithConfig creates new listener using provided config. There are 272 // no default values for config, so caller should ensure its correctness. 273 func NewListenerWithConfig(cfg ListenerConfig) (*Listener, error) { 274 var l net.Listener 275 if cfg.Listener != nil { 276 l = cfg.Listener 277 } else { 278 listener, err := net.Listen(cfg.Protocol, cfg.Address) 279 if err != nil { 280 return nil, err 281 } 282 l = listener 283 } 284 285 return &Listener{ 286 authServer: cfg.AuthServer, 287 handler: cfg.Handler, 288 listener: l, 289 ServerVersion: servenv.AppVersion.MySQLVersion(), 290 connectionID: 1, 291 connReadTimeout: cfg.ConnReadTimeout, 292 connWriteTimeout: cfg.ConnWriteTimeout, 293 connReadBufferSize: cfg.ConnReadBufferSize, 294 connBufferPooling: cfg.ConnBufferPooling, 295 }, nil 296 } 297 298 // Addr returns the listener address. 299 func (l *Listener) Addr() net.Addr { 300 return l.listener.Addr() 301 } 302 303 // Accept runs an accept loop until the listener is closed. 304 func (l *Listener) Accept() { 305 ctx := context.Background() 306 307 for { 308 conn, err := l.listener.Accept() 309 if err != nil { 310 // Close() was probably called. 
311 connRefuse.Add(1) 312 return 313 } 314 315 acceptTime := time.Now() 316 317 connectionID := l.connectionID 318 l.connectionID++ 319 320 connCount.Add(1) 321 connAccept.Add(1) 322 323 go func() { 324 if l.PreHandleFunc != nil { 325 conn, err = l.PreHandleFunc(ctx, conn, connectionID) 326 if err != nil { 327 log.Errorf("mysql_server pre hook: %s", err) 328 return 329 } 330 } 331 332 l.handle(conn, connectionID, acceptTime) 333 }() 334 } 335 } 336 337 // handle is called in a go routine for each client connection. 338 // FIXME(alainjobart) handle per-connection logs in a way that makes sense. 339 func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Time) { 340 if l.connReadTimeout != 0 || l.connWriteTimeout != 0 { 341 conn = netutil.NewConnWithTimeouts(conn, l.connReadTimeout, l.connWriteTimeout) 342 } 343 c := newServerConn(conn, l) 344 c.ConnectionID = connectionID 345 346 // Catch panics, and close the connection in any case. 347 defer func() { 348 if x := recover(); x != nil { 349 log.Errorf("mysql_server caught panic:\n%v\n%s", x, tb.Stack(4)) 350 } 351 // We call endWriterBuffering here in case there's a premature return after 352 // startWriterBuffering is called 353 c.endWriterBuffering() 354 355 if l.connBufferPooling { 356 c.returnReader() 357 } 358 359 conn.Close() 360 }() 361 362 // Tell the handler about the connection coming and going. 363 l.handler.NewConnection(c) 364 defer l.handler.ConnectionClosed(c) 365 366 // Adjust the count of open connections 367 defer connCount.Add(-1) 368 369 // First build and send the server handshake packet. 370 serverAuthPluginData, err := c.writeHandshakeV10(l.ServerVersion, l.authServer, l.TLSConfig.Load() != nil) 371 if err != nil { 372 if err != io.EOF { 373 log.Errorf("Cannot send HandshakeV10 packet to %s: %v", c, err) 374 } 375 return 376 } 377 378 // Wait for the client response. This has to be a direct read, 379 // so we don't buffer the TLS negotiation packets. 
380 response, err := c.readEphemeralPacketDirect() 381 if err != nil { 382 // Don't log EOF errors. They cause too much spam, same as main read loop. 383 if err != io.EOF { 384 log.Infof("Cannot read client handshake response from %s: %v, it may not be a valid MySQL client", c, err) 385 } 386 return 387 } 388 user, clientAuthMethod, clientAuthResponse, err := l.parseClientHandshakePacket(c, true, response) 389 if err != nil { 390 log.Errorf("Cannot parse client handshake response from %s: %v", c, err) 391 return 392 } 393 394 c.recycleReadPacket() 395 396 if c.TLSEnabled() { 397 // SSL was enabled. We need to re-read the auth packet. 398 response, err = c.readEphemeralPacket() 399 if err != nil { 400 log.Errorf("Cannot read post-SSL client handshake response from %s: %v", c, err) 401 return 402 } 403 404 // Returns copies of the data, so we can recycle the buffer. 405 user, clientAuthMethod, clientAuthResponse, err = l.parseClientHandshakePacket(c, false, response) 406 if err != nil { 407 log.Errorf("Cannot parse post-SSL client handshake response from %s: %v", c, err) 408 return 409 } 410 c.recycleReadPacket() 411 412 if con, ok := c.conn.(*tls.Conn); ok { 413 connState := con.ConnectionState() 414 tlsVerStr := tlsVersionToString(connState.Version) 415 if tlsVerStr != "" { 416 connCountByTLSVer.Add(tlsVerStr, 1) 417 defer connCountByTLSVer.Add(tlsVerStr, -1) 418 } 419 } 420 } else { 421 if l.RequireSecureTransport { 422 c.writeErrorPacketFromError(vterrors.Errorf(vtrpc.Code_UNAVAILABLE, "server does not allow insecure connections, client must use SSL/TLS")) 423 return 424 } 425 connCountByTLSVer.Add(versionNoTLS, 1) 426 defer connCountByTLSVer.Add(versionNoTLS, -1) 427 } 428 429 // See what auth method the AuthServer wants to use for that user. 
430 negotiatedAuthMethod, err := negotiateAuthMethod(c, l.authServer, user, clientAuthMethod) 431 432 // We need to send down an additional packet if we either have no negotiated method 433 // at all or incomplete authentication data. 434 // 435 // The latter case happens for example for MySQL 8.0 clients until 8.0.25 who advertise 436 // support for caching_sha2_password by default but with no plugin data. 437 if err != nil || len(clientAuthResponse) == 0 { 438 // If we have no negotiated method yet, we pick the first one 439 // we know about ourselves as that's the last resort option we have here. 440 if err != nil { 441 // The client will disconnect if it doesn't understand 442 // the first auth method that we send, so we only have to send the 443 // first one that we allow for the user. 444 for _, m := range l.authServer.AuthMethods() { 445 if m.HandleUser(c, user) { 446 negotiatedAuthMethod = m 447 break 448 } 449 } 450 } 451 452 if negotiatedAuthMethod == nil { 453 c.writeErrorPacket(CRServerHandshakeErr, SSUnknownSQLState, "No authentication methods available for authentication.") 454 return 455 } 456 457 if !l.AllowClearTextWithoutTLS.Get() && !c.TLSEnabled() && !negotiatedAuthMethod.AllowClearTextWithoutTLS() { 458 c.writeErrorPacket(CRServerHandshakeErr, SSUnknownSQLState, "Cannot use clear text authentication over non-SSL connections.") 459 return 460 } 461 462 serverAuthPluginData, err = negotiatedAuthMethod.AuthPluginData() 463 if err != nil { 464 log.Errorf("Error generating auth switch packet for %s: %v", c, err) 465 return 466 } 467 468 if err := c.writeAuthSwitchRequest(string(negotiatedAuthMethod.Name()), serverAuthPluginData); err != nil { 469 log.Errorf("Error writing auth switch packet for %s: %v", c, err) 470 return 471 } 472 473 clientAuthResponse, err = c.readEphemeralPacket() 474 if err != nil { 475 log.Errorf("Error reading auth switch response for %s: %v", c, err) 476 return 477 } 478 c.recycleReadPacket() 479 } 480 481 userData, err := 
negotiatedAuthMethod.HandleAuthPluginData(c, user, serverAuthPluginData, clientAuthResponse, conn.RemoteAddr()) 482 if err != nil { 483 log.Warningf("Error authenticating user %s using: %s", user, negotiatedAuthMethod.Name()) 484 c.writeErrorPacketFromError(err) 485 return 486 } 487 488 c.User = user 489 c.UserData = userData 490 491 if c.User != "" { 492 connCountPerUser.Add(c.User, 1) 493 defer connCountPerUser.Add(c.User, -1) 494 } 495 496 // Set initial db name. 497 if c.schemaName != "" { 498 err = l.handler.ComQuery(c, "use "+sqlescape.EscapeID(c.schemaName), func(result *sqltypes.Result) error { 499 return nil 500 }) 501 if err != nil { 502 c.writeErrorPacketFromError(err) 503 return 504 } 505 } 506 507 // Negotiation worked, send OK packet. 508 if err := c.writeOKPacket(&PacketOK{statusFlags: c.StatusFlags}); err != nil { 509 log.Errorf("Cannot write OK packet to %s: %v", c, err) 510 return 511 } 512 513 // Record how long we took to establish the connection 514 timings.Record(connectTimingKey, acceptTime) 515 516 // Log a warning if it took too long to connect 517 connectTime := time.Since(acceptTime) 518 if threshold := l.SlowConnectWarnThreshold.Get(); threshold != 0 && connectTime > threshold { 519 connSlow.Add(1) 520 log.Warningf("Slow connection from %s: %v", c, connectTime) 521 } 522 523 // Tell our handler that we're finished handshake and are ready to 524 // process commands. 525 l.handler.ConnectionReady(c) 526 527 for { 528 kontinue := c.handleNextCommand(l.handler) 529 if !kontinue { 530 return 531 } 532 } 533 } 534 535 // Close stops the listener, which prevents accept of any new connections. Existing connections won't be closed. 536 func (l *Listener) Close() { 537 l.listener.Close() 538 } 539 540 // Shutdown closes listener and fails any Ping requests from existing connections. 541 // This can be used for graceful shutdown, to let clients know that they should reconnect to another server. 
542 func (l *Listener) Shutdown() { 543 if l.shutdown.CompareAndSwap(false, true) { 544 l.Close() 545 } 546 } 547 548 func (l *Listener) isShutdown() bool { 549 return l.shutdown.Get() 550 } 551 552 // writeHandshakeV10 writes the Initial Handshake Packet, server side. 553 // It returns the salt data. 554 func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, enableTLS bool) ([]byte, error) { 555 capabilities := CapabilityClientLongPassword | 556 CapabilityClientFoundRows | 557 CapabilityClientLongFlag | 558 CapabilityClientConnectWithDB | 559 CapabilityClientProtocol41 | 560 CapabilityClientTransactions | 561 CapabilityClientSecureConnection | 562 CapabilityClientMultiStatements | 563 CapabilityClientMultiResults | 564 CapabilityClientPluginAuth | 565 CapabilityClientPluginAuthLenencClientData | 566 CapabilityClientDeprecateEOF | 567 CapabilityClientConnAttr 568 if enableTLS { 569 capabilities |= CapabilityClientSSL 570 } 571 572 // Grab the default auth method. This can only be either 573 // mysql_native_password or caching_sha2_password. Both 574 // need the salt as well to be present too. 575 // 576 // Any other auth method will cause clients to throw a 577 // handshake error. 578 authMethod := authServer.DefaultAuthMethodDescription() 579 580 if authMethod != MysqlNativePassword && authMethod != CachingSha2Password { 581 authMethod = MysqlNativePassword 582 } 583 584 length := 585 1 + // protocol version 586 lenNullString(serverVersion) + 587 4 + // connection ID 588 8 + // first part of plugin auth data 589 1 + // filler byte 590 2 + // capability flags (lower 2 bytes) 591 1 + // character set 592 2 + // status flag 593 2 + // capability flags (upper 2 bytes) 594 1 + // length of auth plugin data 595 10 + // reserved (0) 596 13 + // auth-plugin-data 597 lenNullString(string(authMethod)) // auth-plugin-name 598 599 data, pos := c.startEphemeralPacketWithHeader(length) 600 601 // Protocol version. 
602 pos = writeByte(data, pos, protocolVersion) 603 604 // Copy server version. 605 pos = writeNullString(data, pos, serverVersion) 606 607 // Add connectionID in. 608 pos = writeUint32(data, pos, c.ConnectionID) 609 610 // Generate the salt as the plugin data. Will be reused 611 // later on if no auth method switch happens and the real 612 // auth method is also mysql_native_password or caching_sha2_password. 613 pluginData, err := newSalt() 614 if err != nil { 615 return nil, err 616 } 617 // Plugin data is always defined as having a trailing NULL 618 pluginData = append(pluginData, 0) 619 620 pos += copy(data[pos:], pluginData[:8]) 621 622 // One filler byte, always 0. 623 pos = writeByte(data, pos, 0) 624 625 // Lower part of the capability flags. 626 pos = writeUint16(data, pos, uint16(capabilities)) 627 628 // Character set. 629 pos = writeByte(data, pos, collations.Local().DefaultConnectionCharset()) 630 631 // Status flag. 632 pos = writeUint16(data, pos, c.StatusFlags) 633 634 // Upper part of the capability flags. 635 pos = writeUint16(data, pos, uint16(capabilities>>16)) 636 637 // Length of auth plugin data. 638 // Always 21 (8 + 13). 639 pos = writeByte(data, pos, 21) 640 641 // Reserved 10 bytes: all 0 642 pos = writeZeroes(data, pos, 10) 643 644 // Second part of auth plugin data. 645 pos += copy(data[pos:], pluginData[8:]) 646 647 // Copy authPluginName. We always start with the first 648 // registered auth method name. 649 pos = writeNullString(data, pos, string(authMethod)) 650 651 // Sanity check. 
652 if pos != len(data) { 653 return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "error building Handshake packet: got %v bytes expected %v", pos, len(data)) 654 } 655 656 if err := c.writeEphemeralPacket(); err != nil { 657 if strings.HasSuffix(err.Error(), "write: connection reset by peer") { 658 return nil, io.EOF 659 } 660 if strings.HasSuffix(err.Error(), "write: broken pipe") { 661 return nil, io.EOF 662 } 663 return nil, err 664 } 665 666 return pluginData, nil 667 } 668 669 // parseClientHandshakePacket parses the handshake sent by the client. 670 // Returns the username, auth method, auth data, error. 671 // The original data is not pointed at, and can be freed. 672 func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []byte) (string, AuthMethodDescription, []byte, error) { 673 pos := 0 674 675 // Client flags, 4 bytes. 676 clientFlags, pos, ok := readUint32(data, pos) 677 if !ok { 678 return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read client flags") 679 } 680 if clientFlags&CapabilityClientProtocol41 == 0 { 681 return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: only support protocol 4.1") 682 } 683 684 // Remember a subset of the capabilities, so we can use them 685 // later in the protocol. If we re-received the handshake packet 686 // after SSL negotiation, do not overwrite capabilities. 687 if firstTime { 688 c.Capabilities = clientFlags & (CapabilityClientDeprecateEOF | CapabilityClientFoundRows) 689 } 690 691 // set connection capability for executing multi statements 692 if clientFlags&CapabilityClientMultiStatements > 0 { 693 c.Capabilities |= CapabilityClientMultiStatements 694 } 695 696 // Max packet size. Don't do anything with this now. 697 // See doc.go for more information. 
698 _, pos, ok = readUint32(data, pos) 699 if !ok { 700 return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read maxPacketSize") 701 } 702 703 // Character set. Need to handle it. 704 characterSet, pos, ok := readByte(data, pos) 705 if !ok { 706 return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read characterSet") 707 } 708 c.CharacterSet = collations.ID(characterSet) 709 710 // 23x reserved zero bytes. 711 pos += 23 712 713 // Check for SSL. 714 if firstTime && l.TLSConfig.Load() != nil && clientFlags&CapabilityClientSSL > 0 { 715 // Need to switch to TLS, and then re-read the packet. 716 conn := tls.Server(c.conn, l.TLSConfig.Load().(*tls.Config)) 717 c.conn = conn 718 c.bufferedReader.Reset(conn) 719 c.Capabilities |= CapabilityClientSSL 720 return "", "", nil, nil 721 } 722 723 // username 724 username, pos, ok := readNullString(data, pos) 725 if !ok { 726 return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read username") 727 } 728 729 // auth-response can have three forms. 
730 var authResponse []byte 731 if clientFlags&CapabilityClientPluginAuthLenencClientData != 0 { 732 var l uint64 733 l, pos, ok = readLenEncInt(data, pos) 734 if !ok { 735 return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response variable length") 736 } 737 authResponse, pos, ok = readBytesCopy(data, pos, int(l)) 738 if !ok { 739 return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response") 740 } 741 742 } else if clientFlags&CapabilityClientSecureConnection != 0 { 743 var l byte 744 l, pos, ok = readByte(data, pos) 745 if !ok { 746 return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response length") 747 } 748 749 authResponse, pos, ok = readBytesCopy(data, pos, int(l)) 750 if !ok { 751 return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response") 752 } 753 } else { 754 a := "" 755 a, pos, ok = readNullString(data, pos) 756 if !ok { 757 return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response") 758 } 759 authResponse = []byte(a) 760 } 761 762 // db name. 
763 if clientFlags&CapabilityClientConnectWithDB != 0 { 764 dbname := "" 765 dbname, pos, ok = readNullString(data, pos) 766 if !ok { 767 return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read dbname") 768 } 769 c.schemaName = dbname 770 } 771 772 // authMethod (with default) 773 authMethod := MysqlNativePassword 774 if clientFlags&CapabilityClientPluginAuth != 0 { 775 var authMethodStr string 776 authMethodStr, pos, ok = readNullString(data, pos) 777 if !ok { 778 return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read authMethod") 779 } 780 // The JDBC driver sometimes sends an empty string as the auth method when it wants to use mysql_native_password 781 if authMethodStr != "" { 782 authMethod = AuthMethodDescription(authMethodStr) 783 } 784 } 785 786 // Decode connection attributes send by the client 787 if clientFlags&CapabilityClientConnAttr != 0 { 788 if _, _, err := parseConnAttrs(data, pos); err != nil { 789 log.Warningf("Decode connection attributes send by the client: %v", err) 790 } 791 } 792 793 return username, AuthMethodDescription(authMethod), authResponse, nil 794 } 795 796 func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) { 797 var attrLen uint64 798 799 attrLen, pos, ok := readLenEncInt(data, pos) 800 if !ok { 801 return nil, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attributes variable length") 802 } 803 804 var attrLenRead uint64 805 806 attrs := make(map[string]string) 807 808 for attrLenRead < attrLen { 809 var keyLen byte 810 keyLen, pos, ok = readByte(data, pos) 811 if !ok { 812 return nil, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attribute key length") 813 } 814 attrLenRead += uint64(keyLen) + 1 815 816 var connAttrKey []byte 817 connAttrKey, pos, ok = readBytes(data, pos, int(keyLen)) 818 if !ok { 819 return nil, 0, 
vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attribute key") 820 } 821 822 var valLen byte 823 valLen, pos, ok = readByte(data, pos) 824 if !ok { 825 return nil, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attribute value length") 826 } 827 attrLenRead += uint64(valLen) + 1 828 829 var connAttrVal []byte 830 connAttrVal, pos, ok = readBytes(data, pos, int(valLen)) 831 if !ok { 832 return nil, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read connection attribute value") 833 } 834 835 attrs[string(connAttrKey[:])] = string(connAttrVal[:]) 836 } 837 838 return attrs, pos, nil 839 840 } 841 842 // writeAuthSwitchRequest writes an auth switch request packet. 843 func (c *Conn) writeAuthSwitchRequest(pluginName string, pluginData []byte) error { 844 length := 1 + // AuthSwitchRequestPacket 845 len(pluginName) + 1 + // 0-terminated pluginName 846 len(pluginData) 847 848 data, pos := c.startEphemeralPacketWithHeader(length) 849 850 // Packet header. 851 pos = writeByte(data, pos, AuthSwitchRequestPacket) 852 853 // Copy server version. 854 pos = writeNullString(data, pos, pluginName) 855 856 // Copy auth data. 857 pos += copy(data[pos:], pluginData) 858 859 // Sanity check. 860 if pos != len(data) { 861 return vterrors.Errorf(vtrpc.Code_INTERNAL, "error building AuthSwitchRequestPacket packet: got %v bytes expected %v", pos, len(data)) 862 } 863 return c.writeEphemeralPacket() 864 } 865 866 // Whenever we move to a new version of go, we will need add any new supported TLS versions here 867 func tlsVersionToString(version uint16) string { 868 switch version { 869 case tls.VersionTLS10: 870 return versionTLS10 871 case tls.VersionTLS11: 872 return versionTLS11 873 case tls.VersionTLS12: 874 return versionTLS12 875 case tls.VersionTLS13: 876 return versionTLS13 877 default: 878 return versionTLSUnknown 879 } 880 }