github.com/matrixorigin/matrixone@v1.2.0/pkg/frontend/mysql_protocol.go (about) 1 // Copyright 2021 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package frontend 16 17 import ( 18 "bufio" 19 "bytes" 20 "context" 21 "crypto/sha1" 22 "encoding/binary" 23 "encoding/hex" 24 "fmt" 25 "math" 26 "math/rand" 27 "net" 28 "strconv" 29 "strings" 30 "sync/atomic" 31 "time" 32 "unicode" 33 34 "github.com/fagongzi/goetty/v2" 35 goetty_buf "github.com/fagongzi/goetty/v2/buf" 36 "go.uber.org/zap" 37 "golang.org/x/exp/slices" 38 39 "github.com/matrixorigin/matrixone/pkg/common/moerr" 40 "github.com/matrixorigin/matrixone/pkg/config" 41 "github.com/matrixorigin/matrixone/pkg/container/types" 42 "github.com/matrixorigin/matrixone/pkg/container/vector" 43 "github.com/matrixorigin/matrixone/pkg/defines" 44 "github.com/matrixorigin/matrixone/pkg/logutil" 45 planPb "github.com/matrixorigin/matrixone/pkg/pb/plan" 46 "github.com/matrixorigin/matrixone/pkg/pb/proxy" 47 plan2 "github.com/matrixorigin/matrixone/pkg/sql/plan" 48 "github.com/matrixorigin/matrixone/pkg/sql/util" 49 v2 "github.com/matrixorigin/matrixone/pkg/util/metric/v2" 50 "github.com/matrixorigin/matrixone/pkg/vm/process" 51 ) 52 53 // DefaultCapability means default capabilities of the server 54 var DefaultCapability = CLIENT_LONG_PASSWORD | 55 CLIENT_FOUND_ROWS | 56 CLIENT_LONG_FLAG | 57 CLIENT_CONNECT_WITH_DB | 58 CLIENT_LOCAL_FILES | 59 CLIENT_PROTOCOL_41 | 60 
CLIENT_INTERACTIVE | 61 CLIENT_TRANSACTIONS | 62 CLIENT_SECURE_CONNECTION | 63 CLIENT_MULTI_STATEMENTS | 64 CLIENT_MULTI_RESULTS | 65 CLIENT_PLUGIN_AUTH | 66 CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | 67 CLIENT_DEPRECATE_EOF | 68 CLIENT_CONNECT_ATTRS 69 70 // DefaultClientConnStatus default server status 71 var DefaultClientConnStatus = SERVER_STATUS_AUTOCOMMIT 72 73 var serverVersion atomic.Value 74 75 const defaultSaltReadTimeout = time.Millisecond * 200 76 77 const charsetBinary = 0x3f 78 const charsetVarchar = 0x21 79 const boolColumnLength = 12 80 81 func init() { 82 serverVersion.Store("0.5.0") 83 } 84 85 func InitServerVersion(v string) { 86 if len(v) > 0 { 87 switch v[0] { 88 case 'v': // format 'v1.1.1' 89 v = v[1:] 90 serverVersion.Store(v) 91 default: 92 vv := []byte(v) 93 for i := 0; i < len(vv); i++ { 94 if !unicode.IsDigit(rune(vv[i])) && vv[i] != '.' { 95 vv = append(vv[:i], vv[i+1:]...) 96 i-- 97 } 98 } 99 serverVersion.Store(string(vv)) 100 } 101 } else { 102 serverVersion.Store("0.5.0") 103 } 104 } 105 106 const ( 107 clientProtocolVersion uint8 = 10 108 109 /** 110 An answer talks about the charset utf8mb4. 111 https://stackoverflow.com/questions/766809/whats-the-difference-between-utf8-general-ci-and-utf8-unicode-ci 112 It recommends the charset utf8mb4_0900_ai_ci. 113 Maybe we can support utf8mb4_0900_ai_ci in the future. 114 115 A concise research in the Mysql 8.0.23. 
116 117 the charset in sever level 118 ====================================== 119 120 mysql> show variables like 'character_set_server'; 121 +----------------------+---------+ 122 | Variable_name | Value | 123 +----------------------+---------+ 124 | character_set_server | utf8mb4 | 125 +----------------------+---------+ 126 127 mysql> show variables like 'collation_server'; 128 +------------------+--------------------+ 129 | Variable_name | Value | 130 +------------------+--------------------+ 131 | collation_server | utf8mb4_0900_ai_ci | 132 +------------------+--------------------+ 133 134 the charset in database level 135 ===================================== 136 mysql> show variables like 'character_set_database'; 137 +------------------------+---------+ 138 | Variable_name | Value | 139 +------------------------+---------+ 140 | character_set_database | utf8mb4 | 141 +------------------------+---------+ 142 143 mysql> show variables like 'collation_database'; 144 +--------------------+--------------------+ 145 | Variable_name | Value | 146 +--------------------+--------------------+ 147 | collation_database | utf8mb4_0900_ai_ci | 148 +--------------------+--------------------+ 149 150 */ 151 // DefaultCollationID is utf8mb4_bin(46) 152 utf8mb4BinCollationID uint8 = 46 153 154 Utf8mb4CollationID uint8 = 45 155 156 AuthNativePassword string = "mysql_native_password" 157 158 //the length of the mysql protocol header 159 HeaderLengthOfTheProtocol int = 4 160 HeaderOffset int = 0 161 162 // MaxPayloadSize If the payload is larger than or equal to 2^24−1 bytes the length is set to 2^24−1 (ff ff ff) 163 //and additional packets are sent with the rest of the payload until the payload of a packet 164 //is less than 2^24−1 bytes. 
165 MaxPayloadSize uint32 = (1 << 24) - 1 166 167 // DefaultMySQLState is the default state of the mySQL 168 DefaultMySQLState string = "HY000" 169 ) 170 171 type MysqlProtocol interface { 172 Protocol 173 //the server send group row of the result set as an independent packet thread safe 174 SendResultSetTextBatchRow(mrs *MysqlResultSet, cnt uint64) error 175 176 SendResultSetTextBatchRowSpeedup(mrs *MysqlResultSet, cnt uint64) error 177 178 //SendColumnDefinitionPacket the server send the column definition to the client 179 SendColumnDefinitionPacket(ctx context.Context, column Column, cmd int) error 180 181 //SendColumnCountPacket makes the column count packet 182 SendColumnCountPacket(count uint64) error 183 184 SendResponse(ctx context.Context, resp *Response) error 185 186 SendEOFPacketIf(warnings uint16, status uint16) error 187 188 //send OK packet to the client 189 sendOKPacket(affectedRows uint64, lastInsertId uint64, status uint16, warnings uint16, message string) error 190 191 //the OK or EOF packet thread safe 192 sendEOFOrOkPacket(warnings uint16, status uint16) error 193 194 sendLocalInfileRequest(filename string) error 195 196 ResetStatistics() 197 198 GetStats() string 199 200 // CalculateOutTrafficBytes return bytes, mysql packet num send back to client 201 // reset marks Do reset counter after calculation or not. 202 CalculateOutTrafficBytes(reset bool) (int64, int64) 203 204 ParseExecuteData(ctx context.Context, proc *process.Process, stmt *PrepareStmt, data []byte, pos int) error 205 206 ParseSendLongData(ctx context.Context, proc *process.Process, stmt *PrepareStmt, data []byte, pos int) error 207 208 DisableAutoFlush() 209 EnableAutoFlush() 210 Flush() error 211 } 212 213 var _ MysqlProtocol = &MysqlProtocolImpl{} 214 215 type debugStats struct { 216 writeCount uint64 217 // record 2 cases data, all belong to flush op. 
218 // 1) MysqlProtocolImpl.flushOutBuffer do flush buffer op 219 // 2) MysqlProtocolImpl.writePackets do write with goetty.WriteOptions{Flush: true} 220 writeBytes uint64 221 } 222 223 func (ds *debugStats) ResetStats() { 224 ds.writeCount = 0 225 ds.writeBytes = 0 226 } 227 228 func (ds *debugStats) String() string { 229 if ds.writeCount <= 0 { 230 ds.writeCount = 1 231 } 232 return fmt.Sprintf( 233 "writeCount %v \n"+ 234 "writeBytes %v %v MB\n", 235 ds.writeCount, 236 ds.writeBytes, ds.writeBytes/(1024*1024.0), 237 ) 238 } 239 240 func (ds *debugStats) AddFlushBytes(b uint64) { 241 ds.writeBytes += b 242 } 243 244 /* 245 rowHandler maintains the states in encoding the result row 246 */ 247 type rowHandler struct { 248 //the begin position of writing. 249 //the range [beginWriteIndex,beginWriteIndex+3] 250 //for the length and sequenceId of the mysql protocol packet 251 beginOffset int 252 //the bytes in the outbuffer 253 bytesInOutBuffer int 254 //when the number of bytes in the outbuffer exceeds the it, 255 //the outbuffer will be flushed. 256 untilBytesInOutbufToFlush int 257 //the count of the flush 258 flushCount int 259 //the bytes have been response 260 startOffsetInBuffer int 261 } 262 263 /* 264 isInPacket means it is compositing a packet now 265 */ 266 func (rh *rowHandler) isInPacket() bool { 267 return rh.beginOffset >= 0 268 } 269 270 /* 271 resetPacket reset the beginWriteIndex 272 */ 273 func (rh *rowHandler) resetPacket() { 274 rh.beginOffset = -1 275 } 276 277 /* 278 resetFlushOutBuffer clears the bytesInOutBuffer 279 */ 280 func (rh *rowHandler) resetFlushOutBuffer() { 281 rh.bytesInOutBuffer = 0 282 } 283 284 /* 285 resetFlushCount reset flushCount 286 */ 287 func (rh *rowHandler) resetFlushCount() { 288 rh.flushCount = 0 289 } 290 291 // resetStartOffset reset the startOffsetInBuffer 292 // How rowHandler.resetStartOffset, debugStats.writeBytes and MysqlProtocolImpl.CalculateOutTrafficBytes work together ? 293 // 0. init. 
call rowHandler.resetStartOffset at the beginning of query, record the beginning offset of the current buffer. 294 // 1. batch write. inc debugStats.writeBytes and do resetFlushOutBuffer() 295 // 2. last data. MysqlProtocolImpl.CalculateOutTrafficBytes() with debugStats.writeBytes, rowHandler.startOffsetInBuffer and rowHandler.bytesInOutBuffer 296 func (rh *rowHandler) resetStartOffset() { 297 rh.startOffsetInBuffer = rh.bytesInOutBuffer 298 } 299 300 type MysqlProtocolImpl struct { 301 ProtocolImpl 302 303 //joint capability shared by the server and the client 304 capability uint32 305 306 //collation id 307 collationID int 308 309 //collation name 310 collationName string 311 312 //character set 313 charset string 314 315 //max packet size of the client 316 maxClientPacketSize uint32 317 318 //the user of the client 319 username string 320 321 // opaque authentication response data generated by Authentication Method 322 // indicated by the plugin name field. 323 authResponse []byte 324 325 //the default database for the client 326 database string 327 328 // Connection attributes are key-value pairs that application programs 329 // can pass to the server at connect time. 
330 connectAttrs map[string]string 331 332 //for debug 333 debugStats 334 335 //for converting the data into string 336 strconvBuffer []byte 337 338 //for encoding the length into bytes 339 lenEncBuffer []byte 340 341 //for encoding the null bytes in binary row 342 binaryNullBuffer []byte 343 344 rowHandler 345 346 SV *config.FrontendParameters 347 348 ses *Session 349 350 disableAutoFlush bool 351 } 352 353 func (mp *MysqlProtocolImpl) GetSession() *Session { 354 mp.m.Lock() 355 defer mp.m.Unlock() 356 return mp.ses 357 } 358 359 func (mp *MysqlProtocolImpl) GetCapability() uint32 { 360 mp.m.Lock() 361 defer mp.m.Unlock() 362 return mp.capability 363 } 364 365 func (mp *MysqlProtocolImpl) SetCapability(cap uint32) { 366 mp.m.Lock() 367 defer mp.m.Unlock() 368 mp.capability = cap 369 } 370 371 func (mp *MysqlProtocolImpl) AddSequenceId(a uint8) { 372 mp.ses.CountPacket(int64(a)) 373 mp.sequenceId.Add(uint32(a)) 374 } 375 376 func (mp *MysqlProtocolImpl) SetSequenceID(value uint8) { 377 mp.ses.CountPacket(1) 378 mp.sequenceId.Store(uint32(value)) 379 } 380 381 func (mp *MysqlProtocolImpl) GetDatabaseName() string { 382 mp.m.Lock() 383 defer mp.m.Unlock() 384 return mp.database 385 } 386 387 func (mp *MysqlProtocolImpl) SetDatabaseName(s string) { 388 mp.m.Lock() 389 defer mp.m.Unlock() 390 mp.database = s 391 } 392 393 func (mp *MysqlProtocolImpl) GetUserName() string { 394 mp.m.Lock() 395 defer mp.m.Unlock() 396 return mp.username 397 } 398 399 func (mp *MysqlProtocolImpl) SetUserName(s string) { 400 mp.m.Lock() 401 defer mp.m.Unlock() 402 mp.username = s 403 } 404 405 func (mp *MysqlProtocolImpl) GetStats() string { 406 return fmt.Sprintf("flushCount %d %s", 407 mp.flushCount, 408 mp.String()) 409 } 410 411 // CalculateOutTrafficBytes calculate the bytes of the last out traffic, the number of mysql packets 412 func (mp *MysqlProtocolImpl) CalculateOutTrafficBytes(reset bool) (bytes int64, packets int64) { 413 ses := mp.GetSession() 414 // Case 1: send data as 
ResultSet 415 bytes = int64(mp.writeBytes) + int64(mp.bytesInOutBuffer-mp.startOffsetInBuffer) + 416 // Case 2: send data as CSV 417 ses.writeCsvBytes.Load() 418 // mysql packet num + length(sql) / 16KiB + payload / 16 KiB 419 packets = ses.GetPacketCnt() + int64(len(ses.sql)>>14) + int64(ses.payloadCounter>>14) 420 if reset { 421 ses.ResetPacketCounter() 422 } 423 return 424 } 425 426 func (mp *MysqlProtocolImpl) ResetStatistics() { 427 mp.ResetStats() 428 mp.resetFlushCount() 429 mp.resetStartOffset() 430 } 431 432 func (mp *MysqlProtocolImpl) GetConnectAttrs() map[string]string { 433 mp.m.Lock() 434 defer mp.m.Unlock() 435 return mp.connectAttrs 436 } 437 438 func (mp *MysqlProtocolImpl) Quit() { 439 mp.m.Lock() 440 defer mp.m.Unlock() 441 mp.ProtocolImpl.Quit() 442 if mp.strconvBuffer != nil { 443 mp.strconvBuffer = nil 444 } 445 if mp.lenEncBuffer != nil { 446 mp.lenEncBuffer = nil 447 } 448 if mp.binaryNullBuffer != nil { 449 mp.binaryNullBuffer = nil 450 } 451 mp.ses = nil 452 } 453 454 func (mp *MysqlProtocolImpl) SetSession(ses *Session) { 455 mp.m.Lock() 456 defer mp.m.Unlock() 457 mp.ses = ses 458 } 459 460 // handshake response 41 461 type response41 struct { 462 capabilities uint32 463 maxPacketSize uint32 464 collationID uint8 465 username string 466 authResponse []byte 467 database string 468 clientPluginName string 469 isAskForTlsHeader bool 470 connectAttrs map[string]string 471 } 472 473 // handshake response 320 474 type response320 struct { 475 capabilities uint32 476 maxPacketSize uint32 477 username string 478 authResponse []byte 479 database string 480 isAskForTlsHeader bool 481 } 482 483 func (mp *MysqlProtocolImpl) SendPrepareResponse(ctx context.Context, stmt *PrepareStmt) error { 484 dcPrepare, ok := stmt.PreparePlan.GetDcl().Control.(*planPb.DataControl_Prepare) 485 if !ok { 486 return moerr.NewInternalError(ctx, "can not get Prepare plan in prepareStmt") 487 } 488 stmtID, err := GetPrepareStmtID(ctx, stmt.Name) 489 if err != nil { 490 
return moerr.NewInternalError(ctx, "can not get Prepare stmtID") 491 } 492 paramTypes := dcPrepare.Prepare.ParamTypes 493 numParams := len(paramTypes) 494 columns := plan2.GetResultColumnsFromPlan(dcPrepare.Prepare.Plan) 495 numColumns := len(columns) 496 497 var data []byte 498 // status ok 499 data = append(data, 0) 500 // stmt id 501 data = mp.io.AppendUint32(data, uint32(stmtID)) 502 // number columns 503 data = mp.io.AppendUint16(data, uint16(numColumns)) 504 // number params 505 data = mp.io.AppendUint16(data, uint16(numParams)) 506 // filter [00] 507 data = append(data, 0) 508 // warning count 509 data = append(data, 0, 0) // TODO support warning count 510 if err := mp.writePackets(data, false); err != nil { 511 return err 512 } 513 514 cmd := int(COM_STMT_PREPARE) 515 for i := 0; i < numParams; i++ { 516 column := new(MysqlColumn) 517 column.SetName("?") 518 519 err = convertEngineTypeToMysqlType(ctx, types.T(paramTypes[i]), column) 520 if err != nil { 521 return err 522 } 523 524 err = mp.SendColumnDefinitionPacket(ctx, column, cmd) 525 if err != nil { 526 return err 527 } 528 } 529 if numParams > 0 { 530 if err := mp.SendEOFPacketIf(0, mp.GetSession().GetTxnHandler().GetServerStatus()); err != nil { 531 return err 532 } 533 } 534 535 for i := 0; i < numColumns; i++ { 536 column := new(MysqlColumn) 537 column.SetName(columns[i].Name) 538 539 err = convertEngineTypeToMysqlType(ctx, types.T(columns[i].Typ.Id), column) 540 if err != nil { 541 return err 542 } 543 544 err = mp.SendColumnDefinitionPacket(ctx, column, cmd) 545 if err != nil { 546 return err 547 } 548 } 549 if numColumns > 0 { 550 if err := mp.SendEOFPacketIf(0, mp.GetSession().GetTxnHandler().GetServerStatus()); err != nil { 551 return err 552 } 553 } 554 555 return nil 556 } 557 558 func (mp *MysqlProtocolImpl) ParseSendLongData(ctx context.Context, proc *process.Process, stmt *PrepareStmt, data []byte, pos int) error { 559 var err error 560 dcPrepare, ok := 
stmt.PreparePlan.GetDcl().Control.(*planPb.DataControl_Prepare) 561 if !ok { 562 return moerr.NewInternalError(ctx, "can not get Prepare plan in prepareStmt") 563 } 564 numParams := len(dcPrepare.Prepare.ParamTypes) 565 566 paramIdx, newPos, ok := mp.io.ReadUint16(data, pos) 567 if !ok { 568 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 569 } 570 pos = newPos 571 if int(paramIdx) >= numParams { 572 return moerr.NewInternalError(ctx, "get param index out of range. get %d, param length is %d", paramIdx, numParams) 573 } 574 575 if stmt.params == nil { 576 stmt.params = proc.GetVector(types.T_text.ToType()) 577 for i := 0; i < numParams; i++ { 578 err = vector.AppendBytes(stmt.params, []byte{}, false, proc.GetMPool()) 579 if err != nil { 580 return err 581 } 582 } 583 } 584 585 length := len(data) - pos 586 val, _, ok := mp.readCountOfBytes(data, pos, length) 587 if !ok { 588 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 589 } 590 stmt.getFromSendLongData[int(paramIdx)] = struct{}{} 591 return util.SetAnyToStringVector(proc, val, stmt.params, int(paramIdx)) 592 } 593 594 func (mp *MysqlProtocolImpl) ParseExecuteData(ctx context.Context, proc *process.Process, stmt *PrepareStmt, data []byte, pos int) error { 595 var err error 596 dcPrepare, ok := stmt.PreparePlan.GetDcl().Control.(*planPb.DataControl_Prepare) 597 if !ok { 598 return moerr.NewInternalError(ctx, "can not get Prepare plan in prepareStmt") 599 } 600 numParams := len(dcPrepare.Prepare.ParamTypes) 601 602 if stmt.params == nil { 603 stmt.params = proc.GetVector(types.T_text.ToType()) 604 for i := 0; i < numParams; i++ { 605 err = vector.AppendBytes(stmt.params, []byte{}, false, proc.GetMPool()) 606 if err != nil { 607 return err 608 } 609 } 610 } 611 612 var flag uint8 613 flag, pos, ok = mp.io.ReadUint8(data, pos) 614 if !ok { 615 return moerr.NewInternalError(ctx, "malform packet") 616 617 } 618 if flag != 0 { 619 // TODO only support 
CURSOR_TYPE_NO_CURSOR flag now 620 return moerr.NewInvalidInput(ctx, "unsupported Prepare flag '%v'", flag) 621 } 622 623 // skip iteration-count, always 1 624 pos += 4 625 626 if numParams > 0 { 627 var nullBitmaps []byte 628 nullBitmapLen := (numParams + 7) >> 3 629 nullBitmaps, pos, ok = mp.readCountOfBytes(data, pos, nullBitmapLen) 630 if !ok { 631 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 632 } 633 634 // new param bound flag 635 if data[pos] == 1 { 636 pos++ 637 638 // Just the first StmtExecute packet contain parameters type, 639 // we need save it for further use. 640 stmt.ParamTypes, pos, ok = mp.readCountOfBytes(data, pos, numParams<<1) 641 if !ok { 642 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 643 } 644 } else { 645 pos++ 646 } 647 648 // get paramters and set value to session variables 649 for i := 0; i < numParams; i++ { 650 // if params had received via COM_STMT_SEND_LONG_DATA, use them directly(we set the params when deal with COM_STMT_SEND_LONG_DATA). 
651 // ref https://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html 652 if _, ok := stmt.getFromSendLongData[i]; ok { 653 continue 654 } 655 656 if nullBitmaps[i>>3]&(1<<(uint(i)%8)) > 0 { 657 err = util.SetAnyToStringVector(proc, nil, stmt.params, i) 658 if err != nil { 659 return err 660 } 661 continue 662 } 663 664 if (i<<1)+1 >= len(stmt.ParamTypes) { 665 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 666 667 } 668 tp := stmt.ParamTypes[i<<1] 669 isUnsigned := (stmt.ParamTypes[(i<<1)+1] & 0x80) > 0 670 671 switch defines.MysqlType(tp) { 672 case defines.MYSQL_TYPE_NULL: 673 err = util.SetAnyToStringVector(proc, nil, stmt.params, i) 674 675 case defines.MYSQL_TYPE_BIT: 676 val, newPos, ok := mp.io.ReadUint64(data, pos) 677 if !ok { 678 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 679 } 680 681 pos = newPos 682 err = util.SetAnyToStringVector(proc, val, stmt.params, i) 683 684 case defines.MYSQL_TYPE_TINY: 685 val, newPos, ok := mp.io.ReadUint8(data, pos) 686 if !ok { 687 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 688 689 } 690 691 pos = newPos 692 if isUnsigned { 693 // vars[i] = val 694 err = util.SetAnyToStringVector(proc, val, stmt.params, i) 695 } else { 696 // vars[i] = int8(val) 697 err = util.SetAnyToStringVector(proc, int8(val), stmt.params, i) 698 } 699 700 case defines.MYSQL_TYPE_SHORT, defines.MYSQL_TYPE_YEAR: 701 val, newPos, ok := mp.io.ReadUint16(data, pos) 702 if !ok { 703 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 704 } 705 706 pos = newPos 707 if isUnsigned { 708 err = util.SetAnyToStringVector(proc, val, stmt.params, i) 709 } else { 710 err = util.SetAnyToStringVector(proc, int16(val), stmt.params, i) 711 } 712 713 case defines.MYSQL_TYPE_INT24, defines.MYSQL_TYPE_LONG: 714 val, newPos, ok := mp.io.ReadUint32(data, pos) 715 if !ok { 716 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed 
packet") 717 } 718 719 pos = newPos 720 if isUnsigned { 721 err = util.SetAnyToStringVector(proc, val, stmt.params, i) 722 } else { 723 err = util.SetAnyToStringVector(proc, int32(val), stmt.params, i) 724 } 725 726 case defines.MYSQL_TYPE_LONGLONG: 727 val, newPos, ok := mp.io.ReadUint64(data, pos) 728 if !ok { 729 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 730 } 731 732 pos = newPos 733 if isUnsigned { 734 err = util.SetAnyToStringVector(proc, val, stmt.params, i) 735 } else { 736 err = util.SetAnyToStringVector(proc, int64(val), stmt.params, i) 737 } 738 739 case defines.MYSQL_TYPE_FLOAT: 740 val, newPos, ok := mp.io.ReadUint32(data, pos) 741 if !ok { 742 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 743 744 } 745 pos = newPos 746 err = util.SetAnyToStringVector(proc, math.Float32frombits(val), stmt.params, i) 747 748 case defines.MYSQL_TYPE_DOUBLE: 749 val, newPos, ok := mp.io.ReadUint64(data, pos) 750 if !ok { 751 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 752 } 753 pos = newPos 754 err = util.SetAnyToStringVector(proc, math.Float64frombits(val), stmt.params, i) 755 756 // Binary/varbinary has mysql_type_varchar. 
757 case defines.MYSQL_TYPE_VARCHAR, defines.MYSQL_TYPE_VAR_STRING, defines.MYSQL_TYPE_STRING, defines.MYSQL_TYPE_DECIMAL, 758 defines.MYSQL_TYPE_ENUM, defines.MYSQL_TYPE_SET, defines.MYSQL_TYPE_GEOMETRY: 759 val, newPos, ok := mp.readStringLenEnc(data, pos) 760 if !ok { 761 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 762 } 763 pos = newPos 764 err = util.SetAnyToStringVector(proc, val, stmt.params, i) 765 766 case defines.MYSQL_TYPE_BLOB, defines.MYSQL_TYPE_TINY_BLOB, defines.MYSQL_TYPE_MEDIUM_BLOB, defines.MYSQL_TYPE_LONG_BLOB, defines.MYSQL_TYPE_TEXT: 767 val, newPos, ok := mp.readStringLenEnc(data, pos) 768 if !ok { 769 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 770 } 771 pos = newPos 772 // vars[i] = []byte(val) 773 err = util.SetAnyToStringVector(proc, val, stmt.params, i) 774 775 case defines.MYSQL_TYPE_TIME: 776 // See https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_binary_resultset.html 777 // for more details. 778 length, newPos, ok := mp.io.ReadUint8(data, pos) 779 if !ok { 780 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 781 } 782 pos = newPos 783 var val string 784 switch length { 785 case 0: 786 val = "0d 00:00:00" 787 case 8, 12: 788 pos, val = mp.readTime(data, pos, length) 789 default: 790 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 791 } 792 err = util.SetAnyToStringVector(proc, val, stmt.params, i) 793 794 case defines.MYSQL_TYPE_DATE, defines.MYSQL_TYPE_DATETIME, defines.MYSQL_TYPE_TIMESTAMP: 795 // See https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_binary_resultset.html 796 // for more details. 
797 length, newPos, ok := mp.io.ReadUint8(data, pos) 798 if !ok { 799 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 800 801 } 802 pos = newPos 803 var val string 804 switch length { 805 case 0: 806 val = "0000-00-00 00:00:00" 807 case 4: 808 pos, val = mp.readDate(data, pos) 809 case 7: 810 pos, val = mp.readDateTime(data, pos) 811 case 11: 812 pos, val = mp.readTimestamp(data, pos) 813 default: 814 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 815 816 } 817 err = util.SetAnyToStringVector(proc, val, stmt.params, i) 818 819 case defines.MYSQL_TYPE_NEWDECIMAL: 820 // use string for decimal. Not tested 821 val, newPos, ok := mp.readStringLenEnc(data, pos) 822 if !ok { 823 return moerr.NewInvalidInput(ctx, "mysql protocol error, malformed packet") 824 825 } 826 pos = newPos 827 err = util.SetAnyToStringVector(proc, val, stmt.params, i) 828 829 default: 830 return moerr.NewInternalError(ctx, "unsupport parameter type") 831 832 } 833 834 if err != nil { 835 return err 836 } 837 } 838 } 839 840 return nil 841 } 842 843 func (mp *MysqlProtocolImpl) readDate(data []byte, pos int) (int, string) { 844 year, pos, _ := mp.io.ReadUint16(data, pos) 845 month := data[pos] 846 pos++ 847 day := data[pos] 848 pos++ 849 return pos, fmt.Sprintf("%04d-%02d-%02d", year, month, day) 850 } 851 852 func (mp *MysqlProtocolImpl) readTime(data []byte, pos int, len uint8) (int, string) { 853 var retStr string 854 negate := data[pos] 855 pos++ 856 if negate == 1 { 857 retStr += "-" 858 } 859 day, pos, _ := mp.io.ReadUint32(data, pos) 860 if day > 0 { 861 retStr += fmt.Sprintf("%dd ", day) 862 } 863 hour := data[pos] 864 pos++ 865 minute := data[pos] 866 pos++ 867 second := data[pos] 868 pos++ 869 870 if len == 12 { 871 ms, _, _ := mp.io.ReadUint32(data, pos) 872 retStr += fmt.Sprintf("%02d:%02d:%02d.%06d", hour, minute, second, ms) 873 } else { 874 retStr += fmt.Sprintf("%02d:%02d:%02d", hour, minute, second) 875 } 876 877 return pos, 
retStr 878 } 879 880 func (mp *MysqlProtocolImpl) readDateTime(data []byte, pos int) (int, string) { 881 pos, date := mp.readDate(data, pos) 882 hour := data[pos] 883 pos++ 884 minute := data[pos] 885 pos++ 886 second := data[pos] 887 pos++ 888 return pos, fmt.Sprintf("%s %02d:%02d:%02d", date, hour, minute, second) 889 } 890 891 func (mp *MysqlProtocolImpl) readTimestamp(data []byte, pos int) (int, string) { 892 pos, dateTime := mp.readDateTime(data, pos) 893 microSecond, pos, _ := mp.io.ReadUint32(data, pos) 894 return pos, fmt.Sprintf("%s.%06d", dateTime, microSecond) 895 } 896 897 // read an int with length encoded from the buffer at the position 898 // return the int ; position + the count of bytes for length encoded (1 or 3 or 4 or 9) 899 func (mp *MysqlProtocolImpl) readIntLenEnc(data []byte, pos int) (uint64, int, bool) { 900 if pos >= len(data) { 901 return 0, 0, false 902 } 903 switch data[pos] { 904 case 0xfb: 905 //zero, one byte 906 return 0, pos + 1, true 907 case 0xfc: 908 // int in two bytes 909 if pos+2 >= len(data) { 910 return 0, 0, false 911 } 912 value := uint64(data[pos+1]) | 913 uint64(data[pos+2])<<8 914 return value, pos + 3, true 915 case 0xfd: 916 // int in three bytes 917 if pos+3 >= len(data) { 918 return 0, 0, false 919 } 920 value := uint64(data[pos+1]) | 921 uint64(data[pos+2])<<8 | 922 uint64(data[pos+3])<<16 923 return value, pos + 4, true 924 case 0xfe: 925 // int in eight bytes 926 if pos+8 >= len(data) { 927 return 0, 0, false 928 } 929 value := uint64(data[pos+1]) | 930 uint64(data[pos+2])<<8 | 931 uint64(data[pos+3])<<16 | 932 uint64(data[pos+4])<<24 | 933 uint64(data[pos+5])<<32 | 934 uint64(data[pos+6])<<40 | 935 uint64(data[pos+7])<<48 | 936 uint64(data[pos+8])<<56 937 return value, pos + 9, true 938 } 939 // 0-250 940 return uint64(data[pos]), pos + 1, true 941 } 942 943 func (mp *MysqlProtocolImpl) ReadIntLenEnc(data []byte, pos int) (uint64, int, bool) { 944 return mp.readIntLenEnc(data, pos) 945 } 946 947 // write an 
int with length encoded into the buffer at the position 948 // return position + the count of bytes for length encoded (1 or 3 or 4 or 9) 949 func (mp *MysqlProtocolImpl) writeIntLenEnc(data []byte, pos int, value uint64) int { 950 switch { 951 case value < 251: 952 data[pos] = byte(value) 953 return pos + 1 954 case value < (1 << 16): 955 data[pos] = 0xfc 956 data[pos+1] = byte(value) 957 data[pos+2] = byte(value >> 8) 958 return pos + 3 959 case value < (1 << 24): 960 data[pos] = 0xfd 961 data[pos+1] = byte(value) 962 data[pos+2] = byte(value >> 8) 963 data[pos+3] = byte(value >> 16) 964 return pos + 4 965 default: 966 data[pos] = 0xfe 967 data[pos+1] = byte(value) 968 data[pos+2] = byte(value >> 8) 969 data[pos+3] = byte(value >> 16) 970 data[pos+4] = byte(value >> 24) 971 data[pos+5] = byte(value >> 32) 972 data[pos+6] = byte(value >> 40) 973 data[pos+7] = byte(value >> 48) 974 data[pos+8] = byte(value >> 56) 975 return pos + 9 976 } 977 } 978 979 // append an int with length encoded to the buffer 980 // return the buffer 981 func (mp *MysqlProtocolImpl) appendIntLenEnc(data []byte, value uint64) []byte { 982 mp.lenEncBuffer = mp.lenEncBuffer[:9] 983 pos := mp.writeIntLenEnc(mp.lenEncBuffer, 0, value) 984 return mp.append(data, mp.lenEncBuffer[:pos]...) 
985 } 986 987 // read the count of bytes from the buffer at the position 988 // return bytes slice ; position + count ; true - succeeded or false - failed 989 func (mp *MysqlProtocolImpl) readCountOfBytes(data []byte, pos int, count int) ([]byte, int, bool) { 990 if pos+count-1 >= len(data) { 991 return nil, 0, false 992 } 993 return data[pos : pos+count], pos + count, true 994 } 995 996 // write the count of bytes into the buffer at the position 997 // return position + the number of bytes 998 func (mp *MysqlProtocolImpl) writeCountOfBytes(data []byte, pos int, value []byte) int { 999 pos += copy(data[pos:], value) 1000 return pos 1001 } 1002 1003 // append the count of bytes to the buffer 1004 // return the buffer 1005 func (mp *MysqlProtocolImpl) appendCountOfBytes(data []byte, value []byte) []byte { 1006 return mp.append(data, value...) 1007 } 1008 1009 // read a string with fixed length from the buffer at the position 1010 // return string ; position + length ; true - succeeded or false - failed 1011 func (mp *MysqlProtocolImpl) readStringFix(data []byte, pos int, length int) (string, int, bool) { 1012 var sdata []byte 1013 var ok bool 1014 sdata, pos, ok = mp.readCountOfBytes(data, pos, length) 1015 if !ok { 1016 return "", 0, false 1017 } 1018 return string(sdata), pos, true 1019 } 1020 1021 // write a string with fixed length into the buffer at the position 1022 // return pos + string.length 1023 func (mp *MysqlProtocolImpl) writeStringFix(data []byte, pos int, value string, length int) int { 1024 pos += copy(data[pos:], value[0:length]) 1025 return pos 1026 } 1027 1028 // append a string with fixed length to the buffer 1029 // return the buffer 1030 func (mp *MysqlProtocolImpl) appendStringFix(data []byte, value string, length int) []byte { 1031 return mp.append(data, []byte(value[:length])...) 
1032 } 1033 1034 // read a string appended with zero from the buffer at the position 1035 // return string ; position + length of the string + 1; true - succeeded or false - failed 1036 func (mp *MysqlProtocolImpl) readStringNUL(data []byte, pos int) (string, int, bool) { 1037 zeroPos := bytes.IndexByte(data[pos:], 0) 1038 if zeroPos == -1 { 1039 return "", 0, false 1040 } 1041 return string(data[pos : pos+zeroPos]), pos + zeroPos + 1, true 1042 } 1043 1044 // write a string into the buffer at the position, then appended with 0 1045 // return pos + string.length + 1 1046 func (mp *MysqlProtocolImpl) writeStringNUL(data []byte, pos int, value string) int { 1047 pos = mp.writeStringFix(data, pos, value, len(value)) 1048 data[pos] = 0 1049 return pos + 1 1050 } 1051 1052 // read a string with length encoded from the buffer at the position 1053 // return string ; position + the count of bytes for length encoded (1 or 3 or 4 or 9) + length of the string; true - succeeded or false - failed 1054 func (mp *MysqlProtocolImpl) readStringLenEnc(data []byte, pos int) (string, int, bool) { 1055 var value uint64 1056 var ok bool 1057 value, pos, ok = mp.readIntLenEnc(data, pos) 1058 if !ok { 1059 return "", 0, false 1060 } 1061 sLength := int(value) 1062 if pos+sLength-1 >= len(data) { 1063 return "", 0, false 1064 } 1065 return string(data[pos : pos+sLength]), pos + sLength, true 1066 } 1067 1068 // write a string with length encoded into the buffer at the position 1069 // return position + the count of bytes for length encoded (1 or 3 or 4 or 9) + length of the string; 1070 func (mp *MysqlProtocolImpl) writeStringLenEnc(data []byte, pos int, value string) int { 1071 pos = mp.writeIntLenEnc(data, pos, uint64(len(value))) 1072 return mp.writeStringFix(data, pos, value, len(value)) 1073 } 1074 1075 // append a string with length encoded to the buffer 1076 // return the buffer 1077 func (mp *MysqlProtocolImpl) appendStringLenEnc(data []byte, value string) []byte { 1078 data = 
mp.appendIntLenEnc(data, uint64(len(value))) 1079 return mp.appendStringFix(data, value, len(value)) 1080 } 1081 1082 // append bytes with length encoded to the buffer 1083 // return the buffer 1084 func (mp *MysqlProtocolImpl) appendCountOfBytesLenEnc(data []byte, value []byte) []byte { 1085 data = mp.appendIntLenEnc(data, uint64(len(value))) 1086 return mp.appendCountOfBytes(data, value) 1087 } 1088 1089 // append an int64 value converted to string with length encoded to the buffer 1090 // return the buffer 1091 func (mp *MysqlProtocolImpl) appendStringLenEncOfInt64(data []byte, value int64) []byte { 1092 mp.strconvBuffer = mp.strconvBuffer[:0] 1093 mp.strconvBuffer = strconv.AppendInt(mp.strconvBuffer, value, 10) 1094 return mp.appendCountOfBytesLenEnc(data, mp.strconvBuffer) 1095 } 1096 1097 // append an uint64 value converted to string with length encoded to the buffer 1098 // return the buffer 1099 func (mp *MysqlProtocolImpl) appendStringLenEncOfUint64(data []byte, value uint64) []byte { 1100 mp.strconvBuffer = mp.strconvBuffer[:0] 1101 mp.strconvBuffer = strconv.AppendUint(mp.strconvBuffer, value, 10) 1102 return mp.appendCountOfBytesLenEnc(data, mp.strconvBuffer) 1103 } 1104 1105 // append an float32 value converted to string with length encoded to the buffer 1106 // return the buffer 1107 func (mp *MysqlProtocolImpl) appendStringLenEncOfFloat64(data []byte, value float64, bitSize int) []byte { 1108 mp.strconvBuffer = mp.strconvBuffer[:0] 1109 if !math.IsInf(value, 0) { 1110 mp.strconvBuffer = strconv.AppendFloat(mp.strconvBuffer, value, 'f', -1, bitSize) 1111 } else { 1112 if math.IsInf(value, 1) { 1113 mp.strconvBuffer = append(mp.strconvBuffer, []byte("+Infinity")...) 1114 } else { 1115 mp.strconvBuffer = append(mp.strconvBuffer, []byte("-Infinity")...) 
1116 } 1117 } 1118 return mp.appendCountOfBytesLenEnc(data, mp.strconvBuffer) 1119 } 1120 1121 func (mp *MysqlProtocolImpl) appendUint8(data []byte, e uint8) []byte { 1122 return mp.append(data, e) 1123 } 1124 1125 func (mp *MysqlProtocolImpl) appendUint16(data []byte, e uint16) []byte { 1126 buf := mp.lenEncBuffer[:2] 1127 pos := mp.io.WriteUint16(buf, 0, e) 1128 return mp.append(data, buf[:pos]...) 1129 } 1130 1131 func (mp *MysqlProtocolImpl) appendUint32(data []byte, e uint32) []byte { 1132 buf := mp.lenEncBuffer[:4] 1133 pos := mp.io.WriteUint32(buf, 0, e) 1134 return mp.append(data, buf[:pos]...) 1135 } 1136 1137 func (mp *MysqlProtocolImpl) appendUint64(data []byte, e uint64) []byte { 1138 buf := mp.lenEncBuffer[:8] 1139 pos := mp.io.WriteUint64(buf, 0, e) 1140 return mp.append(data, buf[:pos]...) 1141 } 1142 1143 // write the count of zeros into the buffer at the position 1144 // return pos + count 1145 func (mp *MysqlProtocolImpl) writeZeros(data []byte, pos int, count int) int { 1146 for i := 0; i < count; i++ { 1147 data[pos+i] = 0 1148 } 1149 return pos + count 1150 } 1151 1152 // the server get the auth string from HandShakeResponse 1153 // pwd is SHA1(SHA1(password)), AUTH is from client 1154 // hash1 = AUTH XOR SHA1( slat + pwd) 1155 // hash2 = SHA1(hash1) 1156 // check(hash2, hpwd) 1157 func (mp *MysqlProtocolImpl) checkPassword(pwd, salt, auth []byte) bool { 1158 sha := sha1.New() 1159 _, err := sha.Write(salt) 1160 if err != nil { 1161 logutil.Errorf("SHA1(salt) failed.") 1162 return false 1163 } 1164 _, err = sha.Write(pwd) 1165 if err != nil { 1166 logutil.Errorf("SHA1(hpwd) failed.") 1167 return false 1168 } 1169 hash1 := sha.Sum(nil) 1170 1171 if len(auth) != len(hash1) { 1172 return false 1173 } 1174 1175 for i := range hash1 { 1176 hash1[i] ^= auth[i] 1177 } 1178 1179 hash2 := HashSha1(hash1) 1180 return bytes.Equal(pwd, hash2) 1181 } 1182 1183 // the server authenticate that the client can connect and use the database 1184 func (mp 
*MysqlProtocolImpl) authenticateUser(ctx context.Context, authResponse []byte) error { 1185 var psw []byte 1186 var err error 1187 var tenant *TenantInfo 1188 1189 ses := mp.GetSession() 1190 if !mp.SV.SkipCheckUser { 1191 logDebugf(mp.getDebugStringUnsafe(), "authenticate user 1") 1192 psw, err = ses.AuthenticateUser(ctx, mp.GetUserName(), mp.GetDatabaseName(), mp.authResponse, mp.GetSalt(), mp.checkPassword) 1193 if err != nil { 1194 return err 1195 } 1196 logDebugf(mp.getDebugStringUnsafe(), "authenticate user 2") 1197 1198 //TO Check password 1199 if mp.checkPassword(psw, mp.GetSalt(), authResponse) { 1200 logDebugf(mp.getDebugStringUnsafe(), "check password succeeded") 1201 ses.InitGlobalSystemVariables(ctx) 1202 } else { 1203 return moerr.NewInternalError(ctx, "check password failed") 1204 } 1205 } else { 1206 logDebugf(mp.getDebugStringUnsafe(), "skip authenticate user") 1207 //Get tenant info 1208 tenant, err = GetTenantInfo(ctx, mp.GetUserName()) 1209 if err != nil { 1210 return err 1211 } 1212 1213 if ses != nil { 1214 ses.SetTenantInfo(tenant) 1215 1216 //TO Check password 1217 if len(psw) == 0 || mp.checkPassword(psw, mp.GetSalt(), authResponse) { 1218 logInfo(mp.ses, mp.ses.GetDebugString(), "check password succeeded") 1219 } else { 1220 return moerr.NewInternalError(ctx, "check password failed") 1221 } 1222 } 1223 } 1224 mp.incDebugCount(1) 1225 return nil 1226 } 1227 1228 func (mp *MysqlProtocolImpl) HandleHandshake(ctx context.Context, payload []byte) (bool, error) { 1229 var err error 1230 if len(payload) < 2 { 1231 return false, moerr.NewInternalError(ctx, "received a broken response packet") 1232 } 1233 1234 if capabilities, _, ok := mp.io.ReadUint16(payload, 0); !ok { 1235 return false, moerr.NewInternalError(ctx, "read capabilities from response packet failed") 1236 } else if uint32(capabilities)&CLIENT_PROTOCOL_41 != 0 { 1237 var resp41 response41 1238 var ok2 bool 1239 logDebugf(mp.getDebugStringUnsafe(), "analyse handshake response") 1240 if 
ok2, resp41, err = mp.analyseHandshakeResponse41(ctx, payload); !ok2 { 1241 return false, err 1242 } 1243 1244 // client ask server to upgradeTls 1245 if resp41.isAskForTlsHeader { 1246 return true, nil 1247 } 1248 1249 mp.authResponse = resp41.authResponse 1250 mp.capability = mp.capability & resp41.capabilities 1251 1252 if nameAndCharset, ok3 := collationID2CharsetAndName[int(resp41.collationID)]; !ok3 { 1253 return false, moerr.NewInternalError(ctx, "get collationName and charset failed") 1254 } else { 1255 mp.collationID = int(resp41.collationID) 1256 mp.collationName = nameAndCharset.collationName 1257 mp.charset = nameAndCharset.charset 1258 } 1259 1260 mp.maxClientPacketSize = resp41.maxPacketSize 1261 mp.username = resp41.username 1262 mp.database = resp41.database 1263 mp.connectAttrs = resp41.connectAttrs 1264 } else { 1265 var resp320 response320 1266 var ok2 bool 1267 if ok2, resp320, err = mp.analyseHandshakeResponse320(ctx, payload); !ok2 { 1268 return false, err 1269 } 1270 1271 // client ask server to upgradeTls 1272 if resp320.isAskForTlsHeader { 1273 return true, nil 1274 } 1275 1276 mp.authResponse = resp320.authResponse 1277 mp.capability = mp.capability & resp320.capabilities 1278 mp.collationID = int(Utf8mb4CollationID) 1279 mp.collationName = "utf8mb4_general_ci" 1280 mp.charset = "utf8mb4" 1281 1282 mp.maxClientPacketSize = resp320.maxPacketSize 1283 mp.username = resp320.username 1284 mp.database = resp320.database 1285 } 1286 return false, nil 1287 } 1288 1289 func (mp *MysqlProtocolImpl) Authenticate(ctx context.Context) error { 1290 ses := mp.GetSession() 1291 ses.timestampMap[TSAuthenticateStart] = time.Now() 1292 defer func() { 1293 ses.timestampMap[TSAuthenticateEnd] = time.Now() 1294 v2.AuthenticateDurationHistogram.Observe(ses.timestampMap[TSAuthenticateEnd].Sub(ses.timestampMap[TSAuthenticateStart]).Seconds()) 1295 }() 1296 1297 logDebugf(mp.getDebugStringUnsafe(), "authenticate user") 1298 mp.incDebugCount(0) 1299 if err := 
mp.authenticateUser(ctx, mp.authResponse); err != nil { 1300 logutil.Errorf("authenticate user failed.error:%v", err) 1301 errorCode, sqlState, msg := RewriteError(err, mp.username) 1302 ses.timestampMap[TSSendErrPacketStart] = time.Now() 1303 err2 := mp.sendErrPacket(errorCode, sqlState, msg) 1304 ses.timestampMap[TSSendErrPacketEnd] = time.Now() 1305 v2.SendErrPacketDurationHistogram.Observe(ses.timestampMap[TSSendErrPacketEnd].Sub(ses.timestampMap[TSSendErrPacketStart]).Seconds()) 1306 if err2 != nil { 1307 logutil.Errorf("send err packet failed.error:%v", err2) 1308 return err2 1309 } 1310 return err 1311 } 1312 1313 mp.incDebugCount(2) 1314 logDebugf(mp.getDebugStringUnsafe(), "handle handshake end") 1315 ses.timestampMap[TSSendOKPacketStart] = time.Now() 1316 err := mp.sendOKPacket(0, 0, 0, 0, "") 1317 ses.timestampMap[TSSendOKPacketEnd] = time.Now() 1318 v2.SendOKPacketDurationHistogram.Observe(ses.timestampMap[TSSendOKPacketEnd].Sub(ses.timestampMap[TSSendOKPacketStart]).Seconds()) 1319 mp.incDebugCount(3) 1320 logDebugf(mp.getDebugStringUnsafe(), "handle handshake response ok") 1321 if err != nil { 1322 return err 1323 } 1324 return nil 1325 } 1326 1327 // the server makes a handshake v10 packet 1328 // return handshake packet 1329 func (mp *MysqlProtocolImpl) makeHandshakeV10Payload() []byte { 1330 var data = make([]byte, HeaderOffset+256) 1331 var pos = HeaderOffset 1332 //int<1> protocol version 1333 pos = mp.io.WriteUint8(data, pos, clientProtocolVersion) 1334 1335 pos = mp.writeStringNUL(data, pos, mp.SV.ServerVersionPrefix+serverVersion.Load().(string)) 1336 1337 //int<4> connection id 1338 pos = mp.io.WriteUint32(data, pos, mp.ConnectionID()) 1339 1340 //string[8] auth-plugin-data-part-1 1341 pos = mp.writeCountOfBytes(data, pos, mp.GetSalt()[0:8]) 1342 1343 //int<1> filler 0 1344 pos = mp.io.WriteUint8(data, pos, 0) 1345 1346 //int<2> capabilities flags (lower 2 bytes) 1347 pos = mp.io.WriteUint16(data, pos, uint16(mp.capability&0xFFFF)) 1348 1349 
//int<1> character set 1350 pos = mp.io.WriteUint8(data, pos, utf8mb4BinCollationID) 1351 1352 //int<2> status flags 1353 pos = mp.io.WriteUint16(data, pos, DefaultClientConnStatus) 1354 1355 //int<2> capabilities flags (upper 2 bytes) 1356 pos = mp.io.WriteUint16(data, pos, uint16((DefaultCapability>>16)&0xFFFF)) 1357 1358 if (DefaultCapability & CLIENT_PLUGIN_AUTH) != 0 { 1359 //int<1> length of auth-plugin-data 1360 //set 21 always 1361 pos = mp.io.WriteUint8(data, pos, uint8(len(mp.GetSalt())+1)) 1362 } else { 1363 //int<1> [00] 1364 //set 0 always 1365 pos = mp.io.WriteUint8(data, pos, 0) 1366 } 1367 1368 //string[10] reserved (all [00]) 1369 pos = mp.writeZeros(data, pos, 10) 1370 1371 if (DefaultCapability & CLIENT_SECURE_CONNECTION) != 0 { 1372 //string[$len] auth-plugin-data-part-2 ($len=MAX(13, length of auth-plugin-data - 8)) 1373 pos = mp.writeCountOfBytes(data, pos, mp.GetSalt()[8:]) 1374 pos = mp.io.WriteUint8(data, pos, 0) 1375 } 1376 1377 if (DefaultCapability & CLIENT_PLUGIN_AUTH) != 0 { 1378 //string[NUL] auth-plugin name 1379 pos = mp.writeStringNUL(data, pos, AuthNativePassword) 1380 } 1381 1382 return data[:pos] 1383 } 1384 1385 // the server analyses handshake response41 info from the client 1386 // return true - analysed successfully / false - failed ; response41 ; error 1387 func (mp *MysqlProtocolImpl) analyseHandshakeResponse41(ctx context.Context, data []byte) (bool, response41, error) { 1388 var pos = 0 1389 var ok bool 1390 var info response41 1391 1392 //int<4> capabilities flags of the client, CLIENT_PROTOCOL_41 always set 1393 info.capabilities, pos, ok = mp.io.ReadUint32(data, pos) 1394 if !ok { 1395 return false, info, moerr.NewInternalError(ctx, "get capabilities failed") 1396 } 1397 1398 if (info.capabilities & CLIENT_PROTOCOL_41) == 0 { 1399 return false, info, moerr.NewInternalError(ctx, "capabilities does not have protocol 41") 1400 } 1401 1402 //int<4> max-packet size 1403 //max size of a command packet that the client wants 
to send to the server 1404 info.maxPacketSize, pos, ok = mp.io.ReadUint32(data, pos) 1405 if !ok { 1406 return false, info, moerr.NewInternalError(ctx, "get max packet size failed") 1407 } 1408 1409 //int<1> character set 1410 //connection's default character set 1411 info.collationID, pos, ok = mp.io.ReadUint8(data, pos) 1412 if !ok { 1413 return false, info, moerr.NewInternalError(ctx, "get character set failed") 1414 } 1415 1416 if pos+22 >= len(data) { 1417 return false, info, moerr.NewInternalError(ctx, "skip reserved failed") 1418 } 1419 //string[23] reserved (all [0]) 1420 //just skip it 1421 pos += 23 1422 1423 // if client reply for upgradeTls, then data will contains header only. 1424 if pos == len(data) && (info.capabilities&CLIENT_SSL) != 0 { 1425 info.isAskForTlsHeader = true 1426 return true, info, nil 1427 } 1428 1429 //string[NUL] username 1430 info.username, pos, ok = mp.readStringNUL(data, pos) 1431 if !ok { 1432 return false, info, moerr.NewInternalError(ctx, "get username failed") 1433 } 1434 1435 /* 1436 if capabilities & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA { 1437 lenenc-int length of auth-response 1438 string[n] auth-response 1439 } else if capabilities & CLIENT_SECURE_CONNECTION { 1440 int<1> length of auth-response 1441 string[n] auth-response 1442 } else { 1443 string[NUL] auth-response 1444 } 1445 */ 1446 if (info.capabilities & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) != 0 { 1447 var l uint64 1448 l, pos, ok = mp.readIntLenEnc(data, pos) 1449 if !ok { 1450 return false, info, moerr.NewInternalError(ctx, "get length of auth-response failed") 1451 } 1452 info.authResponse, pos, ok = mp.readCountOfBytes(data, pos, int(l)) 1453 if !ok { 1454 return false, info, moerr.NewInternalError(ctx, "get auth-response failed") 1455 } 1456 } else if (info.capabilities & CLIENT_SECURE_CONNECTION) != 0 { 1457 var l uint8 1458 l, pos, ok = mp.io.ReadUint8(data, pos) 1459 if !ok { 1460 return false, info, moerr.NewInternalError(ctx, "get length of 
auth-response failed") 1461 } 1462 info.authResponse, pos, ok = mp.readCountOfBytes(data, pos, int(l)) 1463 if !ok { 1464 return false, info, moerr.NewInternalError(ctx, "get auth-response failed") 1465 } 1466 } else { 1467 var auth string 1468 auth, pos, ok = mp.readStringNUL(data, pos) 1469 if !ok { 1470 return false, info, moerr.NewInternalError(ctx, "get auth-response failed") 1471 } 1472 info.authResponse = []byte(auth) 1473 } 1474 1475 if (info.capabilities & CLIENT_CONNECT_WITH_DB) != 0 { 1476 info.database, pos, ok = mp.readStringNUL(data, pos) 1477 if !ok { 1478 return false, info, moerr.NewInternalError(ctx, "get database failed") 1479 } 1480 } 1481 1482 if (info.capabilities & CLIENT_PLUGIN_AUTH) != 0 { 1483 info.clientPluginName, pos, ok = mp.readStringNUL(data, pos) 1484 if !ok { 1485 return false, info, moerr.NewInternalError(ctx, "get auth plugin name failed") 1486 } 1487 1488 //to switch authenticate method 1489 if info.clientPluginName != AuthNativePassword { 1490 var err error 1491 if info.authResponse, err = mp.negotiateAuthenticationMethod(ctx); err != nil { 1492 return false, info, moerr.NewInternalError(ctx, "negotiate authentication method failed. 
error:%v", err) 1493 } 1494 info.clientPluginName = AuthNativePassword 1495 } 1496 } 1497 1498 // client connection attributes 1499 info.connectAttrs = make(map[string]string) 1500 if info.capabilities&CLIENT_CONNECT_ATTRS != 0 { 1501 var l uint64 1502 var ok bool 1503 l, pos, ok = mp.readIntLenEnc(data, pos) 1504 if !ok { 1505 return false, info, moerr.NewInternalError(ctx, "get length of client-connect-attrs failed") 1506 } 1507 endPos := pos + int(l) 1508 var key, value string 1509 for pos < endPos { 1510 key, pos, ok = mp.readStringLenEnc(data, pos) 1511 if !ok { 1512 return false, info, moerr.NewInternalError(ctx, "get connect-attrs key failed") 1513 } 1514 value, pos, ok = mp.readStringLenEnc(data, pos) 1515 if !ok { 1516 return false, info, moerr.NewInternalError(ctx, "get connect-attrs value failed") 1517 } 1518 info.connectAttrs[key] = value 1519 } 1520 } 1521 1522 return true, info, nil 1523 } 1524 1525 /* 1526 //the server does something after receiving a handshake response41 from the client 1527 //like check user and password 1528 //and other things 1529 func (mp *MysqlProtocolImpl) handleClientResponse41(resp41 response41) error { 1530 //to do something else 1531 //logutil.Infof("capabilities 0x%x\n", resp41.capabilities) 1532 //logutil.Infof("maxPacketSize %d\n", resp41.maxPacketSize) 1533 //logutil.Infof("collationID %d\n", resp41.collationID) 1534 //logutil.Infof("username %s\n", resp41.username) 1535 //logutil.Infof("authResponse: \n") 1536 //update the capabilities with client's capabilities 1537 mp.capability = DefaultCapability & resp41.capabilities 1538 1539 //character set 1540 if nameAndCharset, ok := collationID2CharsetAndName[int(resp41.collationID)]; !ok { 1541 return moerr.NewInternalError(requestCtx, "get collationName and charset failed") 1542 } else { 1543 mp.collationID = int(resp41.collationID) 1544 mp.collationName = nameAndCharset.collationName 1545 mp.charset = nameAndCharset.charset 1546 } 1547 1548 mp.maxClientPacketSize = 
resp41.maxPacketSize 1549 mp.username = resp41.username 1550 mp.database = resp41.database 1551 1552 //logutil.Infof("collationID %d collatonName %s charset %s \n", mp.collationID, mp.collationName, mp.charset) 1553 //logutil.Infof("database %s \n", resp41.database) 1554 //logutil.Infof("clientPluginName %s \n", resp41.clientPluginName) 1555 return nil 1556 } 1557 */ 1558 1559 // the server analyses handshake response320 info from the old client 1560 // return true - analysed successfully / false - failed ; response320 ; error 1561 func (mp *MysqlProtocolImpl) analyseHandshakeResponse320(ctx context.Context, data []byte) (bool, response320, error) { 1562 var pos = 0 1563 var ok bool 1564 var info response320 1565 var capa uint16 1566 1567 //int<2> capabilities flags, CLIENT_PROTOCOL_41 never set 1568 capa, pos, ok = mp.io.ReadUint16(data, pos) 1569 if !ok { 1570 return false, info, moerr.NewInternalError(ctx, "get capabilities failed") 1571 } 1572 info.capabilities = uint32(capa) 1573 1574 if pos+2 >= len(data) { 1575 return false, info, moerr.NewInternalError(ctx, "get max-packet-size failed") 1576 } 1577 1578 //int<3> max-packet size 1579 //max size of a command packet that the client wants to send to the server 1580 info.maxPacketSize = uint32(data[pos]) | uint32(data[pos+1])<<8 | uint32(data[pos+2])<<16 1581 pos += 3 1582 1583 // if client reply for upgradeTls, then data will contains header only. 
1584 if pos == len(data) && (info.capabilities&CLIENT_SSL) != 0 { 1585 info.isAskForTlsHeader = true 1586 return true, info, nil 1587 } 1588 1589 //string[NUL] username 1590 info.username, pos, ok = mp.readStringNUL(data, pos) 1591 if !ok { 1592 return false, info, moerr.NewInternalError(ctx, "get username failed") 1593 } 1594 1595 if (info.capabilities & CLIENT_CONNECT_WITH_DB) != 0 { 1596 var auth string 1597 auth, pos, ok = mp.readStringNUL(data, pos) 1598 if !ok { 1599 return false, info, moerr.NewInternalError(ctx, "get auth-response failed") 1600 } 1601 info.authResponse = []byte(auth) 1602 1603 info.database, _, ok = mp.readStringNUL(data, pos) 1604 if !ok { 1605 return false, info, moerr.NewInternalError(ctx, "get database failed") 1606 } 1607 } else { 1608 info.authResponse, _, ok = mp.readCountOfBytes(data, pos, len(data)-pos) 1609 if !ok { 1610 return false, info, moerr.NewInternalError(ctx, "get auth-response failed") 1611 } 1612 } 1613 1614 return true, info, nil 1615 } 1616 1617 /* 1618 //the server does something after receiving a handshake response320 from the client 1619 //like check user and password 1620 //and other things 1621 func (mp *MysqlProtocolImpl) handleClientResponse320(resp320 response320) error { 1622 //to do something else 1623 //logutil.Infof("capabilities 0x%x\n", resp320.capabilities) 1624 //logutil.Infof("maxPacketSize %d\n", resp320.maxPacketSize) 1625 //logutil.Infof("username %s\n", resp320.username) 1626 //logutil.Infof("authResponse: \n") 1627 1628 //update the capabilities with client's capabilities 1629 mp.capability = DefaultCapability & resp320.capabilities 1630 1631 //if the client does not notice its default charset, the server gives a default charset. 
1632 //Run the sql in mysql 8.0.23 to get the charset 1633 //the sql: select * from information_schema.collations where collation_name = 'utf8mb4_general_ci'; 1634 mp.collationID = int(Utf8mb4CollationID) 1635 mp.collationName = "utf8mb4_general_ci" 1636 mp.charset = "utf8mb4" 1637 1638 mp.maxClientPacketSize = resp320.maxPacketSize 1639 mp.username = resp320.username 1640 mp.database = resp320.database 1641 1642 //logutil.Infof("collationID %d collatonName %s charset %s \n", mp.collationID, mp.collationName, mp.charset) 1643 //logutil.Infof("database %s \n", resp320.database) 1644 return nil 1645 } 1646 */ 1647 1648 // the server makes a AuthSwitchRequest that asks the client to authenticate the data with new method 1649 func (mp *MysqlProtocolImpl) makeAuthSwitchRequestPayload(authMethodName string) []byte { 1650 data := make([]byte, HeaderOffset+1+len(authMethodName)+1+len(mp.GetSalt())+1) 1651 pos := HeaderOffset 1652 pos = mp.io.WriteUint8(data, pos, defines.EOFHeader) 1653 pos = mp.writeStringNUL(data, pos, authMethodName) 1654 pos = mp.writeCountOfBytes(data, pos, mp.GetSalt()) 1655 pos = mp.io.WriteUint8(data, pos, 0) 1656 return data[:pos] 1657 } 1658 1659 // the server can send AuthSwitchRequest to ask client to use designated authentication method, 1660 // if both server and client support CLIENT_PLUGIN_AUTH capability. 
1661 // return data authenticated with new method 1662 func (mp *MysqlProtocolImpl) negotiateAuthenticationMethod(ctx context.Context) ([]byte, error) { 1663 var err error 1664 aswPkt := mp.makeAuthSwitchRequestPayload(AuthNativePassword) 1665 err = mp.writePackets(aswPkt, true) 1666 if err != nil { 1667 return nil, err 1668 } 1669 1670 read, err := mp.tcpConn.Read(goetty.ReadOptions{}) 1671 if err != nil { 1672 return nil, err 1673 } 1674 1675 if read == nil { 1676 return nil, moerr.NewInternalError(ctx, "read nil from tcp conn") 1677 } 1678 1679 pack, ok := read.(*Packet) 1680 if !ok { 1681 return nil, moerr.NewInternalError(ctx, "it is not the Packet") 1682 } 1683 1684 if pack == nil { 1685 return nil, moerr.NewInternalError(ctx, "packet is null") 1686 } 1687 1688 data := pack.Payload 1689 mp.AddSequenceId(1) 1690 return data, nil 1691 } 1692 1693 // extendStatus extends a flag to the status variable. 1694 // SERVER_QUERY_WAS_SLOW and SERVER_STATUS_NO_GOOD_INDEX_USED is not used in other modules, 1695 // so we use it to mark the packet MUST be OK/EOF. 
1696 func extendStatus(old uint16) uint16 { 1697 return old | SERVER_QUERY_WAS_SLOW | SERVER_STATUS_NO_GOOD_INDEX_USED 1698 } 1699 1700 // make a OK packet 1701 func (mp *MysqlProtocolImpl) makeOKPayload(affectedRows, lastInsertId uint64, statusFlags, warnings uint16, message string) []byte { 1702 data := make([]byte, HeaderOffset+128+len(message)+10) 1703 var pos = HeaderOffset 1704 pos = mp.io.WriteUint8(data, pos, defines.OKHeader) 1705 pos = mp.writeIntLenEnc(data, pos, affectedRows) 1706 pos = mp.writeIntLenEnc(data, pos, lastInsertId) 1707 statusFlags = extendStatus(statusFlags) 1708 if (mp.capability & CLIENT_PROTOCOL_41) != 0 { 1709 pos = mp.io.WriteUint16(data, pos, statusFlags) 1710 pos = mp.io.WriteUint16(data, pos, warnings) 1711 } else if (mp.capability & CLIENT_TRANSACTIONS) != 0 { 1712 pos = mp.io.WriteUint16(data, pos, statusFlags) 1713 } 1714 1715 if mp.capability&CLIENT_SESSION_TRACK != 0 { 1716 //TODO:implement it 1717 } else { 1718 //string<lenenc> instead of string<EOF> in the manual of mysql 1719 pos = mp.writeStringLenEnc(data, pos, message) 1720 return data[:pos] 1721 } 1722 return data[:pos] 1723 } 1724 1725 func (mp *MysqlProtocolImpl) makeOKPayloadWithEof(affectedRows, lastInsertId uint64, statusFlags, warnings uint16, message string) []byte { 1726 data := make([]byte, HeaderOffset+128+len(message)+10) 1727 var pos = HeaderOffset 1728 pos = mp.io.WriteUint8(data, pos, defines.EOFHeader) 1729 pos = mp.writeIntLenEnc(data, pos, affectedRows) 1730 pos = mp.writeIntLenEnc(data, pos, lastInsertId) 1731 if (mp.capability & CLIENT_PROTOCOL_41) != 0 { 1732 pos = mp.io.WriteUint16(data, pos, statusFlags) 1733 pos = mp.io.WriteUint16(data, pos, warnings) 1734 } else if (mp.capability & CLIENT_TRANSACTIONS) != 0 { 1735 pos = mp.io.WriteUint16(data, pos, statusFlags) 1736 } 1737 1738 if mp.capability&CLIENT_SESSION_TRACK != 0 { 1739 //TODO:implement it 1740 } else { 1741 //string<lenenc> instead of string<EOF> in the manual of mysql 1742 pos = 
mp.writeStringLenEnc(data, pos, message)
		return data[:pos]
	}
	return data[:pos]
}

