github.com/matrixorigin/matrixone@v0.7.0/pkg/frontend/mysql_protocol.go (about) 1 // Copyright 2021 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package frontend 16 17 import ( 18 "bytes" 19 "context" 20 "crypto/sha1" 21 "encoding/binary" 22 "fmt" 23 "math" 24 "math/rand" 25 "strconv" 26 "strings" 27 "sync" 28 "sync/atomic" 29 "time" 30 "unicode" 31 32 "github.com/fagongzi/goetty/v2" 33 "github.com/matrixorigin/matrixone/pkg/common/moerr" 34 "github.com/matrixorigin/matrixone/pkg/config" 35 "github.com/matrixorigin/matrixone/pkg/container/types" 36 "github.com/matrixorigin/matrixone/pkg/defines" 37 "github.com/matrixorigin/matrixone/pkg/logutil" 38 planPb "github.com/matrixorigin/matrixone/pkg/pb/plan" 39 plan2 "github.com/matrixorigin/matrixone/pkg/sql/plan" 40 ) 41 42 // DefaultCapability means default capabilities of the server 43 var DefaultCapability = CLIENT_LONG_PASSWORD | 44 CLIENT_FOUND_ROWS | 45 CLIENT_LONG_FLAG | 46 CLIENT_CONNECT_WITH_DB | 47 CLIENT_LOCAL_FILES | 48 CLIENT_PROTOCOL_41 | 49 CLIENT_INTERACTIVE | 50 CLIENT_TRANSACTIONS | 51 CLIENT_SECURE_CONNECTION | 52 CLIENT_MULTI_STATEMENTS | 53 CLIENT_MULTI_RESULTS | 54 CLIENT_PLUGIN_AUTH | 55 CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | 56 CLIENT_DEPRECATE_EOF 57 58 // DefaultClientConnStatus default server status 59 var DefaultClientConnStatus = SERVER_STATUS_AUTOCOMMIT 60 61 var serverVersion atomic.Value 62 63 func init() { 64 serverVersion.Store("0.5.0") 
65 } 66 67 func InitServerVersion(v string) { 68 if len(v) > 0 { 69 switch v[0] { 70 case 'v': // format 'v1.1.1' 71 v = v[1:] 72 serverVersion.Store(v) 73 default: 74 vv := []byte(v) 75 for i := 0; i < len(vv); i++ { 76 if !unicode.IsDigit(rune(vv[i])) && vv[i] != '.' { 77 vv = append(vv[:i], vv[i+1:]...) 78 i-- 79 } 80 } 81 serverVersion.Store(string(vv)) 82 } 83 } else { 84 serverVersion.Store("0.5.0") 85 } 86 } 87 88 const ( 89 clientProtocolVersion uint8 = 10 90 91 /** 92 An answer talks about the charset utf8mb4. 93 https://stackoverflow.com/questions/766809/whats-the-difference-between-utf8-general-ci-and-utf8-unicode-ci 94 It recommends the charset utf8mb4_0900_ai_ci. 95 Maybe we can support utf8mb4_0900_ai_ci in the future. 96 97 A concise research in the Mysql 8.0.23. 98 99 the charset in sever level 100 ====================================== 101 102 mysql> show variables like 'character_set_server'; 103 +----------------------+---------+ 104 | Variable_name | Value | 105 +----------------------+---------+ 106 | character_set_server | utf8mb4 | 107 +----------------------+---------+ 108 109 mysql> show variables like 'collation_server'; 110 +------------------+--------------------+ 111 | Variable_name | Value | 112 +------------------+--------------------+ 113 | collation_server | utf8mb4_0900_ai_ci | 114 +------------------+--------------------+ 115 116 the charset in database level 117 ===================================== 118 mysql> show variables like 'character_set_database'; 119 +------------------------+---------+ 120 | Variable_name | Value | 121 +------------------------+---------+ 122 | character_set_database | utf8mb4 | 123 +------------------------+---------+ 124 125 mysql> show variables like 'collation_database'; 126 +--------------------+--------------------+ 127 | Variable_name | Value | 128 +--------------------+--------------------+ 129 | collation_database | utf8mb4_0900_ai_ci | 130 +--------------------+--------------------+ 131 132 */ 
133 // DefaultCollationID is utf8mb4_bin(46) 134 utf8mb4BinCollationID uint8 = 46 135 136 Utf8mb4CollationID uint8 = 45 137 138 AuthNativePassword string = "mysql_native_password" 139 140 //the length of the mysql protocol header 141 HeaderLengthOfTheProtocol int = 4 142 HeaderOffset int = 0 143 144 // MaxPayloadSize If the payload is larger than or equal to 2^24−1 bytes the length is set to 2^24−1 (ff ff ff) 145 //and additional packets are sent with the rest of the payload until the payload of a packet 146 //is less than 2^24−1 bytes. 147 MaxPayloadSize uint32 = (1 << 24) - 1 148 149 // DefaultMySQLState is the default state of the mySQL 150 DefaultMySQLState string = "HY000" 151 ) 152 153 type MysqlProtocol interface { 154 Protocol 155 //the server send group row of the result set as an independent packet thread safe 156 SendResultSetTextBatchRow(mrs *MysqlResultSet, cnt uint64) error 157 158 SendResultSetTextBatchRowSpeedup(mrs *MysqlResultSet, cnt uint64) error 159 160 //SendColumnDefinitionPacket the server send the column definition to the client 161 SendColumnDefinitionPacket(ctx context.Context, column Column, cmd int) error 162 163 //SendColumnCountPacket makes the column count packet 164 SendColumnCountPacket(count uint64) error 165 166 SendResponse(ctx context.Context, resp *Response) error 167 168 SendEOFPacketIf(warnings uint16, status uint16) error 169 170 //send OK packet to the client 171 sendOKPacket(affectedRows uint64, lastInsertId uint64, status uint16, warnings uint16, message string) error 172 173 //the OK or EOF packet thread safe 174 sendEOFOrOkPacket(warnings uint16, status uint16) error 175 176 sendLocalInfileRequest(filename string) error 177 178 ResetStatistics() 179 180 GetStats() string 181 182 ParseExecuteData(ctx context.Context, stmt *PrepareStmt, data []byte, pos int) (names []string, vars []any, err error) 183 } 184 185 var _ MysqlProtocol = &MysqlProtocolImpl{} 186 187 func (ses *Session) GetMysqlProtocol() MysqlProtocol { 188 
ses.mu.Lock() 189 defer ses.mu.Unlock() 190 return ses.protocol.(MysqlProtocol) 191 } 192 193 type debugStats struct { 194 writeCount uint64 195 writeBytes uint64 196 } 197 198 func (ds *debugStats) ResetStats() { 199 ds.writeCount = 0 200 ds.writeBytes = 0 201 } 202 203 func (ds *debugStats) String() string { 204 if ds.writeCount <= 0 { 205 ds.writeCount = 1 206 } 207 return fmt.Sprintf( 208 "writeCount %v \n"+ 209 "writeBytes %v %v MB\n", 210 ds.writeCount, 211 ds.writeBytes, ds.writeBytes/(1024*1024.0), 212 ) 213 } 214 215 /* 216 rowHandler maintains the states in encoding the result row 217 */ 218 type rowHandler struct { 219 //the begin position of writing. 220 //the range [beginWriteIndex,beginWriteIndex+3] 221 //for the length and sequenceId of the mysql protocol packet 222 beginWriteIndex int 223 //the bytes in the outbuffer 224 bytesInOutBuffer int 225 //when the number of bytes in the outbuffer exceeds the it, 226 //the outbuffer will be flushed. 227 untilBytesInOutbufToFlush int 228 //the count of the flush 229 flushCount int 230 enableLog bool 231 } 232 233 /* 234 isInPacket means it is compositing a packet now 235 */ 236 func (rh *rowHandler) isInPacket() bool { 237 return rh.beginWriteIndex >= 0 238 } 239 240 /* 241 resetPacket reset the beginWriteIndex 242 */ 243 func (rh *rowHandler) resetPacket() { 244 rh.beginWriteIndex = -1 245 } 246 247 /* 248 resetFlushOutBuffer clears the bytesInOutBuffer 249 */ 250 func (rh *rowHandler) resetFlushOutBuffer() { 251 rh.bytesInOutBuffer = 0 252 } 253 254 /* 255 resetFlushCount reset flushCount 256 */ 257 func (rh *rowHandler) resetFlushCount() { 258 rh.flushCount = 0 259 } 260 261 type MysqlProtocolImpl struct { 262 ProtocolImpl 263 264 //joint capability shared by the server and the client 265 capability uint32 266 267 //collation id 268 collationID int 269 270 //collation name 271 collationName string 272 273 //character set 274 charset string 275 276 //max packet size of the client 277 maxClientPacketSize 
uint32 278 279 //the user of the client 280 username string 281 282 //the default database for the client 283 database string 284 285 //for debug 286 debugStats 287 288 //for converting the data into string 289 strconvBuffer []byte 290 291 //for encoding the length into bytes 292 lenEncBuffer []byte 293 294 //for encoding the null bytes in binary row 295 binaryNullBuffer []byte 296 297 rowHandler 298 299 SV *config.FrontendParameters 300 301 m sync.Mutex 302 303 ses *Session 304 305 //skip checking the password of the user 306 skipCheckUser bool 307 } 308 309 func (mp *MysqlProtocolImpl) GetSession() *Session { 310 mp.m.Lock() 311 defer mp.m.Unlock() 312 return mp.ses 313 } 314 315 func (mp *MysqlProtocolImpl) SetSkipCheckUser(b bool) { 316 mp.m.Lock() 317 defer mp.m.Unlock() 318 mp.skipCheckUser = b 319 } 320 321 func (mp *MysqlProtocolImpl) GetCapability() uint32 { 322 mp.m.Lock() 323 defer mp.m.Unlock() 324 return mp.capability 325 } 326 327 func (mp *MysqlProtocolImpl) AddSequenceId(a uint8) { 328 mp.sequenceId.Add(uint32(a)) 329 } 330 331 func (mp *MysqlProtocolImpl) GetSkipCheckUser() bool { 332 mp.m.Lock() 333 defer mp.m.Unlock() 334 return mp.skipCheckUser 335 } 336 337 func (mp *MysqlProtocolImpl) GetDatabaseName() string { 338 mp.m.Lock() 339 defer mp.m.Unlock() 340 return mp.database 341 } 342 343 func (mp *MysqlProtocolImpl) SetDatabaseName(s string) { 344 mp.m.Lock() 345 defer mp.m.Unlock() 346 mp.database = s 347 } 348 349 func (mp *MysqlProtocolImpl) GetUserName() string { 350 mp.m.Lock() 351 defer mp.m.Unlock() 352 return mp.username 353 } 354 355 func (mp *MysqlProtocolImpl) SetUserName(s string) { 356 mp.m.Lock() 357 defer mp.m.Unlock() 358 mp.username = s 359 } 360 361 func (mp *MysqlProtocolImpl) GetStats() string { 362 return fmt.Sprintf("flushCount %d %s", 363 mp.flushCount, 364 mp.String()) 365 } 366 367 func (mp *MysqlProtocolImpl) ResetStatistics() { 368 mp.ResetStats() 369 mp.resetFlushCount() 370 } 371 372 func (mp *MysqlProtocolImpl) 
Quit() { 373 mp.m.Lock() 374 defer mp.m.Unlock() 375 mp.ProtocolImpl.Quit() 376 if mp.strconvBuffer != nil { 377 mp.strconvBuffer = nil 378 } 379 if mp.lenEncBuffer != nil { 380 mp.lenEncBuffer = nil 381 } 382 if mp.binaryNullBuffer != nil { 383 mp.binaryNullBuffer = nil 384 } 385 } 386 387 func (mp *MysqlProtocolImpl) SetSession(ses *Session) { 388 mp.m.Lock() 389 defer mp.m.Unlock() 390 mp.ses = ses 391 } 392 393 // handshake response 41 394 type response41 struct { 395 capabilities uint32 396 maxPacketSize uint32 397 collationID uint8 398 username string 399 authResponse []byte 400 database string 401 clientPluginName string 402 isAskForTlsHeader bool 403 } 404 405 // handshake response 320 406 type response320 struct { 407 capabilities uint32 408 maxPacketSize uint32 409 username string 410 authResponse []byte 411 database string 412 isAskForTlsHeader bool 413 } 414 415 func (mp *MysqlProtocolImpl) SendPrepareResponse(ctx context.Context, stmt *PrepareStmt) error { 416 dcPrepare, ok := stmt.PreparePlan.GetDcl().Control.(*planPb.DataControl_Prepare) 417 if !ok { 418 return moerr.NewInternalError(ctx, "can not get Prepare plan in prepareStmt") 419 } 420 stmtID, err := GetPrepareStmtID(ctx, stmt.Name) 421 if err != nil { 422 return moerr.NewInternalError(ctx, "can not get Prepare stmtID") 423 } 424 paramTypes := dcPrepare.Prepare.ParamTypes 425 numParams := len(paramTypes) 426 columns := plan2.GetResultColumnsFromPlan(dcPrepare.Prepare.Plan) 427 numColumns := len(columns) 428 429 var data []byte 430 // status ok 431 data = append(data, 0) 432 // stmt id 433 data = mp.io.AppendUint32(data, uint32(stmtID)) 434 // number columns 435 data = mp.io.AppendUint16(data, uint16(numColumns)) 436 // number params 437 data = mp.io.AppendUint16(data, uint16(numParams)) 438 // filter [00] 439 data = append(data, 0) 440 // warning count 441 data = append(data, 0, 0) // TODO support warning count 442 if err := mp.writePackets(data); err != nil { 443 return err 444 } 445 446 cmd := 
int(COM_STMT_PREPARE) 447 for i := 0; i < numParams; i++ { 448 column := new(MysqlColumn) 449 column.SetName("?") 450 451 err = convertEngineTypeToMysqlType(ctx, types.T(paramTypes[i]), column) 452 if err != nil { 453 return err 454 } 455 456 err = mp.SendColumnDefinitionPacket(ctx, column, cmd) 457 if err != nil { 458 return err 459 } 460 } 461 if numParams > 0 { 462 if err := mp.SendEOFPacketIf(0, 0); err != nil { 463 return err 464 } 465 } 466 467 for i := 0; i < numColumns; i++ { 468 column := new(MysqlColumn) 469 column.SetName(columns[i].Name) 470 471 err = convertEngineTypeToMysqlType(ctx, types.T(columns[i].Typ.Id), column) 472 if err != nil { 473 return err 474 } 475 476 err = mp.SendColumnDefinitionPacket(ctx, column, cmd) 477 if err != nil { 478 return err 479 } 480 } 481 if numColumns > 0 { 482 if err := mp.SendEOFPacketIf(0, 0); err != nil { 483 return err 484 } 485 } 486 487 return nil 488 } 489 490 func (mp *MysqlProtocolImpl) ParseExecuteData(requestCtx context.Context, stmt *PrepareStmt, data []byte, pos int) (names []string, vars []any, err error) { 491 dcPrepare, ok := stmt.PreparePlan.GetDcl().Control.(*planPb.DataControl_Prepare) 492 if !ok { 493 err = moerr.NewInternalError(requestCtx, "can not get Prepare plan in prepareStmt") 494 return 495 } 496 numParams := len(dcPrepare.Prepare.ParamTypes) 497 498 var flag uint8 499 flag, pos, ok = mp.io.ReadUint8(data, pos) 500 if !ok { 501 err = moerr.NewInternalError(requestCtx, "malform packet") 502 return 503 } 504 if flag != 0 { 505 // TODO only support CURSOR_TYPE_NO_CURSOR flag now 506 err = moerr.NewInvalidInput(requestCtx, "unsupported Prepare flag '%v'", flag) 507 return 508 } 509 510 // skip iteration-count, always 1 511 pos += 4 512 513 if numParams > 0 { 514 var nullBitmaps []byte 515 nullBitmapLen := (numParams + 7) >> 3 516 nullBitmaps, pos, ok = mp.readCountOfBytes(data, pos, nullBitmapLen) 517 if !ok { 518 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 
519 return 520 } 521 522 // new param bound flag 523 if data[pos] == 1 { 524 pos++ 525 526 // Just the first StmtExecute packet contain parameters type, 527 // we need save it for further use. 528 stmt.ParamTypes, pos, ok = mp.readCountOfBytes(data, pos, numParams<<1) 529 if !ok { 530 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 531 return 532 } 533 } else { 534 pos++ 535 } 536 537 // get paramters and set value to session variables 538 names = make([]string, numParams) 539 vars = make([]any, numParams) 540 for i := 0; i < numParams; i++ { 541 varName := getPrepareStmtSessionVarName(i) 542 names[i] = varName 543 544 // TODO :if params had received via COM_STMT_SEND_LONG_DATA, use them directly. 545 // ref https://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html 546 547 if nullBitmaps[i>>3]&(1<<(uint(i)%8)) > 0 { 548 vars[i] = nil 549 continue 550 } 551 552 if (i<<1)+1 >= len(stmt.ParamTypes) { 553 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 554 return 555 } 556 tp := stmt.ParamTypes[i<<1] 557 isUnsigned := (stmt.ParamTypes[(i<<1)+1] & 0x80) > 0 558 559 switch defines.MysqlType(tp) { 560 case defines.MYSQL_TYPE_NULL: 561 vars[i] = nil 562 563 case defines.MYSQL_TYPE_TINY: 564 val, newPos, ok := mp.io.ReadUint8(data, pos) 565 if !ok { 566 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 567 return 568 } 569 570 pos = newPos 571 if isUnsigned { 572 vars[i] = val 573 } else { 574 vars[i] = int8(val) 575 } 576 577 case defines.MYSQL_TYPE_SHORT, defines.MYSQL_TYPE_YEAR: 578 val, newPos, ok := mp.io.ReadUint16(data, pos) 579 if !ok { 580 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 581 return 582 } 583 584 pos = newPos 585 if isUnsigned { 586 vars[i] = val 587 } else { 588 vars[i] = int16(val) 589 } 590 591 case defines.MYSQL_TYPE_INT24, defines.MYSQL_TYPE_LONG: 592 val, newPos, ok := mp.io.ReadUint32(data, pos) 593 if 
!ok { 594 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 595 return 596 } 597 598 pos = newPos 599 if isUnsigned { 600 vars[i] = val 601 } else { 602 vars[i] = int32(val) 603 } 604 605 case defines.MYSQL_TYPE_LONGLONG: 606 val, newPos, ok := mp.io.ReadUint64(data, pos) 607 if !ok { 608 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 609 return 610 } 611 612 pos = newPos 613 if isUnsigned { 614 vars[i] = val 615 } else { 616 vars[i] = int64(val) 617 } 618 619 case defines.MYSQL_TYPE_FLOAT: 620 val, newPos, ok := mp.io.ReadUint32(data, pos) 621 if !ok { 622 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 623 return 624 } 625 pos = newPos 626 vars[i] = math.Float32frombits(val) 627 628 case defines.MYSQL_TYPE_DOUBLE: 629 val, newPos, ok := mp.io.ReadUint64(data, pos) 630 if !ok { 631 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 632 return 633 } 634 pos = newPos 635 vars[i] = math.Float64frombits(val) 636 637 case defines.MYSQL_TYPE_VARCHAR, defines.MYSQL_TYPE_VAR_STRING, defines.MYSQL_TYPE_STRING, defines.MYSQL_TYPE_DECIMAL, 638 defines.MYSQL_TYPE_ENUM, defines.MYSQL_TYPE_SET, defines.MYSQL_TYPE_GEOMETRY, defines.MYSQL_TYPE_BIT: 639 val, newPos, ok := mp.readStringLenEnc(data, pos) 640 if !ok { 641 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 642 return 643 } 644 pos = newPos 645 vars[i] = val 646 647 case defines.MYSQL_TYPE_BLOB, defines.MYSQL_TYPE_TINY_BLOB, defines.MYSQL_TYPE_MEDIUM_BLOB, defines.MYSQL_TYPE_LONG_BLOB, defines.MYSQL_TYPE_TEXT: 648 val, newPos, ok := mp.readStringLenEnc(data, pos) 649 if !ok { 650 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 651 return 652 } 653 pos = newPos 654 vars[i] = []byte(val) 655 656 case defines.MYSQL_TYPE_TIME: 657 // See https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_binary_resultset.html 658 // 
for more details. 659 length, newPos, ok := mp.io.ReadUint8(data, pos) 660 if !ok { 661 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 662 return 663 } 664 pos = newPos 665 switch length { 666 case 0: 667 vars[i] = "0d 00:00:00" 668 case 8, 12: 669 pos, vars[i] = mp.readTime(data, pos, length) 670 default: 671 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 672 return 673 } 674 case defines.MYSQL_TYPE_DATE, defines.MYSQL_TYPE_DATETIME, defines.MYSQL_TYPE_TIMESTAMP: 675 // See https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_binary_resultset.html 676 // for more details. 677 length, newPos, ok := mp.io.ReadUint8(data, pos) 678 if !ok { 679 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 680 return 681 } 682 pos = newPos 683 switch length { 684 case 0: 685 vars[i] = "0000-00-00 00:00:00" 686 case 4: 687 pos, vars[i] = mp.readDate(data, pos) 688 case 7: 689 pos, vars[i] = mp.readDateTime(data, pos) 690 case 11: 691 pos, vars[i] = mp.readTimestamp(data, pos) 692 default: 693 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 694 return 695 } 696 697 case defines.MYSQL_TYPE_NEWDECIMAL: 698 // use string for decimal. 
Not tested 699 val, newPos, ok := mp.readStringLenEnc(data, pos) 700 if !ok { 701 err = moerr.NewInvalidInput(requestCtx, "mysql protocol error, malformed packet") 702 return 703 } 704 pos = newPos 705 vars[i] = val 706 707 default: 708 err = moerr.NewInternalError(requestCtx, "unsupport parameter type") 709 return 710 } 711 } 712 } 713 714 return 715 } 716 717 func (mp *MysqlProtocolImpl) readDate(data []byte, pos int) (int, string) { 718 year, pos, _ := mp.io.ReadUint16(data, pos) 719 month := data[pos] 720 pos++ 721 day := data[pos] 722 pos++ 723 return pos, fmt.Sprintf("%04d-%02d-%02d", year, month, day) 724 } 725 726 func (mp *MysqlProtocolImpl) readTime(data []byte, pos int, len uint8) (int, string) { 727 var retStr string 728 negate := data[pos] 729 pos++ 730 if negate == 1 { 731 retStr += "-" 732 } 733 day, pos, _ := mp.io.ReadUint32(data, pos) 734 if day > 0 { 735 retStr += fmt.Sprintf("%dd ", day) 736 } 737 hour := data[pos] 738 pos++ 739 minute := data[pos] 740 pos++ 741 second := data[pos] 742 pos++ 743 744 if len == 12 { 745 ms, _, _ := mp.io.ReadUint32(data, pos) 746 retStr += fmt.Sprintf("%02d:%02d:%02d.%06d", hour, minute, second, ms) 747 } else { 748 retStr += fmt.Sprintf("%02d:%02d:%02d", hour, minute, second) 749 } 750 751 return pos, retStr 752 } 753 754 func (mp *MysqlProtocolImpl) readDateTime(data []byte, pos int) (int, string) { 755 pos, date := mp.readDate(data, pos) 756 hour := data[pos] 757 pos++ 758 minute := data[pos] 759 pos++ 760 second := data[pos] 761 pos++ 762 return pos, fmt.Sprintf("%s %02d:%02d:%02d", date, hour, minute, second) 763 } 764 765 func (mp *MysqlProtocolImpl) readTimestamp(data []byte, pos int) (int, string) { 766 pos, dateTime := mp.readDateTime(data, pos) 767 microSecond, pos, _ := mp.io.ReadUint32(data, pos) 768 return pos, fmt.Sprintf("%s.%06d", dateTime, microSecond) 769 } 770 771 // read an int with length encoded from the buffer at the position 772 // return the int ; position + the count of bytes for length 
// readIntLenEnc reads a length-encoded integer from data at pos.
// It returns the value, pos + the count of bytes of the encoding (1, 3, 4 or 9)
// and true; or 0, 0, false when data is too short.
//
// First byte 0xfc/0xfd/0xfe: a 2-/3-/8-byte little-endian value follows.
// First byte 0xfb: decoded as 0 in one byte. Any other first byte is the value
// itself.
// NOTE(review): in the MySQL protocol 0xfb marks NULL and 0xff starts an error
// packet; neither is distinguished here — confirm callers never pass them.
func (mp *MysqlProtocolImpl) readIntLenEnc(data []byte, pos int) (uint64, int, bool) {
	if pos >= len(data) {
		return 0, 0, false
	}
	switch data[pos] {
	case 0xfb:
		//zero, one byte
		return 0, pos + 1, true
	case 0xfc:
		// int in two bytes
		if pos+2 >= len(data) {
			return 0, 0, false
		}
		value := uint64(data[pos+1]) |
			uint64(data[pos+2])<<8
		return value, pos + 3, true
	case 0xfd:
		// int in three bytes
		if pos+3 >= len(data) {
			return 0, 0, false
		}
		value := uint64(data[pos+1]) |
			uint64(data[pos+2])<<8 |
			uint64(data[pos+3])<<16
		return value, pos + 4, true
	case 0xfe:
		// int in eight bytes
		if pos+8 >= len(data) {
			return 0, 0, false
		}
		value := uint64(data[pos+1]) |
			uint64(data[pos+2])<<8 |
			uint64(data[pos+3])<<16 |
			uint64(data[pos+4])<<24 |
			uint64(data[pos+5])<<32 |
			uint64(data[pos+6])<<40 |
			uint64(data[pos+7])<<48 |
			uint64(data[pos+8])<<56
		return value, pos + 9, true
	}
	// 0-250 (and 0xff, see NOTE above)
	return uint64(data[pos]), pos + 1, true
}

