github.com/cellofellow/gopkg@v0.0.0-20140722061823-eec0544a62ad/database/mysql/packets.go (about) 1 // Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2 // 3 // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 4 // 5 // This Source Code Form is subject to the terms of the Mozilla Public 6 // License, v. 2.0. If a copy of the MPL was not distributed with this file, 7 // You can obtain one at http://mozilla.org/MPL/2.0/. 8 9 package mysql 10 11 import ( 12 "bytes" 13 "crypto/tls" 14 "database/sql/driver" 15 "encoding/binary" 16 "fmt" 17 "io" 18 "math" 19 "time" 20 ) 21 22 // Packets documentation: 23 // http://dev.mysql.com/doc/internals/en/client-server-protocol.html 24 25 // Read packet to buffer 'data' 26 func (mc *mysqlConn) readPacket() ([]byte, error) { 27 var payload []byte 28 for { 29 // Read packet header 30 data, err := mc.buf.readNext(4) 31 if err != nil { 32 errLog.Print(err) 33 mc.Close() 34 return nil, driver.ErrBadConn 35 } 36 37 // Packet Length [24 bit] 38 pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) 39 40 if pktLen < 1 { 41 errLog.Print(ErrMalformPkt) 42 mc.Close() 43 return nil, driver.ErrBadConn 44 } 45 46 // Check Packet Sync [8 bit] 47 if data[3] != mc.sequence { 48 if data[3] > mc.sequence { 49 return nil, ErrPktSyncMul 50 } else { 51 return nil, ErrPktSync 52 } 53 } 54 mc.sequence++ 55 56 // Read packet body [pktLen bytes] 57 data, err = mc.buf.readNext(pktLen) 58 if err != nil { 59 errLog.Print(err) 60 mc.Close() 61 return nil, driver.ErrBadConn 62 } 63 64 isLastPacket := (pktLen < maxPacketSize) 65 66 // Zero allocations for non-splitting packets 67 if isLastPacket && payload == nil { 68 return data, nil 69 } 70 71 payload = append(payload, data...) 72 73 if isLastPacket { 74 return payload, nil 75 } 76 } 77 } 78 79 // Write packet buffer 'data' 80 func (mc *mysqlConn) writePacket(data []byte) error { 81 pktLen := len(data) - 4 82 83 if pktLen > mc.maxPacketAllowed { 84 return ErrPktTooLarge 85 } 86 87 for { 88 var size int 89 if pktLen >= maxPacketSize { 90 data[0] = 0xff 91 data[1] = 0xff 92 data[2] = 0xff 93 size = maxPacketSize 94 } else { 95 data[0] = byte(pktLen) 96 data[1] = byte(pktLen >> 8) 97 data[2] = byte(pktLen >> 16) 98 size = pktLen 99 } 100 data[3] = mc.sequence 101 102 // Write packet 103 n, err := mc.netConn.Write(data[:4+size]) 104 if err == nil && n == 4+size { 105 mc.sequence++ 106 if size != maxPacketSize { 107 return nil 108 } 109 pktLen -= size 110 data = data[size:] 111 continue 112 } 113 114 // Handle error 115 if err == nil { // n != len(data) 116 errLog.Print(ErrMalformPkt) 117 } else { 118 errLog.Print(err) 119 } 120 return driver.ErrBadConn 121 } 122 } 123 124 /****************************************************************************** 125 * Initialisation Process * 126 ******************************************************************************/ 127 128 // Handshake Initialization Packet 129 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake 130 func (mc *mysqlConn) readInitPacket() ([]byte, error) { 131 data, err := mc.readPacket() 132 if err != nil { 133 return nil, err 134 } 135 136 if data[0] == iERR { 137 return nil, mc.handleErrorPacket(data) 138 } 139 140 // protocol version [1 byte] 141 if data[0] < minProtocolVersion { 142 return nil, fmt.Errorf( 143 "Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required", 144 data[0], 145 minProtocolVersion, 146 ) 147 } 148 149 // server version [null terminated string] 150 // connection id [4 bytes] 151 pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 152 153 // first part of the password cipher [8 bytes] 154 cipher := data[pos : pos+8] 155 156 // (filler) always 0x00 [1 byte] 157 pos += 8 + 1 158 159 // capability flags (lower 2 bytes) [2 bytes] 160 mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) 161 if mc.flags&clientProtocol41 == 0 { 162 return nil, ErrOldProtocol 163 } 164 if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { 165 return nil, ErrNoTLS 166 } 167 pos += 2 168 169 if len(data) > pos { 170 // character set [1 byte] 171 // status flags [2 bytes] 172 // capability flags (upper 2 bytes) [2 bytes] 173 // length of auth-plugin-data [1 byte] 174 // reserved (all [00]) [10 bytes] 175 pos += 1 + 2 + 2 + 1 + 10 176 177 // second part of the password cipher [mininum 13 bytes], 178 // where len=MAX(13, length of auth-plugin-data - 8) 179 // 180 // The web documentation is ambiguous about the length. However, 181 // according to mysql-5.7/sql/auth/sql_authentication.cc line 538, 182 // the 13th byte is "\0 byte, terminating the second part of 183 // a scramble". So the second part of the password cipher is 184 // a NULL terminated string that's at least 13 bytes with the 185 // last byte being NULL. 186 // 187 // The official Python library uses the fixed length 12 188 // which seems to work but technically could have a hidden bug. 189 cipher = append(cipher, data[pos:pos+12]...) 190 191 // TODO: Verify string termination 192 // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) 193 // \NUL otherwise 194 // 195 //if data[len(data)-1] == 0 { 196 // return 197 //} 198 //return ErrMalformPkt 199 return cipher, nil 200 } 201 202 // make a memory safe copy of the cipher slice 203 var b [8]byte 204 copy(b[:], cipher) 205 return b[:], nil 206 } 207 208 // Client Authentication Packet 209 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse 210 func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { 211 // Adjust client flags based on server support 212 clientFlags := clientProtocol41 | 213 clientSecureConn | 214 clientLongPassword | 215 clientTransactions | 216 clientLocalFiles | 217 mc.flags&clientLongFlag 218 219 if mc.cfg.clientFoundRows { 220 clientFlags |= clientFoundRows 221 } 222 223 // To enable TLS / SSL 224 if mc.cfg.tls != nil { 225 clientFlags |= clientSSL 226 } 227 228 // User Password 229 scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.passwd)) 230 231 pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff) 232 233 // To specify a db name 234 if n := len(mc.cfg.dbname); n > 0 { 235 clientFlags |= clientConnectWithDB 236 pktLen += n + 1 237 } 238 239 // Calculate packet length and get buffer with that size 240 data := mc.buf.takeSmallBuffer(pktLen + 4) 241 if data == nil { 242 // can not take the buffer. Something must be wrong with the connection 243 errLog.Print(ErrBusyBuffer) 244 return driver.ErrBadConn 245 } 246 247 // ClientFlags [32 bit] 248 data[4] = byte(clientFlags) 249 data[5] = byte(clientFlags >> 8) 250 data[6] = byte(clientFlags >> 16) 251 data[7] = byte(clientFlags >> 24) 252 253 // MaxPacketSize [32 bit] (none) 254 data[8] = 0x00 255 data[9] = 0x00 256 data[10] = 0x00 257 data[11] = 0x00 258 259 // Charset [1 byte] 260 data[12] = mc.cfg.collation 261 262 // SSL Connection Request Packet 263 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest 264 if mc.cfg.tls != nil { 265 // Send TLS / SSL request packet 266 if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { 267 return err 268 } 269 270 // Switch to TLS 271 tlsConn := tls.Client(mc.netConn, mc.cfg.tls) 272 if err := tlsConn.Handshake(); err != nil { 273 return err 274 } 275 mc.netConn = tlsConn 276 mc.buf.rd = tlsConn 277 } 278 279 // Filler [23 bytes] (all 0x00) 280 pos := 13 + 23 281 282 // User [null terminated string] 283 if len(mc.cfg.user) > 0 { 284 pos += copy(data[pos:], mc.cfg.user) 285 } 286 data[pos] = 0x00 287 pos++ 288 289 // ScrambleBuffer [length encoded integer] 290 data[pos] = byte(len(scrambleBuff)) 291 pos += 1 + copy(data[pos+1:], scrambleBuff) 292 293 // Databasename [null terminated string] 294 if len(mc.cfg.dbname) > 0 { 295 pos += copy(data[pos:], mc.cfg.dbname) 296 data[pos] = 0x00 297 } 298 299 // Send Auth packet 300 return mc.writePacket(data) 301 } 302 303 // Client old authentication packet 304 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse 305 func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { 306 // User password 307 scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd)) 308 309 // Calculate the packet lenght and add a tailing 0 310 pktLen := len(scrambleBuff) + 1 311 data := mc.buf.takeSmallBuffer(4 + pktLen) 312 if data == nil { 313 // can not take the buffer. Something must be wrong with the connection 314 errLog.Print(ErrBusyBuffer) 315 return driver.ErrBadConn 316 } 317 318 // Add the scrambled password [null terminated string] 319 copy(data[4:], scrambleBuff) 320 data[4+pktLen-1] = 0x00 321 322 return mc.writePacket(data) 323 } 324 325 /****************************************************************************** 326 * Command Packets * 327 ******************************************************************************/ 328 329 func (mc *mysqlConn) writeCommandPacket(command byte) error { 330 // Reset Packet Sequence 331 mc.sequence = 0 332 333 data := mc.buf.takeSmallBuffer(4 + 1) 334 if data == nil { 335 // can not take the buffer. Something must be wrong with the connection 336 errLog.Print(ErrBusyBuffer) 337 return driver.ErrBadConn 338 } 339 340 // Add command byte 341 data[4] = command 342 343 // Send CMD packet 344 return mc.writePacket(data) 345 } 346 347 func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { 348 // Reset Packet Sequence 349 mc.sequence = 0 350 351 pktLen := 1 + len(arg) 352 data := mc.buf.takeBuffer(pktLen + 4) 353 if data == nil { 354 // can not take the buffer. Something must be wrong with the connection 355 errLog.Print(ErrBusyBuffer) 356 return driver.ErrBadConn 357 } 358 359 // Add command byte 360 data[4] = command 361 362 // Add arg 363 copy(data[5:], arg) 364 365 // Send CMD packet 366 return mc.writePacket(data) 367 } 368 369 func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { 370 // Reset Packet Sequence 371 mc.sequence = 0 372 373 data := mc.buf.takeSmallBuffer(4 + 1 + 4) 374 if data == nil { 375 // can not take the buffer. Something must be wrong with the connection 376 errLog.Print(ErrBusyBuffer) 377 return driver.ErrBadConn 378 } 379 380 // Add command byte 381 data[4] = command 382 383 // Add arg [32 bit] 384 data[5] = byte(arg) 385 data[6] = byte(arg >> 8) 386 data[7] = byte(arg >> 16) 387 data[8] = byte(arg >> 24) 388 389 // Send CMD packet 390 return mc.writePacket(data) 391 } 392 393 /****************************************************************************** 394 * Result Packets * 395 ******************************************************************************/ 396 397 // Returns error if Packet is not an 'Result OK'-Packet 398 func (mc *mysqlConn) readResultOK() error { 399 data, err := mc.readPacket() 400 if err == nil { 401 // packet indicator 402 switch data[0] { 403 404 case iOK: 405 return mc.handleOkPacket(data) 406 407 case iEOF: 408 // someone is using old_passwords 409 return ErrOldPassword 410 411 default: // Error otherwise 412 return mc.handleErrorPacket(data) 413 } 414 } 415 return err 416 } 417 418 // Result Set Header Packet 419 // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset 420 func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { 421 data, err := mc.readPacket() 422 if err == nil { 423 switch data[0] { 424 425 case iOK: 426 return 0, mc.handleOkPacket(data) 427 428 case iERR: 429 return 0, mc.handleErrorPacket(data) 430 431 case iLocalInFile: 432 return 0, mc.handleInFileRequest(string(data[1:])) 433 } 434 435 // column count 436 num, _, n := readLengthEncodedInteger(data) 437 if n-len(data) == 0 { 438 return int(num), nil 439 } 440 441 return 0, ErrMalformPkt 442 } 443 return 0, err 444 } 445 446 // Error Packet 447 // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet 448 func (mc *mysqlConn) handleErrorPacket(data []byte) error { 449 if data[0] != iERR { 450 return ErrMalformPkt 451 } 452 453 // 0xff [1 byte] 454 455 // Error Number [16 bit uint] 456 errno := binary.LittleEndian.Uint16(data[1:3]) 457 458 pos := 3 459 460 // SQL State [optional: # + 5bytes string] 461 if data[3] == 0x23 { 462 //sqlstate := string(data[4 : 4+5]) 463 pos = 9 464 } 465 466 // Error Message [string] 467 return &MySQLError{ 468 Number: errno, 469 Message: string(data[pos:]), 470 } 471 } 472 473 // Ok Packet 474 // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet 475 func (mc *mysqlConn) handleOkPacket(data []byte) error { 476 var n, m int 477 478 // 0x00 [1 byte] 479 480 // Affected rows [Length Coded Binary] 481 mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) 482 483 // Insert id [Length Coded Binary] 484 mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) 485 486 // server_status [2 bytes] 487 488 // warning count [2 bytes] 489 if !mc.strict { 490 return nil 491 } else { 492 pos := 1 + n + m + 2 493 if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { 494 return mc.getWarnings() 495 } 496 return nil 497 } 498 } 499 500 // Read Packets as Field Packets until EOF-Packet or an Error appears 501 // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 502 func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { 503 columns := make([]mysqlField, count) 504 505 for i := 0; ; i++ { 506 data, err := mc.readPacket() 507 if err != nil { 508 return nil, err 509 } 510 511 // EOF Packet 512 if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { 513 if i == count { 514 return columns, nil 515 } 516 return nil, fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns)) 517 } 518 519 // Catalog 520 pos, err := skipLengthEncodedString(data) 521 if err != nil { 522 return nil, err 523 } 524 525 // Database [len coded string] 526 n, err := skipLengthEncodedString(data[pos:]) 527 if err != nil { 528 return nil, err 529 } 530 pos += n 531 532 // Table [len coded string] 533 n, err = skipLengthEncodedString(data[pos:]) 534 if err != nil { 535 return nil, err 536 } 537 pos += n 538 539 // Original table [len coded string] 540 n, err = skipLengthEncodedString(data[pos:]) 541 if err != nil { 542 return nil, err 543 } 544 pos += n 545 546 // Name [len coded string] 547 name, _, n, err := readLengthEncodedString(data[pos:]) 548 if err != nil { 549 return nil, err 550 } 551 columns[i].name = string(name) 552 pos += n 553 554 // Original name [len coded string] 555 n, err = skipLengthEncodedString(data[pos:]) 556 if err != nil { 557 return nil, err 558 } 559 560 // Filler [1 byte] 561 // Charset [16 bit uint] 562 // Length [32 bit uint] 563 pos += n + 1 + 2 + 4 564 565 // Field type [byte] 566 columns[i].fieldType = data[pos] 567 pos++ 568 569 // Flags [16 bit uint] 570 columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) 571 //pos += 2 572 573 // Decimals [8 bit uint] 574 //pos++ 575 576 // Default value [len coded binary] 577 //if pos < len(data) { 578 // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) 579 //} 580 } 581 } 582 583 // Read Packets as Field Packets until EOF-Packet or an Error appears 584 // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow 585 func (rows *textRows) readRow(dest []driver.Value) error { 586 mc := rows.mc 587 588 data, err := mc.readPacket() 589 if err != nil { 590 return err 591 } 592 593 // EOF Packet 594 if data[0] == iEOF && len(data) == 5 { 595 return io.EOF 596 } 597 598 // RowSet Packet 599 var n int 600 var isNull bool 601 pos := 0 602 603 for i := range dest { 604 // Read bytes and convert to string 605 dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) 606 pos += n 607 if err == nil { 608 if !isNull { 609 if !mc.parseTime { 610 continue 611 } else { 612 switch rows.columns[i].fieldType { 613 case fieldTypeTimestamp, fieldTypeDateTime, 614 fieldTypeDate, fieldTypeNewDate: 615 dest[i], err = parseDateTime( 616 string(dest[i].([]byte)), 617 mc.cfg.loc, 618 ) 619 if err == nil { 620 continue 621 } 622 default: 623 continue 624 } 625 } 626 627 } else { 628 dest[i] = nil 629 continue 630 } 631 } 632 return err // err != nil 633 } 634 635 return nil 636 } 637 638 // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read 639 func (mc *mysqlConn) readUntilEOF() error { 640 for { 641 data, err := mc.readPacket() 642 643 // No Err and no EOF Packet 644 if err == nil && data[0] != iEOF { 645 continue 646 } 647 return err // Err or EOF 648 } 649 } 650 651 /****************************************************************************** 652 * Prepared Statements * 653 ******************************************************************************/ 654 655 // Prepare Result Packets 656 // http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html 657 func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { 658 data, err := stmt.mc.readPacket() 659 if err == nil { 660 // packet indicator [1 byte] 661 if data[0] != iOK { 662 return 0, stmt.mc.handleErrorPacket(data) 663 } 664 665 // statement id [4 bytes] 666 stmt.id = binary.LittleEndian.Uint32(data[1:5]) 667 668 // Column count [16 bit uint] 669 columnCount := binary.LittleEndian.Uint16(data[5:7]) 670 671 // Param count [16 bit uint] 672 stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9])) 673 674 // Reserved [8 bit] 675 676 // Warning count [16 bit uint] 677 if !stmt.mc.strict { 678 return columnCount, nil 679 } else { 680 // Check for warnings count > 0, only available in MySQL > 4.1 681 if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { 682 return columnCount, stmt.mc.getWarnings() 683 } 684 return columnCount, nil 685 } 686 } 687 return 0, err 688 } 689 690 // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html 691 func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { 692 maxLen := stmt.mc.maxPacketAllowed - 1 693 pktLen := maxLen 694 695 // After the header (bytes 0-3) follows before the data: 696 // 1 byte command 697 // 4 bytes stmtID 698 // 2 bytes paramID 699 const dataOffset = 1 + 4 + 2 700 701 // Can not use the write buffer since 702 // a) the buffer is too small 703 // b) it is in use 704 data := make([]byte, 4+1+4+2+len(arg)) 705 706 copy(data[4+dataOffset:], arg) 707 708 for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset { 709 if dataOffset+argLen < maxLen { 710 pktLen = dataOffset + argLen 711 } 712 713 stmt.mc.sequence = 0 714 // Add command byte [1 byte] 715 data[4] = comStmtSendLongData 716 717 // Add stmtID [32 bit] 718 data[5] = byte(stmt.id) 719 data[6] = byte(stmt.id >> 8) 720 data[7] = byte(stmt.id >> 16) 721 data[8] = byte(stmt.id >> 24) 722 723 // Add paramID [16 bit] 724 data[9] = byte(paramID) 725 data[10] = byte(paramID >> 8) 726 727 // Send CMD packet 728 err := stmt.mc.writePacket(data[:4+pktLen]) 729 if err == nil { 730 data = data[pktLen-dataOffset:] 731 continue 732 } 733 return err 734 735 } 736 737 // Reset Packet Sequence 738 stmt.mc.sequence = 0 739 return nil 740 } 741 742 // Execute Prepared Statement 743 // http://dev.mysql.com/doc/internals/en/com-stmt-execute.html 744 func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { 745 if len(args) != stmt.paramCount { 746 return fmt.Errorf( 747 "Arguments count mismatch (Got: %d Has: %d)", 748 len(args), 749 stmt.paramCount, 750 ) 751 } 752 753 const minPktLen = 4 + 1 + 4 + 1 + 4 754 mc := stmt.mc 755 756 // Reset packet-sequence 757 mc.sequence = 0 758 759 var data []byte 760 761 if len(args) == 0 { 762 data = mc.buf.takeBuffer(minPktLen) 763 } else { 764 data = mc.buf.takeCompleteBuffer() 765 } 766 if data == nil { 767 // can not take the buffer. Something must be wrong with the connection 768 errLog.Print(ErrBusyBuffer) 769 return driver.ErrBadConn 770 } 771 772 // command [1 byte] 773 data[4] = comStmtExecute 774 775 // statement_id [4 bytes] 776 data[5] = byte(stmt.id) 777 data[6] = byte(stmt.id >> 8) 778 data[7] = byte(stmt.id >> 16) 779 data[8] = byte(stmt.id >> 24) 780 781 // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] 782 data[9] = 0x00 783 784 // iteration_count (uint32(1)) [4 bytes] 785 data[10] = 0x01 786 data[11] = 0x00 787 data[12] = 0x00 788 data[13] = 0x00 789 790 if len(args) > 0 { 791 pos := minPktLen 792 793 var nullMask []byte 794 if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { 795 // buffer has to be extended but we don't know by how much so 796 // we depend on append after all data with known sizes fit. 797 // We stop at that because we deal with a lot of columns here 798 // which makes the required allocation size hard to guess. 799 tmp := make([]byte, pos+maskLen+typesLen) 800 copy(tmp[:pos], data[:pos]) 801 data = tmp 802 nullMask = data[pos : pos+maskLen] 803 pos += maskLen 804 } else { 805 nullMask = data[pos : pos+maskLen] 806 for i := 0; i < maskLen; i++ { 807 nullMask[i] = 0 808 } 809 pos += maskLen 810 } 811 812 // newParameterBoundFlag 1 [1 byte] 813 data[pos] = 0x01 814 pos++ 815 816 // type of each parameter [len(args)*2 bytes] 817 paramTypes := data[pos:] 818 pos += len(args) * 2 819 820 // value of each parameter [n bytes] 821 paramValues := data[pos:pos] 822 valuesCap := cap(paramValues) 823 824 for i, arg := range args { 825 // build NULL-bitmap 826 if arg == nil { 827 nullMask[i/8] |= 1 << (uint(i) & 7) 828 paramTypes[i+i] = fieldTypeNULL 829 paramTypes[i+i+1] = 0x00 830 continue 831 } 832 833 // cache types and values 834 switch v := arg.(type) { 835 case int64: 836 paramTypes[i+i] = fieldTypeLongLong 837 paramTypes[i+i+1] = 0x00 838 839 if cap(paramValues)-len(paramValues)-8 >= 0 { 840 paramValues = paramValues[:len(paramValues)+8] 841 binary.LittleEndian.PutUint64( 842 paramValues[len(paramValues)-8:], 843 uint64(v), 844 ) 845 } else { 846 paramValues = append(paramValues, 847 uint64ToBytes(uint64(v))..., 848 ) 849 } 850 851 case float64: 852 paramTypes[i+i] = fieldTypeDouble 853 paramTypes[i+i+1] = 0x00 854 855 if cap(paramValues)-len(paramValues)-8 >= 0 { 856 paramValues = paramValues[:len(paramValues)+8] 857 binary.LittleEndian.PutUint64( 858 paramValues[len(paramValues)-8:], 859 math.Float64bits(v), 860 ) 861 } else { 862 paramValues = append(paramValues, 863 uint64ToBytes(math.Float64bits(v))..., 864 ) 865 } 866 867 case bool: 868 paramTypes[i+i] = fieldTypeTiny 869 paramTypes[i+i+1] = 0x00 870 871 if v { 872 paramValues = append(paramValues, 0x01) 873 } else { 874 paramValues = append(paramValues, 0x00) 875 } 876 877 case []byte: 878 // Common case (non-nil value) first 879 if v != nil { 880 paramTypes[i+i] = fieldTypeString 881 paramTypes[i+i+1] = 0x00 882 883 if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { 884 paramValues = appendLengthEncodedInteger(paramValues, 885 uint64(len(v)), 886 ) 887 paramValues = append(paramValues, v...) 888 } else { 889 if err := stmt.writeCommandLongData(i, v); err != nil { 890 return err 891 } 892 } 893 continue 894 } 895 896 // Handle []byte(nil) as a NULL value 897 nullMask[i/8] |= 1 << (uint(i) & 7) 898 paramTypes[i+i] = fieldTypeNULL 899 paramTypes[i+i+1] = 0x00 900 901 case string: 902 paramTypes[i+i] = fieldTypeString 903 paramTypes[i+i+1] = 0x00 904 905 if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { 906 paramValues = appendLengthEncodedInteger(paramValues, 907 uint64(len(v)), 908 ) 909 paramValues = append(paramValues, v...) 910 } else { 911 if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { 912 return err 913 } 914 } 915 916 case time.Time: 917 paramTypes[i+i] = fieldTypeString 918 paramTypes[i+i+1] = 0x00 919 920 var val []byte 921 if v.IsZero() { 922 val = []byte("0000-00-00") 923 } else { 924 val = []byte(v.In(mc.cfg.loc).Format(timeFormat)) 925 } 926 927 paramValues = appendLengthEncodedInteger(paramValues, 928 uint64(len(val)), 929 ) 930 paramValues = append(paramValues, val...) 931 932 default: 933 return fmt.Errorf("Can't convert type: %T", arg) 934 } 935 } 936 937 // Check if param values exceeded the available buffer 938 // In that case we must build the data packet with the new values buffer 939 if valuesCap != cap(paramValues) { 940 data = append(data[:pos], paramValues...) 941 mc.buf.buf = data 942 } 943 944 pos += len(paramValues) 945 data = data[:pos] 946 } 947 948 return mc.writePacket(data) 949 } 950 951 // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html 952 func (rows *binaryRows) readRow(dest []driver.Value) error { 953 data, err := rows.mc.readPacket() 954 if err != nil { 955 return err 956 } 957 958 // packet indicator [1 byte] 959 if data[0] != iOK { 960 // EOF Packet 961 if data[0] == iEOF && len(data) == 5 { 962 return io.EOF 963 } 964 965 // Error otherwise 966 return rows.mc.handleErrorPacket(data) 967 } 968 969 // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] 970 pos := 1 + (len(dest)+7+2)>>3 971 nullMask := data[1:pos] 972 973 for i := range dest { 974 // Field is NULL 975 // (byte >> bit-pos) % 2 == 1 976 if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { 977 dest[i] = nil 978 continue 979 } 980 981 // Convert to byte-coded string 982 switch rows.columns[i].fieldType { 983 case fieldTypeNULL: 984 dest[i] = nil 985 continue 986 987 // Numeric Types 988 case fieldTypeTiny: 989 if rows.columns[i].flags&flagUnsigned != 0 { 990 dest[i] = int64(data[pos]) 991 } else { 992 dest[i] = int64(int8(data[pos])) 993 } 994 pos++ 995 continue 996 997 case fieldTypeShort, fieldTypeYear: 998 if rows.columns[i].flags&flagUnsigned != 0 { 999 dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) 1000 } else { 1001 dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) 1002 } 1003 pos += 2 1004 continue 1005 1006 case fieldTypeInt24, fieldTypeLong: 1007 if rows.columns[i].flags&flagUnsigned != 0 { 1008 dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) 1009 } else { 1010 dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) 1011 } 1012 pos += 4 1013 continue 1014 1015 case fieldTypeLongLong: 1016 if rows.columns[i].flags&flagUnsigned != 0 { 1017 val := binary.LittleEndian.Uint64(data[pos : pos+8]) 1018 if val > math.MaxInt64 { 1019 dest[i] = uint64ToString(val) 1020 } else { 1021 dest[i] = int64(val) 1022 } 1023 } else { 1024 dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8])) 1025 } 1026 pos += 8 1027 continue 1028 1029 case fieldTypeFloat: 1030 dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) 1031 pos += 4 1032 continue 1033 1034 case fieldTypeDouble: 1035 dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8])) 1036 pos += 8 1037 continue 1038 1039 // Length coded Binary Strings 1040 case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, 1041 fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, 1042 fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, 1043 fieldTypeVarString, fieldTypeString, fieldTypeGeometry: 1044 var isNull bool 1045 var n int 1046 dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) 1047 pos += n 1048 if err == nil { 1049 if !isNull { 1050 continue 1051 } else { 1052 dest[i] = nil 1053 continue 1054 } 1055 } 1056 return err 1057 1058 // Date YYYY-MM-DD 1059 case fieldTypeDate, fieldTypeNewDate: 1060 num, isNull, n := readLengthEncodedInteger(data[pos:]) 1061 pos += n 1062 1063 if isNull { 1064 dest[i] = nil 1065 continue 1066 } 1067 1068 if rows.mc.parseTime { 1069 dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc) 1070 } else { 1071 dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], false) 1072 } 1073 1074 if err == nil { 1075 pos += int(num) 1076 continue 1077 } else { 1078 return err 1079 } 1080 1081 // Time [-][H]HH:MM:SS[.fractal] 1082 case fieldTypeTime: 1083 num, isNull, n := readLengthEncodedInteger(data[pos:]) 1084 pos += n 1085 1086 if num == 0 { 1087 if isNull { 1088 dest[i] = nil 1089 continue 1090 } else { 1091 dest[i] = []byte("00:00:00") 1092 continue 1093 } 1094 } 1095 1096 var sign string 1097 if data[pos] == 1 { 1098 sign = "-" 1099 } 1100 1101 switch num { 1102 case 8: 1103 dest[i] = []byte(fmt.Sprintf( 1104 sign+"%02d:%02d:%02d", 1105 uint16(data[pos+1])*24+uint16(data[pos+5]), 1106 data[pos+6], 1107 data[pos+7], 1108 )) 1109 pos += 8 1110 continue 1111 case 12: 1112 dest[i] = []byte(fmt.Sprintf( 1113 sign+"%02d:%02d:%02d.%06d", 1114 uint16(data[pos+1])*24+uint16(data[pos+5]), 1115 data[pos+6], 1116 data[pos+7], 1117 binary.LittleEndian.Uint32(data[pos+8:pos+12]), 1118 )) 1119 pos += 12 1120 continue 1121 default: 1122 return fmt.Errorf("Invalid TIME-packet length %d", num) 1123 } 1124 1125 // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] 1126 case fieldTypeTimestamp, fieldTypeDateTime: 1127 num, isNull, n := readLengthEncodedInteger(data[pos:]) 1128 1129 pos += n 1130 1131 if isNull { 1132 dest[i] = nil 1133 continue 1134 } 1135 1136 if rows.mc.parseTime { 1137 dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc) 1138 } else { 1139 dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], true) 1140 } 1141 1142 if err == nil { 1143 pos += int(num) 1144 continue 1145 } else { 1146 return err 1147 } 1148 1149 // Please report if this happens! 1150 default: 1151 return fmt.Errorf("Unknown FieldType %d", rows.columns[i].fieldType) 1152 } 1153 } 1154 1155 return nil 1156 }