// makeLocalInfileRequestPayload builds a LOCAL INFILE request payload
// (LocalInFileHeader followed by the file name). The first HeaderOffset bytes
// are reserved for the packet header written later by writePackets.
func (mp *MysqlProtocolImpl) makeLocalInfileRequestPayload(filename string) []byte {
	data := make([]byte, HeaderOffset+1+len(filename)+1)
	pos := HeaderOffset
	pos = mp.io.WriteUint8(data, pos, defines.LocalInFileHeader)
	pos = mp.writeStringFix(data, pos, filename, len(filename))
	return data[:pos]
}

// sendLocalInfileRequest asks the client to send the contents of filename.
func (mp *MysqlProtocolImpl) sendLocalInfileRequest(filename string) error {
	req := mp.makeLocalInfileRequestPayload(filename)
	return mp.writePackets(req, true)
}

// sendOKPacketWithEof sends the payload built by makeOKPayloadWithEof.
func (mp *MysqlProtocolImpl) sendOKPacketWithEof(affectedRows, lastInsertId uint64, status, warnings uint16, message string) error {
	okPkt := mp.makeOKPayloadWithEof(affectedRows, lastInsertId, status, warnings, message)
	return mp.writePackets(okPkt, true)
}

// send OK packet to the client
func (mp *MysqlProtocolImpl) sendOKPacket(affectedRows, lastInsertId uint64, status, warnings uint16, message string) error {
	okPkt := mp.makeOKPayload(affectedRows, lastInsertId, status, warnings, message)
	return mp.writePackets(okPkt, true)
}