// writeIntLenEnc writes value as a length-encoded integer into data at pos.
// data must have room for up to 9 bytes from pos.
// It returns pos + the count of bytes written (1, 3, 4 or 9).
func (mp *MysqlProtocolImpl) writeIntLenEnc(data []byte, pos int, value uint64) int {
	switch {
	case value < 251:
		data[pos] = byte(value)
		return pos + 1
	case value < (1 << 16):
		data[pos] = 0xfc
		data[pos+1] = byte(value)
		data[pos+2] = byte(value >> 8)
		return pos + 3
	case value < (1 << 24):
		data[pos] = 0xfd
		data[pos+1] = byte(value)
		data[pos+2] = byte(value >> 8)
		data[pos+3] = byte(value >> 16)
		return pos + 4
	default:
		data[pos] = 0xfe
		data[pos+1] = byte(value)
		data[pos+2] = byte(value >> 8)
		data[pos+3] = byte(value >> 16)
		data[pos+4] = byte(value >> 24)
		data[pos+5] = byte(value >> 32)
		data[pos+6] = byte(value >> 40)
		data[pos+7] = byte(value >> 48)
		data[pos+8] = byte(value >> 56)
		return pos + 9
	}
}

// appendIntLenEnc appends value as a length-encoded integer to data and
// returns the buffer. It encodes through mp.lenEncBuffer, which is assumed to
// have capacity >= 9 (presumably set up by the constructor; verify).
func (mp *MysqlProtocolImpl) appendIntLenEnc(data []byte, value uint64) []byte {
	mp.lenEncBuffer = mp.lenEncBuffer[:9]
	pos := mp.writeIntLenEnc(mp.lenEncBuffer, 0, value)
	return mp.append(data, mp.lenEncBuffer[:pos]...)
}

