github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/go-sql-driver/mysql/packets.go (about) 1 // Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2 // 3 // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 4 // 5 // This Source Code Form is subject to the terms of the Mozilla Public 6 // License, v. 2.0. If a copy of the MPL was not distributed with this file, 7 // You can obtain one at http://mozilla.org/MPL/2.0/. 8 9 package mysql 10 11 import ( 12 "bytes" 13 "crypto/tls" 14 "database/sql/driver" 15 "encoding/binary" 16 "errors" 17 "fmt" 18 "io" 19 "math" 20 "time" 21 ) 22 23 // Packets documentation: 24 // http://dev.mysql.com/doc/internals/en/client-server-protocol.html 25 26 // Read packet to buffer 'data' 27 func (mc *mysqlConn) readPacket() ([]byte, error) { 28 var payload []byte 29 for { 30 // Read packet header 31 data, err := mc.buf.readNext(4) 32 if err != nil { 33 errLog.Print(err) 34 mc.Close() 35 return nil, driver.ErrBadConn 36 } 37 38 // Packet Length [24 bit] 39 pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) 40 41 if pktLen < 1 { 42 errLog.Print(ErrMalformPkt) 43 mc.Close() 44 return nil, driver.ErrBadConn 45 } 46 47 // Check Packet Sync [8 bit] 48 if data[3] != mc.sequence { 49 if data[3] > mc.sequence { 50 return nil, ErrPktSyncMul 51 } 52 return nil, ErrPktSync 53 } 54 mc.sequence++ 55 56 // Read packet body [pktLen bytes] 57 data, err = mc.buf.readNext(pktLen) 58 if err != nil { 59 errLog.Print(err) 60 mc.Close() 61 return nil, driver.ErrBadConn 62 } 63 64 isLastPacket := (pktLen < maxPacketSize) 65 66 // Zero allocations for non-splitting packets 67 if isLastPacket && payload == nil { 68 return data, nil 69 } 70 71 payload = append(payload, data...) 72 73 if isLastPacket { 74 return payload, nil 75 } 76 } 77 } 78 79 // Write packet buffer 'data' 80 func (mc *mysqlConn) writePacket(data []byte) error { 81 pktLen := len(data) - 4 82 83 if pktLen > mc.maxPacketAllowed { 84 return ErrPktTooLarge 85 } 86 87 for { 88 var size int 89 if pktLen >= maxPacketSize { 90 data[0] = 0xff 91 data[1] = 0xff 92 data[2] = 0xff 93 size = maxPacketSize 94 } else { 95 data[0] = byte(pktLen) 96 data[1] = byte(pktLen >> 8) 97 data[2] = byte(pktLen >> 16) 98 size = pktLen 99 } 100 data[3] = mc.sequence 101 102 // Write packet 103 if mc.writeTimeout > 0 { 104 if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { 105 return err 106 } 107 } 108 109 n, err := mc.netConn.Write(data[:4+size]) 110 if err == nil && n == 4+size { 111 mc.sequence++ 112 if size != maxPacketSize { 113 return nil 114 } 115 pktLen -= size 116 data = data[size:] 117 continue 118 } 119 120 // Handle error 121 if err == nil { // n != len(data) 122 errLog.Print(ErrMalformPkt) 123 } else { 124 errLog.Print(err) 125 } 126 return driver.ErrBadConn 127 } 128 } 129 130 /****************************************************************************** 131 * Initialisation Process * 132 ******************************************************************************/ 133 134 // Handshake Initialization Packet 135 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake 136 func (mc *mysqlConn) readInitPacket() ([]byte, error) { 137 data, err := mc.readPacket() 138 if err != nil { 139 return nil, err 140 } 141 142 if data[0] == iERR { 143 return nil, mc.handleErrorPacket(data) 144 } 145 146 // protocol version [1 byte] 147 if data[0] < minProtocolVersion { 148 return nil, fmt.Errorf( 149 "unsupported protocol version %d. Version %d or higher is required", 150 data[0], 151 minProtocolVersion, 152 ) 153 } 154 155 // server version [null terminated string] 156 // connection id [4 bytes] 157 pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 158 159 // first part of the password cipher [8 bytes] 160 cipher := data[pos : pos+8] 161 162 // (filler) always 0x00 [1 byte] 163 pos += 8 + 1 164 165 // capability flags (lower 2 bytes) [2 bytes] 166 mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) 167 if mc.flags&clientProtocol41 == 0 { 168 return nil, ErrOldProtocol 169 } 170 if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { 171 return nil, ErrNoTLS 172 } 173 pos += 2 174 175 if len(data) > pos { 176 // character set [1 byte] 177 // status flags [2 bytes] 178 // capability flags (upper 2 bytes) [2 bytes] 179 // length of auth-plugin-data [1 byte] 180 // reserved (all [00]) [10 bytes] 181 pos += 1 + 2 + 2 + 1 + 10 182 183 // second part of the password cipher [mininum 13 bytes], 184 // where len=MAX(13, length of auth-plugin-data - 8) 185 // 186 // The web documentation is ambiguous about the length. However, 187 // according to mysql-5.7/sql/auth/sql_authentication.cc line 538, 188 // the 13th byte is "\0 byte, terminating the second part of 189 // a scramble". So the second part of the password cipher is 190 // a NULL terminated string that's at least 13 bytes with the 191 // last byte being NULL. 192 // 193 // The official Python library uses the fixed length 12 194 // which seems to work but technically could have a hidden bug. 195 cipher = append(cipher, data[pos:pos+12]...) 196 197 // TODO: Verify string termination 198 // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) 199 // \NUL otherwise 200 // 201 //if data[len(data)-1] == 0 { 202 // return 203 //} 204 //return ErrMalformPkt 205 206 // make a memory safe copy of the cipher slice 207 var b [20]byte 208 copy(b[:], cipher) 209 return b[:], nil 210 } 211 212 // make a memory safe copy of the cipher slice 213 var b [8]byte 214 copy(b[:], cipher) 215 return b[:], nil 216 } 217 218 // Client Authentication Packet 219 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse 220 func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { 221 // Adjust client flags based on server support 222 clientFlags := clientProtocol41 | 223 clientSecureConn | 224 clientLongPassword | 225 clientTransactions | 226 clientLocalFiles | 227 clientPluginAuth | 228 clientMultiResults | 229 mc.flags&clientLongFlag 230 231 if mc.cfg.ClientFoundRows { 232 clientFlags |= clientFoundRows 233 } 234 235 // To enable TLS / SSL 236 if mc.cfg.tls != nil { 237 clientFlags |= clientSSL 238 } 239 240 if mc.cfg.MultiStatements { 241 clientFlags |= clientMultiStatements 242 } 243 244 // User Password 245 scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) 246 247 pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 248 249 // To specify a db name 250 if n := len(mc.cfg.DBName); n > 0 { 251 clientFlags |= clientConnectWithDB 252 pktLen += n + 1 253 } 254 255 // Calculate packet length and get buffer with that size 256 data := mc.buf.takeSmallBuffer(pktLen + 4) 257 if data == nil { 258 // can not take the buffer. Something must be wrong with the connection 259 errLog.Print(ErrBusyBuffer) 260 return driver.ErrBadConn 261 } 262 263 // ClientFlags [32 bit] 264 data[4] = byte(clientFlags) 265 data[5] = byte(clientFlags >> 8) 266 data[6] = byte(clientFlags >> 16) 267 data[7] = byte(clientFlags >> 24) 268 269 // MaxPacketSize [32 bit] (none) 270 data[8] = 0x00 271 data[9] = 0x00 272 data[10] = 0x00 273 data[11] = 0x00 274 275 // Charset [1 byte] 276 var found bool 277 data[12], found = collations[mc.cfg.Collation] 278 if !found { 279 // Note possibility for false negatives: 280 // could be triggered although the collation is valid if the 281 // collations map does not contain entries the server supports. 282 return errors.New("unknown collation") 283 } 284 285 // SSL Connection Request Packet 286 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest 287 if mc.cfg.tls != nil { 288 // Send TLS / SSL request packet 289 if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { 290 return err 291 } 292 293 // Switch to TLS 294 tlsConn := tls.Client(mc.netConn, mc.cfg.tls) 295 if err := tlsConn.Handshake(); err != nil { 296 return err 297 } 298 mc.netConn = tlsConn 299 mc.buf.nc = tlsConn 300 } 301 302 // Filler [23 bytes] (all 0x00) 303 pos := 13 304 for ; pos < 13+23; pos++ { 305 data[pos] = 0 306 } 307 308 // User [null terminated string] 309 if len(mc.cfg.User) > 0 { 310 pos += copy(data[pos:], mc.cfg.User) 311 } 312 data[pos] = 0x00 313 pos++ 314 315 // ScrambleBuffer [length encoded integer] 316 data[pos] = byte(len(scrambleBuff)) 317 pos += 1 + copy(data[pos+1:], scrambleBuff) 318 319 // Databasename [null terminated string] 320 if len(mc.cfg.DBName) > 0 { 321 pos += copy(data[pos:], mc.cfg.DBName) 322 data[pos] = 0x00 323 pos++ 324 } 325 326 // Assume native client during response 327 pos += copy(data[pos:], "mysql_native_password") 328 data[pos] = 0x00 329 330 // Send Auth packet 331 return mc.writePacket(data) 332 } 333 334 // Client old authentication packet 335 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse 336 func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { 337 // User password 338 scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) 339 340 // Calculate the packet length and add a tailing 0 341 pktLen := len(scrambleBuff) + 1 342 data := mc.buf.takeSmallBuffer(4 + pktLen) 343 if data == nil { 344 // can not take the buffer. Something must be wrong with the connection 345 errLog.Print(ErrBusyBuffer) 346 return driver.ErrBadConn 347 } 348 349 // Add the scrambled password [null terminated string] 350 copy(data[4:], scrambleBuff) 351 data[4+pktLen-1] = 0x00 352 353 return mc.writePacket(data) 354 } 355 356 // Client clear text authentication packet 357 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse 358 func (mc *mysqlConn) writeClearAuthPacket() error { 359 // Calculate the packet length and add a tailing 0 360 pktLen := len(mc.cfg.Passwd) + 1 361 data := mc.buf.takeSmallBuffer(4 + pktLen) 362 if data == nil { 363 // can not take the buffer. Something must be wrong with the connection 364 errLog.Print(ErrBusyBuffer) 365 return driver.ErrBadConn 366 } 367 368 // Add the clear password [null terminated string] 369 copy(data[4:], mc.cfg.Passwd) 370 data[4+pktLen-1] = 0x00 371 372 return mc.writePacket(data) 373 } 374 375 /****************************************************************************** 376 * Command Packets * 377 ******************************************************************************/ 378 379 func (mc *mysqlConn) writeCommandPacket(command byte) error { 380 // Reset Packet Sequence 381 mc.sequence = 0 382 383 data := mc.buf.takeSmallBuffer(4 + 1) 384 if data == nil { 385 // can not take the buffer. Something must be wrong with the connection 386 errLog.Print(ErrBusyBuffer) 387 return driver.ErrBadConn 388 } 389 390 // Add command byte 391 data[4] = command 392 393 // Send CMD packet 394 return mc.writePacket(data) 395 } 396 397 func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { 398 // Reset Packet Sequence 399 mc.sequence = 0 400 401 pktLen := 1 + len(arg) 402 data := mc.buf.takeBuffer(pktLen + 4) 403 if data == nil { 404 // can not take the buffer. Something must be wrong with the connection 405 errLog.Print(ErrBusyBuffer) 406 return driver.ErrBadConn 407 } 408 409 // Add command byte 410 data[4] = command 411 412 // Add arg 413 copy(data[5:], arg) 414 415 // Send CMD packet 416 return mc.writePacket(data) 417 } 418 419 func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { 420 // Reset Packet Sequence 421 mc.sequence = 0 422 423 data := mc.buf.takeSmallBuffer(4 + 1 + 4) 424 if data == nil { 425 // can not take the buffer. Something must be wrong with the connection 426 errLog.Print(ErrBusyBuffer) 427 return driver.ErrBadConn 428 } 429 430 // Add command byte 431 data[4] = command 432 433 // Add arg [32 bit] 434 data[5] = byte(arg) 435 data[6] = byte(arg >> 8) 436 data[7] = byte(arg >> 16) 437 data[8] = byte(arg >> 24) 438 439 // Send CMD packet 440 return mc.writePacket(data) 441 } 442 443 /****************************************************************************** 444 * Result Packets * 445 ******************************************************************************/ 446 447 // Returns error if Packet is not an 'Result OK'-Packet 448 func (mc *mysqlConn) readResultOK() error { 449 data, err := mc.readPacket() 450 if err == nil { 451 // packet indicator 452 switch data[0] { 453 454 case iOK: 455 return mc.handleOkPacket(data) 456 457 case iEOF: 458 if len(data) > 1 { 459 plugin := string(data[1:bytes.IndexByte(data, 0x00)]) 460 if plugin == "mysql_old_password" { 461 // using old_passwords 462 return ErrOldPassword 463 } else if plugin == "mysql_clear_password" { 464 // using clear text password 465 return ErrCleartextPassword 466 } else { 467 return ErrUnknownPlugin 468 } 469 } else { 470 return ErrOldPassword 471 } 472 473 default: // Error otherwise 474 return mc.handleErrorPacket(data) 475 } 476 } 477 return err 478 } 479 480 // Result Set Header Packet 481 // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset 482 func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { 483 data, err := mc.readPacket() 484 if err == nil { 485 switch data[0] { 486 487 case iOK: 488 return 0, mc.handleOkPacket(data) 489 490 case iERR: 491 return 0, mc.handleErrorPacket(data) 492 493 case iLocalInFile: 494 return 0, mc.handleInFileRequest(string(data[1:])) 495 } 496 497 // column count 498 num, _, n := readLengthEncodedInteger(data) 499 if n-len(data) == 0 { 500 return int(num), nil 501 } 502 503 return 0, ErrMalformPkt 504 } 505 return 0, err 506 } 507 508 // Error Packet 509 // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet 510 func (mc *mysqlConn) handleErrorPacket(data []byte) error { 511 if data[0] != iERR { 512 return ErrMalformPkt 513 } 514 515 // 0xff [1 byte] 516 517 // Error Number [16 bit uint] 518 errno := binary.LittleEndian.Uint16(data[1:3]) 519 520 pos := 3 521 522 // SQL State [optional: # + 5bytes string] 523 if data[3] == 0x23 { 524 //sqlstate := string(data[4 : 4+5]) 525 pos = 9 526 } 527 528 // Error Message [string] 529 return &MySQLError{ 530 Number: errno, 531 Message: string(data[pos:]), 532 } 533 } 534 535 func readStatus(b []byte) statusFlag { 536 return statusFlag(b[0]) | statusFlag(b[1])<<8 537 } 538 539 // Ok Packet 540 // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet 541 func (mc *mysqlConn) handleOkPacket(data []byte) error { 542 var n, m int 543 544 // 0x00 [1 byte] 545 546 // Affected rows [Length Coded Binary] 547 mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) 548 549 // Insert id [Length Coded Binary] 550 mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) 551 552 // server_status [2 bytes] 553 mc.status = readStatus(data[1+n+m : 1+n+m+2]) 554 if err := mc.discardResults(); err != nil { 555 return err 556 } 557 558 // warning count [2 bytes] 559 if !mc.strict { 560 return nil 561 } 562 563 pos := 1 + n + m + 2 564 if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { 565 return mc.getWarnings() 566 } 567 return nil 568 } 569 570 // Read Packets as Field Packets until EOF-Packet or an Error appears 571 // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 572 func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { 573 columns := make([]mysqlField, count) 574 575 for i := 0; ; i++ { 576 data, err := mc.readPacket() 577 if err != nil { 578 return nil, err 579 } 580 581 // EOF Packet 582 if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { 583 if i == count { 584 return columns, nil 585 } 586 return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) 587 } 588 589 // Catalog 590 pos, err := skipLengthEncodedString(data) 591 if err != nil { 592 return nil, err 593 } 594 595 // Database [len coded string] 596 n, err := skipLengthEncodedString(data[pos:]) 597 if err != nil { 598 return nil, err 599 } 600 pos += n 601 602 // Table [len coded string] 603 if mc.cfg.ColumnsWithAlias { 604 tableName, _, n, err := readLengthEncodedString(data[pos:]) 605 if err != nil { 606 return nil, err 607 } 608 pos += n 609 columns[i].tableName = string(tableName) 610 } else { 611 n, err = skipLengthEncodedString(data[pos:]) 612 if err != nil { 613 return nil, err 614 } 615 pos += n 616 } 617 618 // Original table [len coded string] 619 n, err = skipLengthEncodedString(data[pos:]) 620 if err != nil { 621 return nil, err 622 } 623 pos += n 624 625 // Name [len coded string] 626 name, _, n, err := readLengthEncodedString(data[pos:]) 627 if err != nil { 628 return nil, err 629 } 630 columns[i].name = string(name) 631 pos += n 632 633 // Original name [len coded string] 634 n, err = skipLengthEncodedString(data[pos:]) 635 if err != nil { 636 return nil, err 637 } 638 639 // Filler [uint8] 640 // Charset [charset, collation uint8] 641 // Length [uint32] 642 pos += n + 1 + 2 + 4 643 644 // Field type [uint8] 645 columns[i].fieldType = data[pos] 646 pos++ 647 648 // Flags [uint16] 649 columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) 650 pos += 2 651 652 // Decimals [uint8] 653 columns[i].decimals = data[pos] 654 //pos++ 655 656 // Default value [len coded binary] 657 //if pos < len(data) { 658 // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) 659 //} 660 } 661 } 662 663 // Read Packets as Field Packets until EOF-Packet or an Error appears 664 // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow 665 func (rows *textRows) readRow(dest []driver.Value) error { 666 mc := rows.mc 667 668 data, err := mc.readPacket() 669 if err != nil { 670 return err 671 } 672 673 // EOF Packet 674 if data[0] == iEOF && len(data) == 5 { 675 // server_status [2 bytes] 676 rows.mc.status = readStatus(data[3:]) 677 if err := rows.mc.discardResults(); err != nil { 678 return err 679 } 680 rows.mc = nil 681 return io.EOF 682 } 683 if data[0] == iERR { 684 rows.mc = nil 685 return mc.handleErrorPacket(data) 686 } 687 688 // RowSet Packet 689 var n int 690 var isNull bool 691 pos := 0 692 693 for i := range dest { 694 // Read bytes and convert to string 695 dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) 696 pos += n 697 if err == nil { 698 if !isNull { 699 if !mc.parseTime { 700 continue 701 } else { 702 switch rows.columns[i].fieldType { 703 case fieldTypeTimestamp, fieldTypeDateTime, 704 fieldTypeDate, fieldTypeNewDate: 705 dest[i], err = parseDateTime( 706 string(dest[i].([]byte)), 707 mc.cfg.Loc, 708 ) 709 if err == nil { 710 continue 711 } 712 default: 713 continue 714 } 715 } 716 717 } else { 718 dest[i] = nil 719 continue 720 } 721 } 722 return err // err != nil 723 } 724 725 return nil 726 } 727 728 // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read 729 func (mc *mysqlConn) readUntilEOF() error { 730 for { 731 data, err := mc.readPacket() 732 733 // No Err and no EOF Packet 734 if err == nil && data[0] != iEOF { 735 continue 736 } 737 if err == nil && data[0] == iEOF && len(data) == 5 { 738 mc.status = readStatus(data[3:]) 739 } 740 741 return err // Err or EOF 742 } 743 } 744 745 /****************************************************************************** 746 * Prepared Statements * 747 ******************************************************************************/ 748 749 // Prepare Result Packets 750 // http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html 751 func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { 752 data, err := stmt.mc.readPacket() 753 if err == nil { 754 // packet indicator [1 byte] 755 if data[0] != iOK { 756 return 0, stmt.mc.handleErrorPacket(data) 757 } 758 759 // statement id [4 bytes] 760 stmt.id = binary.LittleEndian.Uint32(data[1:5]) 761 762 // Column count [16 bit uint] 763 columnCount := binary.LittleEndian.Uint16(data[5:7]) 764 765 // Param count [16 bit uint] 766 stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9])) 767 768 // Reserved [8 bit] 769 770 // Warning count [16 bit uint] 771 if !stmt.mc.strict { 772 return columnCount, nil 773 } 774 775 // Check for warnings count > 0, only available in MySQL > 4.1 776 if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { 777 return columnCount, stmt.mc.getWarnings() 778 } 779 return columnCount, nil 780 } 781 return 0, err 782 } 783 784 // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html 785 func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { 786 maxLen := stmt.mc.maxPacketAllowed - 1 787 pktLen := maxLen 788 789 // After the header (bytes 0-3) follows before the data: 790 // 1 byte command 791 // 4 bytes stmtID 792 // 2 bytes paramID 793 const dataOffset = 1 + 4 + 2 794 795 // Can not use the write buffer since 796 // a) the buffer is too small 797 // b) it is in use 798 data := make([]byte, 4+1+4+2+len(arg)) 799 800 copy(data[4+dataOffset:], arg) 801 802 for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset { 803 if dataOffset+argLen < maxLen { 804 pktLen = dataOffset + argLen 805 } 806 807 stmt.mc.sequence = 0 808 // Add command byte [1 byte] 809 data[4] = comStmtSendLongData 810 811 // Add stmtID [32 bit] 812 data[5] = byte(stmt.id) 813 data[6] = byte(stmt.id >> 8) 814 data[7] = byte(stmt.id >> 16) 815 data[8] = byte(stmt.id >> 24) 816 817 // Add paramID [16 bit] 818 data[9] = byte(paramID) 819 data[10] = byte(paramID >> 8) 820 821 // Send CMD packet 822 err := stmt.mc.writePacket(data[:4+pktLen]) 823 if err == nil { 824 data = data[pktLen-dataOffset:] 825 continue 826 } 827 return err 828 829 } 830 831 // Reset Packet Sequence 832 stmt.mc.sequence = 0 833 return nil 834 } 835 836 // Execute Prepared Statement 837 // http://dev.mysql.com/doc/internals/en/com-stmt-execute.html 838 func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { 839 if len(args) != stmt.paramCount { 840 return fmt.Errorf( 841 "argument count mismatch (got: %d; has: %d)", 842 len(args), 843 stmt.paramCount, 844 ) 845 } 846 847 const minPktLen = 4 + 1 + 4 + 1 + 4 848 mc := stmt.mc 849 850 // Reset packet-sequence 851 mc.sequence = 0 852 853 var data []byte 854 855 if len(args) == 0 { 856 data = mc.buf.takeBuffer(minPktLen) 857 } else { 858 data = mc.buf.takeCompleteBuffer() 859 } 860 if data == nil { 861 // can not take the buffer. Something must be wrong with the connection 862 errLog.Print(ErrBusyBuffer) 863 return driver.ErrBadConn 864 } 865 866 // command [1 byte] 867 data[4] = comStmtExecute 868 869 // statement_id [4 bytes] 870 data[5] = byte(stmt.id) 871 data[6] = byte(stmt.id >> 8) 872 data[7] = byte(stmt.id >> 16) 873 data[8] = byte(stmt.id >> 24) 874 875 // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] 876 data[9] = 0x00 877 878 // iteration_count (uint32(1)) [4 bytes] 879 data[10] = 0x01 880 data[11] = 0x00 881 data[12] = 0x00 882 data[13] = 0x00 883 884 if len(args) > 0 { 885 pos := minPktLen 886 887 var nullMask []byte 888 if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { 889 // buffer has to be extended but we don't know by how much so 890 // we depend on append after all data with known sizes fit. 891 // We stop at that because we deal with a lot of columns here 892 // which makes the required allocation size hard to guess. 893 tmp := make([]byte, pos+maskLen+typesLen) 894 copy(tmp[:pos], data[:pos]) 895 data = tmp 896 nullMask = data[pos : pos+maskLen] 897 pos += maskLen 898 } else { 899 nullMask = data[pos : pos+maskLen] 900 for i := 0; i < maskLen; i++ { 901 nullMask[i] = 0 902 } 903 pos += maskLen 904 } 905 906 // newParameterBoundFlag 1 [1 byte] 907 data[pos] = 0x01 908 pos++ 909 910 // type of each parameter [len(args)*2 bytes] 911 paramTypes := data[pos:] 912 pos += len(args) * 2 913 914 // value of each parameter [n bytes] 915 paramValues := data[pos:pos] 916 valuesCap := cap(paramValues) 917 918 for i, arg := range args { 919 // build NULL-bitmap 920 if arg == nil { 921 nullMask[i/8] |= 1 << (uint(i) & 7) 922 paramTypes[i+i] = fieldTypeNULL 923 paramTypes[i+i+1] = 0x00 924 continue 925 } 926 927 // cache types and values 928 switch v := arg.(type) { 929 case int64: 930 paramTypes[i+i] = fieldTypeLongLong 931 paramTypes[i+i+1] = 0x00 932 933 if cap(paramValues)-len(paramValues)-8 >= 0 { 934 paramValues = paramValues[:len(paramValues)+8] 935 binary.LittleEndian.PutUint64( 936 paramValues[len(paramValues)-8:], 937 uint64(v), 938 ) 939 } else { 940 paramValues = append(paramValues, 941 uint64ToBytes(uint64(v))..., 942 ) 943 } 944 945 case float64: 946 paramTypes[i+i] = fieldTypeDouble 947 paramTypes[i+i+1] = 0x00 948 949 if cap(paramValues)-len(paramValues)-8 >= 0 { 950 paramValues = paramValues[:len(paramValues)+8] 951 binary.LittleEndian.PutUint64( 952 paramValues[len(paramValues)-8:], 953 math.Float64bits(v), 954 ) 955 } else { 956 paramValues = append(paramValues, 957 uint64ToBytes(math.Float64bits(v))..., 958 ) 959 } 960 961 case bool: 962 paramTypes[i+i] = fieldTypeTiny 963 paramTypes[i+i+1] = 0x00 964 965 if v { 966 paramValues = append(paramValues, 0x01) 967 } else { 968 paramValues = append(paramValues, 0x00) 969 } 970 971 case []byte: 972 // Common case (non-nil value) first 973 if v != nil { 974 paramTypes[i+i] = fieldTypeString 975 paramTypes[i+i+1] = 0x00 976 977 if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { 978 paramValues = appendLengthEncodedInteger(paramValues, 979 uint64(len(v)), 980 ) 981 paramValues = append(paramValues, v...) 982 } else { 983 if err := stmt.writeCommandLongData(i, v); err != nil { 984 return err 985 } 986 } 987 continue 988 } 989 990 // Handle []byte(nil) as a NULL value 991 nullMask[i/8] |= 1 << (uint(i) & 7) 992 paramTypes[i+i] = fieldTypeNULL 993 paramTypes[i+i+1] = 0x00 994 995 case string: 996 paramTypes[i+i] = fieldTypeString 997 paramTypes[i+i+1] = 0x00 998 999 if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { 1000 paramValues = appendLengthEncodedInteger(paramValues, 1001 uint64(len(v)), 1002 ) 1003 paramValues = append(paramValues, v...) 1004 } else { 1005 if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { 1006 return err 1007 } 1008 } 1009 1010 case time.Time: 1011 paramTypes[i+i] = fieldTypeString 1012 paramTypes[i+i+1] = 0x00 1013 1014 var val []byte 1015 if v.IsZero() { 1016 val = []byte("0000-00-00") 1017 } else { 1018 val = []byte(v.In(mc.cfg.Loc).Format(timeFormat)) 1019 } 1020 1021 paramValues = appendLengthEncodedInteger(paramValues, 1022 uint64(len(val)), 1023 ) 1024 paramValues = append(paramValues, val...) 1025 1026 default: 1027 return fmt.Errorf("can not convert type: %T", arg) 1028 } 1029 } 1030 1031 // Check if param values exceeded the available buffer 1032 // In that case we must build the data packet with the new values buffer 1033 if valuesCap != cap(paramValues) { 1034 data = append(data[:pos], paramValues...) 1035 mc.buf.buf = data 1036 } 1037 1038 pos += len(paramValues) 1039 data = data[:pos] 1040 } 1041 1042 return mc.writePacket(data) 1043 } 1044 1045 func (mc *mysqlConn) discardResults() error { 1046 for mc.status&statusMoreResultsExists != 0 { 1047 resLen, err := mc.readResultSetHeaderPacket() 1048 if err != nil { 1049 return err 1050 } 1051 if resLen > 0 { 1052 // columns 1053 if err := mc.readUntilEOF(); err != nil { 1054 return err 1055 } 1056 // rows 1057 if err := mc.readUntilEOF(); err != nil { 1058 return err 1059 } 1060 } else { 1061 mc.status &^= statusMoreResultsExists 1062 } 1063 } 1064 return nil 1065 } 1066 1067 // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html 1068 func (rows *binaryRows) readRow(dest []driver.Value) error { 1069 data, err := rows.mc.readPacket() 1070 if err != nil { 1071 return err 1072 } 1073 1074 // packet indicator [1 byte] 1075 if data[0] != iOK { 1076 // EOF Packet 1077 if data[0] == iEOF && len(data) == 5 { 1078 rows.mc.status = readStatus(data[3:]) 1079 if err := rows.mc.discardResults(); err != nil { 1080 return err 1081 } 1082 rows.mc = nil 1083 return io.EOF 1084 } 1085 rows.mc = nil 1086 1087 // Error otherwise 1088 return rows.mc.handleErrorPacket(data) 1089 } 1090 1091 // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] 1092 pos := 1 + (len(dest)+7+2)>>3 1093 nullMask := data[1:pos] 1094 1095 for i := range dest { 1096 // Field is NULL 1097 // (byte >> bit-pos) % 2 == 1 1098 if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { 1099 dest[i] = nil 1100 continue 1101 } 1102 1103 // Convert to byte-coded string 1104 switch rows.columns[i].fieldType { 1105 case fieldTypeNULL: 1106 dest[i] = nil 1107 continue 1108 1109 // Numeric Types 1110 case fieldTypeTiny: 1111 if rows.columns[i].flags&flagUnsigned != 0 { 1112 dest[i] = int64(data[pos]) 1113 } else { 1114 dest[i] = int64(int8(data[pos])) 1115 } 1116 pos++ 1117 continue 1118 1119 case fieldTypeShort, fieldTypeYear: 1120 if rows.columns[i].flags&flagUnsigned != 0 { 1121 dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) 1122 } else { 1123 dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) 1124 } 1125 pos += 2 1126 continue 1127 1128 case fieldTypeInt24, fieldTypeLong: 1129 if rows.columns[i].flags&flagUnsigned != 0 { 1130 dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) 1131 } else { 1132 dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) 1133 } 1134 pos += 4 1135 continue 1136 1137 case fieldTypeLongLong: 1138 if rows.columns[i].flags&flagUnsigned != 0 { 1139 val := binary.LittleEndian.Uint64(data[pos : pos+8]) 1140 if val > math.MaxInt64 { 1141 dest[i] = uint64ToString(val) 1142 } else { 1143 dest[i] = int64(val) 1144 } 1145 } else { 1146 dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8])) 1147 } 1148 pos += 8 1149 continue 1150 1151 case fieldTypeFloat: 1152 dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) 1153 pos += 4 1154 continue 1155 1156 case fieldTypeDouble: 1157 dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8])) 1158 pos += 8 1159 continue 1160 1161 // Length coded Binary Strings 1162 case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, 1163 fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, 1164 fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, 1165 fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON: 1166 var isNull bool 1167 var n int 1168 dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) 1169 pos += n 1170 if err == nil { 1171 if !isNull { 1172 continue 1173 } else { 1174 dest[i] = nil 1175 continue 1176 } 1177 } 1178 return err 1179 1180 case 1181 fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD 1182 fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal] 1183 fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] 1184 1185 num, isNull, n := readLengthEncodedInteger(data[pos:]) 1186 pos += n 1187 1188 switch { 1189 case isNull: 1190 dest[i] = nil 1191 continue 1192 case rows.columns[i].fieldType == fieldTypeTime: 1193 // database/sql does not support an equivalent to TIME, return a string 1194 var dstlen uint8 1195 switch decimals := rows.columns[i].decimals; decimals { 1196 case 0x00, 0x1f: 1197 dstlen = 8 1198 case 1, 2, 3, 4, 5, 6: 1199 dstlen = 8 + 1 + decimals 1200 default: 1201 return fmt.Errorf( 1202 "protocol error, illegal decimals value %d", 1203 rows.columns[i].decimals, 1204 ) 1205 } 1206 dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) 1207 case rows.mc.parseTime: 1208 dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) 1209 default: 1210 var dstlen uint8 1211 if rows.columns[i].fieldType == fieldTypeDate { 1212 dstlen = 10 1213 } else { 1214 switch decimals := rows.columns[i].decimals; decimals { 1215 case 0x00, 0x1f: 1216 dstlen = 19 1217 case 1, 2, 3, 4, 5, 6: 1218 dstlen = 19 + 1 + decimals 1219 default: 1220 return fmt.Errorf( 1221 "protocol error, illegal decimals value %d", 1222 rows.columns[i].decimals, 1223 ) 1224 } 1225 } 1226 dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false) 1227 } 1228 1229 if err == nil { 1230 pos += int(num) 1231 continue 1232 } else { 1233 return err 1234 } 1235 1236 // Please report if this happens! 1237 default: 1238 return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) 1239 } 1240 } 1241 1242 return nil 1243 }