// makeErrPayload builds an ERR packet payload: ErrHeader, the 2-byte error code,
// and (when CLIENT_PROTOCOL_41 is negotiated) the '#' marker plus a 5-byte
// SQLSTATE, followed by the error message.
func (mp *MysqlProtocolImpl) makeErrPayload(errorCode uint16, sqlState, errorMessage string) []byte {
	// 1 header + 2 error code + 1 '#' + 5 sqlstate = 9 bytes before the message.
	data := make([]byte, HeaderOffset+9+len(errorMessage))
	pos := HeaderOffset
	pos = mp.io.WriteUint8(data, pos, defines.ErrHeader)
	pos = mp.io.WriteUint16(data, pos, errorCode)
	if mp.capability&CLIENT_PROTOCOL_41 != 0 {
		pos = mp.io.WriteUint8(data, pos, '#')
		if len(sqlState) < 5 {
			// Pad short SQLSTATEs with spaces to exactly 5 bytes. strings.Repeat
			// avoids depending on the width of a hand-typed padding literal.
			sqlState += strings.Repeat(" ", 5-len(sqlState))
		}
		pos = mp.writeStringFix(data, pos, sqlState, 5)
	}
	pos = mp.writeStringFix(data, pos, errorMessage, len(errorMessage))
	return data[:pos]
}