// readCountOfBytes returns the count bytes of data starting at pos (a view
// sharing data's backing array), pos + count, and true. It returns
// nil, 0, false when the range is out of bounds or pos/count is negative.
func (mp *MysqlProtocolImpl) readCountOfBytes(data []byte, pos int, count int) ([]byte, int, bool) {
	if pos < 0 || count < 0 || pos+count > len(data) {
		return nil, 0, false
	}
	return data[pos : pos+count], pos + count, true
}

// writeCountOfBytes copies value into data at pos (truncated to the room
// available) and returns pos + the number of bytes copied.
func (mp *MysqlProtocolImpl) writeCountOfBytes(data []byte, pos int, value []byte) int {
	pos += copy(data[pos:], value)
	return pos
}

// appendCountOfBytes appends value to data and returns the buffer.
func (mp *MysqlProtocolImpl) appendCountOfBytes(data []byte, value []byte) []byte {
	return mp.append(data, value...)
}

// readStringFix reads a string of exactly length bytes from data at pos.
// It returns the string, pos + length, and true; or "", 0, false when data
// is too short.
func (mp *MysqlProtocolImpl) readStringFix(data []byte, pos int, length int) (string, int, bool) {
	var sdata []byte
	var ok bool
	sdata, pos, ok = mp.readCountOfBytes(data, pos, length)
	if !ok {
		return "", 0, false
	}
	return string(sdata), pos, true
}

// writeStringFix copies the first length bytes of value into data at pos
// and returns pos + the number of bytes copied.
func (mp *MysqlProtocolImpl) writeStringFix(data []byte, pos int, value string, length int) int {
	pos += copy(data[pos:], value[0:length])
	return pos
}

// appendStringFix appends the first length bytes of value to data and
// returns the buffer.
func (mp *MysqlProtocolImpl) appendStringFix(data []byte, value string, length int) []byte {
	return mp.append(data, []byte(value[:length])...)
}

// readStringNUL reads a zero-terminated string from data at pos.
// It returns the string (without the terminator), the position just past
// the terminator, and true; or "", 0, false when pos is out of range or no
// terminator is found.
func (mp *MysqlProtocolImpl) readStringNUL(data []byte, pos int) (string, int, bool) {
	if pos < 0 || pos > len(data) {
		return "", 0, false
	}
	zeroPos := bytes.IndexByte(data[pos:], 0)
	if zeroPos == -1 {
		return "", 0, false
	}
	return string(data[pos : pos+zeroPos]), pos + zeroPos + 1, true
}

// writeStringNUL writes value into data at pos followed by a 0 byte and
// returns pos + len(value) + 1.
func (mp *MysqlProtocolImpl) writeStringNUL(data []byte, pos int, value string) int {
	pos = mp.writeStringFix(data, pos, value, len(value))
	data[pos] = 0
	return pos + 1
}

// readStringLenEnc reads a length-encoded string from data at pos.
// It returns the string, the position after it, and true; or "", 0, false
// when the length prefix is malformed or exceeds the remaining bytes.
func (mp *MysqlProtocolImpl) readStringLenEnc(data []byte, pos int) (string, int, bool) {
	var value uint64
	var ok bool
	value, pos, ok = mp.readIntLenEnc(data, pos)
	if !ok {
		return "", 0, false
	}
	// Compare in uint64 before converting: the length comes from the client,
	// and a value near 2^64 would wrap to a negative int, pass a signed bounds
	// check and then panic when slicing. readIntLenEnc guarantees pos <= len(data).
	if value > uint64(len(data)-pos) {
		return "", 0, false
	}
	sLength := int(value)
	return string(data[pos : pos+sLength]), pos + sLength, true
}

// writeStringLenEnc writes value as a length-encoded string into data at
// pos and returns the position after it.
func (mp *MysqlProtocolImpl) writeStringLenEnc(data []byte, pos int, value string) int {
	pos = mp.writeIntLenEnc(data, pos, uint64(len(value)))
	return mp.writeStringFix(data, pos, value, len(value))
}

// appendStringLenEnc appends value as a length-encoded string to data and
// returns the buffer.
func (mp *MysqlProtocolImpl) appendStringLenEnc(data []byte, value string) []byte {
	data = mp.appendIntLenEnc(data, uint64(len(value)))
	return mp.appendStringFix(data, value, len(value))
}

// appendCountOfBytesLenEnc appends value as length-encoded bytes to data and
// returns the buffer.
func (mp *MysqlProtocolImpl) appendCountOfBytesLenEnc(data []byte, value []byte) []byte {
	data = mp.appendIntLenEnc(data, uint64(len(value)))
	return mp.appendCountOfBytes(data, value)
}

// appendStringLenEncOfInt64 appends the decimal text of value as a
// length-encoded string to data, using mp.strconvBuffer as scratch space.
func (mp *MysqlProtocolImpl) appendStringLenEncOfInt64(data []byte, value int64) []byte {
	mp.strconvBuffer = mp.strconvBuffer[:0]
	mp.strconvBuffer = strconv.AppendInt(mp.strconvBuffer, value, 10)
	return mp.appendCountOfBytesLenEnc(data, mp.strconvBuffer)
}

// appendStringLenEncOfUint64 appends the decimal text of value as a
// length-encoded string to data, using mp.strconvBuffer as scratch space.
func (mp *MysqlProtocolImpl) appendStringLenEncOfUint64(data []byte, value uint64) []byte {
	mp.strconvBuffer = mp.strconvBuffer[:0]
	mp.strconvBuffer = strconv.AppendUint(mp.strconvBuffer, value, 10)
	return mp.appendCountOfBytesLenEnc(data, mp.strconvBuffer)
}

// appendStringLenEncOfFloat64 appends the text of a float64 value as a
// length-encoded string to data. bitSize (32 or 64) selects the precision
// used for the shortest 'f' formatting. Infinities are written as
// "+Infinity" / "-Infinity".
func (mp *MysqlProtocolImpl) appendStringLenEncOfFloat64(data []byte, value float64, bitSize int) []byte {
	mp.strconvBuffer = mp.strconvBuffer[:0]
	switch {
	case math.IsInf(value, 1):
		mp.strconvBuffer = append(mp.strconvBuffer, []byte("+Infinity")...)
	case math.IsInf(value, -1):
		mp.strconvBuffer = append(mp.strconvBuffer, []byte("-Infinity")...)
	default:
		mp.strconvBuffer = strconv.AppendFloat(mp.strconvBuffer, value, 'f', -1, bitSize)
	}
	return mp.appendCountOfBytesLenEnc(data, mp.strconvBuffer)
}

// appendUint8 appends one byte to data.
func (mp *MysqlProtocolImpl) appendUint8(data []byte, e uint8) []byte {
	return mp.append(data, e)
}

// appendUint16 appends e as encoded by mp.io.WriteUint16 (2 bytes).
func (mp *MysqlProtocolImpl) appendUint16(data []byte, e uint16) []byte {
	buf := mp.lenEncBuffer[:2]
	pos := mp.io.WriteUint16(buf, 0, e)
	return mp.append(data, buf[:pos]...)
}

// appendUint32 appends e as encoded by mp.io.WriteUint32 (4 bytes).
func (mp *MysqlProtocolImpl) appendUint32(data []byte, e uint32) []byte {
	buf := mp.lenEncBuffer[:4]
	pos := mp.io.WriteUint32(buf, 0, e)
	return mp.append(data, buf[:pos]...)
}

// appendUint64 appends e as encoded by mp.io.WriteUint64 (8 bytes).
func (mp *MysqlProtocolImpl) appendUint64(data []byte, e uint64) []byte {
	buf := mp.lenEncBuffer[:8]
	pos := mp.io.WriteUint64(buf, 0, e)
	return mp.append(data, buf[:pos]...)
}

// writeZeros writes count zero bytes into data at pos and returns pos + count.
func (mp *MysqlProtocolImpl) writeZeros(data []byte, pos int, count int) int {
	for i := 0; i < count; i++ {
		data[pos+i] = 0
	}
	return pos + count
}

