vitess.io/vitess@v0.16.2/go/mysql/conn.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package mysql 18 19 import ( 20 "bufio" 21 "crypto/tls" 22 "crypto/x509" 23 "errors" 24 "fmt" 25 "io" 26 "net" 27 "strings" 28 "sync" 29 "time" 30 31 "vitess.io/vitess/go/mysql/collations" 32 33 "vitess.io/vitess/go/sqlescape" 34 35 "vitess.io/vitess/go/bucketpool" 36 "vitess.io/vitess/go/sqltypes" 37 "vitess.io/vitess/go/sync2" 38 "vitess.io/vitess/go/vt/log" 39 querypb "vitess.io/vitess/go/vt/proto/query" 40 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 41 "vitess.io/vitess/go/vt/sqlparser" 42 "vitess.io/vitess/go/vt/vterrors" 43 ) 44 45 const ( 46 // connBufferSize is how much we buffer for reading and 47 // writing. It is also how much we allocate for ephemeral buffers. 48 connBufferSize = 16 * 1024 49 50 // packetHeaderSize is the 4 bytes of header per MySQL packet 51 // sent over 52 packetHeaderSize = 4 53 ) 54 55 // Constants for how ephemeral buffers were used for reading / writing. 56 const ( 57 // ephemeralUnused means the ephemeral buffer is not in use at this 58 // moment. This is the default value, and is checked so we don't 59 // read or write a packet while one is already used. 60 ephemeralUnused = iota 61 62 // ephemeralWrite means we currently in process of writing from currentEphemeralBuffer 63 ephemeralWrite 64 65 // ephemeralRead means we currently in process of reading into currentEphemeralBuffer 66 ephemeralRead 67 ) 68 69 // A Getter has a Get() 70 type Getter interface { 71 Get() *querypb.VTGateCallerID 72 } 73 74 // Conn is a connection between a client and a server, using the MySQL 75 // binary protocol. It is built on top of an existing net.Conn, that 76 // has already been established. 77 // 78 // Use Connect on the client side to create a connection. 79 // Use NewListener to create a server side and listen for connections. 80 type Conn struct { 81 // fields contains the fields definitions for an on-going 82 // streaming query. It is set by ExecuteStreamFetch, and 83 // cleared by the last FetchNext(). It is nil if no streaming 84 // query is in progress. If the streaming query returned no 85 // fields, this is set to an empty array (but not nil). 86 fields []*querypb.Field 87 88 // salt is sent by the server during initial handshake to be used for authentication 89 salt []byte 90 91 // authPluginName is the name of server's authentication plugin. 92 // It is set during the initial handshake. 93 authPluginName AuthMethodDescription 94 95 // schemaName is the default database name to use. It is set 96 // during handshake, and by ComInitDb packets. Both client and 97 // servers maintain it. This member is private because it's 98 // non-authoritative: the client can change the schema name 99 // through the 'USE' statement, which will bypass this variable. 100 schemaName string 101 102 // ClientData is a place where an application can store any 103 // connection-related data. Mostly used on the server side, to 104 // avoid maps indexed by ConnectionID for instance. 105 ClientData any 106 107 // conn is the underlying network connection. 108 // Calling Close() on the Conn will close this connection. 109 // If there are any ongoing reads or writes, they may get interrupted. 110 conn net.Conn 111 112 // flavor contains the auto-detected flavor for this client 113 // connection. It is unused for server-side connections. 114 flavor flavor 115 116 // ServerVersion is set during Connect with the server 117 // version. It is not changed afterwards. It is unused for 118 // server-side connections. 119 ServerVersion string 120 121 // User is the name used by the client to connect. 122 // It is set during the initial handshake. 123 User string // For server-side connections, listener points to the server object. 124 125 // UserData is custom data returned by the AuthServer module. 126 // It is set during the initial handshake. 127 UserData Getter 128 129 bufferedReader *bufio.Reader 130 flushTimer *time.Timer 131 header [packetHeaderSize]byte 132 133 // Keep track of how and of the buffer we allocated for an 134 // ephemeral packet on the read and write sides. 135 // These fields are used by: 136 // - startEphemeralPacketWithHeader / writeEphemeralPacket methods for writes. 137 // - readEphemeralPacket / recycleReadPacket methods for reads. 138 currentEphemeralPolicy int 139 // currentEphemeralBuffer for tracking allocated temporary buffer for writes and reads respectively. 140 // It can be allocated from bufPool or heap and should be recycled in the same manner. 141 currentEphemeralBuffer *[]byte 142 143 listener *Listener 144 145 // Buffered writing has a timer which flushes on inactivity. 146 bufferedWriter *bufio.Writer 147 148 // PrepareData is the map to use a prepared statement. 149 PrepareData map[uint32]*PrepareData 150 151 // protects the bufferedWriter and bufferedReader 152 bufMu sync.Mutex 153 154 // Capabilities is the current set of features this connection 155 // is using. It is the features that are both supported by 156 // the client and the server, and currently in use. 157 // It is set during the initial handshake. 158 // 159 // It is only used for CapabilityClientDeprecateEOF 160 // and CapabilityClientFoundRows. 161 Capabilities uint32 162 163 // closed is set to true when Close() is called on the connection. 164 closed sync2.AtomicBool 165 166 // ConnectionID is set: 167 // - at Connect() time for clients, with the value returned by 168 // the server. 169 // - at accept time for the server. 170 ConnectionID uint32 171 172 // StatementID is the prepared statement ID. 173 StatementID uint32 174 175 // StatusFlags are the status flags we will base our returned flags on. 176 // This is a bit field, with values documented in constants.go. 177 // An interesting value here would be ServerStatusAutocommit. 178 // It is only used by the server. These flags can be changed 179 // by Handler methods. 180 StatusFlags uint16 181 182 // CharacterSet is the charset for this connection, as negotiated 183 // in our handshake with the server. Note that although the MySQL protocol lists this 184 // as a "character set", the returned byte value is actually a Collation ID, 185 // and hence it's casted as such here. 186 // If the user has specified a custom Collation in the ConnParams for this 187 // connection, once the CharacterSet has been negotiated, we will override 188 // it via SQL and update this field accordingly. 189 CharacterSet collations.ID 190 191 // Packet encoding variables. 192 sequence uint8 193 194 // ExpectSemiSyncIndicator is applicable when the connection is used for replication (ComBinlogDump). 195 // When 'true', events are assumed to be padded with 2-byte semi-sync information 196 // See https://dev.mysql.com/doc/internals/en/semi-sync-binlog-event.html 197 ExpectSemiSyncIndicator bool 198 199 // enableQueryInfo controls whether we parse the INFO field in QUERY_OK packets 200 // See: ConnParams.EnableQueryInfo 201 enableQueryInfo bool 202 } 203 204 // splitStatementFunciton is the function that is used to split the statement in case of a multi-statement query. 205 var splitStatementFunction = sqlparser.SplitStatementToPieces 206 207 // PrepareData is a buffer used for store prepare statement meta data 208 type PrepareData struct { 209 ParamsType []int32 210 ColumnNames []string 211 PrepareStmt string 212 BindVars map[string]*querypb.BindVariable 213 StatementID uint32 214 ParamsCount uint16 215 } 216 217 // execResult is an enum signifying the result of executing a query 218 type execResult byte 219 220 const ( 221 execSuccess execResult = iota 222 execErr 223 connErr 224 ) 225 226 // bufPool is used to allocate and free buffers in an efficient way. 227 var bufPool = bucketpool.New(connBufferSize, MaxPacketSize) 228 229 // writersPool is used for pooling bufio.Writer objects. 230 var writersPool = sync.Pool{New: func() any { return bufio.NewWriterSize(nil, connBufferSize) }} 231 232 var readersPool = sync.Pool{New: func() any { return bufio.NewReaderSize(nil, connBufferSize) }} 233 234 // newConn is an internal method to create a Conn. Used by client and server 235 // side for common creation code. 236 func newConn(conn net.Conn) *Conn { 237 return &Conn{ 238 conn: conn, 239 closed: sync2.NewAtomicBool(false), 240 bufferedReader: bufio.NewReaderSize(conn, connBufferSize), 241 } 242 } 243 244 // newServerConn should be used to create server connections. 245 // 246 // It stashes a reference to the listener to be able to determine if 247 // the server is shutting down, and has the ability to control buffer 248 // size for reads. 249 func newServerConn(conn net.Conn, listener *Listener) *Conn { 250 c := &Conn{ 251 conn: conn, 252 listener: listener, 253 closed: sync2.NewAtomicBool(false), 254 PrepareData: make(map[uint32]*PrepareData), 255 } 256 257 if listener.connReadBufferSize > 0 { 258 var buf *bufio.Reader 259 if listener.connBufferPooling { 260 buf = readersPool.Get().(*bufio.Reader) 261 buf.Reset(conn) 262 } else { 263 buf = bufio.NewReaderSize(conn, listener.connReadBufferSize) 264 } 265 266 c.bufferedReader = buf 267 } 268 269 return c 270 } 271 272 // startWriterBuffering starts using buffered writes. This should 273 // be terminated by a call to endWriteBuffering. 274 func (c *Conn) startWriterBuffering() { 275 c.bufMu.Lock() 276 defer c.bufMu.Unlock() 277 278 c.bufferedWriter = writersPool.Get().(*bufio.Writer) 279 c.bufferedWriter.Reset(c.conn) 280 } 281 282 // endWriterBuffering must be called to terminate startWriteBuffering. 283 func (c *Conn) endWriterBuffering() error { 284 c.bufMu.Lock() 285 defer c.bufMu.Unlock() 286 287 if c.bufferedWriter == nil { 288 return nil 289 } 290 291 defer func() { 292 c.bufferedWriter.Reset(nil) 293 writersPool.Put(c.bufferedWriter) 294 c.bufferedWriter = nil 295 }() 296 297 c.stopFlushTimer() 298 return c.bufferedWriter.Flush() 299 } 300 301 func (c *Conn) returnReader() { 302 if c.bufferedReader == nil { 303 return 304 } 305 c.bufferedReader.Reset(nil) 306 readersPool.Put(c.bufferedReader) 307 } 308 309 // getWriter returns the current writer. It may be either 310 // the original connection or a wrapper. The returned unget 311 // function must be invoked after the writing is finished. 312 // In buffered mode, the unget starts a timer to flush any 313 // buffered data. 314 func (c *Conn) getWriter() (w io.Writer, unget func()) { 315 c.bufMu.Lock() 316 if c.bufferedWriter != nil { 317 return c.bufferedWriter, func() { 318 c.startFlushTimer() 319 c.bufMu.Unlock() 320 } 321 } 322 c.bufMu.Unlock() 323 return c.conn, func() {} 324 } 325 326 // startFlushTimer must be called while holding lock on bufMu. 327 func (c *Conn) startFlushTimer() { 328 c.stopFlushTimer() 329 c.flushTimer = time.AfterFunc(mysqlServerFlushDelay, func() { 330 c.bufMu.Lock() 331 defer c.bufMu.Unlock() 332 333 if c.bufferedWriter == nil { 334 return 335 } 336 c.stopFlushTimer() 337 c.bufferedWriter.Flush() 338 }) 339 } 340 341 // stopFlushTimer must be called while holding lock on bufMu. 342 func (c *Conn) stopFlushTimer() { 343 if c.flushTimer != nil { 344 c.flushTimer.Stop() 345 c.flushTimer = nil 346 } 347 } 348 349 // getReader returns reader for connection. It can be *bufio.Reader or net.Conn 350 // depending on which buffer size was passed to newServerConn. 351 func (c *Conn) getReader() io.Reader { 352 if c.bufferedReader != nil { 353 return c.bufferedReader 354 } 355 return c.conn 356 } 357 358 func (c *Conn) readHeaderFrom(r io.Reader) (int, error) { 359 // Note io.ReadFull will return two different types of errors: 360 // 1. if the socket is already closed, and the go runtime knows it, 361 // then ReadFull will return an error (different than EOF), 362 // something like 'read: connection reset by peer'. 363 // 2. if the socket is not closed while we start the read, 364 // but gets closed after the read is started, we'll get io.EOF. 365 if _, err := io.ReadFull(r, c.header[:]); err != nil { 366 // The special casing of propagating io.EOF up 367 // is used by the server side only, to suppress an error 368 // message if a client just disconnects. 369 if err == io.EOF { 370 return 0, err 371 } 372 if strings.HasSuffix(err.Error(), "read: connection reset by peer") { 373 return 0, io.EOF 374 } 375 return 0, vterrors.Wrapf(err, "io.ReadFull(header size) failed") 376 } 377 378 sequence := uint8(c.header[3]) 379 if sequence != c.sequence { 380 return 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence, expected %v got %v", c.sequence, sequence) 381 } 382 383 c.sequence++ 384 385 return int(uint32(c.header[0]) | uint32(c.header[1])<<8 | uint32(c.header[2])<<16), nil 386 } 387 388 // readEphemeralPacket attempts to read a packet into buffer from sync.Pool. Do 389 // not use this method if the contents of the packet needs to be kept 390 // after the next readEphemeralPacket. 391 // 392 // Note if the connection is closed already, an error will be 393 // returned, and it may not be io.EOF. If the connection closes while 394 // we are stuck waiting for data, an error will also be returned, and 395 // it most likely will be io.EOF. 396 func (c *Conn) readEphemeralPacket() ([]byte, error) { 397 if c.currentEphemeralPolicy != ephemeralUnused { 398 panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "readEphemeralPacket: unexpected currentEphemeralPolicy: %v", c.currentEphemeralPolicy)) 399 } 400 401 r := c.getReader() 402 403 length, err := c.readHeaderFrom(r) 404 if err != nil { 405 return nil, err 406 } 407 408 c.currentEphemeralPolicy = ephemeralRead 409 if length == 0 { 410 // This can be caused by the packet after a packet of 411 // exactly size MaxPacketSize. 412 return nil, nil 413 } 414 415 // Use the bufPool. 416 if length < MaxPacketSize { 417 c.currentEphemeralBuffer = bufPool.Get(length) 418 if _, err := io.ReadFull(r, *c.currentEphemeralBuffer); err != nil { 419 return nil, vterrors.Wrapf(err, "io.ReadFull(packet body of length %v) failed", length) 420 } 421 return *c.currentEphemeralBuffer, nil 422 } 423 424 // Much slower path, revert to allocating everything from scratch. 425 // We're going to concatenate a lot of data anyway, can't really 426 // optimize this code path easily. 427 data := make([]byte, length) 428 if _, err := io.ReadFull(r, data); err != nil { 429 return nil, vterrors.Wrapf(err, "io.ReadFull(packet body of length %v) failed", length) 430 } 431 for { 432 next, err := c.readOnePacket() 433 if err != nil { 434 return nil, err 435 } 436 437 if len(next) == 0 { 438 // Again, the packet after a packet of exactly size MaxPacketSize. 439 break 440 } 441 442 data = append(data, next...) 443 if len(next) < MaxPacketSize { 444 break 445 } 446 } 447 448 return data, nil 449 } 450 451 // readEphemeralPacketDirect attempts to read a packet from the socket directly. 452 // It needs to be used for the first handshake packet the server receives, 453 // so we do't buffer the SSL negotiation packet. As a shortcut, only 454 // packets smaller than MaxPacketSize can be read here. 455 // This function usually shouldn't be used - use readEphemeralPacket. 456 func (c *Conn) readEphemeralPacketDirect() ([]byte, error) { 457 if c.currentEphemeralPolicy != ephemeralUnused { 458 panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "readEphemeralPacketDirect: unexpected currentEphemeralPolicy: %v", c.currentEphemeralPolicy)) 459 } 460 461 var r io.Reader = c.conn 462 463 length, err := c.readHeaderFrom(r) 464 if err != nil { 465 return nil, err 466 } 467 468 c.currentEphemeralPolicy = ephemeralRead 469 if length == 0 { 470 // This can be caused by the packet after a packet of 471 // exactly size MaxPacketSize. 472 return nil, nil 473 } 474 475 if length < MaxPacketSize { 476 c.currentEphemeralBuffer = bufPool.Get(length) 477 if _, err := io.ReadFull(r, *c.currentEphemeralBuffer); err != nil { 478 return nil, vterrors.Wrapf(err, "io.ReadFull(packet body of length %v) failed", length) 479 } 480 return *c.currentEphemeralBuffer, nil 481 } 482 483 return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "readEphemeralPacketDirect doesn't support more than one packet") 484 } 485 486 // recycleReadPacket recycles the read packet. It needs to be called 487 // after readEphemeralPacket was called. 488 func (c *Conn) recycleReadPacket() { 489 if c.currentEphemeralPolicy != ephemeralRead { 490 // Programming error. 491 panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "trying to call recycleReadPacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy)) 492 } 493 if c.currentEphemeralBuffer != nil { 494 // We are using the pool, put the buffer back in. 495 bufPool.Put(c.currentEphemeralBuffer) 496 c.currentEphemeralBuffer = nil 497 } 498 c.currentEphemeralPolicy = ephemeralUnused 499 } 500 501 // readOnePacket reads a single packet into a newly allocated buffer. 502 func (c *Conn) readOnePacket() ([]byte, error) { 503 r := c.getReader() 504 length, err := c.readHeaderFrom(r) 505 if err != nil { 506 return nil, err 507 } 508 if length == 0 { 509 // This can be caused by the packet after a packet of 510 // exactly size MaxPacketSize. 511 return nil, nil 512 } 513 514 data := make([]byte, length) 515 if _, err := io.ReadFull(r, data); err != nil { 516 return nil, vterrors.Wrapf(err, "io.ReadFull(packet body of length %v) failed", length) 517 } 518 return data, nil 519 } 520 521 // readPacket reads a packet from the underlying connection. 522 // It re-assembles packets that span more than one message. 523 // This method returns a generic error, not a SQLError. 524 func (c *Conn) readPacket() ([]byte, error) { 525 // Optimize for a single packet case. 526 data, err := c.readOnePacket() 527 if err != nil { 528 return nil, err 529 } 530 531 // This is a single packet. 532 if len(data) < MaxPacketSize { 533 return data, nil 534 } 535 536 // There is more than one packet, read them all. 537 for { 538 next, err := c.readOnePacket() 539 if err != nil { 540 return nil, err 541 } 542 543 if len(next) == 0 { 544 // Again, the packet after a packet of exactly size MaxPacketSize. 545 break 546 } 547 548 data = append(data, next...) 549 if len(next) < MaxPacketSize { 550 break 551 } 552 } 553 554 return data, nil 555 } 556 557 // ReadPacket reads a packet from the underlying connection. 558 // it is the public API version, that returns a SQLError. 559 // The memory for the packet is always allocated, and it is owned by the caller 560 // after this function returns. 561 func (c *Conn) ReadPacket() ([]byte, error) { 562 result, err := c.readPacket() 563 if err != nil { 564 return nil, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) 565 } 566 return result, err 567 } 568 569 // writePacket writes a packet, possibly cutting it into multiple 570 // chunks. Note this is not very efficient, as the client probably 571 // has to build the []byte and that makes a memory copy. 572 // Try to use startEphemeralPacketWithHeader/writeEphemeralPacket instead. 573 // 574 // This method returns a generic error, not a SQLError. 575 func (c *Conn) writePacket(data []byte) error { 576 index := 0 577 dataLength := len(data) - packetHeaderSize 578 579 w, unget := c.getWriter() 580 defer unget() 581 582 var header [packetHeaderSize]byte 583 for { 584 // toBeSent is capped to MaxPacketSize. 585 toBeSent := dataLength 586 if toBeSent > MaxPacketSize { 587 toBeSent = MaxPacketSize 588 } 589 590 // save the first 4 bytes of the payload, we will overwrite them with the 591 // header below 592 copy(header[0:packetHeaderSize], data[index:index+packetHeaderSize]) 593 594 // Compute and write the header. 595 data[index] = byte(toBeSent) 596 data[index+1] = byte(toBeSent >> 8) 597 data[index+2] = byte(toBeSent >> 16) 598 data[index+3] = c.sequence 599 600 // Write the body. 601 if n, err := w.Write(data[index : index+toBeSent+packetHeaderSize]); err != nil { 602 return vterrors.Wrapf(err, "Write(packet) failed") 603 } else if n != (toBeSent + packetHeaderSize) { 604 return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Write(packet) returned a short write: %v < %v", n, (toBeSent + packetHeaderSize)) 605 } 606 607 // restore the first 4 bytes once the network send is done 608 copy(data[index:index+packetHeaderSize], header[0:packetHeaderSize]) 609 610 // Update our state. 611 c.sequence++ 612 dataLength -= toBeSent 613 if dataLength == 0 { 614 if toBeSent == MaxPacketSize { 615 // The packet we just sent had exactly 616 // MaxPacketSize size, we need to 617 // sent a zero-size packet too. 618 header[0] = 0 619 header[1] = 0 620 header[2] = 0 621 header[3] = c.sequence 622 if n, err := w.Write(header[:]); err != nil { 623 return vterrors.Wrapf(err, "Write(empty header) failed") 624 } else if n != packetHeaderSize { 625 return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Write(empty header) returned a short write: %v < 4", n) 626 } 627 c.sequence++ 628 } 629 return nil 630 } 631 index += toBeSent 632 } 633 } 634 635 func (c *Conn) startEphemeralPacketWithHeader(length int) ([]byte, int) { 636 if c.currentEphemeralPolicy != ephemeralUnused { 637 panic("startEphemeralPacketWithHeader cannot be used while a packet is already started.") 638 } 639 640 c.currentEphemeralPolicy = ephemeralWrite 641 // get buffer from pool or it'll be allocated if length is too big 642 c.currentEphemeralBuffer = bufPool.Get(length + packetHeaderSize) 643 return *c.currentEphemeralBuffer, packetHeaderSize 644 } 645 646 // writeEphemeralPacket writes the packet that was allocated by 647 // startEphemeralPacketWithHeader. 648 func (c *Conn) writeEphemeralPacket() error { 649 defer c.recycleWritePacket() 650 651 switch c.currentEphemeralPolicy { 652 case ephemeralWrite: 653 if err := c.writePacket(*c.currentEphemeralBuffer); err != nil { 654 return vterrors.Wrapf(err, "conn %v", c.ID()) 655 } 656 case ephemeralUnused, ephemeralRead: 657 // Programming error. 658 panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "conn %v: trying to call writeEphemeralPacket while currentEphemeralPolicy is %v", c.ID(), c.currentEphemeralPolicy)) 659 } 660 661 return nil 662 } 663 664 // recycleWritePacket recycles the write packet. It needs to be called 665 // after writeEphemeralPacket was called. 666 func (c *Conn) recycleWritePacket() { 667 if c.currentEphemeralPolicy != ephemeralWrite { 668 // Programming error. 669 panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "trying to call recycleWritePacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy)) 670 } 671 // Release our reference so the buffer can be gced 672 bufPool.Put(c.currentEphemeralBuffer) 673 c.currentEphemeralBuffer = nil 674 c.currentEphemeralPolicy = ephemeralUnused 675 } 676 677 // writeComQuit writes a Quit message for the server, to indicate we 678 // want to close the connection. 679 // Client -> Server. 680 // Returns SQLError(CRServerGone) if it can't. 681 func (c *Conn) writeComQuit() error { 682 // This is a new command, need to reset the sequence. 683 c.sequence = 0 684 685 data, pos := c.startEphemeralPacketWithHeader(1) 686 data[pos] = ComQuit 687 if err := c.writeEphemeralPacket(); err != nil { 688 return NewSQLError(CRServerGone, SSUnknownSQLState, err.Error()) 689 } 690 return nil 691 } 692 693 // RemoteAddr returns the underlying socket RemoteAddr(). 694 func (c *Conn) RemoteAddr() net.Addr { 695 return c.conn.RemoteAddr() 696 } 697 698 // ID returns the MySQL connection ID for this connection. 699 func (c *Conn) ID() int64 { 700 return int64(c.ConnectionID) 701 } 702 703 // Ident returns a useful identification string for error logging 704 func (c *Conn) String() string { 705 return fmt.Sprintf("client %v (%s)", c.ConnectionID, c.RemoteAddr().String()) 706 } 707 708 // Close closes the connection. It can be called from a different go 709 // routine to interrupt the current connection. 710 func (c *Conn) Close() { 711 if c.closed.CompareAndSwap(false, true) { 712 c.conn.Close() 713 } 714 } 715 716 // IsClosed returns true if this connection was ever closed by the 717 // Close() method. Note if the other side closes the connection, but 718 // Close() wasn't called, this will return false. 719 func (c *Conn) IsClosed() bool { 720 return c.closed.Get() 721 } 722 723 // 724 // Packet writing methods, for generic packets. 725 // 726 727 // writeOKPacket writes an OK packet. 728 // Server -> Client. 729 // This method returns a generic error, not a SQLError. 730 func (c *Conn) writeOKPacket(packetOk *PacketOK) error { 731 return c.writeOKPacketWithHeader(packetOk, OKPacket) 732 } 733 734 // writeOKPacketWithEOFHeader writes an OK packet with an EOF header. 735 // This is used at the end of a result set if 736 // CapabilityClientDeprecateEOF is set. 737 // Server -> Client. 738 // This method returns a generic error, not a SQLError. 739 func (c *Conn) writeOKPacketWithEOFHeader(packetOk *PacketOK) error { 740 return c.writeOKPacketWithHeader(packetOk, EOFPacket) 741 } 742 743 // writeOKPacketWithEOFHeader writes an OK packet with an EOF header. 744 // This is used at the end of a result set if 745 // CapabilityClientDeprecateEOF is set. 746 // Server -> Client. 747 // This method returns a generic error, not a SQLError. 748 func (c *Conn) writeOKPacketWithHeader(packetOk *PacketOK, headerType byte) error { 749 length := 1 + // OKPacket 750 lenEncIntSize(packetOk.affectedRows) + 751 lenEncIntSize(packetOk.lastInsertID) 752 // assuming CapabilityClientProtocol41 753 length += 4 // status_flags + warnings 754 755 var gtidData []byte 756 if c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack { 757 length += lenEncStringSize(packetOk.info) // info 758 if packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged { 759 gtidData = getLenEncString([]byte(packetOk.sessionStateData)) 760 gtidData = append([]byte{0x00}, gtidData...) 761 gtidData = getLenEncString(gtidData) 762 gtidData = append([]byte{0x03}, gtidData...) 763 gtidData = append(getLenEncInt(uint64(len(gtidData))), gtidData...) 764 length += len(gtidData) 765 } 766 } else { 767 length += len(packetOk.info) // info 768 } 769 770 bytes, pos := c.startEphemeralPacketWithHeader(length) 771 data := &coder{data: bytes, pos: pos} 772 data.writeByte(headerType) //header - OK or EOF 773 data.writeLenEncInt(packetOk.affectedRows) 774 data.writeLenEncInt(packetOk.lastInsertID) 775 data.writeUint16(packetOk.statusFlags) 776 data.writeUint16(packetOk.warnings) 777 if c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack { 778 data.writeLenEncString(packetOk.info) 779 if packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged { 780 data.writeEOFString(string(gtidData)) 781 } 782 } else { 783 data.writeEOFString(packetOk.info) 784 } 785 return c.writeEphemeralPacket() 786 } 787 788 func getLenEncString(value []byte) []byte { 789 data := getLenEncInt(uint64(len(value))) 790 return append(data, value...) 791 } 792 793 func getLenEncInt(i uint64) []byte { 794 var data []byte 795 switch { 796 case i < 251: 797 data = append(data, byte(i)) 798 case i < 1<<16: 799 data = append(data, 0xfc) 800 data = append(data, byte(i)) 801 data = append(data, byte(i>>8)) 802 case i < 1<<24: 803 data = append(data, 0xfd) 804 data = append(data, byte(i)) 805 data = append(data, byte(i>>8)) 806 data = append(data, byte(i>>16)) 807 default: 808 data = append(data, 0xfe) 809 data = append(data, byte(i)) 810 data = append(data, byte(i>>8)) 811 data = append(data, byte(i>>16)) 812 data = append(data, byte(i>>24)) 813 data = append(data, byte(i>>32)) 814 data = append(data, byte(i>>40)) 815 data = append(data, byte(i>>48)) 816 data = append(data, byte(i>>56)) 817 } 818 return data 819 } 820 821 func (c *Conn) WriteErrorAndLog(format string, args ...interface{}) bool { 822 return c.writeErrorAndLog(ERUnknownComError, SSNetError, format, args...) 823 } 824 825 func (c *Conn) writeErrorAndLog(errorCode uint16, sqlState string, format string, args ...any) bool { 826 if err := c.writeErrorPacket(errorCode, sqlState, format, args...); err != nil { 827 log.Errorf("Error writing error to %s: %v", c, err) 828 return false 829 } 830 return true 831 } 832 833 func (c *Conn) writeErrorPacketFromErrorAndLog(err error) bool { 834 werr := c.writeErrorPacketFromError(err) 835 if werr != nil { 836 log.Errorf("Error writing error to %s: %v", c, werr) 837 return false 838 } 839 return true 840 } 841 842 // writeErrorPacket writes an error packet. 843 // Server -> Client. 844 // This method returns a generic error, not a SQLError. 845 func (c *Conn) writeErrorPacket(errorCode uint16, sqlState string, format string, args ...any) error { 846 errorMessage := fmt.Sprintf(format, args...) 847 length := 1 + 2 + 1 + 5 + len(errorMessage) 848 data, pos := c.startEphemeralPacketWithHeader(length) 849 pos = writeByte(data, pos, ErrPacket) 850 pos = writeUint16(data, pos, errorCode) 851 pos = writeByte(data, pos, '#') 852 if sqlState == "" { 853 sqlState = SSUnknownSQLState 854 } 855 if len(sqlState) != 5 { 856 panic("sqlState has to be 5 characters long") 857 } 858 pos = writeEOFString(data, pos, sqlState) 859 _ = writeEOFString(data, pos, errorMessage) 860 861 return c.writeEphemeralPacket() 862 } 863 864 // writeErrorPacketFromError writes an error packet, from a regular error. 865 // See writeErrorPacket for other info. 866 func (c *Conn) writeErrorPacketFromError(err error) error { 867 if se, ok := err.(*SQLError); ok { 868 return c.writeErrorPacket(uint16(se.Num), se.State, "%v", se.Message) 869 } 870 871 return c.writeErrorPacket(ERUnknownError, SSUnknownSQLState, "unknown error: %v", err) 872 } 873 874 // writeEOFPacket writes an EOF packet, through the buffer, and 875 // doesn't flush (as it is used as part of a query result). 876 func (c *Conn) writeEOFPacket(flags uint16, warnings uint16) error { 877 length := 5 878 data, pos := c.startEphemeralPacketWithHeader(length) 879 pos = writeByte(data, pos, EOFPacket) 880 pos = writeUint16(data, pos, warnings) 881 _ = writeUint16(data, pos, flags) 882 883 return c.writeEphemeralPacket() 884 } 885 886 // handleNextCommand is called in the server loop to process 887 // incoming packets. 888 func (c *Conn) handleNextCommand(handler Handler) bool { 889 c.sequence = 0 890 data, err := c.readEphemeralPacket() 891 if err != nil { 892 // Don't log EOF errors. They cause too much spam. 893 if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { 894 log.Errorf("Error reading packet from %s: %v", c, err) 895 } 896 return false 897 } 898 if len(data) == 0 { 899 return false 900 } 901 902 switch data[0] { 903 case ComQuit: 904 c.recycleReadPacket() 905 return false 906 case ComInitDB: 907 db := c.parseComInitDB(data) 908 c.recycleReadPacket() 909 res := c.execQuery("use "+sqlescape.EscapeID(db), handler, false) 910 return res != connErr 911 case ComQuery: 912 return c.handleComQuery(handler, data) 913 case ComPing: 914 return c.handleComPing() 915 case ComSetOption: 916 return c.handleComSetOption(data) 917 case ComPrepare: 918 return c.handleComPrepare(handler, data) 919 case ComStmtExecute: 920 return c.handleComStmtExecute(handler, data) 921 case ComStmtSendLongData: 922 return c.handleComStmtSendLongData(data) 923 case ComStmtClose: 924 stmtID, ok := c.parseComStmtClose(data) 925 c.recycleReadPacket() 926 if ok { 927 delete(c.PrepareData, stmtID) 928 } 929 case ComStmtReset: 930 return c.handleComStmtReset(data) 931 case ComResetConnection: 932 c.handleComResetConnection(handler) 933 return true 934 case ComFieldList: 935 c.recycleReadPacket() 936 if !c.writeErrorAndLog(ERUnknownComError, SSNetError, "command handling not implemented yet: %v", data[0]) { 937 return false 938 } 939 case ComBinlogDump: 940 return c.handleComBinlogDump(handler, data) 941 case ComBinlogDumpGTID: 942 return c.handleComBinlogDumpGTID(handler, data) 943 case ComRegisterReplica: 944 return c.handleComRegisterReplica(handler, data) 945 default: 946 log.Errorf("Got unhandled packet (default) from %s, returning error: %v", c, data) 947 c.recycleReadPacket() 948 if !c.writeErrorAndLog(ERUnknownComError, SSNetError, "command handling not implemented yet: %v", data[0]) { 949 return false 950 } 951 } 952 953 return true 954 } 955 956 func (c *Conn) handleComRegisterReplica(handler Handler, data []byte) (kontinue bool) { 957 c.recycleReadPacket() 958 959 replicaHost, replicaPort, replicaUser, replicaPassword, err := c.parseComRegisterReplica(data) 960 if err != nil { 961 log.Errorf("conn %v: parseComRegisterReplica failed: %v", c.ID(), err) 962 return false 963 } 964 if err := handler.ComRegisterReplica(c, replicaHost, replicaPort, replicaUser, replicaPassword); err != nil { 965 c.writeErrorPacketFromError(err) 966 return false 967 } 968 if err := c.writeOKPacket(&PacketOK{}); err != nil { 969 c.writeErrorPacketFromError(err) 970 } 971 return true 972 } 973 974 func (c *Conn) handleComBinlogDump(handler Handler, data []byte) (kontinue bool) { 975 c.recycleReadPacket() 976 kontinue = true 977 978 c.startWriterBuffering() 979 defer func() { 980 if err := c.endWriterBuffering(); err != nil { 981 log.Errorf("conn %v: flush() failed: %v", c.ID(), err) 982 kontinue = false 983 } 984 }() 985 986 logfile, binlogPos, err := c.parseComBinlogDump(data) 987 if err != nil { 988 log.Errorf("conn %v: parseComBinlogDumpGTID failed: %v", c.ID(), err) 989 return false 990 } 991 if err := handler.ComBinlogDump(c, logfile, binlogPos); err != nil { 992 log.Error(err.Error()) 993 return false 994 } 995 return kontinue 996 } 997 998 func (c *Conn) handleComBinlogDumpGTID(handler Handler, data []byte) (kontinue bool) { 999 c.recycleReadPacket() 1000 kontinue = true 1001 1002 c.startWriterBuffering() 1003 defer func() { 1004 if err := c.endWriterBuffering(); err != nil { 1005 log.Errorf("conn %v: flush() failed: %v", c.ID(), err) 1006 kontinue = false 1007 } 1008 }() 1009 1010 logFile, logPos, position, err := c.parseComBinlogDumpGTID(data) 1011 if err != nil { 1012 log.Errorf("conn %v: parseComBinlogDumpGTID failed: %v", c.ID(), err) 1013 return false 1014 } 1015 if err := handler.ComBinlogDumpGTID(c, logFile, logPos, position.GTIDSet); err != nil { 1016 log.Error(err.Error()) 1017 return false 1018 } 1019 return kontinue 1020 } 1021 1022 func (c *Conn) handleComResetConnection(handler Handler) { 1023 // Clean up and reset the connection 1024 c.recycleReadPacket() 1025 handler.ComResetConnection(c) 1026 // Reset prepared statements 1027 c.PrepareData = make(map[uint32]*PrepareData) 1028 err := c.writeOKPacket(&PacketOK{}) 1029 if err != nil { 1030 c.writeErrorPacketFromError(err) 1031 } 1032 } 1033 1034 func (c *Conn) handleComStmtReset(data []byte) bool { 1035 stmtID, ok := c.parseComStmtReset(data) 1036 c.recycleReadPacket() 1037 if !ok { 1038 log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) 1039 if !c.writeErrorAndLog(ERUnknownComError, SSNetError, "error handling packet: %v", data) { 1040 return false 1041 } 1042 } 1043 1044 prepare, ok := c.PrepareData[stmtID] 1045 if !ok { 1046 log.Error("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data) 1047 if !c.writeErrorAndLog(CRCommandsOutOfSync, SSNetError, "commands were executed in an improper order: %v", data) { 1048 return false 1049 } 1050 } 1051 1052 if prepare.BindVars != nil { 1053 for k := range prepare.BindVars { 1054 prepare.BindVars[k] = nil 1055 } 1056 } 1057 1058 if err := c.writeOKPacket(&PacketOK{statusFlags: c.StatusFlags}); err != nil { 1059 log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err) 1060 return false 1061 } 1062 return true 1063 } 1064 1065 func (c *Conn) handleComStmtSendLongData(data []byte) bool { 1066 stmtID, paramID, chunk, ok := c.parseComStmtSendLongData(data) 1067 c.recycleReadPacket() 1068 if !ok { 1069 err := fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data) 1070 return c.writeErrorPacketFromErrorAndLog(err) 1071 } 1072 1073 prepare, ok := c.PrepareData[stmtID] 1074 if !ok { 1075 err := fmt.Errorf("got wrong statement id from client %v, statement ID(%v) is not found from record", c.ConnectionID, stmtID) 1076 return c.writeErrorPacketFromErrorAndLog(err) 1077 } 1078 1079 if prepare.BindVars == nil || 1080 prepare.ParamsCount == uint16(0) || 1081 paramID >= prepare.ParamsCount { 1082 err := fmt.Errorf("invalid parameter Number from client %v, statement: %v", c.ConnectionID, prepare.PrepareStmt) 1083 return c.writeErrorPacketFromErrorAndLog(err) 1084 } 1085 1086 key := fmt.Sprintf("v%d", paramID+1) 1087 if val, ok := prepare.BindVars[key]; ok { 1088 val.Value = append(val.Value, chunk...) 1089 } else { 1090 prepare.BindVars[key] = sqltypes.BytesBindVariable(chunk) 1091 } 1092 return true 1093 } 1094 1095 func (c *Conn) handleComStmtExecute(handler Handler, data []byte) (kontinue bool) { 1096 c.startWriterBuffering() 1097 defer func() { 1098 if err := c.endWriterBuffering(); err != nil { 1099 log.Errorf("conn %v: flush() failed: %v", c.ID(), err) 1100 kontinue = false 1101 } 1102 }() 1103 queryStart := time.Now() 1104 stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data) 1105 c.recycleReadPacket() 1106 1107 if stmtID != uint32(0) { 1108 defer func() { 1109 // Allocate a new bindvar map every time since VTGate.Execute() mutates it. 1110 prepare := c.PrepareData[stmtID] 1111 prepare.BindVars = make(map[string]*querypb.BindVariable, prepare.ParamsCount) 1112 }() 1113 } 1114 1115 if err != nil { 1116 return c.writeErrorPacketFromErrorAndLog(err) 1117 } 1118 1119 fieldSent := false 1120 // sendFinished is set if the response should just be an OK packet. 1121 sendFinished := false 1122 prepare := c.PrepareData[stmtID] 1123 err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { 1124 if sendFinished { 1125 // Failsafe: Unreachable if server is well-behaved. 1126 return io.EOF 1127 } 1128 1129 if !fieldSent { 1130 fieldSent = true 1131 1132 if len(qr.Fields) == 0 { 1133 sendFinished = true 1134 // We should not send any more packets after this. 1135 ok := PacketOK{ 1136 affectedRows: qr.RowsAffected, 1137 lastInsertID: qr.InsertID, 1138 statusFlags: c.StatusFlags, 1139 warnings: 0, 1140 info: "", 1141 sessionStateData: qr.SessionStateChanges, 1142 } 1143 return c.writeOKPacket(&ok) 1144 } 1145 if err := c.writeFields(qr); err != nil { 1146 return err 1147 } 1148 } 1149 1150 return c.writeBinaryRows(qr) 1151 }) 1152 1153 // If no field was sent, we expect an error. 1154 if !fieldSent { 1155 // This is just a failsafe. Should never happen. 1156 if err == nil || err == io.EOF { 1157 err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) 1158 } 1159 if !c.writeErrorPacketFromErrorAndLog(err) { 1160 return false 1161 } 1162 } else { 1163 if err != nil { 1164 // We can't send an error in the middle of a stream. 1165 // All we can do is abort the send, which will cause a 2013. 1166 log.Errorf("Error in the middle of a stream to %s: %v", c, err) 1167 return false 1168 } 1169 1170 // Send the end packet only sendFinished is false (results were streamed). 1171 // In this case the affectedRows and lastInsertID are always 0 since it 1172 // was a read operation. 1173 if !sendFinished { 1174 if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { 1175 log.Errorf("Error writing result to %s: %v", c, err) 1176 return false 1177 } 1178 } 1179 } 1180 1181 timings.Record(queryTimingKey, queryStart) 1182 return true 1183 } 1184 1185 func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) { 1186 c.startWriterBuffering() 1187 defer func() { 1188 if err := c.endWriterBuffering(); err != nil { 1189 log.Errorf("conn %v: flush() failed: %v", c.ID(), err) 1190 kontinue = false 1191 } 1192 }() 1193 1194 query := c.parseComPrepare(data) 1195 c.recycleReadPacket() 1196 1197 var queries []string 1198 if c.Capabilities&CapabilityClientMultiStatements != 0 { 1199 var err error 1200 queries, err = splitStatementFunction(query) 1201 if err != nil { 1202 log.Errorf("Conn %v: Error splitting query: %v", c, err) 1203 return c.writeErrorPacketFromErrorAndLog(err) 1204 } 1205 if len(queries) != 1 { 1206 log.Errorf("Conn %v: can not prepare multiple statements", c) 1207 return c.writeErrorPacketFromErrorAndLog(err) 1208 } 1209 } else { 1210 queries = []string{query} 1211 } 1212 1213 // Popoulate PrepareData 1214 c.StatementID++ 1215 prepare := &PrepareData{ 1216 StatementID: c.StatementID, 1217 PrepareStmt: queries[0], 1218 } 1219 1220 statement, err := sqlparser.ParseStrictDDL(query) 1221 if err != nil { 1222 log.Errorf("Conn %v: Error parsing prepared statement: %v", c, err) 1223 if !c.writeErrorPacketFromErrorAndLog(err) { 1224 return false 1225 } 1226 } 1227 1228 paramsCount := uint16(0) 1229 _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { 1230 switch node := node.(type) { 1231 case sqlparser.Argument: 1232 if strings.HasPrefix(string(node), "v") { 1233 paramsCount++ 1234 } 1235 } 1236 return true, nil 1237 }, statement) 1238 1239 if paramsCount > 0 { 1240 prepare.ParamsCount = paramsCount 1241 prepare.ParamsType = make([]int32, paramsCount) 1242 prepare.BindVars = make(map[string]*querypb.BindVariable, paramsCount) 1243 } 1244 1245 bindVars := make(map[string]*querypb.BindVariable, paramsCount) 1246 for i := uint16(0); i < paramsCount; i++ { 1247 parameterID := fmt.Sprintf("v%d", i+1) 1248 bindVars[parameterID] = &querypb.BindVariable{} 1249 } 1250 1251 c.PrepareData[c.StatementID] = prepare 1252 1253 fld, err := handler.ComPrepare(c, queries[0], bindVars) 1254 1255 if err != nil { 1256 return c.writeErrorPacketFromErrorAndLog(err) 1257 } 1258 1259 if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil { 1260 log.Error("Error writing prepare data to client %v: %v", c.ConnectionID, err) 1261 return false 1262 } 1263 return true 1264 } 1265 1266 func (c *Conn) handleComSetOption(data []byte) bool { 1267 operation, ok := c.parseComSetOption(data) 1268 c.recycleReadPacket() 1269 if ok { 1270 switch operation { 1271 case 0: 1272 c.Capabilities |= CapabilityClientMultiStatements 1273 case 1: 1274 c.Capabilities &^= CapabilityClientMultiStatements 1275 default: 1276 log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", c.ConnectionID, data) 1277 if !c.writeErrorAndLog(ERUnknownComError, SSNetError, "error handling packet: %v", data) { 1278 return false 1279 } 1280 } 1281 if err := c.writeEndResult(false, 0, 0, 0); err != nil { 1282 log.Errorf("Error writeEndResult error %v ", err) 1283 return false 1284 } 1285 } else { 1286 log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", c.ConnectionID, data) 1287 if !c.writeErrorAndLog(ERUnknownComError, SSNetError, "error handling packet: %v", data) { 1288 return false 1289 } 1290 } 1291 return true 1292 } 1293 1294 func (c *Conn) handleComPing() bool { 1295 c.recycleReadPacket() 1296 // Return error if listener was shut down and OK otherwise 1297 if c.listener.isShutdown() { 1298 if !c.writeErrorAndLog(ERServerShutdown, SSNetError, "Server shutdown in progress") { 1299 return false 1300 } 1301 } else { 1302 if err := c.writeOKPacket(&PacketOK{statusFlags: c.StatusFlags}); err != nil { 1303 log.Errorf("Error writing ComPing result to %s: %v", c, err) 1304 return false 1305 } 1306 } 1307 return true 1308 } 1309 1310 var errEmptyStatement = NewSQLError(EREmptyQuery, SSClientError, "Query was empty") 1311 1312 func (c *Conn) handleComQuery(handler Handler, data []byte) (kontinue bool) { 1313 c.startWriterBuffering() 1314 defer func() { 1315 if err := c.endWriterBuffering(); err != nil { 1316 log.Errorf("conn %v: flush() failed: %v", c.ID(), err) 1317 kontinue = false 1318 } 1319 }() 1320 1321 queryStart := time.Now() 1322 query := c.parseComQuery(data) 1323 c.recycleReadPacket() 1324 1325 var queries []string 1326 var err error 1327 if c.Capabilities&CapabilityClientMultiStatements != 0 { 1328 queries, err = splitStatementFunction(query) 1329 if err != nil { 1330 log.Errorf("Conn %v: Error splitting query: %v", c, err) 1331 return c.writeErrorPacketFromErrorAndLog(err) 1332 } 1333 } else { 1334 queries = []string{query} 1335 } 1336 1337 if len(queries) == 0 { 1338 return c.writeErrorPacketFromErrorAndLog(errEmptyStatement) 1339 } 1340 1341 for index, sql := range queries { 1342 more := false 1343 if index != len(queries)-1 { 1344 more = true 1345 } 1346 res := c.execQuery(sql, handler, more) 1347 if res != execSuccess { 1348 return res != connErr 1349 } 1350 } 1351 1352 timings.Record(queryTimingKey, queryStart) 1353 return true 1354 } 1355 1356 func (c *Conn) execQuery(query string, handler Handler, more bool) execResult { 1357 callbackCalled := false 1358 // sendFinished is set if the response should just be an OK packet. 1359 sendFinished := false 1360 1361 err := handler.ComQuery(c, query, func(qr *sqltypes.Result) error { 1362 flag := c.StatusFlags 1363 if more { 1364 flag |= ServerMoreResultsExists 1365 } 1366 if sendFinished { 1367 // Failsafe: Unreachable if server is well-behaved. 1368 return io.EOF 1369 } 1370 1371 if !callbackCalled { 1372 callbackCalled = true 1373 1374 if len(qr.Fields) == 0 { 1375 sendFinished = true 1376 1377 // A successful callback with no fields means that this was a 1378 // DML or other write-only operation. 1379 // 1380 // We should not send any more packets after this, but make sure 1381 // to extract the affected rows and last insert id from the result 1382 // struct here since clients expect it. 1383 ok := PacketOK{ 1384 affectedRows: qr.RowsAffected, 1385 lastInsertID: qr.InsertID, 1386 statusFlags: flag, 1387 warnings: handler.WarningCount(c), 1388 info: "", 1389 sessionStateData: qr.SessionStateChanges, 1390 } 1391 return c.writeOKPacket(&ok) 1392 } 1393 if err := c.writeFields(qr); err != nil { 1394 return err 1395 } 1396 } 1397 1398 return c.writeRows(qr) 1399 }) 1400 1401 // If callback was not called, we expect an error. 1402 if !callbackCalled { 1403 // This is just a failsafe. Should never happen. 1404 if err == nil || err == io.EOF { 1405 err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) 1406 } 1407 if !c.writeErrorPacketFromErrorAndLog(err) { 1408 return connErr 1409 } 1410 return execErr 1411 } 1412 if err != nil { 1413 // We can't send an error in the middle of a stream. 1414 // All we can do is abort the send, which will cause a 2013. 1415 log.Errorf("Error in the middle of a stream to %s: %v", c, err) 1416 return connErr 1417 } 1418 1419 // Send the end packet only sendFinished is false (results were streamed). 1420 // In this case the affectedRows and lastInsertID are always 0 since it 1421 // was a read operation. 1422 if !sendFinished { 1423 if err := c.writeEndResult(more, 0, 0, handler.WarningCount(c)); err != nil { 1424 log.Errorf("Error writing result to %s: %v", c, err) 1425 return connErr 1426 } 1427 } 1428 1429 return execSuccess 1430 } 1431 1432 // 1433 // Packet parsing methods, for generic packets. 1434 // 1435 1436 // isEOFPacket determines whether a data packet is an EOF. In case the client capabilities 1437 // do not have DEPRECATE_EOF set, DO NOT blindly compare the first byte of a packet to EOFPacket 1438 // as you might do for other packet types, as 0xfe is overloaded as a first byte. 1439 1440 // In case that DEPRECATE_EOF is set, we have really an OK packet which is always maximum a single 1441 // packet and not multiple, but otherwise 0xfe definitely indicates it is an EOF. 1442 // 1443 // Per https://dev.mysql.com/doc/internals/en/packet-EOF_Packet.html, a packet starting with 0xfe 1444 // but having length >= 9 (on top of 4 byte header) without DEPRECATE_EOF set is not a true EOF but 1445 // a LengthEncodedInteger (typically preceding a LengthEncodedString). Thus, all EOF checks without 1446 // DEPRECATE_EOF must validate the payload size before exiting. 1447 // 1448 // More docs here: 1449 // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_response_packets.html 1450 func (c *Conn) isEOFPacket(data []byte) bool { 1451 if data[0] != EOFPacket { 1452 return false 1453 } 1454 if c.Capabilities&CapabilityClientDeprecateEOF == 0 { 1455 return len(data) < 9 1456 } 1457 return len(data) < MaxPacketSize 1458 } 1459 1460 // parseEOFPacket returns the warning count and a boolean to indicate if there 1461 // are more results to receive. 1462 // 1463 // Note: This is only valid on actual EOF packets and not on OK packets with the EOF 1464 // type code set, i.e. should not be used if ClientDeprecateEOF is set. 1465 func parseEOFPacket(data []byte) (warnings uint16, statusFlags uint16, err error) { 1466 // The warning count is in position 2 & 3 1467 warnings, _, _ = readUint16(data, 1) 1468 1469 // The status flag is in position 4 & 5 1470 statusFlags, _, ok := readUint16(data, 3) 1471 if !ok { 1472 return 0, 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid EOF packet statusFlags: %v", data) 1473 } 1474 return warnings, statusFlags, nil 1475 } 1476 1477 // PacketOK contains the ok packet details 1478 type PacketOK struct { 1479 affectedRows uint64 1480 lastInsertID uint64 1481 statusFlags uint16 1482 warnings uint16 1483 info string 1484 1485 // at the moment, we only store GTID information in this field 1486 sessionStateData string 1487 } 1488 1489 func (c *Conn) parseOKPacket(in []byte) (*PacketOK, error) { 1490 data := &coder{ 1491 data: in, 1492 pos: 1, // We already read the type. 1493 } 1494 packetOK := &PacketOK{} 1495 1496 fail := func(format string, args ...any) (*PacketOK, error) { 1497 return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, format, args...) 1498 } 1499 1500 // Affected rows. 1501 affectedRows, ok := data.readLenEncInt() 1502 if !ok { 1503 return fail("invalid OK packet affectedRows: %v", data) 1504 } 1505 packetOK.affectedRows = affectedRows 1506 1507 // Last Insert ID. 1508 lastInsertID, ok := data.readLenEncInt() 1509 if !ok { 1510 return fail("invalid OK packet lastInsertID: %v", data) 1511 } 1512 packetOK.lastInsertID = lastInsertID 1513 1514 // Status flags. 1515 statusFlags, ok := data.readUint16() 1516 if !ok { 1517 return fail("invalid OK packet statusFlags: %v", data) 1518 } 1519 packetOK.statusFlags = statusFlags 1520 1521 // assuming CapabilityClientProtocol41 1522 // Warnings. 1523 warnings, ok := data.readUint16() 1524 if !ok { 1525 return fail("invalid OK packet warnings: %v", data) 1526 } 1527 packetOK.warnings = warnings 1528 1529 // info 1530 info, _ := data.readLenEncInfo() 1531 if c.enableQueryInfo { 1532 packetOK.info = info 1533 } 1534 1535 if c.Capabilities&uint32(CapabilityClientSessionTrack) == CapabilityClientSessionTrack { 1536 // session tracking 1537 if statusFlags&ServerSessionStateChanged == ServerSessionStateChanged { 1538 length, ok := data.readLenEncInt() 1539 if !ok || length == 0 { 1540 // In case we have no more data or a zero length string, there's no additional information so 1541 // we can return the packet. 1542 return packetOK, nil 1543 } 1544 1545 // Alright, now we need to read each sub packet from the session state change. 1546 for { 1547 sscType, ok := data.readByte() 1548 if !ok { 1549 // We're done, there's no more session state parts in the packet. 1550 break 1551 } 1552 sessionLen, ok := data.readLenEncInt() 1553 if !ok { 1554 return fail("invalid OK packet session state change length for type %v", sscType) 1555 } 1556 1557 if sscType != SessionTrackGtids { 1558 // Still need to increase the pointer here to indicate we're consuming 1559 // but otherwise ignoring the rest of this packet 1560 data.pos = data.pos + int(sessionLen) 1561 continue 1562 } 1563 1564 // read (and ignore for now) the GTIDS encoding specification code: 1 byte 1565 _, ok = data.readByte() 1566 if !ok { 1567 return fail("invalid OK packet gtids type: %v", data) 1568 } 1569 1570 gtids, ok := data.readLenEncString() 1571 if !ok { 1572 return fail("invalid OK packet gtids: %v", data) 1573 } 1574 packetOK.sessionStateData = gtids 1575 } 1576 } 1577 } 1578 1579 return packetOK, nil 1580 } 1581 1582 // isErrorPacket determines whether or not the packet is an error packet. Mostly here for 1583 // consistency with isEOFPacket 1584 func isErrorPacket(data []byte) bool { 1585 return data[0] == ErrPacket 1586 } 1587 1588 // ParseErrorPacket parses the error packet and returns a SQLError. 1589 func ParseErrorPacket(data []byte) error { 1590 // We already read the type. 1591 pos := 1 1592 1593 // Error code is 2 bytes. 1594 code, pos, ok := readUint16(data, pos) 1595 if !ok { 1596 return NewSQLError(CRUnknownError, SSUnknownSQLState, "invalid error packet code: %v", data) 1597 } 1598 1599 // '#' marker of the SQL state is 1 byte. Ignored. 1600 pos++ 1601 1602 // SQL state is 5 bytes 1603 sqlState, pos, ok := readBytes(data, pos, 5) 1604 if !ok { 1605 return NewSQLError(CRUnknownError, SSUnknownSQLState, "invalid error packet sqlState: %v", data) 1606 } 1607 1608 // Human readable error message is the rest. 1609 msg := string(data[pos:]) 1610 1611 return NewSQLError(int(code), string(sqlState), "%v", msg) 1612 } 1613 1614 // GetTLSClientCerts gets TLS certificates. 1615 func (c *Conn) GetTLSClientCerts() []*x509.Certificate { 1616 if tlsConn, ok := c.conn.(*tls.Conn); ok { 1617 return tlsConn.ConnectionState().PeerCertificates 1618 } 1619 return nil 1620 } 1621 1622 // TLSEnabled returns true if this connection is using TLS. 1623 func (c *Conn) TLSEnabled() bool { 1624 return c.Capabilities&CapabilityClientSSL > 0 1625 } 1626 1627 // IsUnixSocket returns true if this connection is over a Unix socket. 1628 func (c *Conn) IsUnixSocket() bool { 1629 _, ok := c.listener.listener.(*net.UnixListener) 1630 return ok 1631 } 1632 1633 // GetRawConn returns the raw net.Conn for nefarious purposes. 1634 func (c *Conn) GetRawConn() net.Conn { 1635 return c.conn 1636 }