/*
the server sends the Error packet

information from https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html
mysql version 8.0.23
usually it is in the directory /usr/local/include/mysql/mysqld_error.h

Error information includes several elements: an error code, SQLSTATE value, and message string.

	Error code: This value is numeric. It is MySQL-specific and is not portable to other database systems.
	SQLSTATE value: This value is a five-character string (for example, '42S02'). SQLSTATE values are taken from ANSI SQL and ODBC and are more standardized than the numeric error codes.
	Message string: This string provides a textual description of the error.

The error is also pushed into the session's error info when a session is attached.
*/
func (mp *MysqlProtocolImpl) sendErrPacket(errorCode uint16, sqlState, errorMessage string) error {
	if mp.ses != nil {
		mp.ses.GetErrInfo().push(errorCode, errorMessage)
	}
	errPkt := mp.makeErrPayload(errorCode, sqlState, errorMessage)
	return mp.writePackets(errPkt, true)
}

// makeEOFPayload builds an EOF packet payload; warnings and status are only
// included when CLIENT_PROTOCOL_41 is negotiated.
func (mp *MysqlProtocolImpl) makeEOFPayload(warnings, status uint16) []byte {
	data := make([]byte, HeaderOffset+10)
	pos := HeaderOffset
	pos = mp.io.WriteUint8(data, pos, defines.EOFHeader)
	if mp.capability&CLIENT_PROTOCOL_41 != 0 {
		pos = mp.io.WriteUint16(data, pos, warnings)
		pos = mp.io.WriteUint16(data, pos, status)
	}
	return data[:pos]
}