// the server calculates the hash value of the password with the algorithm
// and judges it with the authentication data from the client.
1024 // Algorithm: SHA1( password ) XOR SHA1( slat + SHA1( SHA1( password ) ) ) 1025 func (mp *MysqlProtocolImpl) checkPassword(password, salt, auth []byte) bool { 1026 //if len(password) == 0 { 1027 // return false 1028 //} 1029 //hash1 = SHA1(password) 1030 sha := sha1.New() 1031 _, err := sha.Write(password) 1032 if err != nil { 1033 logutil.Errorf("SHA1(password) failed.") 1034 return false 1035 } 1036 hash1 := sha.Sum(nil) 1037 1038 //hash2 = SHA1(SHA1(password)) 1039 sha.Reset() 1040 _, err = sha.Write(hash1) 1041 if err != nil { 1042 logutil.Errorf("SHA1(SHA1(password)) failed.") 1043 return false 1044 } 1045 hash2 := sha.Sum(nil) 1046 1047 //hash3 = SHA1(salt + SHA1(SHA1(password))) 1048 sha.Reset() 1049 _, err = sha.Write(salt) 1050 if err != nil { 1051 logutil.Errorf("write salt failed.") 1052 return false 1053 } 1054 _, err = sha.Write(hash2) 1055 if err != nil { 1056 logutil.Errorf("write SHA1(SHA1(password)) failed.") 1057 return false 1058 } 1059 hash3 := sha.Sum(nil) 1060 1061 //SHA1(password) XOR SHA1(salt + SHA1(SHA1(password))) 1062 for i := range hash1 { 1063 hash1[i] ^= hash3[i] 1064 } 1065 1066 logDebugf(mp.getProfile(profileTypeConcise), "server calculated %v", hash1) 1067 logDebugf(mp.getProfile(profileTypeConcise), "client calculated %v", auth) 1068 1069 return bytes.Equal(hash1, auth) 1070 } 1071 1072 // the server authenticate that the client can connect and use the database 1073 func (mp *MysqlProtocolImpl) authenticateUser(ctx context.Context, authResponse []byte) error { 1074 var psw []byte 1075 var err error 1076 var tenant *TenantInfo 1077 1078 ses := mp.GetSession() 1079 if !mp.GetSkipCheckUser() { 1080 logDebugf(mp.getProfile(profileTypeConcise), "authenticate user 1") 1081 psw, err = ses.AuthenticateUser(mp.GetUserName()) 1082 if err != nil { 1083 return err 1084 } 1085 logDebugf(mp.getProfile(profileTypeConcise), "authenticate user 2") 1086 1087 //TO Check password 1088 if mp.checkPassword(psw, mp.GetSalt(), authResponse) { 1089 
logInfof(mp.getProfile(profileTypeConcise), "check password succeeded") 1090 } else { 1091 return moerr.NewInternalError(ctx, "check password failed") 1092 } 1093 } else { 1094 logDebugf(mp.getProfile(profileTypeConcise), "skip authenticate user") 1095 //Get tenant info 1096 tenant, err = GetTenantInfo(ctx, mp.GetUserName()) 1097 if err != nil { 1098 return err 1099 } 1100 1101 if ses != nil { 1102 ses.SetTenantInfo(tenant) 1103 1104 //TO Check password 1105 if len(psw) == 0 || mp.checkPassword(psw, mp.GetSalt(), authResponse) { 1106 logInfof(mp.getProfile(profileTypeConcise), "check password succeeded") 1107 } else { 1108 return moerr.NewInternalError(ctx, "check password failed") 1109 } 1110 } 1111 } 1112 1113 return nil 1114 } 1115 1116 func (mp *MysqlProtocolImpl) HandleHandshake(ctx context.Context, payload []byte) (bool, error) { 1117 var err, err2 error 1118 if len(payload) < 2 { 1119 return false, moerr.NewInternalError(ctx, "received a broken response packet") 1120 } 1121 1122 var authResponse []byte 1123 if capabilities, _, ok := mp.io.ReadUint16(payload, 0); !ok { 1124 return false, moerr.NewInternalError(ctx, "read capabilities from response packet failed") 1125 } else if uint32(capabilities)&CLIENT_PROTOCOL_41 != 0 { 1126 var resp41 response41 1127 var ok2 bool 1128 logDebugf(mp.getProfile(profileTypeConcise), "analyse handshake response") 1129 if ok2, resp41, err = mp.analyseHandshakeResponse41(ctx, payload); !ok2 { 1130 return false, err 1131 } 1132 1133 // client ask server to upgradeTls 1134 if resp41.isAskForTlsHeader { 1135 return true, nil 1136 } 1137 1138 authResponse = resp41.authResponse 1139 mp.capability = mp.capability & resp41.capabilities 1140 1141 if nameAndCharset, ok3 := collationID2CharsetAndName[int(resp41.collationID)]; !ok3 { 1142 return false, moerr.NewInternalError(ctx, "get collationName and charset failed") 1143 } else { 1144 mp.collationID = int(resp41.collationID) 1145 mp.collationName = nameAndCharset.collationName 1146 
mp.charset = nameAndCharset.charset 1147 } 1148 1149 mp.maxClientPacketSize = resp41.maxPacketSize 1150 mp.username = resp41.username 1151 mp.database = resp41.database 1152 } else { 1153 var resp320 response320 1154 var ok2 bool 1155 if ok2, resp320, err = mp.analyseHandshakeResponse320(ctx, payload); !ok2 { 1156 return false, err 1157 } 1158 1159 // client ask server to upgradeTls 1160 if resp320.isAskForTlsHeader { 1161 return true, nil 1162 } 1163 1164 authResponse = resp320.authResponse 1165 mp.capability = mp.capability & resp320.capabilities 1166 mp.collationID = int(Utf8mb4CollationID) 1167 mp.collationName = "utf8mb4_general_ci" 1168 mp.charset = "utf8mb4" 1169 1170 mp.maxClientPacketSize = resp320.maxPacketSize 1171 mp.username = resp320.username 1172 mp.database = resp320.database 1173 } 1174 1175 logDebugf(mp.getProfile(profileTypeConcise), "authenticate user") 1176 if err = mp.authenticateUser(ctx, authResponse); err != nil { 1177 logutil.Errorf("authenticate user failed.error:%v", err) 1178 fail := moerr.MysqlErrorMsgRefer[moerr.ER_ACCESS_DENIED_ERROR] 1179 tipsFormat := "Access denied for user %s. 
%s" 1180 msg := fmt.Sprintf(tipsFormat, mp.username, err.Error()) 1181 err2 = mp.sendErrPacket(fail.ErrorCode, fail.SqlStates[0], msg) 1182 if err2 != nil { 1183 logutil.Errorf("send err packet failed.error:%v", err2) 1184 return false, err2 1185 } 1186 return false, err 1187 } 1188 1189 logDebugf(mp.getProfile(profileTypeConcise), "handle handshake end") 1190 err = mp.sendOKPacket(0, 0, 0, 0, "") 1191 if err != nil { 1192 return false, err 1193 } 1194 return false, nil 1195 } 1196 1197 // the server makes a handshake v10 packet 1198 // return handshake packet 1199 func (mp *MysqlProtocolImpl) makeHandshakeV10Payload() []byte { 1200 var data = make([]byte, HeaderOffset+256) 1201 var pos = HeaderOffset 1202 //int<1> protocol version 1203 pos = mp.io.WriteUint8(data, pos, clientProtocolVersion) 1204 1205 pos = mp.writeStringNUL(data, pos, mp.SV.ServerVersionPrefix+serverVersion.Load().(string)) 1206 1207 //int<4> connection id 1208 pos = mp.io.WriteUint32(data, pos, mp.ConnectionID()) 1209 1210 //string[8] auth-plugin-data-part-1 1211 pos = mp.writeCountOfBytes(data, pos, mp.GetSalt()[0:8]) 1212 1213 //int<1> filler 0 1214 pos = mp.io.WriteUint8(data, pos, 0) 1215 1216 //int<2> capabilities flags (lower 2 bytes) 1217 pos = mp.io.WriteUint16(data, pos, uint16(mp.capability&0xFFFF)) 1218 1219 //int<1> character set 1220 pos = mp.io.WriteUint8(data, pos, utf8mb4BinCollationID) 1221 1222 //int<2> status flags 1223 pos = mp.io.WriteUint16(data, pos, DefaultClientConnStatus) 1224 1225 //int<2> capabilities flags (upper 2 bytes) 1226 pos = mp.io.WriteUint16(data, pos, uint16((DefaultCapability>>16)&0xFFFF)) 1227 1228 if (DefaultCapability & CLIENT_PLUGIN_AUTH) != 0 { 1229 //int<1> length of auth-plugin-data 1230 //set 21 always 1231 pos = mp.io.WriteUint8(data, pos, uint8(len(mp.GetSalt())+1)) 1232 } else { 1233 //int<1> [00] 1234 //set 0 always 1235 pos = mp.io.WriteUint8(data, pos, 0) 1236 } 1237 1238 //string[10] reserved (all [00]) 1239 pos = mp.writeZeros(data, pos, 
10) 1240 1241 if (DefaultCapability & CLIENT_SECURE_CONNECTION) != 0 { 1242 //string[$len] auth-plugin-data-part-2 ($len=MAX(13, length of auth-plugin-data - 8)) 1243 pos = mp.writeCountOfBytes(data, pos, mp.GetSalt()[8:]) 1244 pos = mp.io.WriteUint8(data, pos, 0) 1245 } 1246 1247 if (DefaultCapability & CLIENT_PLUGIN_AUTH) != 0 { 1248 //string[NUL] auth-plugin name 1249 pos = mp.writeStringNUL(data, pos, AuthNativePassword) 1250 } 1251 1252 return data[:pos] 1253 } 1254 1255 // the server analyses handshake response41 info from the client 1256 // return true - analysed successfully / false - failed ; response41 ; error 1257 func (mp *MysqlProtocolImpl) analyseHandshakeResponse41(ctx context.Context, data []byte) (bool, response41, error) { 1258 var pos = 0 1259 var ok bool 1260 var info response41 1261 1262 //int<4> capabilities flags of the client, CLIENT_PROTOCOL_41 always set 1263 info.capabilities, pos, ok = mp.io.ReadUint32(data, pos) 1264 if !ok { 1265 return false, info, moerr.NewInternalError(ctx, "get capabilities failed") 1266 } 1267 1268 if (info.capabilities & CLIENT_PROTOCOL_41) == 0 { 1269 return false, info, moerr.NewInternalError(ctx, "capabilities does not have protocol 41") 1270 } 1271 1272 //int<4> max-packet size 1273 //max size of a command packet that the client wants to send to the server 1274 info.maxPacketSize, pos, ok = mp.io.ReadUint32(data, pos) 1275 if !ok { 1276 return false, info, moerr.NewInternalError(ctx, "get max packet size failed") 1277 } 1278 1279 //int<1> character set 1280 //connection's default character set 1281 info.collationID, pos, ok = mp.io.ReadUint8(data, pos) 1282 if !ok { 1283 return false, info, moerr.NewInternalError(ctx, "get character set failed") 1284 } 1285 1286 if pos+22 >= len(data) { 1287 return false, info, moerr.NewInternalError(ctx, "skip reserved failed") 1288 } 1289 //string[23] reserved (all [0]) 1290 //just skip it 1291 pos += 23 1292 1293 // if client reply for upgradeTls, then data will contains 
header only. 1294 if pos == len(data) && (info.capabilities&CLIENT_SSL) != 0 { 1295 info.isAskForTlsHeader = true 1296 return true, info, nil 1297 } 1298 1299 //string[NUL] username 1300 info.username, pos, ok = mp.readStringNUL(data, pos) 1301 if !ok { 1302 return false, info, moerr.NewInternalError(ctx, "get username failed") 1303 } 1304 1305 /* 1306 if capabilities & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA { 1307 lenenc-int length of auth-response 1308 string[n] auth-response 1309 } else if capabilities & CLIENT_SECURE_CONNECTION { 1310 int<1> length of auth-response 1311 string[n] auth-response 1312 } else { 1313 string[NUL] auth-response 1314 } 1315 */ 1316 if (info.capabilities & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) != 0 { 1317 var l uint64 1318 l, pos, ok = mp.readIntLenEnc(data, pos) 1319 if !ok { 1320 return false, info, moerr.NewInternalError(ctx, "get length of auth-response failed") 1321 } 1322 info.authResponse, pos, ok = mp.readCountOfBytes(data, pos, int(l)) 1323 if !ok { 1324 return false, info, moerr.NewInternalError(ctx, "get auth-response failed") 1325 } 1326 } else if (info.capabilities & CLIENT_SECURE_CONNECTION) != 0 { 1327 var l uint8 1328 l, pos, ok = mp.io.ReadUint8(data, pos) 1329 if !ok { 1330 return false, info, moerr.NewInternalError(ctx, "get length of auth-response failed") 1331 } 1332 info.authResponse, pos, ok = mp.readCountOfBytes(data, pos, int(l)) 1333 if !ok { 1334 return false, info, moerr.NewInternalError(ctx, "get auth-response failed") 1335 } 1336 } else { 1337 var auth string 1338 auth, pos, ok = mp.readStringNUL(data, pos) 1339 if !ok { 1340 return false, info, moerr.NewInternalError(ctx, "get auth-response failed") 1341 } 1342 info.authResponse = []byte(auth) 1343 } 1344 1345 if (info.capabilities & CLIENT_CONNECT_WITH_DB) != 0 { 1346 info.database, pos, ok = mp.readStringNUL(data, pos) 1347 if !ok { 1348 return false, info, moerr.NewInternalError(ctx, "get database failed") 1349 } 1350 } 1351 1352 if (info.capabilities & 
CLIENT_PLUGIN_AUTH) != 0 { 1353 info.clientPluginName, _, ok = mp.readStringNUL(data, pos) 1354 if !ok { 1355 return false, info, moerr.NewInternalError(ctx, "get auth plugin name failed") 1356 } 1357 1358 //to switch authenticate method 1359 if info.clientPluginName != AuthNativePassword { 1360 var err error 1361 if info.authResponse, err = mp.negotiateAuthenticationMethod(ctx); err != nil { 1362 return false, info, moerr.NewInternalError(ctx, "negotiate authentication method failed. error:%v", err) 1363 } 1364 info.clientPluginName = AuthNativePassword 1365 } 1366 } 1367 1368 //drop client connection attributes 1369 return true, info, nil 1370 } 1371 1372 /* 1373 //the server does something after receiving a handshake response41 from the client 1374 //like check user and password 1375 //and other things 1376 func (mp *MysqlProtocolImpl) handleClientResponse41(resp41 response41) error { 1377 //to do something else 1378 //logutil.Infof("capabilities 0x%x\n", resp41.capabilities) 1379 //logutil.Infof("maxPacketSize %d\n", resp41.maxPacketSize) 1380 //logutil.Infof("collationID %d\n", resp41.collationID) 1381 //logutil.Infof("username %s\n", resp41.username) 1382 //logutil.Infof("authResponse: \n") 1383 //update the capabilities with client's capabilities 1384 mp.capability = DefaultCapability & resp41.capabilities 1385 1386 //character set 1387 if nameAndCharset, ok := collationID2CharsetAndName[int(resp41.collationID)]; !ok { 1388 return moerr.NewInternalError(requestCtx, "get collationName and charset failed") 1389 } else { 1390 mp.collationID = int(resp41.collationID) 1391 mp.collationName = nameAndCharset.collationName 1392 mp.charset = nameAndCharset.charset 1393 } 1394 1395 mp.maxClientPacketSize = resp41.maxPacketSize 1396 mp.username = resp41.username 1397 mp.database = resp41.database 1398 1399 //logutil.Infof("collationID %d collatonName %s charset %s \n", mp.collationID, mp.collationName, mp.charset) 1400 //logutil.Infof("database %s \n", resp41.database) 
1401 //logutil.Infof("clientPluginName %s \n", resp41.clientPluginName) 1402 return nil 1403 } 1404 */ 1405 1406 // the server analyses handshake response320 info from the old client 1407 // return true - analysed successfully / false - failed ; response320 ; error 1408 func (mp *MysqlProtocolImpl) analyseHandshakeResponse320(ctx context.Context, data []byte) (bool, response320, error) { 1409 var pos = 0 1410 var ok bool 1411 var info response320 1412 var capa uint16 1413 1414 //int<2> capabilities flags, CLIENT_PROTOCOL_41 never set 1415 capa, pos, ok = mp.io.ReadUint16(data, pos) 1416 if !ok { 1417 return false, info, moerr.NewInternalError(ctx, "get capabilities failed") 1418 } 1419 info.capabilities = uint32(capa) 1420 1421 if pos+2 >= len(data) { 1422 return false, info, moerr.NewInternalError(ctx, "get max-packet-size failed") 1423 } 1424 1425 //int<3> max-packet size 1426 //max size of a command packet that the client wants to send to the server 1427 info.maxPacketSize = uint32(data[pos]) | uint32(data[pos+1])<<8 | uint32(data[pos+2])<<16 1428 pos += 3 1429 1430 // if client reply for upgradeTls, then data will contains header only. 
1431 if pos == len(data) && (info.capabilities&CLIENT_SSL) != 0 { 1432 info.isAskForTlsHeader = true 1433 return true, info, nil 1434 } 1435 1436 //string[NUL] username 1437 info.username, pos, ok = mp.readStringNUL(data, pos) 1438 if !ok { 1439 return false, info, moerr.NewInternalError(ctx, "get username failed") 1440 } 1441 1442 if (info.capabilities & CLIENT_CONNECT_WITH_DB) != 0 { 1443 var auth string 1444 auth, pos, ok = mp.readStringNUL(data, pos) 1445 if !ok { 1446 return false, info, moerr.NewInternalError(ctx, "get auth-response failed") 1447 } 1448 info.authResponse = []byte(auth) 1449 1450 info.database, _, ok = mp.readStringNUL(data, pos) 1451 if !ok { 1452 return false, info, moerr.NewInternalError(ctx, "get database failed") 1453 } 1454 } else { 1455 info.authResponse, _, ok = mp.readCountOfBytes(data, pos, len(data)-pos) 1456 if !ok { 1457 return false, info, moerr.NewInternalError(ctx, "get auth-response failed") 1458 } 1459 } 1460 1461 return true, info, nil 1462 } 1463 1464 /* 1465 //the server does something after receiving a handshake response320 from the client 1466 //like check user and password 1467 //and other things 1468 func (mp *MysqlProtocolImpl) handleClientResponse320(resp320 response320) error { 1469 //to do something else 1470 //logutil.Infof("capabilities 0x%x\n", resp320.capabilities) 1471 //logutil.Infof("maxPacketSize %d\n", resp320.maxPacketSize) 1472 //logutil.Infof("username %s\n", resp320.username) 1473 //logutil.Infof("authResponse: \n") 1474 1475 //update the capabilities with client's capabilities 1476 mp.capability = DefaultCapability & resp320.capabilities 1477 1478 //if the client does not notice its default charset, the server gives a default charset. 
1479 //Run the sql in mysql 8.0.23 to get the charset 1480 //the sql: select * from information_schema.collations where collation_name = 'utf8mb4_general_ci'; 1481 mp.collationID = int(Utf8mb4CollationID) 1482 mp.collationName = "utf8mb4_general_ci" 1483 mp.charset = "utf8mb4" 1484 1485 mp.maxClientPacketSize = resp320.maxPacketSize 1486 mp.username = resp320.username 1487 mp.database = resp320.database 1488 1489 //logutil.Infof("collationID %d collatonName %s charset %s \n", mp.collationID, mp.collationName, mp.charset) 1490 //logutil.Infof("database %s \n", resp320.database) 1491 return nil 1492 } 1493 */ 1494 1495 // the server makes a AuthSwitchRequest that asks the client to authenticate the data with new method 1496 func (mp *MysqlProtocolImpl) makeAuthSwitchRequestPayload(authMethodName string) []byte { 1497 data := make([]byte, HeaderOffset+1+len(authMethodName)+1+len(mp.GetSalt())+1) 1498 pos := HeaderOffset 1499 pos = mp.io.WriteUint8(data, pos, defines.EOFHeader) 1500 pos = mp.writeStringNUL(data, pos, authMethodName) 1501 pos = mp.writeCountOfBytes(data, pos, mp.GetSalt()) 1502 pos = mp.io.WriteUint8(data, pos, 0) 1503 return data[:pos] 1504 } 1505 1506 // the server can send AuthSwitchRequest to ask client to use designated authentication method, 1507 // if both server and client support CLIENT_PLUGIN_AUTH capability. 
1508 // return data authenticated with new method 1509 func (mp *MysqlProtocolImpl) negotiateAuthenticationMethod(ctx context.Context) ([]byte, error) { 1510 var err error 1511 aswPkt := mp.makeAuthSwitchRequestPayload(AuthNativePassword) 1512 err = mp.writePackets(aswPkt) 1513 if err != nil { 1514 return nil, err 1515 } 1516 1517 read, err := mp.tcpConn.Read(goetty.ReadOptions{}) 1518 if err != nil { 1519 return nil, err 1520 } 1521 1522 if read == nil { 1523 return nil, moerr.NewInternalError(ctx, "read nil from tcp conn") 1524 } 1525 1526 pack, ok := read.(*Packet) 1527 if !ok { 1528 return nil, moerr.NewInternalError(ctx, "it is not the Packet") 1529 } 1530 1531 if pack == nil { 1532 return nil, moerr.NewInternalError(ctx, "packet is null") 1533 } 1534 1535 data := pack.Payload 1536 mp.AddSequenceId(1) 1537 return data, nil 1538 } 1539 1540 // make a OK packet 1541 func (mp *MysqlProtocolImpl) makeOKPayload(affectedRows, lastInsertId uint64, statusFlags, warnings uint16, message string) []byte { 1542 data := make([]byte, HeaderOffset+128+len(message)+10) 1543 var pos = HeaderOffset 1544 pos = mp.io.WriteUint8(data, pos, defines.OKHeader) 1545 pos = mp.writeIntLenEnc(data, pos, affectedRows) 1546 pos = mp.writeIntLenEnc(data, pos, lastInsertId) 1547 if (mp.capability & CLIENT_PROTOCOL_41) != 0 { 1548 pos = mp.io.WriteUint16(data, pos, statusFlags) 1549 pos = mp.io.WriteUint16(data, pos, warnings) 1550 } else if (mp.capability & CLIENT_TRANSACTIONS) != 0 { 1551 pos = mp.io.WriteUint16(data, pos, statusFlags) 1552 } 1553 1554 if mp.capability&CLIENT_SESSION_TRACK != 0 { 1555 //TODO:implement it 1556 } else { 1557 //string<lenenc> instead of string<EOF> in the manual of mysql 1558 pos = mp.writeStringLenEnc(data, pos, message) 1559 return data[:pos] 1560 } 1561 return data[:pos] 1562 } 1563 1564 func (mp *MysqlProtocolImpl) makeOKPayloadWithEof(affectedRows, lastInsertId uint64, statusFlags, warnings uint16, message string) []byte { 1565 data := make([]byte, 
HeaderOffset+128+len(message)+10) 1566 var pos = HeaderOffset 1567 pos = mp.io.WriteUint8(data, pos, defines.EOFHeader) 1568 pos = mp.writeIntLenEnc(data, pos, affectedRows) 1569 pos = mp.writeIntLenEnc(data, pos, lastInsertId) 1570 if (mp.capability & CLIENT_PROTOCOL_41) != 0 { 1571 pos = mp.io.WriteUint16(data, pos, statusFlags) 1572 pos = mp.io.WriteUint16(data, pos, warnings) 1573 } else if (mp.capability & CLIENT_TRANSACTIONS) != 0 { 1574 pos = mp.io.WriteUint16(data, pos, statusFlags) 1575 } 1576 1577 if mp.capability&CLIENT_SESSION_TRACK != 0 { 1578 //TODO:implement it 1579 } else { 1580 //string<lenenc> instead of string<EOF> in the manual of mysql 1581 pos = mp.writeStringLenEnc(data, pos, message) 1582 return data[:pos] 1583 } 1584 return data[:pos] 1585 } 1586 1587 func (mp *MysqlProtocolImpl) makeLocalInfileRequestPayload(filename string) []byte { 1588 data := make([]byte, HeaderOffset+1+len(filename)+1) 1589 pos := HeaderOffset 1590 pos = mp.io.WriteUint8(data, pos, defines.LocalInFileHeader) 1591 pos = mp.writeStringFix(data, pos, filename, len(filename)) 1592 return data[:pos] 1593 } 1594 1595 func (mp *MysqlProtocolImpl) sendLocalInfileRequest(filename string) error { 1596 req := mp.makeLocalInfileRequestPayload(filename) 1597 return mp.writePackets(req) 1598 } 1599 1600 func (mp *MysqlProtocolImpl) sendOKPacketWithEof(affectedRows, lastInsertId uint64, status, warnings uint16, message string) error { 1601 okPkt := mp.makeOKPayloadWithEof(affectedRows, lastInsertId, status, warnings, message) 1602 return mp.writePackets(okPkt) 1603 } 1604 1605 // send OK packet to the client 1606 func (mp *MysqlProtocolImpl) sendOKPacket(affectedRows, lastInsertId uint64, status, warnings uint16, message string) error { 1607 okPkt := mp.makeOKPayload(affectedRows, lastInsertId, status, warnings, message) 1608 return mp.writePackets(okPkt) 1609 } 1610 1611 // make Err packet 1612 func (mp *MysqlProtocolImpl) makeErrPayload(errorCode uint16, sqlState, errorMessage 
string) []byte { 1613 data := make([]byte, HeaderOffset+9+len(errorMessage)) 1614 pos := HeaderOffset 1615 pos = mp.io.WriteUint8(data, pos, defines.ErrHeader) 1616 pos = mp.io.WriteUint16(data, pos, errorCode) 1617 if mp.capability&CLIENT_PROTOCOL_41 != 0 { 1618 pos = mp.io.WriteUint8(data, pos, '#') 1619 if len(sqlState) < 5 { 1620 stuff := " " 1621 sqlState += stuff[:5-len(sqlState)] 1622 } 1623 pos = mp.writeStringFix(data, pos, sqlState, 5) 1624 } 1625 pos = mp.writeStringFix(data, pos, errorMessage, len(errorMessage)) 1626 return data[:pos] 1627 } 1628 1629 /* 1630 the server sends the Error packet 1631 1632 information from https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html 1633 mysql version 8.0.23 1634 usually it is in the directory /usr/local/include/mysql/mysqld_error.h 1635 1636 Error information includes several elements: an error code, SQLSTATE value, and message string. 1637 1638 Error code: This value is numeric. It is MySQL-specific and is not portable to other database systems. 1639 SQLSTATE value: This value is a five-character string (for example, '42S02'). SQLSTATE values are taken from ANSI SQL and ODBC and are more standardized than the numeric error codes. 1640 Message string: This string provides a textual description of the error. 
1641 */ 1642 func (mp *MysqlProtocolImpl) sendErrPacket(errorCode uint16, sqlState, errorMessage string) error { 1643 if mp.ses != nil { 1644 mp.ses.GetErrInfo().push(errorCode, errorMessage) 1645 } 1646 errPkt := mp.makeErrPayload(errorCode, sqlState, errorMessage) 1647 return mp.writePackets(errPkt) 1648 } 1649 1650 func (mp *MysqlProtocolImpl) makeEOFPayload(warnings, status uint16) []byte { 1651 data := make([]byte, HeaderOffset+10) 1652 pos := HeaderOffset 1653 pos = mp.io.WriteUint8(data, pos, defines.EOFHeader) 1654 if mp.capability&CLIENT_PROTOCOL_41 != 0 { 1655 pos = mp.io.WriteUint16(data, pos, warnings) 1656 pos = mp.io.WriteUint16(data, pos, status) 1657 } 1658 return data[:pos] 1659 } 1660 1661 func (mp *MysqlProtocolImpl) sendEOFPacket(warnings, status uint16) error { 1662 data := mp.makeEOFPayload(warnings, status) 1663 return mp.writePackets(data) 1664 } 1665 1666 func (mp *MysqlProtocolImpl) SendEOFPacketIf(warnings, status uint16) error { 1667 //If the CLIENT_DEPRECATE_EOF client capabilities flag is not set, EOF_Packet 1668 if mp.capability&CLIENT_DEPRECATE_EOF == 0 { 1669 return mp.sendEOFPacket(warnings, status) 1670 } 1671 return nil 1672 } 1673 1674 // the OK or EOF packet 1675 // thread safe 1676 func (mp *MysqlProtocolImpl) sendEOFOrOkPacket(warnings, status uint16) error { 1677 //If the CLIENT_DEPRECATE_EOF client capabilities flag is set, OK_Packet; else EOF_Packet. 
1678 if mp.capability&CLIENT_DEPRECATE_EOF != 0 { 1679 return mp.sendOKPacketWithEof(0, 0, status, 0, "") 1680 } else { 1681 return mp.sendEOFPacket(warnings, status) 1682 } 1683 } 1684 1685 func setColLength(column *MysqlColumn, width int32) { 1686 column.length = column.columnType.GetLength(width) 1687 } 1688 1689 func setColFlag(column *MysqlColumn) { 1690 if column.auto_incr { 1691 column.flag |= uint16(defines.AUTO_INCREMENT_FLAG) 1692 } 1693 } 1694 1695 func setCharacter(column *MysqlColumn) { 1696 switch column.columnType { 1697 // blob type should use 0x3f to show the binary data 1698 case defines.MYSQL_TYPE_VARCHAR, defines.MYSQL_TYPE_STRING, defines.MYSQL_TYPE_TEXT: 1699 column.SetCharset(0x21) 1700 default: 1701 column.SetCharset(0x3f) 1702 } 1703 } 1704 1705 // make the column information with the format of column definition41 1706 func (mp *MysqlProtocolImpl) makeColumnDefinition41Payload(column *MysqlColumn, cmd int) []byte { 1707 space := HeaderOffset + 8*9 + //lenenc bytes of 8 fields 1708 21 + //fixed-length fields 1709 3 + // catalog "def" 1710 len(column.Schema()) + 1711 len(column.Table()) + 1712 len(column.OrgTable()) + 1713 len(column.Name()) + 1714 len(column.OrgName()) + 1715 len(column.DefaultValue()) + 1716 100 // for safe 1717 1718 data := make([]byte, space) 1719 pos := HeaderOffset 1720 1721 //lenenc_str catalog(always "def") 1722 pos = mp.writeStringLenEnc(data, pos, "def") 1723 1724 //lenenc_str schema 1725 pos = mp.writeStringLenEnc(data, pos, column.Schema()) 1726 1727 //lenenc_str table 1728 pos = mp.writeStringLenEnc(data, pos, column.Table()) 1729 1730 //lenenc_str org_table 1731 pos = mp.writeStringLenEnc(data, pos, column.OrgTable()) 1732 1733 //lenenc_str name 1734 pos = mp.writeStringLenEnc(data, pos, column.Name()) 1735 1736 //lenenc_str org_name 1737 pos = mp.writeStringLenEnc(data, pos, column.OrgName()) 1738 1739 //lenenc_int length of fixed-length fields [0c] 1740 pos = mp.io.WriteUint8(data, pos, 0x0c) 1741 1742 
//int<2> character set 1743 pos = mp.io.WriteUint16(data, pos, column.Charset()) 1744 1745 //int<4> column length 1746 pos = mp.io.WriteUint32(data, pos, column.Length()) 1747 1748 //int<1> type 1749 pos = mp.io.WriteUint8(data, pos, uint8(column.ColumnType())) 1750 1751 //int<2> flags 1752 pos = mp.io.WriteUint16(data, pos, column.Flag()) 1753 1754 //int<1> decimals 1755 pos = mp.io.WriteUint8(data, pos, column.Decimal()) 1756 1757 //int<2> filler [00] [00] 1758 pos = mp.io.WriteUint16(data, pos, 0) 1759 1760 if CommandType(cmd) == COM_FIELD_LIST { 1761 pos = mp.writeIntLenEnc(data, pos, uint64(len(column.DefaultValue()))) 1762 pos = mp.writeCountOfBytes(data, pos, column.DefaultValue()) 1763 } 1764 1765 return data[:pos] 1766 } 1767 1768 // SendColumnDefinitionPacket the server send the column definition to the client 1769 func (mp *MysqlProtocolImpl) SendColumnDefinitionPacket(ctx context.Context, column Column, cmd int) error { 1770 mysqlColumn, ok := column.(*MysqlColumn) 1771 if !ok { 1772 return moerr.NewInternalError(ctx, "sendColumn need MysqlColumn") 1773 } 1774 1775 var data []byte 1776 if mp.capability&CLIENT_PROTOCOL_41 != 0 { 1777 data = mp.makeColumnDefinition41Payload(mysqlColumn, cmd) 1778 } 1779 1780 return mp.writePackets(data) 1781 } 1782 1783 // SendColumnCountPacket makes the column count packet 1784 func (mp *MysqlProtocolImpl) SendColumnCountPacket(count uint64) error { 1785 data := make([]byte, HeaderOffset+20) 1786 pos := HeaderOffset 1787 pos = mp.writeIntLenEnc(data, pos, count) 1788 1789 return mp.writePackets(data[:pos]) 1790 } 1791 1792 func (mp *MysqlProtocolImpl) sendColumns(ctx context.Context, mrs *MysqlResultSet, cmd int, warnings, status uint16) error { 1793 //column_count * Protocol::ColumnDefinition packets 1794 for i := uint64(0); i < mrs.GetColumnCount(); i++ { 1795 var col Column 1796 col, err := mrs.GetColumn(ctx, i) 1797 if err != nil { 1798 return err 1799 } 1800 1801 err = mp.SendColumnDefinitionPacket(ctx, col, cmd) 
1802 if err != nil { 1803 return err 1804 } 1805 } 1806 1807 //If the CLIENT_DEPRECATE_EOF client capabilities flag is not set, EOF_Packet 1808 if mp.capability&CLIENT_DEPRECATE_EOF == 0 { 1809 err := mp.sendEOFPacket(warnings, status) 1810 if err != nil { 1811 return err 1812 } 1813 } 1814 return nil 1815 } 1816 1817 // the server convert every row of the result set into the format that mysql protocol needs 1818 func (mp *MysqlProtocolImpl) makeResultSetBinaryRow(data []byte, mrs *MysqlResultSet, rowIdx uint64) ([]byte, error) { 1819 data = mp.append(data, defines.OKHeader) // append OkHeader 1820 1821 ctx := mp.ses.GetRequestContext() 1822 1823 // get null buffer 1824 buffer := mp.binaryNullBuffer[:0] 1825 columnsLength := mrs.GetColumnCount() 1826 numBytes4Null := (columnsLength + 7 + 2) / 8 1827 for i := uint64(0); i < numBytes4Null; i++ { 1828 buffer = append(buffer, 0) 1829 } 1830 for i := uint64(0); i < columnsLength; i++ { 1831 if isNil, err := mrs.ColumnIsNull(ctx, rowIdx, i); err != nil { 1832 return nil, err 1833 } else if isNil { 1834 bytePos := (i + 2) / 8 1835 bitPos := byte((i + 2) % 8) 1836 idx := int(bytePos) 1837 buffer[idx] |= 1 << bitPos 1838 continue 1839 } 1840 } 1841 data = mp.append(data, buffer...) 
1842 1843 for i := uint64(0); i < columnsLength; i++ { 1844 if isNil, err := mrs.ColumnIsNull(ctx, rowIdx, i); err != nil { 1845 return nil, err 1846 } else if isNil { 1847 continue 1848 } 1849 1850 column, err := mrs.GetColumn(ctx, uint64(i)) 1851 if err != nil { 1852 return nil, err 1853 } 1854 mysqlColumn, ok := column.(*MysqlColumn) 1855 if !ok { 1856 return nil, moerr.NewInternalError(mp.ses.requestCtx, "sendColumn need MysqlColumn") 1857 } 1858 1859 switch mysqlColumn.ColumnType() { 1860 case defines.MYSQL_TYPE_TINY: 1861 if value, err := mrs.GetInt64(ctx, rowIdx, i); err != nil { 1862 return nil, err 1863 } else { 1864 data = mp.appendUint8(data, uint8(value)) 1865 } 1866 case defines.MYSQL_TYPE_SHORT, defines.MYSQL_TYPE_YEAR: 1867 if value, err := mrs.GetInt64(ctx, rowIdx, i); err != nil { 1868 return nil, err 1869 } else { 1870 data = mp.appendUint16(data, uint16(value)) 1871 } 1872 case defines.MYSQL_TYPE_INT24, defines.MYSQL_TYPE_LONG: 1873 if value, err := mrs.GetInt64(ctx, rowIdx, i); err != nil { 1874 return nil, err 1875 } else { 1876 buffer = mp.appendUint32(buffer, uint32(value)) 1877 } 1878 case defines.MYSQL_TYPE_LONGLONG: 1879 if value, err := mrs.GetUint64(ctx, rowIdx, i); err != nil { 1880 return nil, err 1881 } else { 1882 buffer = mp.appendUint64(buffer, value) 1883 } 1884 case defines.MYSQL_TYPE_FLOAT: 1885 if value, err := mrs.GetFloat64(ctx, rowIdx, i); err != nil { 1886 return nil, err 1887 } else { 1888 buffer = mp.appendUint32(buffer, math.Float32bits(float32(value))) 1889 } 1890 case defines.MYSQL_TYPE_DOUBLE: 1891 if value, err := mrs.GetFloat64(ctx, rowIdx, i); err != nil { 1892 return nil, err 1893 } else { 1894 buffer = mp.appendUint64(buffer, math.Float64bits(value)) 1895 } 1896 case defines.MYSQL_TYPE_VARCHAR, defines.MYSQL_TYPE_VAR_STRING, defines.MYSQL_TYPE_STRING, defines.MYSQL_TYPE_BLOB, defines.MYSQL_TYPE_TEXT, defines.MYSQL_TYPE_JSON: 1897 if value, err := mrs.GetString(ctx, rowIdx, i); err != nil { 1898 return nil, err 
1899 } else { 1900 data = mp.appendStringLenEnc(data, value) 1901 } 1902 // TODO: some type, we use string now. someday need fix it 1903 case defines.MYSQL_TYPE_DECIMAL: 1904 if value, err := mrs.GetString(ctx, rowIdx, i); err != nil { 1905 return nil, err 1906 } else { 1907 data = mp.appendStringLenEnc(data, value) 1908 } 1909 case defines.MYSQL_TYPE_UUID: 1910 if value, err := mrs.GetString(ctx, rowIdx, i); err != nil { 1911 return nil, err 1912 } else { 1913 data = mp.appendStringLenEnc(data, value) 1914 } 1915 case defines.MYSQL_TYPE_DATE: 1916 if value, err := mrs.GetValue(ctx, rowIdx, i); err != nil { 1917 return nil, err 1918 } else { 1919 data = mp.appendDate(data, value.(types.Date)) 1920 } 1921 case defines.MYSQL_TYPE_TIME: 1922 if value, err := mrs.GetString(ctx, rowIdx, i); err != nil { 1923 return nil, err 1924 } else { 1925 var t types.Time 1926 var err error 1927 idx := strings.Index(value, ".") 1928 if idx == -1 { 1929 t, err = types.ParseTime(value, 0) 1930 } else { 1931 t, err = types.ParseTime(value, int32(len(value)-idx-1)) 1932 } 1933 if err != nil { 1934 data = mp.appendStringLenEnc(data, value) 1935 } else { 1936 data = mp.appendTime(data, t) 1937 } 1938 } 1939 case defines.MYSQL_TYPE_DATETIME, defines.MYSQL_TYPE_TIMESTAMP: 1940 if value, err := mrs.GetString(ctx, rowIdx, i); err != nil { 1941 return nil, err 1942 } else { 1943 var dt types.Datetime 1944 var err error 1945 idx := strings.Index(value, ".") 1946 if idx == -1 { 1947 dt, err = types.ParseDatetime(value, 0) 1948 } else { 1949 dt, err = types.ParseDatetime(value, int32(len(value)-idx-1)) 1950 } 1951 if err != nil { 1952 data = mp.appendStringLenEnc(data, value) 1953 } else { 1954 data = mp.appendDatetime(data, dt) 1955 } 1956 } 1957 // case defines.MYSQL_TYPE_TIMESTAMP: 1958 // if value, err := mrs.GetString(rowIdx, i); err != nil { 1959 // return nil, err 1960 // } else { 1961 // data = mp.appendStringLenEnc(data, value) 1962 // } 1963 default: 1964 return nil, 
moerr.NewInternalError(ctx, "type is not supported in binary text result row") 1965 } 1966 } 1967 1968 return data, nil 1969 } 1970 1971 // the server convert every row of the result set into the format that mysql protocol needs 1972 func (mp *MysqlProtocolImpl) makeResultSetTextRow(data []byte, mrs *MysqlResultSet, r uint64) ([]byte, error) { 1973 ctx := mp.ses.GetRequestContext() 1974 for i := uint64(0); i < mrs.GetColumnCount(); i++ { 1975 column, err := mrs.GetColumn(ctx, i) 1976 if err != nil { 1977 return nil, err 1978 } 1979 mysqlColumn, ok := column.(*MysqlColumn) 1980 if !ok { 1981 return nil, moerr.NewInternalError(mp.ses.requestCtx, "sendColumn need MysqlColumn") 1982 } 1983 1984 if isNil, err1 := mrs.ColumnIsNull(ctx, r, i); err1 != nil { 1985 return nil, err1 1986 } else if isNil { 1987 //NULL is sent as 0xfb 1988 data = mp.appendUint8(data, 0xFB) 1989 continue 1990 } 1991 1992 switch mysqlColumn.ColumnType() { 1993 case defines.MYSQL_TYPE_JSON: 1994 if value, err2 := mrs.GetString(ctx, r, i); err2 != nil { 1995 return nil, err2 1996 } else { 1997 data = mp.appendStringLenEnc(data, value) 1998 } 1999 case defines.MYSQL_TYPE_BOOL: 2000 if value, err2 := mrs.GetString(ctx, r, i); err2 != nil { 2001 return nil, err2 2002 } else { 2003 data = mp.appendStringLenEnc(data, value) 2004 } 2005 case defines.MYSQL_TYPE_DECIMAL: 2006 if value, err2 := mrs.GetString(ctx, r, i); err2 != nil { 2007 return nil, err2 2008 } else { 2009 data = mp.appendStringLenEnc(data, value) 2010 } 2011 case defines.MYSQL_TYPE_UUID: 2012 if value, err2 := mrs.GetString(ctx, r, i); err2 != nil { 2013 return nil, err2 2014 } else { 2015 data = mp.appendStringLenEnc(data, value) 2016 } 2017 case defines.MYSQL_TYPE_TINY, defines.MYSQL_TYPE_SHORT, defines.MYSQL_TYPE_INT24, defines.MYSQL_TYPE_LONG, defines.MYSQL_TYPE_YEAR: 2018 if value, err2 := mrs.GetInt64(ctx, r, i); err2 != nil { 2019 return nil, err2 2020 } else { 2021 if mysqlColumn.ColumnType() == defines.MYSQL_TYPE_YEAR { 2022 if 
value == 0 { 2023 data = mp.appendStringLenEnc(data, "0000") 2024 } else { 2025 data = mp.appendStringLenEncOfInt64(data, value) 2026 } 2027 } else { 2028 data = mp.appendStringLenEncOfInt64(data, value) 2029 } 2030 } 2031 case defines.MYSQL_TYPE_FLOAT: 2032 if value, err2 := mrs.GetFloat64(ctx, r, i); err2 != nil { 2033 return nil, err2 2034 } else { 2035 data = mp.appendStringLenEncOfFloat64(data, value, 32) 2036 } 2037 case defines.MYSQL_TYPE_DOUBLE: 2038 if value, err2 := mrs.GetFloat64(ctx, r, i); err2 != nil { 2039 return nil, err2 2040 } else { 2041 data = mp.appendStringLenEncOfFloat64(data, value, 64) 2042 } 2043 case defines.MYSQL_TYPE_LONGLONG: 2044 if uint32(mysqlColumn.Flag())&defines.UNSIGNED_FLAG != 0 { 2045 if value, err2 := mrs.GetUint64(ctx, r, i); err2 != nil { 2046 return nil, err2 2047 } else { 2048 data = mp.appendStringLenEncOfUint64(data, value) 2049 } 2050 } else { 2051 if value, err2 := mrs.GetInt64(ctx, r, i); err2 != nil { 2052 return nil, err2 2053 } else { 2054 data = mp.appendStringLenEncOfInt64(data, value) 2055 } 2056 } 2057 case defines.MYSQL_TYPE_VARCHAR, defines.MYSQL_TYPE_VAR_STRING, defines.MYSQL_TYPE_STRING, defines.MYSQL_TYPE_BLOB, defines.MYSQL_TYPE_TEXT: 2058 if value, err2 := mrs.GetString(ctx, r, i); err2 != nil { 2059 return nil, err2 2060 } else { 2061 data = mp.appendStringLenEnc(data, value) 2062 } 2063 case defines.MYSQL_TYPE_DATE: 2064 if value, err2 := mrs.GetValue(ctx, r, i); err2 != nil { 2065 return nil, err2 2066 } else { 2067 data = mp.appendStringLenEnc(data, value.(types.Date).String()) 2068 } 2069 case defines.MYSQL_TYPE_DATETIME: 2070 if value, err2 := mrs.GetString(ctx, r, i); err2 != nil { 2071 return nil, err2 2072 } else { 2073 data = mp.appendStringLenEnc(data, value) 2074 } 2075 case defines.MYSQL_TYPE_TIME: 2076 if value, err2 := mrs.GetString(ctx, r, i); err2 != nil { 2077 return nil, err2 2078 } else { 2079 data = mp.appendStringLenEnc(data, value) 2080 } 2081 case defines.MYSQL_TYPE_TIMESTAMP: 
2082 if value, err2 := mrs.GetString(ctx, r, i); err2 != nil { 2083 return nil, err2 2084 } else { 2085 data = mp.appendStringLenEnc(data, value) 2086 } 2087 default: 2088 return nil, moerr.NewInternalError(mp.ses.requestCtx, "unsupported column type %d ", mysqlColumn.ColumnType()) 2089 } 2090 } 2091 return data, nil 2092 } 2093 2094 // the server send group row of the result set as an independent packet 2095 // thread safe 2096 func (mp *MysqlProtocolImpl) SendResultSetTextBatchRow(mrs *MysqlResultSet, cnt uint64) error { 2097 if cnt == 0 { 2098 return nil 2099 } 2100 2101 mp.GetLock().Lock() 2102 defer mp.GetLock().Unlock() 2103 var err error = nil 2104 2105 for i := uint64(0); i < cnt; i++ { 2106 if err = mp.sendResultSetTextRow(mrs, i); err != nil { 2107 return err 2108 } 2109 } 2110 return err 2111 } 2112 2113 func (mp *MysqlProtocolImpl) SendResultSetTextBatchRowSpeedup(mrs *MysqlResultSet, cnt uint64) error { 2114 if cnt == 0 { 2115 return nil 2116 } 2117 2118 cmd := mp.GetSession().GetCmd() 2119 mp.GetLock().Lock() 2120 defer mp.GetLock().Unlock() 2121 var err error = nil 2122 2123 binary := false 2124 // XXX now we known COM_QUERY will use textRow, COM_STMT_EXECUTE use binaryRow 2125 if CommandType(cmd) == COM_STMT_EXECUTE { 2126 binary = true 2127 } 2128 2129 //make rows into the batch 2130 for i := uint64(0); i < cnt; i++ { 2131 err = mp.openRow(nil) 2132 if err != nil { 2133 return err 2134 } 2135 //begin1 := time.Now() 2136 if binary { 2137 _, err = mp.makeResultSetBinaryRow(nil, mrs, i) 2138 } else { 2139 _, err = mp.makeResultSetTextRow(nil, mrs, i) 2140 } 2141 //mp.makeTime += time.Since(begin1) 2142 2143 if err != nil { 2144 //ERR_Packet in case of error 2145 err1 := mp.sendErrPacket(moerr.ER_UNKNOWN_ERROR, DefaultMySQLState, err.Error()) 2146 if err1 != nil { 2147 return err1 2148 } 2149 return err 2150 } 2151 2152 //output into outbuf 2153 err = mp.closeRow(nil) 2154 if err != nil { 2155 return err 2156 } 2157 } 2158 2159 return err 2160 } 2161 
2162 // open a new row of the resultset 2163 func (mp *MysqlProtocolImpl) openRow(_ []byte) error { 2164 if mp.enableLog { 2165 logutil.Info("openRow") 2166 } 2167 return mp.openPacket() 2168 } 2169 2170 // close a finished row of the resultset 2171 func (mp *MysqlProtocolImpl) closeRow(_ []byte) error { 2172 if mp.enableLog { 2173 logutil.Info("closeRow") 2174 } 2175 2176 err := mp.closePacket(true) 2177 if err != nil { 2178 return err 2179 } 2180 2181 err = mp.flushOutBuffer() 2182 if err != nil { 2183 return err 2184 } 2185 return err 2186 } 2187 2188 // flushOutBuffer the data in the outbuf into the network 2189 func (mp *MysqlProtocolImpl) flushOutBuffer() error { 2190 if mp.enableLog { 2191 logutil.Info("flush") 2192 } 2193 2194 if mp.bytesInOutBuffer >= mp.untilBytesInOutbufToFlush { 2195 mp.flushCount++ 2196 mp.writeBytes += uint64(mp.bytesInOutBuffer) 2197 // FIXME: use a suitable timeout value 2198 err := mp.tcpConn.Flush(0) 2199 if err != nil { 2200 return err 2201 } 2202 mp.resetFlushOutBuffer() 2203 } 2204 return nil 2205 } 2206 2207 // open a new mysql protocol packet 2208 func (mp *MysqlProtocolImpl) openPacket() error { 2209 if mp.enableLog { 2210 logutil.Info("openPacket") 2211 } 2212 2213 outbuf := mp.tcpConn.OutBuf() 2214 n := 4 2215 outbuf.Grow(n) 2216 writeIdx := outbuf.GetWriteIndex() 2217 mp.beginWriteIndex = writeIdx 2218 writeIdx += n 2219 mp.bytesInOutBuffer += n 2220 outbuf.SetWriteIndex(writeIdx) 2221 if mp.enableLog { 2222 logutil.Infof("openPacket curWriteIdx %d", outbuf.GetWriteIndex()) 2223 } 2224 return nil 2225 } 2226 2227 // fill the packet with data 2228 func (mp *MysqlProtocolImpl) fillPacket(elems ...byte) error { 2229 if mp.enableLog { 2230 logutil.Infof("fillPacket len %d", len(elems)) 2231 } 2232 outbuf := mp.tcpConn.OutBuf() 2233 n := len(elems) 2234 i := 0 2235 curLen := 0 2236 hasDataLen := 0 2237 curDataLen := 0 2238 var err error 2239 var buf []byte 2240 for ; i < n; i += curLen { 2241 if !mp.isInPacket() { 2242 err = 
mp.openPacket() 2243 if err != nil { 2244 return err 2245 } 2246 } 2247 //length of data in the packet 2248 hasDataLen = outbuf.GetWriteIndex() - mp.beginWriteIndex - HeaderLengthOfTheProtocol 2249 curLen = int(MaxPayloadSize) - hasDataLen 2250 curLen = Min(curLen, n-i) 2251 if curLen < 0 { 2252 return moerr.NewInternalError(mp.ses.requestCtx, "needLen %d < 0. hasDataLen %d n - i %d", curLen, hasDataLen, n-i) 2253 } 2254 outbuf.Grow(curLen) 2255 buf = outbuf.RawBuf() 2256 writeIdx := outbuf.GetWriteIndex() 2257 copy(buf[writeIdx:], elems[i:i+curLen]) 2258 writeIdx += curLen 2259 mp.bytesInOutBuffer += curLen 2260 outbuf.SetWriteIndex(writeIdx) 2261 if mp.enableLog { 2262 logutil.Infof("fillPacket curWriteIdx %d", outbuf.GetWriteIndex()) 2263 } 2264 2265 //> 16MB, split it 2266 curDataLen = outbuf.GetWriteIndex() - mp.beginWriteIndex - HeaderLengthOfTheProtocol 2267 if curDataLen == int(MaxPayloadSize) { 2268 err = mp.closePacket(i+curLen == n) 2269 if err != nil { 2270 return err 2271 } 2272 2273 err = mp.flushOutBuffer() 2274 if err != nil { 2275 return err 2276 } 2277 } 2278 } 2279 2280 return nil 2281 } 2282 2283 // close a mysql protocol packet 2284 func (mp *MysqlProtocolImpl) closePacket(appendZeroPacket bool) error { 2285 if mp.enableLog { 2286 logutil.Info("closePacket") 2287 } 2288 if !mp.isInPacket() { 2289 return nil 2290 } 2291 outbuf := mp.tcpConn.OutBuf() 2292 payLoadLen := outbuf.GetWriteIndex() - mp.beginWriteIndex - 4 2293 if mp.enableLog { 2294 logutil.Infof("closePacket curWriteIdx %d", outbuf.GetWriteIndex()) 2295 } 2296 if payLoadLen < 0 || payLoadLen > int(MaxPayloadSize) { 2297 return moerr.NewInternalError(mp.ses.requestCtx, "invalid payload len :%d curWriteIdx %d beginWriteIdx %d ", 2298 payLoadLen, outbuf.GetWriteIndex(), mp.beginWriteIndex) 2299 } 2300 2301 buf := outbuf.RawBuf() 2302 binary.LittleEndian.PutUint32(buf[mp.beginWriteIndex:], uint32(payLoadLen)) 2303 buf[mp.beginWriteIndex+3] = mp.GetSequenceId() 2304 2305 
mp.AddSequenceId(1) 2306 2307 if appendZeroPacket && payLoadLen == int(MaxPayloadSize) { //last 16MB packet,append a zero packet 2308 //if the size of the last packet is exactly MaxPayloadSize, a zero-size payload should be sent 2309 err := mp.openPacket() 2310 if err != nil { 2311 return err 2312 } 2313 buf = outbuf.RawBuf() 2314 binary.LittleEndian.PutUint32(buf[mp.beginWriteIndex:], uint32(0)) 2315 buf[mp.beginWriteIndex+3] = mp.GetSequenceId() 2316 mp.AddSequenceId(1) 2317 } 2318 2319 mp.resetPacket() 2320 return nil 2321 } 2322 2323 /* 2324 * 2325 append the elems into the outbuffer 2326 */ 2327 func (mp *MysqlProtocolImpl) append(_ []byte, elems ...byte) []byte { 2328 err := mp.fillPacket(elems...) 2329 if err != nil { 2330 panic(err) 2331 } 2332 return mp.tcpConn.OutBuf().RawBuf() 2333 } 2334 2335 func (mp *MysqlProtocolImpl) appendDatetime(data []byte, dt types.Datetime) []byte { 2336 if dt.MicroSec() != 0 { 2337 data = mp.append(data, 11) 2338 data = mp.appendUint16(data, uint16(dt.Year())) 2339 data = mp.append(data, dt.Month(), dt.Day(), byte(dt.Hour()), byte(dt.Minute()), byte(dt.Sec())) 2340 data = mp.appendUint32(data, uint32(dt.MicroSec())) 2341 } else if dt.Hour() != 0 || dt.Minute() != 0 || dt.Sec() != 0 { 2342 data = mp.append(data, 7) 2343 data = mp.appendUint16(data, uint16(dt.Year())) 2344 data = mp.append(data, dt.Month(), dt.Day(), byte(dt.Hour()), byte(dt.Minute()), byte(dt.Sec())) 2345 } else { 2346 data = mp.append(data, 4) 2347 data = mp.appendUint16(data, uint16(dt.Year())) 2348 data = mp.append(data, dt.Month(), dt.Day()) 2349 } 2350 return data 2351 } 2352 2353 func (mp *MysqlProtocolImpl) appendTime(data []byte, t types.Time) []byte { 2354 if int64(t) == 0 { 2355 data = mp.append(data, 0) 2356 } else { 2357 hour, minute, sec, msec, isNeg := t.ClockFormat() 2358 day := uint32(hour / 24) 2359 hour = hour % 24 2360 if msec != 0 { 2361 data = mp.append(data, 12) 2362 if isNeg { 2363 data = append(data, byte(1)) 2364 } else { 2365 data = 
append(data, byte(0)) 2366 } 2367 data = mp.appendUint32(data, day) 2368 data = mp.append(data, uint8(hour), minute, sec) 2369 data = mp.appendUint64(data, msec) 2370 } else { 2371 data = mp.append(data, 8) 2372 if isNeg { 2373 data = append(data, byte(1)) 2374 } else { 2375 data = append(data, byte(0)) 2376 } 2377 data = mp.appendUint32(data, day) 2378 data = mp.append(data, uint8(hour), minute, sec) 2379 } 2380 } 2381 return data 2382 } 2383 2384 func (mp *MysqlProtocolImpl) appendDate(data []byte, value types.Date) []byte { 2385 if int32(value) == 0 { 2386 data = mp.append(data, 0) 2387 } else { 2388 data = mp.append(data, 4) 2389 data = mp.appendUint16(data, value.Year()) 2390 data = mp.append(data, value.Month(), value.Day()) 2391 } 2392 return data 2393 } 2394 2395 // the server send every row of the result set as an independent packet 2396 // thread safe 2397 func (mp *MysqlProtocolImpl) SendResultSetTextRow(mrs *MysqlResultSet, r uint64) error { 2398 mp.GetLock().Lock() 2399 defer mp.GetLock().Unlock() 2400 2401 return mp.sendResultSetTextRow(mrs, r) 2402 } 2403 2404 // the server send every row of the result set as an independent packet 2405 func (mp *MysqlProtocolImpl) sendResultSetTextRow(mrs *MysqlResultSet, r uint64) error { 2406 var err error 2407 err = mp.openRow(nil) 2408 if err != nil { 2409 return err 2410 } 2411 if _, err = mp.makeResultSetTextRow(nil, mrs, r); err != nil { 2412 //ERR_Packet in case of error 2413 err1 := mp.sendErrPacket(moerr.ER_UNKNOWN_ERROR, DefaultMySQLState, err.Error()) 2414 if err1 != nil { 2415 return err1 2416 } 2417 return err 2418 } 2419 2420 err = mp.closeRow(nil) 2421 if err != nil { 2422 return err 2423 } 2424 2425 //begin2 := time.Now() 2426 //err = mp.writePackets(data) 2427 //if err != nil { 2428 // return moerr.NewInternalError("send result set text row failed. 
error: %v", err) 2429 //} 2430 //mp.sendTime += time.Since(begin2) 2431 2432 return nil 2433 } 2434 2435 // the server send the result set of execution the client 2436 // the routine follows the article: https://dev.mysql.com/doc/internals/en/com-query-response.html 2437 func (mp *MysqlProtocolImpl) sendResultSet(ctx context.Context, set ResultSet, cmd int, warnings, status uint16) error { 2438 mysqlRS, ok := set.(*MysqlResultSet) 2439 if !ok { 2440 return moerr.NewInternalError(ctx, "sendResultSet need MysqlResultSet") 2441 } 2442 2443 //A packet containing a Protocol::LengthEncodedInteger column_count 2444 err := mp.SendColumnCountPacket(mysqlRS.GetColumnCount()) 2445 if err != nil { 2446 return err 2447 } 2448 2449 if err = mp.sendColumns(ctx, mysqlRS, cmd, warnings, status); err != nil { 2450 return err 2451 } 2452 2453 //One or more ProtocolText::ResultsetRow packets, each containing column_count values 2454 for i := uint64(0); i < mysqlRS.GetRowCount(); i++ { 2455 if err = mp.sendResultSetTextRow(mysqlRS, i); err != nil { 2456 return err 2457 } 2458 } 2459 2460 //If the CLIENT_DEPRECATE_EOF client capabilities flag is set, OK_Packet; else EOF_Packet. 
2461 if mp.capability&CLIENT_DEPRECATE_EOF != 0 { 2462 err := mp.sendOKPacketWithEof(0, 0, status, 0, "") 2463 if err != nil { 2464 return err 2465 } 2466 } else { 2467 err := mp.sendEOFPacket(warnings, status) 2468 if err != nil { 2469 return err 2470 } 2471 } 2472 2473 return nil 2474 } 2475 2476 // the server sends the payload to the client 2477 func (mp *MysqlProtocolImpl) writePackets(payload []byte) error { 2478 //protocol header length 2479 var headerLen = HeaderOffset 2480 var header [4]byte 2481 2482 //position of the first data byte 2483 var i = headerLen 2484 var length = len(payload) 2485 var curLen int 2486 for ; i < length; i += curLen { 2487 //var packet []byte = mp.packet[:0] 2488 curLen = Min(int(MaxPayloadSize), length-i) 2489 2490 //make mysql client protocol header 2491 //4 bytes 2492 //int<3> the length of payload 2493 mp.io.WriteUint32(header[:], 0, uint32(curLen)) 2494 2495 //int<1> sequence id 2496 mp.io.WriteUint8(header[:], 3, mp.GetSequenceId()) 2497 2498 //send packet 2499 var packet = append(header[:], payload[i:i+curLen]...) 
2500 2501 err := mp.tcpConn.Write(packet, goetty.WriteOptions{Flush: true}) 2502 if err != nil { 2503 return err 2504 } 2505 mp.AddSequenceId(1) 2506 2507 if i+curLen == length && curLen == int(MaxPayloadSize) { 2508 //if the size of the last packet is exactly MaxPayloadSize, a zero-size payload should be sent 2509 header[0] = 0 2510 header[1] = 0 2511 header[2] = 0 2512 header[3] = mp.GetSequenceId() 2513 2514 //send header / zero-sized packet 2515 err := mp.tcpConn.Write(header[:], goetty.WriteOptions{Flush: true}) 2516 if err != nil { 2517 return err 2518 } 2519 2520 mp.AddSequenceId(1) 2521 } 2522 } 2523 return nil 2524 } 2525 2526 /* 2527 //ther server reads a part of payload from the connection 2528 //the part may be a whole payload 2529 func (mp *MysqlProtocolImpl) recvPartOfPayload() ([]byte, error) { 2530 //var length int 2531 //var header []byte 2532 //var err error 2533 //if header, err = mp.io.ReadPacket(4); err != nil { 2534 // return nil, moerr.NewInternalError("read header failed.error:%v", err) 2535 //} else if header[3] != mp.sequenceId { 2536 // return nil, moerr.NewInternalError("client sequence id %d != server sequence id %d", header[3], mp.sequenceId) 2537 //} 2538 2539 mp.sequenceId++ 2540 //length = int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) 2541 2542 var payload []byte 2543 //if payload, err = mp.io.ReadPacket(length); err != nil { 2544 // return nil, moerr.NewInternalError("read payload failed.error:%v", err) 2545 //} 2546 return payload, nil 2547 } 2548 2549 //the server read a payload from the connection 2550 func (mp *MysqlProtocolImpl) recvPayload() ([]byte, error) { 2551 payload, err := mp.recvPartOfPayload() 2552 if err != nil { 2553 return nil, err 2554 } 2555 2556 //only one part 2557 if len(payload) < int(MaxPayloadSize) { 2558 return payload, nil 2559 } 2560 2561 //payload has been split into many parts. 
2562 //read them all together 2563 var part []byte 2564 for { 2565 part, err = mp.recvPartOfPayload() 2566 if err != nil { 2567 return nil, err 2568 } 2569 2570 payload = append(payload, part...) 2571 2572 //only one part 2573 if len(part) < int(MaxPayloadSize) { 2574 break 2575 } 2576 } 2577 return payload, nil 2578 } 2579 */ 2580 2581 /* 2582 generate random ascii string. 2583 Reference to :mysql 8.0.23 mysys/crypt_genhash_impl.cc generate_user_salt(char*,int) 2584 */ 2585 func generate_salt(n int) []byte { 2586 buf := make([]byte, n) 2587 r := rand.New(rand.NewSource(time.Now().UTC().UnixNano())) 2588 r.Read(buf) 2589 for i := 0; i < n; i++ { 2590 buf[i] &= 0x7f 2591 if buf[i] == 0 || buf[i] == '$' { 2592 buf[i]++ 2593 } 2594 } 2595 return buf 2596 } 2597 2598 func NewMysqlClientProtocol(connectionID uint32, tcp goetty.IOSession, maxBytesToFlush int, SV *config.FrontendParameters) *MysqlProtocolImpl { 2599 salt := generate_salt(20) 2600 mysql := &MysqlProtocolImpl{ 2601 ProtocolImpl: ProtocolImpl{ 2602 io: NewIOPackage(true), 2603 tcpConn: tcp, 2604 salt: salt, 2605 connectionID: connectionID, 2606 }, 2607 charset: "utf8mb4", 2608 capability: DefaultCapability, 2609 strconvBuffer: make([]byte, 0, 16*1024), 2610 lenEncBuffer: make([]byte, 0, 10), 2611 binaryNullBuffer: make([]byte, 0, 512), 2612 rowHandler: rowHandler{ 2613 beginWriteIndex: 0, 2614 bytesInOutBuffer: 0, 2615 untilBytesInOutbufToFlush: maxBytesToFlush * 1024, 2616 enableLog: false, 2617 }, 2618 SV: SV, 2619 } 2620 2621 mysql.MakeProfile() 2622 2623 if SV.EnableTls { 2624 mysql.capability = mysql.capability | CLIENT_SSL 2625 } 2626 2627 mysql.resetPacket() 2628 2629 return mysql 2630 }