vitess.io/vitess@v0.16.2/go/mysql/client.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package mysql 18 19 import ( 20 "context" 21 "crypto/rsa" 22 "crypto/tls" 23 "crypto/x509" 24 "encoding/pem" 25 "fmt" 26 "net" 27 "time" 28 29 "vitess.io/vitess/go/mysql/collations" 30 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 31 "vitess.io/vitess/go/vt/vterrors" 32 "vitess.io/vitess/go/vt/vttls" 33 ) 34 35 // connectResult is used by Connect. 36 type connectResult struct { 37 c *Conn 38 err error 39 } 40 41 // Connect creates a connection to a server. 42 // It then handles the initial handshake. 43 // 44 // If context is canceled before the end of the process, this function 45 // will return nil, ctx.Err(). 46 // 47 // FIXME(alainjobart) once we have more of a server side, add test cases 48 // to cover all failure scenarios. 49 func Connect(ctx context.Context, params *ConnParams) (*Conn, error) { 50 if params.ConnectTimeoutMs != 0 { 51 var cancel context.CancelFunc 52 ctx, cancel = context.WithTimeout(ctx, time.Duration(params.ConnectTimeoutMs)*time.Millisecond) 53 defer cancel() 54 } 55 netProto := "tcp" 56 addr := "" 57 if params.UnixSocket != "" { 58 netProto = "unix" 59 addr = params.UnixSocket 60 } else { 61 addr = net.JoinHostPort(params.Host, fmt.Sprintf("%v", params.Port)) 62 } 63 64 // Start a background connection routine. It first 65 // establishes a network connection, returns it on the channel, 66 // then starts the negotiation, and returns the result on the channel. 67 // It can send on the channel, before closing it: 68 // - a connectResult with an error and nothing else (when dial fails). 69 // - a connectResult with a *Conn and no error, then another one 70 // with possibly an error. 71 status := make(chan connectResult) 72 go func() { 73 defer close(status) 74 var err error 75 var conn net.Conn 76 77 // Cap the Dial time with the context deadline, plus a 78 // few seconds. We want to reclaim resources quickly 79 // and not let this go routine stuck in Dial forever. 80 // 81 // We add a few seconds so we detect the context is 82 // Done() before timing out the Dial. That way we'll 83 // return the right error to the client (ctx.Err(), vs 84 // DialTimeout() error). 85 if deadline, ok := ctx.Deadline(); ok { 86 timeout := time.Until(deadline) + 5*time.Second 87 conn, err = net.DialTimeout(netProto, addr, timeout) 88 } else { 89 conn, err = net.Dial(netProto, addr) 90 } 91 if err != nil { 92 // If we get an error, the connection to a Unix socket 93 // should return a 2002, but for a TCP socket it 94 // should return a 2003. 95 if netProto == "tcp" { 96 status <- connectResult{ 97 err: NewSQLError(CRConnHostError, SSUnknownSQLState, "net.Dial(%v) failed: %v", addr, err), 98 } 99 } else { 100 status <- connectResult{ 101 err: NewSQLError(CRConnectionError, SSUnknownSQLState, "net.Dial(%v) to local server failed: %v", addr, err), 102 } 103 } 104 return 105 } 106 107 // Send the connection back, so the other side can close it. 108 c := newConn(conn) 109 status <- connectResult{ 110 c: c, 111 } 112 113 // During the handshake, and if the context is 114 // canceled, the connection will be closed. That will 115 // make any read or write just return with an error 116 // right away. 117 status <- connectResult{ 118 err: c.clientHandshake(params), 119 } 120 }() 121 122 // Wait on the context and the status, for the connection to happen. 123 var c *Conn 124 select { 125 case <-ctx.Done(): 126 // The background routine may send us a few things, 127 // wait for them and terminate them properly in the 128 // background. 129 go func() { 130 dialCR := <-status // This one can take a while. 131 if dialCR.err != nil { 132 // Dial failed, nothing else to do. 133 return 134 } 135 // Dial worked, close the connection, wait for the end. 136 // We wait as not to leave a channel with an unread value. 137 dialCR.c.Close() 138 <-status 139 }() 140 return nil, ctx.Err() 141 case cr := <-status: 142 if cr.err != nil { 143 // Dial failed, no connection was ever established. 144 return nil, cr.err 145 } 146 147 // Dial worked, we have a connection. Keep going. 148 c = cr.c 149 } 150 151 // Wait for the end of the handshake. 152 select { 153 case <-ctx.Done(): 154 // We are interrupted. Close the connection, wait for 155 // the handshake to finish in the background. 156 c.Close() 157 go func() { 158 // Since we closed the connection, this one should be fast. 159 // We wait as not to leave a channel with an unread value. 160 <-status 161 }() 162 return nil, ctx.Err() 163 case cr := <-status: 164 if cr.err != nil { 165 c.Close() 166 return nil, cr.err 167 } 168 } 169 170 return c, nil 171 } 172 173 // Ping implements mysql ping command. 174 func (c *Conn) Ping() error { 175 // This is a new command, need to reset the sequence. 176 c.sequence = 0 177 data, pos := c.startEphemeralPacketWithHeader(1) 178 data[pos] = ComPing 179 180 if err := c.writeEphemeralPacket(); err != nil { 181 return NewSQLError(CRServerGone, SSUnknownSQLState, "%v", err) 182 } 183 data, err := c.readEphemeralPacket() 184 if err != nil { 185 return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) 186 } 187 defer c.recycleReadPacket() 188 switch data[0] { 189 case OKPacket: 190 return nil 191 case ErrPacket: 192 return ParseErrorPacket(data) 193 } 194 return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected packet type: %d", data[0]) 195 } 196 197 // clientHandshake handles the client side of the handshake. 198 // Note the connection can be closed while this is running. 199 // Returns a SQLError. 200 func (c *Conn) clientHandshake(params *ConnParams) error { 201 // if EnableQueryInfo is set, make sure that all queries starting with the handshake 202 // will actually process the INFO fields in QUERY_OK packets 203 if params.EnableQueryInfo { 204 c.enableQueryInfo = true 205 } 206 207 // Wait for the server initial handshake packet, and parse it. 208 data, err := c.readPacket() 209 if err != nil { 210 return NewSQLError(CRServerLost, "", "initial packet read failed: %v", err) 211 } 212 capabilities, salt, err := c.parseInitialHandshakePacket(data) 213 if err != nil { 214 return err 215 } 216 c.fillFlavor(params) 217 c.salt = salt 218 219 // Sanity check. 220 if capabilities&CapabilityClientProtocol41 == 0 { 221 return NewSQLError(CRVersionError, SSUnknownSQLState, "cannot connect to servers earlier than 4.1") 222 } 223 224 // Remember a subset of the capabilities, so we can use them 225 // later in the protocol. 226 c.Capabilities = 0 227 if !params.DisableClientDeprecateEOF { 228 c.Capabilities = capabilities & (CapabilityClientDeprecateEOF) 229 } 230 231 charset, err := collations.Local().ParseConnectionCharset(params.Charset) 232 if err != nil { 233 return err 234 } 235 236 // Handle switch to SSL if necessary. 237 if params.SslEnabled() { 238 // If client asked for SSL, but server doesn't support it, 239 // stop right here. 240 if params.SslRequired() && capabilities&CapabilityClientSSL == 0 { 241 return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "server doesn't support SSL but client asked for it") 242 } 243 244 // The ServerName to verify depends on what the hostname is. 245 // We use the params's ServerName if specified. Otherwise: 246 // - If using a socket, we use "localhost". 247 // - If it is an IP address, we need to prefix it with 'IP:'. 248 // - If not, we can just use it as is. 249 serverName := "localhost" 250 if params.ServerName != "" { 251 serverName = params.ServerName 252 } else if params.Host != "" { 253 if net.ParseIP(params.Host) != nil { 254 serverName = "IP:" + params.Host 255 } else { 256 serverName = params.Host 257 } 258 } 259 260 tlsVersion, err := vttls.TLSVersionToNumber(params.TLSMinVersion) 261 if err != nil { 262 return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "error parsing minimal TLS version: %v", err) 263 } 264 265 // Build the TLS config. 266 clientConfig, err := vttls.ClientConfig(params.EffectiveSslMode(), params.SslCert, params.SslKey, params.SslCa, params.SslCrl, serverName, tlsVersion) 267 if err != nil { 268 return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "error loading client cert and ca: %v", err) 269 } 270 271 // Send the SSLRequest packet. 272 if err := c.writeSSLRequest(capabilities, charset, params); err != nil { 273 return err 274 } 275 276 // Switch to SSL. 277 conn := tls.Client(c.conn, clientConfig) 278 c.conn = conn 279 c.bufferedReader.Reset(conn) 280 c.Capabilities |= CapabilityClientSSL 281 } 282 283 // Password encryption. 284 var scrambledPassword []byte 285 if c.authPluginName == CachingSha2Password { 286 scrambledPassword = ScrambleCachingSha2Password(salt, []byte(params.Pass)) 287 } else { 288 scrambledPassword = ScrambleMysqlNativePassword(salt, []byte(params.Pass)) 289 } 290 291 // Client Session Tracking Capability. 292 if capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack { 293 // If the server also supports it, we will have enabled 294 // it so we also add it to our capabilities. 295 c.Capabilities |= CapabilityClientSessionTrack 296 } else if params.Flags&CapabilityClientSessionTrack == CapabilityClientSessionTrack { 297 // If client asked for ClientSessionTrack, but server doesn't support it, 298 // stop right here. 299 return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "server doesn't support ClientSessionTrack but client asked for it") 300 } 301 302 // Build and send our handshake response 41. 303 // Note this one will never have SSL flag on. 304 if err := c.writeHandshakeResponse41(capabilities, scrambledPassword, charset, params); err != nil { 305 return err 306 } 307 308 // Read the server response. 309 if err := c.handleAuthResponse(params); err != nil { 310 return err 311 } 312 313 // If the server didn't support DbName in its handshake, set 314 // it now. This is what the 'mysql' client does. 315 if capabilities&CapabilityClientConnectWithDB == 0 && params.DbName != "" { 316 // Write the packet. 317 if err := c.writeComInitDB(params.DbName); err != nil { 318 return err 319 } 320 321 // Wait for response, should be OK. 322 response, err := c.readPacket() 323 if err != nil { 324 return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) 325 } 326 switch response[0] { 327 case OKPacket: 328 // OK packet, we are authenticated. 329 return nil 330 case ErrPacket: 331 return ParseErrorPacket(response) 332 default: 333 // FIXME(alainjobart) handle extra auth cases and so on. 334 return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response is asking for more information, not implemented yet: %v", response) 335 } 336 } 337 338 return nil 339 } 340 341 // parseInitialHandshakePacket parses the initial handshake from the server. 342 // It returns a SQLError with the right code. 343 func (c *Conn) parseInitialHandshakePacket(data []byte) (uint32, []byte, error) { 344 pos := 0 345 346 // Protocol version. 347 pver, pos, ok := readByte(data, pos) 348 if !ok { 349 return 0, nil, NewSQLError(CRVersionError, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no protocol version") 350 } 351 352 // Server is allowed to immediately send ERR packet 353 if pver == ErrPacket { 354 errorCode, pos, _ := readUint16(data, pos) 355 // Normally there would be a 1-byte sql_state_marker field and a 5-byte 356 // sql_state field here, but docs say these will not be present in this case. 357 errorMsg, _, _ := readEOFString(data, pos) 358 return 0, nil, NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "immediate error from server errorCode=%v errorMsg=%v", errorCode, errorMsg) 359 } 360 361 if pver != protocolVersion { 362 return 0, nil, NewSQLError(CRVersionError, SSUnknownSQLState, "bad protocol version: %v", pver) 363 } 364 365 // Read the server version. 366 c.ServerVersion, pos, ok = readNullString(data, pos) 367 if !ok { 368 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no server version") 369 } 370 371 // Read the connection id. 372 c.ConnectionID, pos, ok = readUint32(data, pos) 373 if !ok { 374 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no connection id") 375 } 376 377 // Read the first part of the auth-plugin-data 378 authPluginData, pos, ok := readBytes(data, pos, 8) 379 if !ok { 380 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no auth-plugin-data-part-1") 381 } 382 383 // One byte filler, 0. We don't really care about the value. 384 _, pos, ok = readByte(data, pos) 385 if !ok { 386 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no filler") 387 } 388 389 // Lower 2 bytes of the capability flags. 390 capLower, pos, ok := readUint16(data, pos) 391 if !ok { 392 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no capability flags (lower 2 bytes)") 393 } 394 var capabilities = uint32(capLower) 395 396 // The packet can end here. 397 if pos == len(data) { 398 return capabilities, authPluginData, nil 399 } 400 401 // Character set. 402 characterSet, pos, ok := readByte(data, pos) 403 if !ok { 404 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no character set") 405 } 406 c.CharacterSet = collations.ID(characterSet) 407 408 // Status flags. Ignored. 409 _, pos, ok = readUint16(data, pos) 410 if !ok { 411 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no status flags") 412 } 413 414 // Upper 2 bytes of the capability flags. 415 capUpper, pos, ok := readUint16(data, pos) 416 if !ok { 417 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no capability flags (upper 2 bytes)") 418 } 419 capabilities += uint32(capUpper) << 16 420 421 // Length of auth-plugin-data, or 0. 422 // Only with CLIENT_PLUGIN_AUTH capability. 423 var authPluginDataLength byte 424 if capabilities&CapabilityClientPluginAuth != 0 { 425 authPluginDataLength, pos, ok = readByte(data, pos) 426 if !ok { 427 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no length of auth-plugin-data") 428 } 429 } else { 430 // One byte filler, 0. We don't really care about the value. 431 _, pos, ok = readByte(data, pos) 432 if !ok { 433 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no length of auth-plugin-data filler") 434 } 435 } 436 437 // 10 reserved 0 bytes. 438 pos += 10 439 440 if capabilities&CapabilityClientSecureConnection != 0 { 441 // The next part of the auth-plugin-data. 442 // The length is max(13, length of auth-plugin-data - 8). 443 l := int(authPluginDataLength) - 8 444 if l > 13 { 445 l = 13 446 } 447 var authPluginDataPart2 []byte 448 authPluginDataPart2, pos, ok = readBytes(data, pos, l) 449 if !ok { 450 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no auth-plugin-data-part-2") 451 } 452 453 // The last byte has to be 0, and is not part of the data. 454 if authPluginDataPart2[l-1] != 0 { 455 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: auth-plugin-data-part-2 is not 0 terminated") 456 } 457 authPluginData = append(authPluginData, authPluginDataPart2[0:l-1]...) 458 } 459 460 // Auth-plugin name. 461 if capabilities&CapabilityClientPluginAuth != 0 { 462 authPluginName, _, ok := readNullString(data, pos) 463 if !ok { 464 // Fallback for versions prior to 5.5.10 and 465 // 5.6.2 that don't have a null terminated string. 466 authPluginName = string(data[pos : len(data)-1]) 467 } 468 c.authPluginName = AuthMethodDescription(authPluginName) 469 } 470 471 return capabilities, authPluginData, nil 472 } 473 474 // writeSSLRequest writes the SSLRequest packet. It's just a truncated 475 // HandshakeResponse41. 476 func (c *Conn) writeSSLRequest(capabilities uint32, characterSet uint8, params *ConnParams) error { 477 // Build our flags, with CapabilityClientSSL. 478 capabilityFlags := CapabilityFlagsSsl | 479 // If the server supported 480 // CapabilityClientDeprecateEOF, we also support it. 481 c.Capabilities&CapabilityClientDeprecateEOF | 482 // If the server supported 483 // CapabilityClientSessionTrack, we also support it. 484 c.Capabilities&CapabilityClientSessionTrack | 485 // Pass-through ClientFoundRows flag. 486 CapabilityClientFoundRows&uint32(params.Flags) 487 488 length := 489 4 + // Client capability flags. 490 4 + // Max-packet size. 491 1 + // Character set. 492 23 // Reserved. 493 494 // Add the DB name if the server supports it. 495 if params.DbName != "" && (capabilities&CapabilityClientConnectWithDB != 0) { 496 capabilityFlags |= CapabilityClientConnectWithDB 497 } 498 499 data, pos := c.startEphemeralPacketWithHeader(length) 500 501 // Client capability flags. 502 pos = writeUint32(data, pos, capabilityFlags) 503 504 // Max-packet size, always 0. See doc.go. 505 pos = writeZeroes(data, pos, 4) 506 507 // Character set. 508 _ = writeByte(data, pos, characterSet) 509 510 // And send it as is. 511 if err := c.writeEphemeralPacket(); err != nil { 512 return NewSQLError(CRServerLost, SSUnknownSQLState, "cannot send SSLRequest: %v", err) 513 } 514 return nil 515 } 516 517 // CapabilityFlags are client capability flag sent to mysql on connect 518 const CapabilityFlags uint32 = CapabilityClientLongPassword | 519 CapabilityClientLongFlag | 520 CapabilityClientProtocol41 | 521 CapabilityClientTransactions | 522 CapabilityClientSecureConnection | 523 CapabilityClientMultiStatements | 524 CapabilityClientMultiResults | 525 CapabilityClientPluginAuth | 526 CapabilityClientPluginAuthLenencClientData 527 528 // CapabilityFlagsSsl signals that we can handle SSL as well 529 const CapabilityFlagsSsl = CapabilityFlags | 530 CapabilityClientSSL 531 532 // writeHandshakeResponse41 writes the handshake response. 533 // Returns a SQLError. 534 func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword []byte, characterSet uint8, params *ConnParams) error { 535 // Build our flags. 536 capabilityFlags := CapabilityFlags | 537 // If the server supported 538 // CapabilityClientDeprecateEOF, we also support it. 539 c.Capabilities&CapabilityClientDeprecateEOF | 540 // Pass-through ClientFoundRows flag. 541 CapabilityClientFoundRows&uint32(params.Flags) | 542 // If the server supported 543 // CapabilityClientSessionTrack, we also support it. 544 c.Capabilities&CapabilityClientSessionTrack 545 546 // FIXME(alainjobart) add multi statement. 547 548 length := 549 4 + // Client capability flags. 550 4 + // Max-packet size. 551 1 + // Character set. 552 23 + // Reserved. 553 lenNullString(params.Uname) + 554 // length of scrambled password is handled below. 555 len(scrambledPassword) + 556 len(c.authPluginName) + 557 1 // terminating zero. 558 559 // Add the DB name if the server supports it. 560 if params.DbName != "" && (capabilities&CapabilityClientConnectWithDB != 0) { 561 capabilityFlags |= CapabilityClientConnectWithDB 562 length += lenNullString(params.DbName) 563 } 564 565 if capabilities&CapabilityClientPluginAuthLenencClientData != 0 { 566 length += lenEncIntSize(uint64(len(scrambledPassword))) 567 } else { 568 length++ 569 } 570 571 data, pos := c.startEphemeralPacketWithHeader(length) 572 573 // Client capability flags. 574 pos = writeUint32(data, pos, capabilityFlags) 575 576 // Max-packet size, always 0. See doc.go. 577 pos = writeZeroes(data, pos, 4) 578 579 // Character set. 580 pos = writeByte(data, pos, characterSet) 581 582 // 23 reserved bytes, all 0. 583 pos = writeZeroes(data, pos, 23) 584 585 // Username 586 pos = writeNullString(data, pos, params.Uname) 587 588 // Scrambled password. The length is encoded as variable length if 589 // CapabilityClientPluginAuthLenencClientData is set. 590 if capabilities&CapabilityClientPluginAuthLenencClientData != 0 { 591 pos = writeLenEncInt(data, pos, uint64(len(scrambledPassword))) 592 } else { 593 data[pos] = byte(len(scrambledPassword)) 594 pos++ 595 } 596 pos += copy(data[pos:], scrambledPassword) 597 598 // DbName, only if server supports it. 599 if params.DbName != "" && (capabilities&CapabilityClientConnectWithDB != 0) { 600 pos = writeNullString(data, pos, params.DbName) 601 c.schemaName = params.DbName 602 } 603 604 // Assume native client during response 605 pos = writeNullString(data, pos, string(c.authPluginName)) 606 607 // Sanity-check the length. 608 if pos != len(data) { 609 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "writeHandshakeResponse41: only packed %v bytes, out of %v allocated", pos, len(data)) 610 } 611 612 if err := c.writeEphemeralPacket(); err != nil { 613 return NewSQLError(CRServerLost, SSUnknownSQLState, "cannot send HandshakeResponse41: %v", err) 614 } 615 return nil 616 } 617 618 // handleAuthResponse parses server's response after client sends the password for authentication 619 // and handles next steps for AuthSwitchRequestPacket and AuthMoreDataPacket. 620 func (c *Conn) handleAuthResponse(params *ConnParams) error { 621 response, err := c.readPacket() 622 if err != nil { 623 return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) 624 } 625 626 switch response[0] { 627 case OKPacket: 628 // OK packet, we are authenticated. Save the user, keep going. 629 c.User = params.Uname 630 case AuthSwitchRequestPacket: 631 // Server is asking to use a different auth method 632 if err = c.handleAuthSwitchPacket(params, response); err != nil { 633 return err 634 } 635 case AuthMoreDataPacket: 636 // Server is requesting more data - maybe un-scrambled password 637 if err := c.handleAuthMoreDataPacket(response[1], params); err != nil { 638 return err 639 } 640 case ErrPacket: 641 return ParseErrorPacket(response) 642 default: 643 return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response cannot be parsed: %v", response) 644 } 645 646 return nil 647 } 648 649 // handleAuthSwitchPacket scrambles password for the plugin requested by the server and retries authentication 650 func (c *Conn) handleAuthSwitchPacket(params *ConnParams, response []byte) error { 651 var err error 652 var salt []byte 653 c.authPluginName, salt, err = parseAuthSwitchRequest(response) 654 if err != nil { 655 return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse auth switch request: %v", err) 656 } 657 if salt != nil { 658 c.salt = salt 659 } 660 switch c.authPluginName { 661 case MysqlClearPassword: 662 if err := c.writeClearTextPassword(params); err != nil { 663 return err 664 } 665 case MysqlNativePassword: 666 scrambledPassword := ScrambleMysqlNativePassword(c.salt, []byte(params.Pass)) 667 if err := c.writeScrambledPassword(scrambledPassword); err != nil { 668 return err 669 } 670 case CachingSha2Password: 671 scrambledPassword := ScrambleCachingSha2Password(c.salt, []byte(params.Pass)) 672 if err := c.writeScrambledPassword(scrambledPassword); err != nil { 673 return err 674 } 675 default: 676 return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "server asked for unsupported auth method: %v", c.authPluginName) 677 } 678 679 // The response could be an OKPacket, AuthMoreDataPacket or ErrPacket 680 return c.handleAuthResponse(params) 681 } 682 683 // handleAuthMoreDataPacket handles response of CachingSha2Password authentication and sends full password to the 684 // server if requested 685 func (c *Conn) handleAuthMoreDataPacket(data byte, params *ConnParams) error { 686 switch data { 687 case CachingSha2FastAuth: 688 // User credentials are verified using the cache ("Fast" path). 689 // Next packet should be an OKPacket 690 return c.handleAuthResponse(params) 691 case CachingSha2FullAuth: 692 // User credentials are not cached, we have to exchange full password. 693 if c.Capabilities&CapabilityClientSSL > 0 || params.UnixSocket != "" { 694 // If we are using an SSL connection or Unix socket, write clear text password 695 if err := c.writeClearTextPassword(params); err != nil { 696 return err 697 } 698 } else { 699 // If we are not using an SSL connection or Unix socket, we have to fetch a public key 700 // from the server to encrypt password 701 pub, err := c.requestPublicKey() 702 if err != nil { 703 return err 704 } 705 // Encrypt password with public key 706 enc, err := EncryptPasswordWithPublicKey(c.salt, []byte(params.Pass), pub) 707 if err != nil { 708 return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error encrypting password with public key: %v", err) 709 } 710 // Write encrypted password 711 if err := c.writeScrambledPassword(enc); err != nil { 712 return err 713 } 714 } 715 // Next packet should either be an OKPacket or ErrPacket 716 return c.handleAuthResponse(params) 717 default: 718 return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse AuthMoreDataPacket: %v", data) 719 } 720 } 721 722 func parseAuthSwitchRequest(data []byte) (AuthMethodDescription, []byte, error) { 723 pos := 1 724 pluginName, pos, ok := readNullString(data, pos) 725 if !ok { 726 return "", nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot get plugin name from AuthSwitchRequest: %v", data) 727 } 728 729 // If this was a request with a salt in it, max 20 bytes 730 salt := data[pos:] 731 if len(salt) > 20 { 732 salt = salt[:20] 733 } 734 return AuthMethodDescription(pluginName), salt, nil 735 } 736 737 // requestPublicKey requests a public key from the server 738 func (c *Conn) requestPublicKey() (rsaKey *rsa.PublicKey, err error) { 739 // get public key from server 740 data, pos := c.startEphemeralPacketWithHeader(1) 741 data[pos] = 0x02 742 if err := c.writeEphemeralPacket(); err != nil { 743 return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error sending public key request packet: %v", err) 744 } 745 746 response, err := c.readPacket() 747 if err != nil { 748 return nil, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) 749 } 750 751 // Server should respond with a AuthMoreDataPacket containing the public key 752 if response[0] != AuthMoreDataPacket { 753 return nil, ParseErrorPacket(response) 754 } 755 756 block, _ := pem.Decode(response[1:]) 757 pub, err := x509.ParsePKIXPublicKey(block.Bytes) 758 if err != nil { 759 return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to parse public key from server: %v", err) 760 } 761 762 return pub.(*rsa.PublicKey), nil 763 } 764 765 // writeClearTextPassword writes the clear text password. 766 // Returns a SQLError. 767 func (c *Conn) writeClearTextPassword(params *ConnParams) error { 768 length := len(params.Pass) + 1 769 data, pos := c.startEphemeralPacketWithHeader(length) 770 pos = writeNullString(data, pos, params.Pass) 771 // Sanity check. 772 if pos != len(data) { 773 return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error building ClearTextPassword packet: got %v bytes expected %v", pos, len(data)) 774 } 775 return c.writeEphemeralPacket() 776 } 777 778 // writeScrambledPassword writes the encrypted mysql_native_password format 779 // Returns a SQLError. 780 func (c *Conn) writeScrambledPassword(scrambledPassword []byte) error { 781 data, pos := c.startEphemeralPacketWithHeader(len(scrambledPassword)) 782 pos += copy(data[pos:], scrambledPassword) 783 // Sanity check. 784 if pos != len(data) { 785 return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error building %v packet: got %v bytes expected %v", c.authPluginName, pos, len(data)) 786 } 787 return c.writeEphemeralPacket() 788 }