// sendEOFPacket unconditionally sends an EOF packet.
func (mp *MysqlProtocolImpl) sendEOFPacket(warnings, status uint16) error {
	data := mp.makeEOFPayload(warnings, status)
	return mp.writePackets(data, true)
}

// SendEOFPacketIf sends an EOF packet only when the client did not negotiate
// CLIENT_DEPRECATE_EOF; otherwise it does nothing.
func (mp *MysqlProtocolImpl) SendEOFPacketIf(warnings, status uint16) error {
	//If the CLIENT_DEPRECATE_EOF client capabilities flag is not set, EOF_Packet
	if mp.capability&CLIENT_DEPRECATE_EOF == 0 {
		return mp.sendEOFPacket(warnings, status)
	}
	return nil
}

// the OK or EOF packet
// thread safe
func (mp *MysqlProtocolImpl) sendEOFOrOkPacket(warnings, status uint16) error {
	//If the CLIENT_DEPRECATE_EOF client capabilities flag is set, OK_Packet; else EOF_Packet.
	if mp.capability&CLIENT_DEPRECATE_EOF != 0 {
		return mp.sendOKPacketWithEof(0, 0, status, 0, "")
	}
	return mp.sendEOFPacket(warnings, status)
}

// setColLength derives the column display length from its type and width.
func setColLength(column *MysqlColumn, width int32) {
	column.length = column.columnType.GetLength(width)
}

