github.com/hellobchain/third_party@v0.0.0-20230331131523-deb0478a2e52/go-sql-driver/mysql/packets.go (about) 1 // Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2 // 3 // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 4 // 5 // This Source Code Form is subject to the terms of the Mozilla Public 6 // License, v. 2.0. If a copy of the MPL was not distributed with this file, 7 // You can obtain one at http://mozilla.org/MPL/2.0/. 8 9 package mysql 10 11 import ( 12 "bytes" 13 "database/sql/driver" 14 "encoding/binary" 15 "errors" 16 "fmt" 17 "github.com/hellobchain/newcryptosm/tls" 18 "io" 19 "math" 20 "time" 21 ) 22 23 // Packets documentation: 24 // http://dev.mysql.com/doc/internals/en/client-server-protocol.html 25 26 // Read packet to buffer 'data' 27 func (mc *mysqlConn) readPacket() ([]byte, error) { 28 var prevData []byte 29 for { 30 // read packet header 31 data, err := mc.buf.readNext(4) 32 if err != nil { 33 if cerr := mc.canceled.Value(); cerr != nil { 34 return nil, cerr 35 } 36 errLog.Print(err) 37 mc.Close() 38 return nil, ErrInvalidConn 39 } 40 41 // packet length [24 bit] 42 pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) 43 44 // check packet sync [8 bit] 45 if data[3] != mc.sequence { 46 if data[3] > mc.sequence { 47 return nil, ErrPktSyncMul 48 } 49 return nil, ErrPktSync 50 } 51 mc.sequence++ 52 53 // packets with length 0 terminate a previous packet which is a 54 // multiple of (2^24)-1 bytes long 55 if pktLen == 0 { 56 // there was no previous packet 57 if prevData == nil { 58 errLog.Print(ErrMalformPkt) 59 mc.Close() 60 return nil, ErrInvalidConn 61 } 62 63 return prevData, nil 64 } 65 66 // read packet body [pktLen bytes] 67 data, err = mc.buf.readNext(pktLen) 68 if err != nil { 69 if cerr := mc.canceled.Value(); cerr != nil { 70 return nil, cerr 71 } 72 errLog.Print(err) 73 mc.Close() 74 return nil, ErrInvalidConn 75 } 76 77 // return data if this was the last packet 78 if pktLen < maxPacketSize { 79 // zero allocations for non-split packets 80 if prevData == nil { 81 return data, nil 82 } 83 84 return append(prevData, data...), nil 85 } 86 87 prevData = append(prevData, data...) 88 } 89 } 90 91 // Write packet buffer 'data' 92 func (mc *mysqlConn) writePacket(data []byte) error { 93 pktLen := len(data) - 4 94 95 if pktLen > mc.maxAllowedPacket { 96 return ErrPktTooLarge 97 } 98 99 // Perform a stale connection check. We only perform this check for 100 // the first query on a connection that has been checked out of the 101 // connection pool: a fresh connection from the pool is more likely 102 // to be stale, and it has not performed any previous writes that 103 // could cause data corruption, so it's safe to return ErrBadConn 104 // if the check fails. 105 if mc.reset { 106 mc.reset = false 107 conn := mc.netConn 108 if mc.rawConn != nil { 109 conn = mc.rawConn 110 } 111 var err error 112 // If this connection has a ReadTimeout which we've been setting on 113 // reads, reset it to its default value before we attempt a non-blocking 114 // read, otherwise the scheduler will just time us out before we can read 115 if mc.cfg.ReadTimeout != 0 { 116 err = conn.SetReadDeadline(time.Time{}) 117 } 118 if err == nil { 119 err = connCheck(conn) 120 } 121 if err != nil { 122 errLog.Print("closing bad idle connection: ", err) 123 mc.Close() 124 return driver.ErrBadConn 125 } 126 } 127 128 for { 129 var size int 130 if pktLen >= maxPacketSize { 131 data[0] = 0xff 132 data[1] = 0xff 133 data[2] = 0xff 134 size = maxPacketSize 135 } else { 136 data[0] = byte(pktLen) 137 data[1] = byte(pktLen >> 8) 138 data[2] = byte(pktLen >> 16) 139 size = pktLen 140 } 141 data[3] = mc.sequence 142 143 // Write packet 144 if mc.writeTimeout > 0 { 145 if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { 146 return err 147 } 148 } 149 150 n, err := mc.netConn.Write(data[:4+size]) 151 if err == nil && n == 4+size { 152 mc.sequence++ 153 if size != maxPacketSize { 154 return nil 155 } 156 pktLen -= size 157 data = data[size:] 158 continue 159 } 160 161 // Handle error 162 if err == nil { // n != len(data) 163 mc.cleanup() 164 errLog.Print(ErrMalformPkt) 165 } else { 166 if cerr := mc.canceled.Value(); cerr != nil { 167 return cerr 168 } 169 if n == 0 && pktLen == len(data)-4 { 170 // only for the first loop iteration when nothing was written yet 171 return errBadConnNoWrite 172 } 173 mc.cleanup() 174 errLog.Print(err) 175 } 176 return ErrInvalidConn 177 } 178 } 179 180 /****************************************************************************** 181 * Initialization Process * 182 ******************************************************************************/ 183 184 // Handshake Initialization Packet 185 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake 186 func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { 187 data, err = mc.readPacket() 188 if err != nil { 189 // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since 190 // in connection initialization we don't risk retrying non-idempotent actions. 191 if err == ErrInvalidConn { 192 return nil, "", driver.ErrBadConn 193 } 194 return 195 } 196 197 if data[0] == iERR { 198 return nil, "", mc.handleErrorPacket(data) 199 } 200 201 // protocol version [1 byte] 202 if data[0] < minProtocolVersion { 203 return nil, "", fmt.Errorf( 204 "unsupported protocol version %d. Version %d or higher is required", 205 data[0], 206 minProtocolVersion, 207 ) 208 } 209 210 // server version [null terminated string] 211 // connection id [4 bytes] 212 pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 213 214 // first part of the password cipher [8 bytes] 215 authData := data[pos : pos+8] 216 217 // (filler) always 0x00 [1 byte] 218 pos += 8 + 1 219 220 // capability flags (lower 2 bytes) [2 bytes] 221 mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) 222 if mc.flags&clientProtocol41 == 0 { 223 return nil, "", ErrOldProtocol 224 } 225 if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { 226 if mc.cfg.TLSConfig == "preferred" { 227 mc.cfg.tls = nil 228 } else { 229 return nil, "", ErrNoTLS 230 } 231 } 232 pos += 2 233 234 if len(data) > pos { 235 // character set [1 byte] 236 // status flags [2 bytes] 237 // capability flags (upper 2 bytes) [2 bytes] 238 // length of auth-plugin-data [1 byte] 239 // reserved (all [00]) [10 bytes] 240 pos += 1 + 2 + 2 + 1 + 10 241 242 // second part of the password cipher [mininum 13 bytes], 243 // where len=MAX(13, length of auth-plugin-data - 8) 244 // 245 // The web documentation is ambiguous about the length. However, 246 // according to mysql-5.7/sql/auth/sql_authentication.cc line 538, 247 // the 13th byte is "\0 byte, terminating the second part of 248 // a scramble". So the second part of the password cipher is 249 // a NULL terminated string that's at least 13 bytes with the 250 // last byte being NULL. 251 // 252 // The official Python library uses the fixed length 12 253 // which seems to work but technically could have a hidden bug. 254 authData = append(authData, data[pos:pos+12]...) 255 pos += 13 256 257 // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) 258 // \NUL otherwise 259 if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { 260 plugin = string(data[pos : pos+end]) 261 } else { 262 plugin = string(data[pos:]) 263 } 264 265 // make a memory safe copy of the cipher slice 266 var b [20]byte 267 copy(b[:], authData) 268 return b[:], plugin, nil 269 } 270 271 // make a memory safe copy of the cipher slice 272 var b [8]byte 273 copy(b[:], authData) 274 return b[:], plugin, nil 275 } 276 277 // Client Authentication Packet 278 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse 279 func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { 280 // Adjust client flags based on server support 281 clientFlags := clientProtocol41 | 282 clientSecureConn | 283 clientLongPassword | 284 clientTransactions | 285 clientLocalFiles | 286 clientPluginAuth | 287 clientMultiResults | 288 mc.flags&clientLongFlag 289 290 if mc.cfg.ClientFoundRows { 291 clientFlags |= clientFoundRows 292 } 293 294 // To enable TLS / SSL 295 if mc.cfg.tls != nil { 296 clientFlags |= clientSSL 297 } 298 299 if mc.cfg.MultiStatements { 300 clientFlags |= clientMultiStatements 301 } 302 303 // encode length of the auth plugin data 304 var authRespLEIBuf [9]byte 305 authRespLen := len(authResp) 306 authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) 307 if len(authRespLEI) > 1 { 308 // if the length can not be written in 1 byte, it must be written as a 309 // length encoded integer 310 clientFlags |= clientPluginAuthLenEncClientData 311 } 312 313 pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 314 315 // To specify a db name 316 if n := len(mc.cfg.DBName); n > 0 { 317 clientFlags |= clientConnectWithDB 318 pktLen += n + 1 319 } 320 321 // Calculate packet length and get buffer with that size 322 data, err := mc.buf.takeSmallBuffer(pktLen + 4) 323 if err != nil { 324 // cannot take the buffer. Something must be wrong with the connection 325 errLog.Print(err) 326 return errBadConnNoWrite 327 } 328 329 // ClientFlags [32 bit] 330 data[4] = byte(clientFlags) 331 data[5] = byte(clientFlags >> 8) 332 data[6] = byte(clientFlags >> 16) 333 data[7] = byte(clientFlags >> 24) 334 335 // MaxPacketSize [32 bit] (none) 336 data[8] = 0x00 337 data[9] = 0x00 338 data[10] = 0x00 339 data[11] = 0x00 340 341 // Charset [1 byte] 342 var found bool 343 data[12], found = collations[mc.cfg.Collation] 344 if !found { 345 // Note possibility for false negatives: 346 // could be triggered although the collation is valid if the 347 // collations map does not contain entries the server supports. 348 return errors.New("unknown collation") 349 } 350 351 // SSL Connection Request Packet 352 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest 353 if mc.cfg.tls != nil { 354 // Send TLS / SSL request packet 355 if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { 356 return err 357 } 358 359 // Switch to TLS 360 tlsConn := tls.Client(mc.netConn, mc.cfg.tls) 361 if err := tlsConn.Handshake(); err != nil { 362 return err 363 } 364 mc.rawConn = mc.netConn 365 mc.netConn = tlsConn 366 mc.buf.nc = tlsConn 367 } 368 369 // Filler [23 bytes] (all 0x00) 370 pos := 13 371 for ; pos < 13+23; pos++ { 372 data[pos] = 0 373 } 374 375 // User [null terminated string] 376 if len(mc.cfg.User) > 0 { 377 pos += copy(data[pos:], mc.cfg.User) 378 } 379 data[pos] = 0x00 380 pos++ 381 382 // Auth Data [length encoded integer] 383 pos += copy(data[pos:], authRespLEI) 384 pos += copy(data[pos:], authResp) 385 386 // Databasename [null terminated string] 387 if len(mc.cfg.DBName) > 0 { 388 pos += copy(data[pos:], mc.cfg.DBName) 389 data[pos] = 0x00 390 pos++ 391 } 392 393 pos += copy(data[pos:], plugin) 394 data[pos] = 0x00 395 pos++ 396 397 // Send Auth packet 398 return mc.writePacket(data[:pos]) 399 } 400 401 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse 402 func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { 403 pktLen := 4 + len(authData) 404 data, err := mc.buf.takeSmallBuffer(pktLen) 405 if err != nil { 406 // cannot take the buffer. Something must be wrong with the connection 407 errLog.Print(err) 408 return errBadConnNoWrite 409 } 410 411 // Add the auth data [EOF] 412 copy(data[4:], authData) 413 return mc.writePacket(data) 414 } 415 416 /****************************************************************************** 417 * Command Packets * 418 ******************************************************************************/ 419 420 func (mc *mysqlConn) writeCommandPacket(command byte) error { 421 // Reset Packet Sequence 422 mc.sequence = 0 423 424 data, err := mc.buf.takeSmallBuffer(4 + 1) 425 if err != nil { 426 // cannot take the buffer. Something must be wrong with the connection 427 errLog.Print(err) 428 return errBadConnNoWrite 429 } 430 431 // Add command byte 432 data[4] = command 433 434 // Send CMD packet 435 return mc.writePacket(data) 436 } 437 438 func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { 439 // Reset Packet Sequence 440 mc.sequence = 0 441 442 pktLen := 1 + len(arg) 443 data, err := mc.buf.takeBuffer(pktLen + 4) 444 if err != nil { 445 // cannot take the buffer. Something must be wrong with the connection 446 errLog.Print(err) 447 return errBadConnNoWrite 448 } 449 450 // Add command byte 451 data[4] = command 452 453 // Add arg 454 copy(data[5:], arg) 455 456 // Send CMD packet 457 return mc.writePacket(data) 458 } 459 460 func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { 461 // Reset Packet Sequence 462 mc.sequence = 0 463 464 data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) 465 if err != nil { 466 // cannot take the buffer. Something must be wrong with the connection 467 errLog.Print(err) 468 return errBadConnNoWrite 469 } 470 471 // Add command byte 472 data[4] = command 473 474 // Add arg [32 bit] 475 data[5] = byte(arg) 476 data[6] = byte(arg >> 8) 477 data[7] = byte(arg >> 16) 478 data[8] = byte(arg >> 24) 479 480 // Send CMD packet 481 return mc.writePacket(data) 482 } 483 484 /****************************************************************************** 485 * Result Packets * 486 ******************************************************************************/ 487 488 func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { 489 data, err := mc.readPacket() 490 if err != nil { 491 return nil, "", err 492 } 493 494 // packet indicator 495 switch data[0] { 496 497 case iOK: 498 return nil, "", mc.handleOkPacket(data) 499 500 case iAuthMoreData: 501 return data[1:], "", err 502 503 case iEOF: 504 if len(data) == 1 { 505 // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest 506 return nil, "mysql_old_password", nil 507 } 508 pluginEndIndex := bytes.IndexByte(data, 0x00) 509 if pluginEndIndex < 0 { 510 return nil, "", ErrMalformPkt 511 } 512 plugin := string(data[1:pluginEndIndex]) 513 authData := data[pluginEndIndex+1:] 514 return authData, plugin, nil 515 516 default: // Error otherwise 517 return nil, "", mc.handleErrorPacket(data) 518 } 519 } 520 521 // Returns error if Packet is not an 'Result OK'-Packet 522 func (mc *mysqlConn) readResultOK() error { 523 data, err := mc.readPacket() 524 if err != nil { 525 return err 526 } 527 528 if data[0] == iOK { 529 return mc.handleOkPacket(data) 530 } 531 return mc.handleErrorPacket(data) 532 } 533 534 // Result Set Header Packet 535 // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset 536 func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { 537 data, err := mc.readPacket() 538 if err == nil { 539 switch data[0] { 540 541 case iOK: 542 return 0, mc.handleOkPacket(data) 543 544 case iERR: 545 return 0, mc.handleErrorPacket(data) 546 547 case iLocalInFile: 548 return 0, mc.handleInFileRequest(string(data[1:])) 549 } 550 551 // column count 552 num, _, n := readLengthEncodedInteger(data) 553 if n-len(data) == 0 { 554 return int(num), nil 555 } 556 557 return 0, ErrMalformPkt 558 } 559 return 0, err 560 } 561 562 // Error Packet 563 // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet 564 func (mc *mysqlConn) handleErrorPacket(data []byte) error { 565 if data[0] != iERR { 566 return ErrMalformPkt 567 } 568 569 // 0xff [1 byte] 570 571 // Error Number [16 bit uint] 572 errno := binary.LittleEndian.Uint16(data[1:3]) 573 574 // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION 575 // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) 576 if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { 577 // Oops; we are connected to a read-only connection, and won't be able 578 // to issue any write statements. Since RejectReadOnly is configured, 579 // we throw away this connection hoping this one would have write 580 // permission. This is specifically for a possible race condition 581 // during failover (e.g. on AWS Aurora). See README.md for more. 582 // 583 // We explicitly close the connection before returning 584 // driver.ErrBadConn to ensure that `database/sql` purges this 585 // connection and initiates a new one for next statement next time. 586 mc.Close() 587 return driver.ErrBadConn 588 } 589 590 pos := 3 591 592 // SQL State [optional: # + 5bytes string] 593 if data[3] == 0x23 { 594 //sqlstate := string(data[4 : 4+5]) 595 pos = 9 596 } 597 598 // Error Message [string] 599 return &MySQLError{ 600 Number: errno, 601 Message: string(data[pos:]), 602 } 603 } 604 605 func readStatus(b []byte) statusFlag { 606 return statusFlag(b[0]) | statusFlag(b[1])<<8 607 } 608 609 // Ok Packet 610 // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet 611 func (mc *mysqlConn) handleOkPacket(data []byte) error { 612 var n, m int 613 614 // 0x00 [1 byte] 615 616 // Affected rows [Length Coded Binary] 617 mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) 618 619 // Insert id [Length Coded Binary] 620 mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) 621 622 // server_status [2 bytes] 623 mc.status = readStatus(data[1+n+m : 1+n+m+2]) 624 if mc.status&statusMoreResultsExists != 0 { 625 return nil 626 } 627 628 // warning count [2 bytes] 629 630 return nil 631 } 632 633 // Read Packets as Field Packets until EOF-Packet or an Error appears 634 // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 635 func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { 636 columns := make([]mysqlField, count) 637 638 for i := 0; ; i++ { 639 data, err := mc.readPacket() 640 if err != nil { 641 return nil, err 642 } 643 644 // EOF Packet 645 if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { 646 if i == count { 647 return columns, nil 648 } 649 return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) 650 } 651 652 // Catalog 653 pos, err := skipLengthEncodedString(data) 654 if err != nil { 655 return nil, err 656 } 657 658 // Database [len coded string] 659 n, err := skipLengthEncodedString(data[pos:]) 660 if err != nil { 661 return nil, err 662 } 663 pos += n 664 665 // Table [len coded string] 666 if mc.cfg.ColumnsWithAlias { 667 tableName, _, n, err := readLengthEncodedString(data[pos:]) 668 if err != nil { 669 return nil, err 670 } 671 pos += n 672 columns[i].tableName = string(tableName) 673 } else { 674 n, err = skipLengthEncodedString(data[pos:]) 675 if err != nil { 676 return nil, err 677 } 678 pos += n 679 } 680 681 // Original table [len coded string] 682 n, err = skipLengthEncodedString(data[pos:]) 683 if err != nil { 684 return nil, err 685 } 686 pos += n 687 688 // Name [len coded string] 689 name, _, n, err := readLengthEncodedString(data[pos:]) 690 if err != nil { 691 return nil, err 692 } 693 columns[i].name = string(name) 694 pos += n 695 696 // Original name [len coded string] 697 n, err = skipLengthEncodedString(data[pos:]) 698 if err != nil { 699 return nil, err 700 } 701 pos += n 702 703 // Filler [uint8] 704 pos++ 705 706 // Charset [charset, collation uint8] 707 columns[i].charSet = data[pos] 708 pos += 2 709 710 // Length [uint32] 711 columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) 712 pos += 4 713 714 // Field type [uint8] 715 columns[i].fieldType = fieldType(data[pos]) 716 pos++ 717 718 // Flags [uint16] 719 columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) 720 pos += 2 721 722 // Decimals [uint8] 723 columns[i].decimals = data[pos] 724 //pos++ 725 726 // Default value [len coded binary] 727 //if pos < len(data) { 728 // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) 729 //} 730 } 731 } 732 733 // Read Packets as Field Packets until EOF-Packet or an Error appears 734 // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow 735 func (rows *textRows) readRow(dest []driver.Value) error { 736 mc := rows.mc 737 738 if rows.rs.done { 739 return io.EOF 740 } 741 742 data, err := mc.readPacket() 743 if err != nil { 744 return err 745 } 746 747 // EOF Packet 748 if data[0] == iEOF && len(data) == 5 { 749 // server_status [2 bytes] 750 rows.mc.status = readStatus(data[3:]) 751 rows.rs.done = true 752 if !rows.HasNextResultSet() { 753 rows.mc = nil 754 } 755 return io.EOF 756 } 757 if data[0] == iERR { 758 rows.mc = nil 759 return mc.handleErrorPacket(data) 760 } 761 762 // RowSet Packet 763 var n int 764 var isNull bool 765 pos := 0 766 767 for i := range dest { 768 // Read bytes and convert to string 769 dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) 770 pos += n 771 if err == nil { 772 if !isNull { 773 if !mc.parseTime { 774 continue 775 } else { 776 switch rows.rs.columns[i].fieldType { 777 case fieldTypeTimestamp, fieldTypeDateTime, 778 fieldTypeDate, fieldTypeNewDate: 779 dest[i], err = parseDateTime( 780 string(dest[i].([]byte)), 781 mc.cfg.Loc, 782 ) 783 if err == nil { 784 continue 785 } 786 default: 787 continue 788 } 789 } 790 791 } else { 792 dest[i] = nil 793 continue 794 } 795 } 796 return err // err != nil 797 } 798 799 return nil 800 } 801 802 // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read 803 func (mc *mysqlConn) readUntilEOF() error { 804 for { 805 data, err := mc.readPacket() 806 if err != nil { 807 return err 808 } 809 810 switch data[0] { 811 case iERR: 812 return mc.handleErrorPacket(data) 813 case iEOF: 814 if len(data) == 5 { 815 mc.status = readStatus(data[3:]) 816 } 817 return nil 818 } 819 } 820 } 821 822 /****************************************************************************** 823 * Prepared Statements * 824 ******************************************************************************/ 825 826 // Prepare Result Packets 827 // http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html 828 func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { 829 data, err := stmt.mc.readPacket() 830 if err == nil { 831 // packet indicator [1 byte] 832 if data[0] != iOK { 833 return 0, stmt.mc.handleErrorPacket(data) 834 } 835 836 // statement id [4 bytes] 837 stmt.id = binary.LittleEndian.Uint32(data[1:5]) 838 839 // Column count [16 bit uint] 840 columnCount := binary.LittleEndian.Uint16(data[5:7]) 841 842 // Param count [16 bit uint] 843 stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9])) 844 845 // Reserved [8 bit] 846 847 // Warning count [16 bit uint] 848 849 return columnCount, nil 850 } 851 return 0, err 852 } 853 854 // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html 855 func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { 856 maxLen := stmt.mc.maxAllowedPacket - 1 857 pktLen := maxLen 858 859 // After the header (bytes 0-3) follows before the data: 860 // 1 byte command 861 // 4 bytes stmtID 862 // 2 bytes paramID 863 const dataOffset = 1 + 4 + 2 864 865 // Cannot use the write buffer since 866 // a) the buffer is too small 867 // b) it is in use 868 data := make([]byte, 4+1+4+2+len(arg)) 869 870 copy(data[4+dataOffset:], arg) 871 872 for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset { 873 if dataOffset+argLen < maxLen { 874 pktLen = dataOffset + argLen 875 } 876 877 stmt.mc.sequence = 0 878 // Add command byte [1 byte] 879 data[4] = comStmtSendLongData 880 881 // Add stmtID [32 bit] 882 data[5] = byte(stmt.id) 883 data[6] = byte(stmt.id >> 8) 884 data[7] = byte(stmt.id >> 16) 885 data[8] = byte(stmt.id >> 24) 886 887 // Add paramID [16 bit] 888 data[9] = byte(paramID) 889 data[10] = byte(paramID >> 8) 890 891 // Send CMD packet 892 err := stmt.mc.writePacket(data[:4+pktLen]) 893 if err == nil { 894 data = data[pktLen-dataOffset:] 895 continue 896 } 897 return err 898 899 } 900 901 // Reset Packet Sequence 902 stmt.mc.sequence = 0 903 return nil 904 } 905 906 // Execute Prepared Statement 907 // http://dev.mysql.com/doc/internals/en/com-stmt-execute.html 908 func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { 909 if len(args) != stmt.paramCount { 910 return fmt.Errorf( 911 "argument count mismatch (got: %d; has: %d)", 912 len(args), 913 stmt.paramCount, 914 ) 915 } 916 917 const minPktLen = 4 + 1 + 4 + 1 + 4 918 mc := stmt.mc 919 920 // Determine threshold dynamically to avoid packet size shortage. 921 longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) 922 if longDataSize < 64 { 923 longDataSize = 64 924 } 925 926 // Reset packet-sequence 927 mc.sequence = 0 928 929 var data []byte 930 var err error 931 932 if len(args) == 0 { 933 data, err = mc.buf.takeBuffer(minPktLen) 934 } else { 935 data, err = mc.buf.takeCompleteBuffer() 936 // In this case the len(data) == cap(data) which is used to optimise the flow below. 937 } 938 if err != nil { 939 // cannot take the buffer. Something must be wrong with the connection 940 errLog.Print(err) 941 return errBadConnNoWrite 942 } 943 944 // command [1 byte] 945 data[4] = comStmtExecute 946 947 // statement_id [4 bytes] 948 data[5] = byte(stmt.id) 949 data[6] = byte(stmt.id >> 8) 950 data[7] = byte(stmt.id >> 16) 951 data[8] = byte(stmt.id >> 24) 952 953 // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] 954 data[9] = 0x00 955 956 // iteration_count (uint32(1)) [4 bytes] 957 data[10] = 0x01 958 data[11] = 0x00 959 data[12] = 0x00 960 data[13] = 0x00 961 962 if len(args) > 0 { 963 pos := minPktLen 964 965 var nullMask []byte 966 if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) { 967 // buffer has to be extended but we don't know by how much so 968 // we depend on append after all data with known sizes fit. 969 // We stop at that because we deal with a lot of columns here 970 // which makes the required allocation size hard to guess. 971 tmp := make([]byte, pos+maskLen+typesLen) 972 copy(tmp[:pos], data[:pos]) 973 data = tmp 974 nullMask = data[pos : pos+maskLen] 975 // No need to clean nullMask as make ensures that. 976 pos += maskLen 977 } else { 978 nullMask = data[pos : pos+maskLen] 979 for i := range nullMask { 980 nullMask[i] = 0 981 } 982 pos += maskLen 983 } 984 985 // newParameterBoundFlag 1 [1 byte] 986 data[pos] = 0x01 987 pos++ 988 989 // type of each parameter [len(args)*2 bytes] 990 paramTypes := data[pos:] 991 pos += len(args) * 2 992 993 // value of each parameter [n bytes] 994 paramValues := data[pos:pos] 995 valuesCap := cap(paramValues) 996 997 for i, arg := range args { 998 // build NULL-bitmap 999 if arg == nil { 1000 nullMask[i/8] |= 1 << (uint(i) & 7) 1001 paramTypes[i+i] = byte(fieldTypeNULL) 1002 paramTypes[i+i+1] = 0x00 1003 continue 1004 } 1005 1006 // cache types and values 1007 switch v := arg.(type) { 1008 case int64: 1009 paramTypes[i+i] = byte(fieldTypeLongLong) 1010 paramTypes[i+i+1] = 0x00 1011 1012 if cap(paramValues)-len(paramValues)-8 >= 0 { 1013 paramValues = paramValues[:len(paramValues)+8] 1014 binary.LittleEndian.PutUint64( 1015 paramValues[len(paramValues)-8:], 1016 uint64(v), 1017 ) 1018 } else { 1019 paramValues = append(paramValues, 1020 uint64ToBytes(uint64(v))..., 1021 ) 1022 } 1023 1024 case uint64: 1025 paramTypes[i+i] = byte(fieldTypeLongLong) 1026 paramTypes[i+i+1] = 0x80 // type is unsigned 1027 1028 if cap(paramValues)-len(paramValues)-8 >= 0 { 1029 paramValues = paramValues[:len(paramValues)+8] 1030 binary.LittleEndian.PutUint64( 1031 paramValues[len(paramValues)-8:], 1032 uint64(v), 1033 ) 1034 } else { 1035 paramValues = append(paramValues, 1036 uint64ToBytes(uint64(v))..., 1037 ) 1038 } 1039 1040 case float64: 1041 paramTypes[i+i] = byte(fieldTypeDouble) 1042 paramTypes[i+i+1] = 0x00 1043 1044 if cap(paramValues)-len(paramValues)-8 >= 0 { 1045 paramValues = paramValues[:len(paramValues)+8] 1046 binary.LittleEndian.PutUint64( 1047 paramValues[len(paramValues)-8:], 1048 math.Float64bits(v), 1049 ) 1050 } else { 1051 paramValues = append(paramValues, 1052 uint64ToBytes(math.Float64bits(v))..., 1053 ) 1054 } 1055 1056 case bool: 1057 paramTypes[i+i] = byte(fieldTypeTiny) 1058 paramTypes[i+i+1] = 0x00 1059 1060 if v { 1061 paramValues = append(paramValues, 0x01) 1062 } else { 1063 paramValues = append(paramValues, 0x00) 1064 } 1065 1066 case []byte: 1067 // Common case (non-nil value) first 1068 if v != nil { 1069 paramTypes[i+i] = byte(fieldTypeString) 1070 paramTypes[i+i+1] = 0x00 1071 1072 if len(v) < longDataSize { 1073 paramValues = appendLengthEncodedInteger(paramValues, 1074 uint64(len(v)), 1075 ) 1076 paramValues = append(paramValues, v...) 1077 } else { 1078 if err := stmt.writeCommandLongData(i, v); err != nil { 1079 return err 1080 } 1081 } 1082 continue 1083 } 1084 1085 // Handle []byte(nil) as a NULL value 1086 nullMask[i/8] |= 1 << (uint(i) & 7) 1087 paramTypes[i+i] = byte(fieldTypeNULL) 1088 paramTypes[i+i+1] = 0x00 1089 1090 case string: 1091 paramTypes[i+i] = byte(fieldTypeString) 1092 paramTypes[i+i+1] = 0x00 1093 1094 if len(v) < longDataSize { 1095 paramValues = appendLengthEncodedInteger(paramValues, 1096 uint64(len(v)), 1097 ) 1098 paramValues = append(paramValues, v...) 1099 } else { 1100 if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { 1101 return err 1102 } 1103 } 1104 1105 case time.Time: 1106 paramTypes[i+i] = byte(fieldTypeString) 1107 paramTypes[i+i+1] = 0x00 1108 1109 var a [64]byte 1110 var b = a[:0] 1111 1112 if v.IsZero() { 1113 b = append(b, "0000-00-00"...) 1114 } else { 1115 b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) 1116 } 1117 1118 paramValues = appendLengthEncodedInteger(paramValues, 1119 uint64(len(b)), 1120 ) 1121 paramValues = append(paramValues, b...) 1122 1123 default: 1124 return fmt.Errorf("cannot convert type: %T", arg) 1125 } 1126 } 1127 1128 // Check if param values exceeded the available buffer 1129 // In that case we must build the data packet with the new values buffer 1130 if valuesCap != cap(paramValues) { 1131 data = append(data[:pos], paramValues...) 1132 if err = mc.buf.store(data); err != nil { 1133 errLog.Print(err) 1134 return errBadConnNoWrite 1135 } 1136 } 1137 1138 pos += len(paramValues) 1139 data = data[:pos] 1140 } 1141 1142 return mc.writePacket(data) 1143 } 1144 1145 func (mc *mysqlConn) discardResults() error { 1146 for mc.status&statusMoreResultsExists != 0 { 1147 resLen, err := mc.readResultSetHeaderPacket() 1148 if err != nil { 1149 return err 1150 } 1151 if resLen > 0 { 1152 // columns 1153 if err := mc.readUntilEOF(); err != nil { 1154 return err 1155 } 1156 // rows 1157 if err := mc.readUntilEOF(); err != nil { 1158 return err 1159 } 1160 } 1161 } 1162 return nil 1163 } 1164 1165 // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html 1166 func (rows *binaryRows) readRow(dest []driver.Value) error { 1167 data, err := rows.mc.readPacket() 1168 if err != nil { 1169 return err 1170 } 1171 1172 // packet indicator [1 byte] 1173 if data[0] != iOK { 1174 // EOF Packet 1175 if data[0] == iEOF && len(data) == 5 { 1176 rows.mc.status = readStatus(data[3:]) 1177 rows.rs.done = true 1178 if !rows.HasNextResultSet() { 1179 rows.mc = nil 1180 } 1181 return io.EOF 1182 } 1183 mc := rows.mc 1184 rows.mc = nil 1185 1186 // Error otherwise 1187 return mc.handleErrorPacket(data) 1188 } 1189 1190 // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] 1191 pos := 1 + (len(dest)+7+2)>>3 1192 nullMask := data[1:pos] 1193 1194 for i := range dest { 1195 // Field is NULL 1196 // (byte >> bit-pos) % 2 == 1 1197 if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { 1198 dest[i] = nil 1199 continue 1200 } 1201 1202 // Convert to byte-coded string 1203 switch rows.rs.columns[i].fieldType { 1204 case fieldTypeNULL: 1205 dest[i] = nil 1206 continue 1207 1208 // Numeric Types 1209 case fieldTypeTiny: 1210 if rows.rs.columns[i].flags&flagUnsigned != 0 { 1211 dest[i] = int64(data[pos]) 1212 } else { 1213 dest[i] = int64(int8(data[pos])) 1214 } 1215 pos++ 1216 continue 1217 1218 case fieldTypeShort, fieldTypeYear: 1219 if rows.rs.columns[i].flags&flagUnsigned != 0 { 1220 dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) 1221 } else { 1222 dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) 1223 } 1224 pos += 2 1225 continue 1226 1227 case fieldTypeInt24, fieldTypeLong: 1228 if rows.rs.columns[i].flags&flagUnsigned != 0 { 1229 dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) 1230 } else { 1231 dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) 1232 } 1233 pos += 4 1234 continue 1235 1236 case fieldTypeLongLong: 1237 if rows.rs.columns[i].flags&flagUnsigned != 0 { 1238 val := binary.LittleEndian.Uint64(data[pos : pos+8]) 1239 if val > math.MaxInt64 { 1240 dest[i] = uint64ToString(val) 1241 } else { 1242 dest[i] = int64(val) 1243 } 1244 } else { 1245 dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8])) 1246 } 1247 pos += 8 1248 continue 1249 1250 case fieldTypeFloat: 1251 dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) 1252 pos += 4 1253 continue 1254 1255 case fieldTypeDouble: 1256 dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8])) 1257 pos += 8 1258 continue 1259 1260 // Length coded Binary Strings 1261 case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, 1262 fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, 1263 fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, 1264 fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON: 1265 var isNull bool 1266 var n int 1267 dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) 1268 pos += n 1269 if err == nil { 1270 if !isNull { 1271 continue 1272 } else { 1273 dest[i] = nil 1274 continue 1275 } 1276 } 1277 return err 1278 1279 case 1280 fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD 1281 fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal] 1282 fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] 1283 1284 num, isNull, n := readLengthEncodedInteger(data[pos:]) 1285 pos += n 1286 1287 switch { 1288 case isNull: 1289 dest[i] = nil 1290 continue 1291 case rows.rs.columns[i].fieldType == fieldTypeTime: 1292 // database/sql does not support an equivalent to TIME, return a string 1293 var dstlen uint8 1294 switch decimals := rows.rs.columns[i].decimals; decimals { 1295 case 0x00, 0x1f: 1296 dstlen = 8 1297 case 1, 2, 3, 4, 5, 6: 1298 dstlen = 8 + 1 + decimals 1299 default: 1300 return fmt.Errorf( 1301 "protocol error, illegal decimals value %d", 1302 rows.rs.columns[i].decimals, 1303 ) 1304 } 1305 dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen) 1306 case rows.mc.parseTime: 1307 dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) 1308 default: 1309 var dstlen uint8 1310 if rows.rs.columns[i].fieldType == fieldTypeDate { 1311 dstlen = 10 1312 } else { 1313 switch decimals := rows.rs.columns[i].decimals; decimals { 1314 case 0x00, 0x1f: 1315 dstlen = 19 1316 case 1, 2, 3, 4, 5, 6: 1317 dstlen = 19 + 1 + decimals 1318 default: 1319 return fmt.Errorf( 1320 "protocol error, illegal decimals value %d", 1321 rows.rs.columns[i].decimals, 1322 ) 1323 } 1324 } 1325 dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen) 1326 } 1327 1328 if err == nil { 1329 pos += int(num) 1330 continue 1331 } else { 1332 return err 1333 } 1334 1335 // Please report if this happens! 1336 default: 1337 return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) 1338 } 1339 } 1340 1341 return nil 1342 }