// setColFlag adds AUTO_INCREMENT_FLAG to the column flags for auto-increment columns.
func setColFlag(column *MysqlColumn) {
	if column.auto_incr {
		column.flag |= uint16(defines.AUTO_INCREMENT_FLAG)
	}
}

// setCharacter picks the charset id: charsetVarchar for character types,
// charsetBinary (0x3f) for everything else.
func setCharacter(column *MysqlColumn) {
	switch column.columnType {
	// blob type should use 0x3f to show the binary data
	case defines.MYSQL_TYPE_VARCHAR, defines.MYSQL_TYPE_STRING, defines.MYSQL_TYPE_TEXT:
		column.SetCharset(charsetVarchar)
	case defines.MYSQL_TYPE_VAR_STRING:
		column.SetCharset(charsetVarchar)
	default:
		column.SetCharset(charsetBinary)
	}
}

// make the column information with the format of column definition41
// BOOL columns are advertised as VARCHAR with charsetVarchar and
// boolColumnLength. For COM_FIELD_LIST the default value is appended.
func (mp *MysqlProtocolImpl) makeColumnDefinition41Payload(column *MysqlColumn, cmd int) []byte {
	space := HeaderOffset + 8*9 + //lenenc bytes of 8 fields
		21 + //fixed-length fields
		3 + // catalog "def"
		len(column.Schema()) +
		len(column.Table()) +
		len(column.OrgTable()) +
		len(column.Name()) +
		len(column.OrgName()) +
		len(column.DefaultValue()) +
		100 // for safe

	data := make([]byte, space)
	pos := HeaderOffset

	//lenenc_str catalog(always "def")
	pos = mp.writeStringLenEnc(data, pos, "def")

	//lenenc_str schema
	pos = mp.writeStringLenEnc(data, pos, column.Schema())

	//lenenc_str table
	pos = mp.writeStringLenEnc(data, pos, column.Table())

	//lenenc_str org_table
	pos = mp.writeStringLenEnc(data, pos, column.OrgTable())

	//lenenc_str name
	pos = mp.writeStringLenEnc(data, pos, column.Name())

	//lenenc_str org_name
	pos = mp.writeStringLenEnc(data, pos, column.OrgName())

	//lenenc_int length of fixed-length fields [0c]
	pos = mp.io.WriteUint8(data, pos, 0x0c)

	if column.ColumnType() == defines.MYSQL_TYPE_BOOL {
		//int<2> character set
		pos = mp.io.WriteUint16(data, pos, charsetVarchar)
		//int<4> column length
		pos = mp.io.WriteUint32(data, pos, boolColumnLength)
		//int<1> type
		pos = mp.io.WriteUint8(data, pos, uint8(defines.MYSQL_TYPE_VARCHAR))
	} else {
		//int<2> character set
		pos = mp.io.WriteUint16(data, pos, column.Charset())
		//int<4> column length
		pos = mp.io.WriteUint32(data, pos, column.Length())
		//int<1> type
		pos = mp.io.WriteUint8(data, pos, uint8(column.ColumnType()))
	}

	//int<2> flags
	pos = mp.io.WriteUint16(data, pos, column.Flag())

	//int<1> decimals
	pos = mp.io.WriteUint8(data, pos, column.Decimal())

	//int<2> filler [00] [00]
	pos = mp.io.WriteUint16(data, pos, 0)

	if CommandType(cmd) == COM_FIELD_LIST {
		pos = mp.writeIntLenEnc(data, pos, uint64(len(column.DefaultValue())))
		pos = mp.writeCountOfBytes(data, pos, column.DefaultValue())
	}

	return data[:pos]
}

// SendColumnDefinitionPacket the server send the column definition to the client
// Only MysqlColumn is accepted. When CLIENT_PROTOCOL_41 is not negotiated the
// payload stays nil.
func (mp *MysqlProtocolImpl) SendColumnDefinitionPacket(ctx context.Context, column Column, cmd int) error {
	mysqlColumn, ok := column.(*MysqlColumn)
	if !ok {
		return moerr.NewInternalError(ctx, "sendColumn need MysqlColumn")
	}

	var data []byte
	if mp.capability&CLIENT_PROTOCOL_41 != 0 {
		data = mp.makeColumnDefinition41Payload(mysqlColumn, cmd)
	}

	return mp.writePackets(data, true)
}

// SendColumnCountPacket makes the column count packet
// The count is sent as a length-encoded integer (at most 9 bytes).
func (mp *MysqlProtocolImpl) SendColumnCountPacket(count uint64) error {
	data := make([]byte, HeaderOffset+20)
	pos := HeaderOffset
	pos = mp.writeIntLenEnc(data, pos, count)

	return mp.writePackets(data[:pos], true)
}

// sendColumns sends one column definition packet per column of mrs, followed
// by an EOF packet when CLIENT_DEPRECATE_EOF is not negotiated.
func (mp *MysqlProtocolImpl) sendColumns(ctx context.Context, mrs *MysqlResultSet, cmd int, warnings, status uint16) error {
	//column_count * Protocol::ColumnDefinition packets
	for i := uint64(0); i < mrs.GetColumnCount(); i++ {
		col, err := mrs.GetColumn(ctx, i)
		if err != nil {
			return err
		}

		err = mp.SendColumnDefinitionPacket(ctx, col, cmd)
		if err != nil {
			return err
		}
	}

	//If the CLIENT_DEPRECATE_EOF client capabilities flag is not set, EOF_Packet
	if mp.capability&CLIENT_DEPRECATE_EOF == 0 {
		err := mp.sendEOFPacket(warnings, status)
		if err != nil {
			return err
		}
	}
	return nil
}

// the server convert every row of the result set into the format that mysql protocol needs
//
// makeResultSetBinaryRow writes row rowIdx in the binary resultset row format:
// the OK header, a NULL bitmap whose first two bits are reserved, then each
// non-NULL value encoded by its column type. Bytes go into the connection's
// out buffer via mp.append, which ignores its slice argument.
func (mp *MysqlProtocolImpl) makeResultSetBinaryRow(data []byte, mrs *MysqlResultSet, rowIdx uint64) ([]byte, error) {
	data = mp.append(data, defines.OKHeader) // append OkHeader

	// get null buffer
	buffer := mp.binaryNullBuffer[:0]
	columnsLength := mrs.GetColumnCount()
	// NULL bitmap length is (column_count + 7 + 2) / 8 with a bit offset of 2.
	numBytes4Null := (columnsLength + 7 + 2) / 8
	for i := uint64(0); i < numBytes4Null; i++ {
		buffer = append(buffer, 0)
	}
	for i := uint64(0); i < columnsLength; i++ {
		if isNil, err := mrs.ColumnIsNull(mp.ctx, rowIdx, i); err != nil {
			return nil, err
		} else if isNil {
			bytePos := (i + 2) / 8
			bitPos := byte((i + 2) % 8)
			idx := int(bytePos)
			buffer[idx] |= 1 << bitPos
		}
	}
	data = mp.append(data, buffer...)

	for i := uint64(0); i < columnsLength; i++ {
		if isNil, err := mrs.ColumnIsNull(mp.ctx, rowIdx, i); err != nil {
			return nil, err
		} else if isNil {
			// NULL values are represented only in the bitmap.
			continue
		}

		column, err := mrs.GetColumn(mp.ctx, uint64(i))
		if err != nil {
			return nil, err
		}
		mysqlColumn, ok := column.(*MysqlColumn)
		if !ok {
			return nil, moerr.NewInternalError(mp.ctx, "sendColumn need MysqlColumn")
		}

		switch mysqlColumn.ColumnType() {
		case defines.MYSQL_TYPE_BOOL:
			if value, err := mrs.GetString(mp.ctx, rowIdx, i); err != nil {
				return nil, err
			} else {
				data = mp.appendStringLenEnc(data, value)
			}
		case defines.MYSQL_TYPE_TINY:
			if value, err := mrs.GetInt64(mp.ctx, rowIdx, i); err != nil {
				return nil, err
			} else {
				data = mp.appendUint8(data, uint8(value))
			}
		case defines.MYSQL_TYPE_SHORT, defines.MYSQL_TYPE_YEAR:
			if value, err := mrs.GetInt64(mp.ctx, rowIdx, i); err != nil {
				return nil, err
			} else {
				data = mp.appendUint16(data, uint16(value))
			}
		case defines.MYSQL_TYPE_INT24, defines.MYSQL_TYPE_LONG:
			if value, err := mrs.GetInt64(mp.ctx, rowIdx, i); err != nil {
				return nil, err
			} else {
				data = mp.appendUint32(data, uint32(value))
			}
		case defines.MYSQL_TYPE_LONGLONG:
			if value, err := mrs.GetUint64(mp.ctx, rowIdx, i); err != nil {
				return nil, err
			} else {
				data = mp.appendUint64(data, value)
			}
		case defines.MYSQL_TYPE_FLOAT:
			if value, err := mrs.GetValue(mp.ctx, rowIdx, i); err != nil {
				return nil, err
			} else {
				switch v := value.(type) {
				case float32:
					data = mp.appendUint32(data, math.Float32bits(v))
				case float64:
					data = mp.appendUint32(data, math.Float32bits(float32(v)))
				default:
				}
			}
		case defines.MYSQL_TYPE_DOUBLE:
			if value, err := mrs.GetValue(mp.ctx, rowIdx, i); err != nil {
				return nil, err
			} else {
				switch v := value.(type) {
				case float32:
					data = mp.appendUint64(data, math.Float64bits(float64(v)))
				case float64:
					data = mp.appendUint64(data, math.Float64bits(v))
				default:
				}
			}

		// Binary/varbinary will be sent out as varchar type.
		case defines.MYSQL_TYPE_VARCHAR, defines.MYSQL_TYPE_VAR_STRING, defines.MYSQL_TYPE_STRING,
			defines.MYSQL_TYPE_BLOB, defines.MYSQL_TYPE_TEXT, defines.MYSQL_TYPE_JSON:
			if value, err := mrs.GetString(mp.ctx, rowIdx, i); err != nil {
				return nil, err
			} else {
				data = mp.appendStringLenEnc(data, value)
			}
		// TODO: some type, we use string now. someday need fix it
		case defines.MYSQL_TYPE_DECIMAL:
			if value, err := mrs.GetString(mp.ctx, rowIdx, i); err != nil {
				return nil, err
			} else {
				data = mp.appendStringLenEnc(data, value)
			}
		case defines.MYSQL_TYPE_UUID:
			if value, err := mrs.GetString(mp.ctx, rowIdx, i); err != nil {
				return nil, err
			} else {
				data = mp.appendStringLenEnc(data, value)
			}
		case defines.MYSQL_TYPE_DATE:
			if value, err := mrs.GetValue(mp.ctx, rowIdx, i); err != nil {
				return nil, err
			} else {
				data = mp.appendDate(data, value.(types.Date))
			}
		case defines.MYSQL_TYPE_TIME:
			if value, err := mrs.GetString(mp.ctx, rowIdx, i); err != nil {
				return nil, err
			} else {
				var t types.Time
				var err error
				// Scale is the number of digits after '.', or 0 when absent.
				idx := strings.Index(value, ".")
				if idx == -1 {
					t, err = types.ParseTime(value, 0)
				} else {
					t, err = types.ParseTime(value, int32(len(value)-idx-1))
				}
				if err != nil {
					// Fall back to the textual form when it cannot be parsed.
					data = mp.appendStringLenEnc(data, value)
				} else {
					data = mp.appendTime(data, t)
				}
			}
		case defines.MYSQL_TYPE_DATETIME, defines.MYSQL_TYPE_TIMESTAMP:
			if value, err := mrs.GetString(mp.ctx, rowIdx, i); err != nil {
				return nil, err
			} else {
				var dt types.Datetime
				var err error
				idx := strings.Index(value, ".")
				if idx == -1 {
					dt, err = types.ParseDatetime(value, 0)
				} else {
					dt, err = types.ParseDatetime(value, int32(len(value)-idx-1))
				}
				if err != nil {
					data = mp.appendStringLenEnc(data, value)
				} else {
					data = mp.appendDatetime(data, dt)
				}
			}
		// case defines.MYSQL_TYPE_TIMESTAMP:
		// 	if value, err := mrs.GetString(rowIdx, i); err != nil {
		// 		return nil, err
		// 	} else {
		// 		data = mp.appendStringLenEnc(data, value)
		// 	}
		default:
			return nil, moerr.NewInternalError(mp.ctx, "type is not supported in binary text result row")
		}
	}

	return data, nil
}

// the server convert every row of the result set into the format that mysql protocol needs
//
// makeResultSetTextRow writes row r in the text resultset row format: NULL is
// the single byte 0xFB, and every other value is a length-encoded string.
func (mp *MysqlProtocolImpl) makeResultSetTextRow(data []byte, mrs *MysqlResultSet, r uint64) ([]byte, error) {
	for i := uint64(0); i < mrs.GetColumnCount(); i++ {
		column, err := mrs.GetColumn(mp.ctx, i)
		if err != nil {
			return nil, err
		}
		mysqlColumn, ok := column.(*MysqlColumn)
		if !ok {
			return nil, moerr.NewInternalError(mp.ctx, "sendColumn need MysqlColumn")
		}

		if isNil, err1 := mrs.ColumnIsNull(mp.ctx, r, i); err1 != nil {
			return nil, err1
		} else if isNil {
			//NULL is sent as 0xfb
			data = mp.appendUint8(data, 0xFB)
			continue
		}

		switch mysqlColumn.ColumnType() {
		case defines.MYSQL_TYPE_JSON:
			if value, err2 := mrs.GetString(mp.ctx, r, i); err2 != nil {
				return nil, err2
			} else {
				data = mp.appendStringLenEnc(data, value)
			}
		case defines.MYSQL_TYPE_BOOL:
			if value, err2 := mrs.GetString(mp.ctx, r, i); err2 != nil {
				return nil, err2
			} else {
				data = mp.appendStringLenEnc(data, value)
			}
		case defines.MYSQL_TYPE_BIT:
			if value, err2 := mrs.GetUint64(mp.ctx, r, i); err2 != nil {
				return nil, err2
			} else {
				// Keep only the bytes needed for the bit width, then reverse
				// the encoded bytes before sending them.
				bitLength := mysqlColumn.ColumnImpl.Length()
				byteLength := (bitLength + 7) / 8
				b := types.EncodeUint64(&value)[:byteLength]
				slices.Reverse(b)
				data = mp.appendStringLenEnc(data, string(b))
			}
		case defines.MYSQL_TYPE_DECIMAL:
			if value, err2 := mrs.GetString(mp.ctx, r, i); err2 != nil {
				return nil, err2
			} else {
				data = mp.appendStringLenEnc(data, value)
			}
		case defines.MYSQL_TYPE_UUID:
			if value, err2 := mrs.GetString(mp.ctx, r, i); err2 != nil {
				return nil, err2
			} else {
				data = mp.appendStringLenEnc(data, value)
			}
		case defines.MYSQL_TYPE_TINY, defines.MYSQL_TYPE_SHORT, defines.MYSQL_TYPE_INT24, defines.MYSQL_TYPE_LONG, defines.MYSQL_TYPE_YEAR:
			if value, err2 := mrs.GetInt64(mp.ctx, r, i); err2 != nil {
				return nil, err2
			} else {
				if mysqlColumn.ColumnType() == defines.MYSQL_TYPE_YEAR {
					// Year 0 is rendered as "0000".
					if value == 0 {
						data = mp.appendStringLenEnc(data, "0000")
					} else {
						data = mp.appendStringLenEncOfInt64(data, value)
					}
				} else {
					data = mp.appendStringLenEncOfInt64(data, value)
				}
			}
		case defines.MYSQL_TYPE_FLOAT:
			if value, err := mrs.GetValue(mp.ctx, r, i); err != nil {
				return nil, err
			} else {
				switch v := value.(type) {
				case float32:
					data = mp.appendStringLenEncOfFloat64(data, float64(v), 32)
				case float64:
					data = mp.appendStringLenEncOfFloat64(data, v, 32)
				default:
				}
			}
		case defines.MYSQL_TYPE_DOUBLE:
			if value, err := mrs.GetValue(mp.ctx, r, i); err != nil {
				return nil, err
			} else {
				switch v := value.(type) {
				case float32:
					data = mp.appendStringLenEncOfFloat64(data, float64(v), 64)
				case float64:
					data = mp.appendStringLenEncOfFloat64(data, v, 64)
				default:
				}
			}
		case defines.MYSQL_TYPE_LONGLONG:
			if uint32(mysqlColumn.Flag())&defines.UNSIGNED_FLAG != 0 {
				if value, err2 := mrs.GetUint64(mp.ctx, r, i); err2 != nil {
					return nil, err2
				} else {
					data = mp.appendStringLenEncOfUint64(data, value)
				}
			} else {
				if value, err2 := mrs.GetInt64(mp.ctx, r, i); err2 != nil {
					return nil, err2
				} else {
					data = mp.appendStringLenEncOfInt64(data, value)
				}
			}
		// Binary/varbinary will be sent out as varchar type.
		case defines.MYSQL_TYPE_VARCHAR, defines.MYSQL_TYPE_VAR_STRING, defines.MYSQL_TYPE_STRING,
			defines.MYSQL_TYPE_BLOB, defines.MYSQL_TYPE_TEXT:
			if value, err2 := mrs.GetString(mp.ctx, r, i); err2 != nil {
				return nil, err2
			} else {
				data = mp.appendStringLenEnc(data, value)
			}
		case defines.MYSQL_TYPE_DATE:
			if value, err2 := mrs.GetValue(mp.ctx, r, i); err2 != nil {
				return nil, err2
			} else {
				data = mp.appendStringLenEnc(data, value.(types.Date).String())
			}
		case defines.MYSQL_TYPE_DATETIME:
			if value, err2 := mrs.GetString(mp.ctx, r, i); err2 != nil {
				return nil, err2
			} else {
				data = mp.appendStringLenEnc(data, value)
			}
		case defines.MYSQL_TYPE_TIME:
			if value, err2 := mrs.GetString(mp.ctx, r, i); err2 != nil {
				return nil, err2
			} else {
				data = mp.appendStringLenEnc(data, value)
			}
		case defines.MYSQL_TYPE_TIMESTAMP:
			if value, err2 := mrs.GetString(mp.ctx, r, i); err2 != nil {
				return nil, err2
			} else {
				data = mp.appendStringLenEnc(data, value)
			}
		case defines.MYSQL_TYPE_ENUM:
			if value, err2 := mrs.GetString(mp.ctx, r, i); err2 != nil {
				return nil, err2
			} else {
				data = mp.appendStringLenEnc(data, value)
			}
		default:
			return nil, moerr.NewInternalError(mp.ctx, "unsupported column type %d ", mysqlColumn.ColumnType())
		}
	}
	return data, nil
}

// the server send group row of the result set as an independent packet
// thread safe
// It holds mp.m while sending rows 0..cnt-1 of mrs as text rows.
func (mp *MysqlProtocolImpl) SendResultSetTextBatchRow(mrs *MysqlResultSet, cnt uint64) error {
	if cnt == 0 {
		return nil
	}

	mp.m.Lock()
	defer mp.m.Unlock()

	for i := uint64(0); i < cnt; i++ {
		if err := mp.sendResultSetTextRow(mrs, i); err != nil {
			return err
		}
	}
	return nil
}

// SendResultSetTextBatchRowSpeedup writes rows 0..cnt-1 of mrs straight into
// the out buffer while holding mp.m. Rows use the binary format when the
// session's current command is COM_STMT_EXECUTE, the text format otherwise.
// If a row cannot be encoded, an ERR packet is sent and the encoding error is
// returned.
func (mp *MysqlProtocolImpl) SendResultSetTextBatchRowSpeedup(mrs *MysqlResultSet, cnt uint64) error {
	if cnt == 0 {
		return nil
	}

	cmd := mp.GetSession().GetCmd()
	mp.m.Lock()
	defer mp.m.Unlock()
	var err error = nil

	// XXX now we known COM_QUERY will use textRow, COM_STMT_EXECUTE use binaryRow
	useBinaryRow := cmd == COM_STMT_EXECUTE

	//make rows into the batch
	for i := uint64(0); i < cnt; i++ {
		err = mp.openRow(nil)
		if err != nil {
			return err
		}
		if useBinaryRow {
			_, err = mp.makeResultSetBinaryRow(nil, mrs, i)
		} else {
			_, err = mp.makeResultSetTextRow(nil, mrs, i)
		}

		if err != nil {
			//ERR_Packet in case of error
			err1 := mp.sendErrPacket(moerr.ER_UNKNOWN_ERROR, DefaultMySQLState, err.Error())
			if err1 != nil {
				return err1
			}
			return err
		}

		//output into outbuf
		err = mp.closeRow(nil)
		if err != nil {
			return err
		}
	}

	return err
}

// open a new row of the resultset
func (mp *MysqlProtocolImpl) openRow(_ []byte) error {
	return mp.openPacket()
}

// close a finished row of the resultset
// It finalizes the packet header, then flushes when the buffered byte count
// has reached the configured threshold.
func (mp *MysqlProtocolImpl) closeRow(_ []byte) error {
	if err := mp.closePacket(true); err != nil {
		return err
	}
	return mp.flushOutBuffer()
}

// flushOutBuffer the data in the outbuf into the network
// It flushes only once bytesInOutBuffer >= untilBytesInOutbufToFlush.
func (mp *MysqlProtocolImpl) flushOutBuffer() error {
	if mp.bytesInOutBuffer >= mp.untilBytesInOutbufToFlush {
		mp.flushCount++
		mp.writeBytes += uint64(mp.bytesInOutBuffer)
		// FIXME: use a suitable timeout value
		mp.incDebugCount(8)
		err := mp.tcpConn.Flush(0)
		mp.incDebugCount(9)
		if err != nil {
			return err
		}
		mp.resetFlushOutBuffer()
	}
	return nil
}

// open a new mysql protocol packet
// It reserves 4 header bytes in the out buffer and records the packet start
// as an offset relative to the buffer's read index.
func (mp *MysqlProtocolImpl) openPacket() error {
	outbuf := mp.tcpConn.OutBuf()
	n := 4
	outbuf.Grow(n)
	/*
		offset = writerIndex - readerIndex
		writerIndex = GetReaderIndex() + offset
	*/
	offset := outbuf.GetWriteIndex() - outbuf.GetReadIndex()
	writeIdx := beginWriteIndex(outbuf, offset)
	mp.beginOffset = offset
	writeIdx += n
	mp.bytesInOutBuffer += n
	outbuf.SetWriteIndex(writeIdx)
	return nil
}

// beginWriteIndex converts an offset relative to the read index into an
// absolute index. It panics on a negative offset.
func beginWriteIndex(outbuf *goetty_buf.ByteBuf, offset int) int {
	if offset < 0 {
		panic("invalid offset")
	}
	return outbuf.GetReadIndex() + offset
}

// fill the packet with data
// Data is split across packets: when the current packet reaches
// MaxPayloadSize it is closed (and flushed if over threshold), and a new
// packet is opened for the remaining bytes.
func (mp *MysqlProtocolImpl) fillPacket(elems ...byte) error {
	outbuf := mp.tcpConn.OutBuf()
	n := len(elems)
	i := 0
	curLen := 0
	hasDataLen := 0
	curDataLen := 0
	var err error
	var buf []byte
	for ; i < n; i += curLen {
		if !mp.isInPacket() {
			err = mp.openPacket()
			if err != nil {
				return err
			}
		}
		//length of data in the packet
		hasDataLen = outbuf.GetWriteIndex() - beginWriteIndex(outbuf, mp.beginOffset) - HeaderLengthOfTheProtocol
		curLen = int(MaxPayloadSize) - hasDataLen
		curLen = Min(curLen, n-i)
		if curLen < 0 {
			return moerr.NewInternalError(mp.ctx, "needLen %d < 0. hasDataLen %d n - i %d", curLen, hasDataLen, n-i)
		}
		outbuf.Grow(curLen)
		buf = outbuf.RawBuf()
		writeIdx := outbuf.GetWriteIndex()
		copy(buf[writeIdx:], elems[i:i+curLen])
		writeIdx += curLen
		mp.bytesInOutBuffer += curLen
		outbuf.SetWriteIndex(writeIdx)

		//> 16MB, split it
		curDataLen = outbuf.GetWriteIndex() - beginWriteIndex(outbuf, mp.beginOffset) - HeaderLengthOfTheProtocol
		if curDataLen == int(MaxPayloadSize) {
			err = mp.closePacket(i+curLen == n)
			if err != nil {
				return err
			}

			err = mp.flushOutBuffer()
			if err != nil {
				return err
			}
		}
	}

	return nil
}

// close a mysql protocol packet
// It writes the 3-byte payload length and the sequence id into the reserved
// header, then advances the sequence id. If appendZeroPacket is set and the
// payload is exactly MaxPayloadSize, an empty trailing packet is emitted.
func (mp *MysqlProtocolImpl) closePacket(appendZeroPacket bool) error {
	if !mp.isInPacket() {
		return nil
	}
	outbuf := mp.tcpConn.OutBuf()
	payLoadLen := outbuf.GetWriteIndex() - beginWriteIndex(outbuf, mp.beginOffset) - 4
	if payLoadLen < 0 || payLoadLen > int(MaxPayloadSize) {
		return moerr.NewInternalError(mp.ctx, "invalid payload len :%d curWriteIdx %d beginWriteIdx %d ",
			payLoadLen, outbuf.GetWriteIndex(), beginWriteIndex(outbuf, mp.beginOffset))
	}

	buf := outbuf.RawBuf()
	// PutUint32 writes 4 bytes; byte 3 is then overwritten by the sequence id.
	binary.LittleEndian.PutUint32(buf[beginWriteIndex(outbuf, mp.beginOffset):], uint32(payLoadLen))
	buf[beginWriteIndex(outbuf, mp.beginOffset)+3] = mp.GetSequenceId()

	mp.AddSequenceId(1)

	if appendZeroPacket && payLoadLen == int(MaxPayloadSize) { //last 16MB packet,append a zero packet
		//if the size of the last packet is exactly MaxPayloadSize, a zero-size payload should be sent
		err := mp.openPacket()
		if err != nil {
			return err
		}
		buf = outbuf.RawBuf()
		binary.LittleEndian.PutUint32(buf[beginWriteIndex(outbuf, mp.beginOffset):], uint32(0))
		buf[beginWriteIndex(outbuf, mp.beginOffset)+3] = mp.GetSequenceId()
		mp.AddSequenceId(1)
	}

	mp.resetPacket()
	return nil
}

/*
*
append the elems into the outbuffer

The slice argument is ignored: bytes are written through fillPacket and the
out buffer's raw backing slice is returned. Callers must therefore always go
through mp.append (never Go's built-in append) to emit bytes. Panics if
fillPacket fails.
*/
func (mp *MysqlProtocolImpl) append(_ []byte, elems ...byte) []byte {
	err := mp.fillPacket(elems...)
	if err != nil {
		panic(err)
	}
	return mp.tcpConn.OutBuf().RawBuf()
}

// appendDatetime encodes dt in the binary DATETIME format. The length byte is
// 11 (with microseconds), 7 (with a non-zero time of day) or 4 (date only).
func (mp *MysqlProtocolImpl) appendDatetime(data []byte, dt types.Datetime) []byte {
	if dt.MicroSec() != 0 {
		data = mp.append(data, 11)
		data = mp.appendUint16(data, uint16(dt.Year()))
		data = mp.append(data, dt.Month(), dt.Day(), byte(dt.Hour()), byte(dt.Minute()), byte(dt.Sec()))
		data = mp.appendUint32(data, uint32(dt.MicroSec()))
	} else if dt.Hour() != 0 || dt.Minute() != 0 || dt.Sec() != 0 {
		data = mp.append(data, 7)
		data = mp.appendUint16(data, uint16(dt.Year()))
		data = mp.append(data, dt.Month(), dt.Day(), byte(dt.Hour()), byte(dt.Minute()), byte(dt.Sec()))
	} else {
		data = mp.append(data, 4)
		data = mp.appendUint16(data, uint16(dt.Year()))
		data = mp.append(data, dt.Month(), dt.Day())
	}
	return data
}

// appendTime encodes t in the MySQL binary TIME format:
//
//	length(1) is_negative(1) days(4) hours(1) minutes(1) seconds(1) [micro_seconds(4)]
//
// The length byte is 0 for a zero time, 8 without a fractional part, or 12 with one.
// The sign byte must go through mp.append; the built-in append does not write
// into the packet. The fractional part is sent as 4 bytes, as the format requires.
// NOTE(review): assumes msec from ClockFormat is in microseconds — confirm.
func (mp *MysqlProtocolImpl) appendTime(data []byte, t types.Time) []byte {
	if int64(t) == 0 {
		return mp.append(data, 0)
	}
	hour, minute, sec, msec, isNeg := t.ClockFormat()
	day := uint32(hour / 24)
	hour = hour % 24
	var sign byte
	if isNeg {
		sign = 1
	}
	if msec != 0 {
		data = mp.append(data, 12, sign)
		data = mp.appendUint32(data, day)
		data = mp.append(data, uint8(hour), minute, sec)
		data = mp.appendUint32(data, uint32(msec))
	} else {
		data = mp.append(data, 8, sign)
		data = mp.appendUint32(data, day)
		data = mp.append(data, uint8(hour), minute, sec)
	}
	return data
}

// appendDate encodes value in the binary DATE format: length 0 for a zero
// date, otherwise length 4 followed by year(2), month(1) and day(1).
func (mp *MysqlProtocolImpl) appendDate(data []byte, value types.Date) []byte {
	if int32(value) == 0 {
		data = mp.append(data, 0)
	} else {
		data = mp.append(data, 4)
		data = mp.appendUint16(data, value.Year())
		data = mp.append(data, value.Month(), value.Day())
	}
	return data
}

// the server send every row of the result set as an independent packet
// thread safe
func (mp *MysqlProtocolImpl) SendResultSetTextRow(mrs *MysqlResultSet, r uint64) error {
	mp.m.Lock()
	defer mp.m.Unlock()

	return mp.sendResultSetTextRow(mrs, r)
}

// the server send every row of the result set as an independent packet
// If encoding fails, an ERR packet is sent and the encoding error is returned.
func (mp *MysqlProtocolImpl) sendResultSetTextRow(mrs *MysqlResultSet, r uint64) error {
	var err error
	err = mp.openRow(nil)
	if err != nil {
		return err
	}
	if _, err = mp.makeResultSetTextRow(nil, mrs, r); err != nil {
		//ERR_Packet in case of error
		err1 := mp.sendErrPacket(moerr.ER_UNKNOWN_ERROR, DefaultMySQLState, err.Error())
		if err1 != nil {
			return err1
		}
		return err
	}

	err = mp.closeRow(nil)
	if err != nil {
		return err
	}

	//begin2 := time.Now()
	//err = mp.writePackets(data)
	//if err != nil {
	//	return moerr.NewInternalError("send result set text row failed. error: %v", err)
	//}
	//mp.sendTime += time.Since(begin2)

	return nil
}

// the server send the result set of execution the client
// the routine follows the article: https://dev.mysql.com/doc/internals/en/com-query-response.html
func (mp *MysqlProtocolImpl) sendResultSet(ctx context.Context, set ResultSet, cmd int, warnings, status uint16) error {
	mysqlRS, ok := set.(*MysqlResultSet)
	if !ok {
		return moerr.NewInternalError(ctx, "sendResultSet need MysqlResultSet")
	}

	//A packet containing a Protocol::LengthEncodedInteger column_count
	err := mp.SendColumnCountPacket(mysqlRS.GetColumnCount())
	if err != nil {
		return err
	}

	if err = mp.sendColumns(ctx, mysqlRS, cmd, warnings, status); err != nil {
		return err
	}

	//One or more ProtocolText::ResultsetRow packets, each containing column_count values
	for i := uint64(0); i < mysqlRS.GetRowCount(); i++ {
		if err = mp.sendResultSetTextRow(mysqlRS, i); err != nil {
			return err
		}
	}

	//If the CLIENT_DEPRECATE_EOF client capabilities flag is set, OK_Packet; else EOF_Packet.
	if mp.capability&CLIENT_DEPRECATE_EOF != 0 {
		return mp.sendOKPacketWithEof(0, 0, extendStatus(status), 0, "")
	}
	return mp.sendEOFPacket(warnings, extendStatus(status))
}

// the server sends the payload to the client
//
// payload[:HeaderOffset] is reserved and skipped. The rest is split into
// chunks of at most MaxPayloadSize bytes. Each chunk is prefixed with a 4-byte
// header (3-byte length + sequence id), written and flushed. When the last
// chunk is exactly MaxPayloadSize, an empty packet follows.
func (mp *MysqlProtocolImpl) writePackets(payload []byte, _ bool) error {

	//protocol header length
	var headerLen = HeaderOffset
	var header [4]byte

	//position of the first data byte
	var i = headerLen
	var length = len(payload)
	var curLen int
	for ; i < length; i += curLen {
		curLen = Min(int(MaxPayloadSize), length-i)

		//make mysql client protocol header
		//4 bytes
		//int<3> the length of payload
		mp.io.WriteUint32(header[:], 0, uint32(curLen))

		//int<1> sequence id
		mp.io.WriteUint8(header[:], 3, mp.GetSequenceId())

		//send packet
		var packet = append(header[:], payload[i:i+curLen]...)

		mp.incDebugCount(4)
		err := mp.tcpConn.Write(packet, goetty.WriteOptions{Flush: true})
		mp.incDebugCount(5)
		if err != nil {
			return err
		}
		mp.AddFlushBytes(uint64(len(packet)))
		mp.AddSequenceId(1)

		if i+curLen == length && curLen == int(MaxPayloadSize) {
			//if the size of the last packet is exactly MaxPayloadSize, a zero-size payload should be sent
			header[0] = 0
			header[1] = 0
			header[2] = 0
			header[3] = mp.GetSequenceId()
			mp.incDebugCount(6)
			//send header / zero-sized packet
			err := mp.tcpConn.Write(header[:], goetty.WriteOptions{Flush: true})
			mp.AddFlushBytes(uint64(len(header)))
			mp.incDebugCount(7)
			if err != nil {
				return err
			}

			mp.AddSequenceId(1)
		}
	}

	return nil
}

// MakeHandshakePayload exposes (*MysqlProtocolImpl).makeHandshakeV10Payload() function.
func (mp *MysqlProtocolImpl) MakeHandshakePayload() []byte {
	return mp.makeHandshakeV10Payload()
}

// WritePacket exposes (*MysqlProtocolImpl).writePackets() function.
func (mp *MysqlProtocolImpl) WritePacket(payload []byte) error {
	return mp.writePackets(payload, true)
}

// MakeOKPayload exposes (*MysqlProtocolImpl).makeOKPayload() function.
func (mp *MysqlProtocolImpl) MakeOKPayload(affectedRows, lastInsertId uint64, statusFlags, warnings uint16, message string) []byte {
	return mp.makeOKPayload(affectedRows, lastInsertId, statusFlags, warnings, message)
}

// MakeErrPayload exposes (*MysqlProtocolImpl).makeErrPayload() function.
func (mp *MysqlProtocolImpl) MakeErrPayload(errCode uint16, sqlState, errorMessage string) []byte {
	return mp.makeErrPayload(errCode, sqlState, errorMessage)
}

// MakeEOFPayload exposes (*MysqlProtocolImpl).makeEOFPayload() function.
2753 func (mp *MysqlProtocolImpl) MakeEOFPayload(warnings, status uint16) []byte { 2754 return mp.makeEOFPayload(warnings, status) 2755 } 2756 2757 // receiveExtraInfo tries to receive salt and labels read from proxy module. 2758 func (mp *MysqlProtocolImpl) receiveExtraInfo(rs goetty.IOSession) { 2759 // TODO(volgariver6): when proxy is stable, remove this deadline setting. 2760 if err := rs.RawConn().SetReadDeadline(time.Now().Add(defaultSaltReadTimeout)); err != nil { 2761 logDebugf(mp.GetDebugString(), "failed to set deadline for salt updating: %v", err) 2762 return 2763 } 2764 var i proxy.ExtraInfo 2765 reader := bufio.NewReader(rs.RawConn()) 2766 if err := i.Decode(reader); err != nil { 2767 // If the error is timeout, we treat it as normal case and do not update extra info. 2768 if err, ok := err.(net.Error); ok && err.Timeout() { 2769 logInfo(mp.ses, mp.GetDebugString(), "cannot get salt, maybe not use proxy", 2770 zap.Error(err)) 2771 } else { 2772 logError(mp.ses, mp.GetDebugString(), "failed to get extra info", 2773 zap.Error(err)) 2774 } 2775 return 2776 } 2777 2778 // must from proxy if extraInfo is received 2779 mp.GetSession().fromProxy = true 2780 mp.SetSalt(i.Salt) 2781 mp.connectionID = i.ConnectionID 2782 ses := mp.GetSession() 2783 ses.requestLabel = i.Label 2784 if i.InternalConn { 2785 ses.connType = ConnTypeInternal 2786 } else { 2787 ses.connType = ConnTypeExternal 2788 } 2789 ses.clientAddr = i.ClientAddr 2790 ses.proxyAddr = mp.Peer() 2791 } 2792 2793 /* 2794 //ther server reads a part of payload from the connection 2795 //the part may be a whole payload 2796 func (mp *MysqlProtocolImpl) recvPartOfPayload() ([]byte, error) { 2797 //var length int 2798 //var header []byte 2799 //var err error 2800 //if header, err = mp.io.ReadPacket(4); err != nil { 2801 // return nil, moerr.NewInternalError("read header failed.error:%v", err) 2802 //} else if header[3] != mp.sequenceId { 2803 // return nil, moerr.NewInternalError("client sequence id %d != 
server sequence id %d", header[3], mp.sequenceId) 2804 //} 2805 2806 mp.sequenceId++ 2807 //length = int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) 2808 2809 var payload []byte 2810 //if payload, err = mp.io.ReadPacket(length); err != nil { 2811 // return nil, moerr.NewInternalError("read payload failed.error:%v", err) 2812 //} 2813 return payload, nil 2814 } 2815 2816 //the server read a payload from the connection 2817 func (mp *MysqlProtocolImpl) recvPayload() ([]byte, error) { 2818 payload, err := mp.recvPartOfPayload() 2819 if err != nil { 2820 return nil, err 2821 } 2822 2823 //only one part 2824 if len(payload) < int(MaxPayloadSize) { 2825 return payload, nil 2826 } 2827 2828 //payload has been split into many parts. 2829 //read them all together 2830 var part []byte 2831 for { 2832 part, err = mp.recvPartOfPayload() 2833 if err != nil { 2834 return nil, err 2835 } 2836 2837 payload = append(payload, part...) 2838 2839 //only one part 2840 if len(part) < int(MaxPayloadSize) { 2841 break 2842 } 2843 } 2844 return payload, nil 2845 } 2846 */ 2847 2848 /* 2849 generate random ascii string. 
2850 Reference to :mysql 8.0.23 mysys/crypt_genhash_impl.cc generate_user_salt(char*,int) 2851 */ 2852 func generate_salt(n int) []byte { 2853 buf := make([]byte, n) 2854 r := rand.New(rand.NewSource(time.Now().UTC().UnixNano())) 2855 r.Read(buf) 2856 for i := 0; i < n; i++ { 2857 buf[i] &= 0x7f 2858 if buf[i] == 0 || buf[i] == '$' { 2859 buf[i]++ 2860 } 2861 } 2862 return buf 2863 } 2864 2865 func NewMysqlClientProtocol(connectionID uint32, tcp goetty.IOSession, maxBytesToFlush int, SV *config.FrontendParameters) *MysqlProtocolImpl { 2866 salt := generate_salt(20) 2867 mysql := &MysqlProtocolImpl{ 2868 ProtocolImpl: ProtocolImpl{ 2869 io: NewIOPackage(true), 2870 tcpConn: tcp, 2871 salt: salt, 2872 connectionID: connectionID, 2873 }, 2874 charset: "utf8mb4", 2875 capability: DefaultCapability, 2876 strconvBuffer: make([]byte, 0, 16*1024), 2877 lenEncBuffer: make([]byte, 0, 10), 2878 binaryNullBuffer: make([]byte, 0, 512), 2879 rowHandler: rowHandler{ 2880 beginOffset: -1, 2881 bytesInOutBuffer: 0, 2882 untilBytesInOutbufToFlush: maxBytesToFlush * 1024, 2883 }, 2884 SV: SV, 2885 } 2886 2887 if SV.EnableTls { 2888 mysql.capability = mysql.capability | CLIENT_SSL 2889 } 2890 2891 mysql.resetPacket() 2892 2893 return mysql 2894 } 2895 2896 // HashSha1 is used to calcute a sha1 hash 2897 // SHA1() 2898 func HashSha1(toHash []byte) []byte { 2899 sha := sha1.New() 2900 _, err := sha.Write(toHash) 2901 if err != nil { 2902 logutil.Errorf("SHA1 failed.") 2903 return nil 2904 } 2905 return sha.Sum(nil) 2906 } 2907 2908 // HashPassWord is uesed to hash password 2909 // *SHA1(SHA1(password)) 2910 func HashPassWord(pwd string) string { 2911 if len(pwd) == 0 { 2912 return "" 2913 } 2914 hash1 := HashSha1([]byte(pwd)) 2915 hash2 := HashSha1(hash1) 2916 2917 return fmt.Sprintf("*%X", hash2) 2918 } 2919 2920 func HashPassWordWithByte(pwd []byte) string { 2921 if len(pwd) == 0 { 2922 return "" 2923 } 2924 hash1 := HashSha1(pwd) 2925 hash2 := HashSha1(hash1) 2926 2927 return 
fmt.Sprintf("*%X", hash2) 2928 } 2929 2930 // GetPassWord is used to get hash byte password 2931 // SHA1(SHA1(password)) 2932 func GetPassWord(pwd string) ([]byte, error) { 2933 pwdByte, err := hex.DecodeString(pwd[1:]) 2934 if err != nil { 2935 logutil.Errorf("GetPassWord failed.") 2936 } 2937 return pwdByte, nil 2938 }