github.com/matrixorigin/matrixone@v0.7.0/pkg/frontend/mysql_protocol_test.go (about) 1 // Copyright 2021 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package frontend 16 17 import ( 18 "bytes" 19 "context" 20 "database/sql" 21 "encoding/binary" 22 "fmt" 23 "github.com/stretchr/testify/assert" 24 "math" 25 "reflect" 26 "strconv" 27 "sync" 28 "testing" 29 "time" 30 31 // mysqlDriver "github.com/go-sql-driver/mysql" 32 "github.com/BurntSushi/toml" 33 "github.com/fagongzi/goetty/v2" 34 "github.com/fagongzi/goetty/v2/buf" 35 "github.com/golang/mock/gomock" 36 fuzz "github.com/google/gofuzz" 37 "github.com/matrixorigin/matrixone/pkg/common/moerr" 38 "github.com/matrixorigin/matrixone/pkg/config" 39 "github.com/matrixorigin/matrixone/pkg/container/types" 40 "github.com/matrixorigin/matrixone/pkg/defines" 41 mock_frontend "github.com/matrixorigin/matrixone/pkg/frontend/test" 42 "github.com/matrixorigin/matrixone/pkg/sql/parsers" 43 "github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect" 44 "github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect/mysql" 45 "github.com/matrixorigin/matrixone/pkg/sql/parsers/tree" 46 "github.com/matrixorigin/matrixone/pkg/sql/plan" 47 "github.com/matrixorigin/matrixone/pkg/vm/engine" 48 "github.com/matrixorigin/matrixone/pkg/vm/process" 49 "github.com/prashantv/gostub" 50 "github.com/smartystreets/goconvey/convey" 51 "github.com/stretchr/testify/require" 52 ) 53 54 type TestRoutineManager 
struct { 55 rwlock sync.Mutex 56 clients map[goetty.IOSession]*Routine 57 58 pu *config.ParameterUnit 59 } 60 61 func (tRM *TestRoutineManager) Created(rs goetty.IOSession) { 62 pro := NewMysqlClientProtocol(nextConnectionID(), rs, 1024, tRM.pu.SV) 63 pro.SetSkipCheckUser(true) 64 exe := NewMysqlCmdExecutor() 65 routine := NewRoutine(context.TODO(), pro, exe, tRM.pu.SV, rs) 66 67 hsV10pkt := pro.makeHandshakeV10Payload() 68 err := pro.writePackets(hsV10pkt) 69 if err != nil { 70 panic(err) 71 } 72 73 tRM.rwlock.Lock() 74 defer tRM.rwlock.Unlock() 75 tRM.clients[rs] = routine 76 } 77 78 func (tRM *TestRoutineManager) Closed(rs goetty.IOSession) { 79 tRM.rwlock.Lock() 80 defer tRM.rwlock.Unlock() 81 delete(tRM.clients, rs) 82 } 83 84 func NewTestRoutineManager(pu *config.ParameterUnit) *TestRoutineManager { 85 rm := &TestRoutineManager{ 86 clients: make(map[goetty.IOSession]*Routine), 87 pu: pu, 88 } 89 return rm 90 } 91 92 func TestMysqlClientProtocol_Handshake(t *testing.T) { 93 //client connection method: mysql -h 127.0.0.1 -P 6001 --default-auth=mysql_native_password -uroot -p 94 //client connect 95 //ion method: mysql -h 127.0.0.1 -P 6001 -udump -p 96 97 var db *sql.DB 98 var err error 99 //before anything using the configuration 100 pu := config.NewParameterUnit(&config.FrontendParameters{}, nil, nil, nil, nil) 101 _, err = toml.DecodeFile("test/system_vars_config.toml", pu.SV) 102 require.NoError(t, err) 103 104 ctx := context.WithValue(context.TODO(), config.ParameterUnitKey, pu) 105 rm, _ := NewRoutineManager(ctx, pu) 106 rm.SetSkipCheckUser(true) 107 108 wg := sync.WaitGroup{} 109 wg.Add(1) 110 111 //running server 112 go func() { 113 defer wg.Done() 114 echoServer(rm.Handler, rm, NewSqlCodec()) 115 }() 116 117 time.Sleep(time.Second * 2) 118 db, err = openDbConn(t, 6001) 119 require.NoError(t, err) 120 closeDbConn(t, db) 121 122 time.Sleep(time.Millisecond * 10) 123 //close server 124 setServer(1) 125 wg.Wait() 126 } 127 128 func newMrsForConnectionId(rows 
[][]interface{}) *MysqlResultSet { 129 mrs := &MysqlResultSet{} 130 131 col1 := &MysqlColumn{} 132 col1.SetName("connection_id") 133 col1.SetColumnType(defines.MYSQL_TYPE_LONGLONG) 134 135 mrs.AddColumn(col1) 136 137 for _, row := range rows { 138 mrs.AddRow(row) 139 } 140 141 return mrs 142 } 143 144 func TestKIll(t *testing.T) { 145 //client connection method: mysql -h 127.0.0.1 -P 6001 --default-auth=mysql_native_password -uroot -p 146 //client connect 147 //ion method: mysql -h 127.0.0.1 -P 6001 -udump -p 148 ctrl := gomock.NewController(t) 149 defer ctrl.Finish() 150 var conn1, conn2 *sql.DB 151 var err error 152 var connIdRow *sql.Row 153 154 //before anything using the configuration 155 eng := mock_frontend.NewMockEngine(ctrl) 156 eng.EXPECT().New(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() 157 eng.EXPECT().Commit(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() 158 eng.EXPECT().Rollback(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() 159 txnClient := mock_frontend.NewMockTxnClient(ctrl) 160 pu, err := getParameterUnit("test/system_vars_config.toml", eng, txnClient) 161 require.NoError(t, err) 162 163 sql1 := "select connection_id();" 164 var sql2, sql3, sql4 string 165 noResultSet := make(map[string]bool) 166 resultSet := make(map[string]genMrs) 167 resultSet[sql1] = func(ses *Session) *MysqlResultSet { 168 mrs := newMrsForConnectionId([][]interface{}{ 169 {ses.GetConnectionID()}, 170 }) 171 return mrs 172 } 173 174 var wrapperStubFunc = func(db, sql, user string, eng engine.Engine, proc *process.Process, ses *Session) ([]ComputationWrapper, error) { 175 var cw []ComputationWrapper = nil 176 var stmts []tree.Statement = nil 177 var cmdFieldStmt *InternalCmdFieldList 178 var err error 179 if isCmdFieldListSql(sql) { 180 cmdFieldStmt, err = parseCmdFieldList(proc.Ctx, sql) 181 if err != nil { 182 return nil, err 183 } 184 stmts = append(stmts, cmdFieldStmt) 185 } else { 186 stmts, err = parsers.Parse(proc.Ctx, dialect.MYSQL, sql) 187 if err != nil 
{ 188 return nil, err 189 } 190 } 191 192 for _, stmt := range stmts { 193 cw = append(cw, newMockWrapper(ctrl, ses, resultSet, noResultSet, sql, stmt, proc)) 194 } 195 return cw, nil 196 } 197 198 bhStub := gostub.Stub(&GetComputationWrapper, wrapperStubFunc) 199 defer bhStub.Reset() 200 201 ctx := context.WithValue(context.TODO(), config.ParameterUnitKey, pu) 202 rm, _ := NewRoutineManager(ctx, pu) 203 rm.SetSkipCheckUser(true) 204 205 wg := sync.WaitGroup{} 206 wg.Add(1) 207 208 //running server 209 go func() { 210 defer wg.Done() 211 echoServer(rm.Handler, rm, NewSqlCodec()) 212 }() 213 214 time.Sleep(time.Second * 2) 215 conn1, err = openDbConn(t, 6001) 216 require.NoError(t, err) 217 218 time.Sleep(time.Second * 2) 219 conn2, err = openDbConn(t, 6001) 220 require.NoError(t, err) 221 222 //get the connection id of conn1 223 var conn1Id uint64 224 connIdRow = conn1.QueryRow(sql1) 225 err = connIdRow.Scan(&conn1Id) 226 require.NoError(t, err) 227 228 //get the connection id of conn2 229 var conn2Id uint64 230 connIdRow = conn2.QueryRow(sql1) 231 err = connIdRow.Scan(&conn2Id) 232 require.NoError(t, err) 233 234 //conn2 kills the query 235 sql3 = fmt.Sprintf("kill query %d;", conn1Id) 236 noResultSet[sql3] = true 237 _, err = conn2.Exec(sql3) 238 require.NoError(t, err) 239 240 //conn2 kills the connection 1 241 sql2 = fmt.Sprintf("kill %d;", conn1Id) 242 noResultSet[sql2] = true 243 _, err = conn2.Exec(sql2) 244 require.NoError(t, err) 245 246 //conn2 kills itself 247 sql4 = fmt.Sprintf("kill %d;", conn2Id) 248 noResultSet[sql4] = true 249 _, err = conn2.Exec(sql4) 250 require.NoError(t, err) 251 252 //close the connection 253 closeDbConn(t, conn1) 254 closeDbConn(t, conn2) 255 256 time.Sleep(time.Millisecond * 10) 257 //close server 258 setServer(1) 259 wg.Wait() 260 } 261 262 func TestReadIntLenEnc(t *testing.T) { 263 var intEnc MysqlProtocolImpl 264 var data = make([]byte, 24) 265 var cases = [][]uint64{ 266 {0, 123, 250}, 267 {251, 10000, 1<<16 - 1}, 268 {1 
<< 16, 1<<16 + 10000, 1<<24 - 1}, 269 {1 << 24, 1<<24 + 10000, 1<<64 - 1}, 270 } 271 var caseLens = []int{1, 3, 4, 9} 272 for j := 0; j < len(cases); j++ { 273 for i := 0; i < len(cases[j]); i++ { 274 value := cases[j][i] 275 p1 := intEnc.writeIntLenEnc(data, 0, value) 276 val, p2, ok := intEnc.readIntLenEnc(data, 0) 277 if !ok || p1 != caseLens[j] || p1 != p2 || val != value { 278 t.Errorf("IntLenEnc %d failed.", value) 279 break 280 } 281 _, _, ok = intEnc.readIntLenEnc(data[0:caseLens[j]-1], 0) 282 if ok { 283 t.Errorf("read IntLenEnc failed.") 284 break 285 } 286 } 287 } 288 } 289 290 func TestReadCountOfBytes(t *testing.T) { 291 var client MysqlProtocolImpl 292 var data = make([]byte, 24) 293 var length = 10 294 for i := 0; i < length; i++ { 295 data[i] = byte(length - i) 296 } 297 298 r, pos, ok := client.readCountOfBytes(data, 0, length) 299 if !ok || pos != length { 300 t.Error("read bytes failed.") 301 return 302 } 303 304 for i := 0; i < length; i++ { 305 if r[i] != data[i] { 306 t.Error("read != write") 307 break 308 } 309 } 310 311 _, _, ok = client.readCountOfBytes(data, 0, 100) 312 if ok { 313 t.Error("read bytes failed.") 314 return 315 } 316 317 _, pos, ok = client.readCountOfBytes(data, 0, 0) 318 if !ok || pos != 0 { 319 t.Error("read bytes failed.") 320 return 321 } 322 } 323 324 func TestReadStringFix(t *testing.T) { 325 var client MysqlProtocolImpl 326 var data = make([]byte, 24) 327 var length = 10 328 var s = "haha, test read string fix function" 329 pos := client.writeStringFix(data, 0, s, length) 330 if pos != length { 331 t.Error("write string fix failed.") 332 return 333 } 334 var x string 335 var ok bool 336 337 x, pos, ok = client.readStringFix(data, 0, length) 338 if !ok || pos != length || x != s[0:length] { 339 t.Error("read string fix failed.") 340 return 341 } 342 var sLen = []int{ 343 length + 10, 344 length + 20, 345 length + 30, 346 } 347 for i := 0; i < len(sLen); i++ { 348 x, pos, ok = client.readStringFix(data, 0, sLen[i]) 349 
if ok && pos == sLen[i] && x == s[0:sLen[i]] { 350 t.Error("read string fix failed.") 351 return 352 } 353 } 354 355 //empty string 356 pos = client.writeStringFix(data, 0, s, 0) 357 if pos != 0 { 358 t.Error("write string fix failed.") 359 return 360 } 361 362 x, pos, ok = client.readStringFix(data, 0, 0) 363 if !ok || pos != 0 || x != "" { 364 t.Error("read string fix failed.") 365 return 366 } 367 } 368 369 func TestReadStringNUL(t *testing.T) { 370 var client MysqlProtocolImpl 371 var data = make([]byte, 24) 372 var length = 10 373 var s = "haha, test read string fix function" 374 pos := client.writeStringNUL(data, 0, s[0:length]) 375 if pos != length+1 { 376 t.Error("write string NUL failed.") 377 return 378 } 379 var x string 380 var ok bool 381 382 x, pos, ok = client.readStringNUL(data, 0) 383 if !ok || pos != length+1 || x != s[0:length] { 384 t.Error("read string NUL failed.") 385 return 386 } 387 var sLen = []int{ 388 length + 10, 389 length + 20, 390 length + 30, 391 } 392 for i := 0; i < len(sLen); i++ { 393 x, pos, ok = client.readStringNUL(data, 0) 394 if ok && pos == sLen[i]+1 && x == s[0:sLen[i]] { 395 t.Error("read string NUL failed.") 396 return 397 } 398 } 399 } 400 401 func TestReadStringLenEnc(t *testing.T) { 402 var client MysqlProtocolImpl 403 var data = make([]byte, 24) 404 var length = 10 405 var s = "haha, test read string fix function" 406 pos := client.writeStringLenEnc(data, 0, s[0:length]) 407 if pos != length+1 { 408 t.Error("write string lenenc failed.") 409 return 410 } 411 var x string 412 var ok bool 413 414 x, pos, ok = client.readStringLenEnc(data, 0) 415 if !ok || pos != length+1 || x != s[0:length] { 416 t.Error("read string lenenc failed.") 417 return 418 } 419 420 //empty string 421 pos = client.writeStringLenEnc(data, 0, s[0:0]) 422 if pos != 1 { 423 t.Error("write string lenenc failed.") 424 return 425 } 426 427 x, pos, ok = client.readStringLenEnc(data, 0) 428 if !ok || pos != 1 || x != s[0:0] { 429 t.Error("read string 
lenenc failed.") 430 return 431 } 432 } 433 434 // can not run this test case in ubuntu+golang1.9, let's add an issue(#4656) for that, I will fixed in someday. 435 // func TestMysqlClientProtocol_TlsHandshake(t *testing.T) { 436 // //before anything using the configuration 437 // pu := config.NewParameterUnit(&config.FrontendParameters{}, nil, nil, nil) 438 // _, err := toml.DecodeFile("test/system_vars_config.toml", pu.SV) 439 // if err != nil { 440 // panic(err) 441 // } 442 // pu.SV.EnableTls = true 443 // ctx := context.WithValue(context.TODO(), config.ParameterUnitKey, pu) 444 // rm, _ := NewRoutineManager(ctx, pu) 445 // rm.SetSkipCheckUser(true) 446 447 // wg := sync.WaitGroup{} 448 // wg.Add(1) 449 450 // // //running server 451 // go func() { 452 // defer wg.Done() 453 // echoServer(rm.Handler, rm, NewSqlCodec()) 454 // }() 455 456 // // to := NewTimeout(1*time.Minute, false) 457 // // for isClosed() && !to.isTimeout() { 458 // // } 459 460 // time.Sleep(time.Second * 2) 461 // db := open_tls_db(t, 6001) 462 // closeDbConn(t, db) 463 464 // time.Sleep(time.Millisecond * 10) 465 // //close server 466 // setServer(1) 467 // wg.Wait() 468 // } 469 470 func makeMysqlTinyIntResultSet(unsigned bool) *MysqlResultSet { 471 var rs = &MysqlResultSet{} 472 473 name := "Tiny" 474 if unsigned { 475 name = name + "Uint" 476 } else { 477 name = name + "Int" 478 } 479 480 mysqlCol := new(MysqlColumn) 481 mysqlCol.SetName(name) 482 mysqlCol.SetOrgName(name + "OrgName") 483 mysqlCol.SetColumnType(defines.MYSQL_TYPE_TINY) 484 mysqlCol.SetSchema(name + "Schema") 485 mysqlCol.SetTable(name + "Table") 486 mysqlCol.SetOrgTable(name + "Table") 487 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 488 mysqlCol.SetSigned(!unsigned) 489 490 rs.AddColumn(mysqlCol) 491 if unsigned { 492 var cases = []uint8{0, 1, 254, 255} 493 for _, v := range cases { 494 var data = make([]interface{}, 1) 495 data[0] = v 496 rs.AddRow(data) 497 } 498 } else { 499 var cases = []int8{-128, -127, 127} 500 
for _, v := range cases { 501 var data = make([]interface{}, 1) 502 data[0] = v 503 rs.AddRow(data) 504 } 505 } 506 507 return rs 508 } 509 510 func makeMysqlTinyResult(unsigned bool) *MysqlExecutionResult { 511 return NewMysqlExecutionResult(0, 0, 0, 0, makeMysqlTinyIntResultSet(unsigned)) 512 } 513 514 func makeMysqlShortResultSet(unsigned bool) *MysqlResultSet { 515 var rs = &MysqlResultSet{} 516 517 name := "Short" 518 if unsigned { 519 name = name + "Uint" 520 } else { 521 name = name + "Int" 522 } 523 mysqlCol := new(MysqlColumn) 524 mysqlCol.SetName(name) 525 mysqlCol.SetOrgName(name + "OrgName") 526 mysqlCol.SetColumnType(defines.MYSQL_TYPE_SHORT) 527 mysqlCol.SetSchema(name + "Schema") 528 mysqlCol.SetTable(name + "Table") 529 mysqlCol.SetOrgTable(name + "Table") 530 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 531 mysqlCol.SetSigned(!unsigned) 532 533 rs.AddColumn(mysqlCol) 534 if unsigned { 535 var cases = []uint16{0, 1, 254, 255, 65535} 536 for _, v := range cases { 537 var data = make([]interface{}, 1) 538 data[0] = v 539 rs.AddRow(data) 540 } 541 } else { 542 var cases = []int16{-32768, 0, 32767} 543 for _, v := range cases { 544 var data = make([]interface{}, 1) 545 data[0] = v 546 rs.AddRow(data) 547 } 548 } 549 550 return rs 551 } 552 553 func makeMysqlShortResult(unsigned bool) *MysqlExecutionResult { 554 return NewMysqlExecutionResult(0, 0, 0, 0, makeMysqlShortResultSet(unsigned)) 555 } 556 557 func makeMysqlLongResultSet(unsigned bool) *MysqlResultSet { 558 var rs = &MysqlResultSet{} 559 560 name := "Long" 561 if unsigned { 562 name = name + "Uint" 563 } else { 564 name = name + "Int" 565 } 566 mysqlCol := new(MysqlColumn) 567 mysqlCol.SetName(name) 568 mysqlCol.SetOrgName(name + "OrgName") 569 mysqlCol.SetColumnType(defines.MYSQL_TYPE_LONG) 570 mysqlCol.SetSchema(name + "Schema") 571 mysqlCol.SetTable(name + "Table") 572 mysqlCol.SetOrgTable(name + "Table") 573 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 574 mysqlCol.SetSigned(!unsigned) 
575 576 rs.AddColumn(mysqlCol) 577 if unsigned { 578 var cases = []uint32{0, 4294967295} 579 for _, v := range cases { 580 var data = make([]interface{}, 1) 581 data[0] = v 582 rs.AddRow(data) 583 } 584 } else { 585 var cases = []int32{-2147483648, 0, 2147483647} 586 for _, v := range cases { 587 var data = make([]interface{}, 1) 588 data[0] = v 589 rs.AddRow(data) 590 } 591 } 592 593 return rs 594 } 595 596 func makeMysqlLongResult(unsigned bool) *MysqlExecutionResult { 597 return NewMysqlExecutionResult(0, 0, 0, 0, makeMysqlLongResultSet(unsigned)) 598 } 599 600 func makeMysqlLongLongResultSet(unsigned bool) *MysqlResultSet { 601 var rs = &MysqlResultSet{} 602 603 name := "LongLong" 604 if unsigned { 605 name = name + "Uint" 606 } else { 607 name = name + "Int" 608 } 609 mysqlCol := new(MysqlColumn) 610 mysqlCol.SetName(name) 611 mysqlCol.SetOrgName(name + "OrgName") 612 mysqlCol.SetColumnType(defines.MYSQL_TYPE_LONGLONG) 613 mysqlCol.SetSchema(name + "Schema") 614 mysqlCol.SetTable(name + "Table") 615 mysqlCol.SetOrgTable(name + "Table") 616 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 617 mysqlCol.SetSigned(!unsigned) 618 619 rs.AddColumn(mysqlCol) 620 if unsigned { 621 var cases = []uint64{0, 4294967295, 18446744073709551615} 622 for _, v := range cases { 623 var data = make([]interface{}, 1) 624 data[0] = v 625 rs.AddRow(data) 626 } 627 } else { 628 var cases = []int64{-9223372036854775808, 0, 9223372036854775807} 629 for _, v := range cases { 630 var data = make([]interface{}, 1) 631 data[0] = v 632 rs.AddRow(data) 633 } 634 } 635 636 return rs 637 } 638 639 func makeMysqlLongLongResult(unsigned bool) *MysqlExecutionResult { 640 return NewMysqlExecutionResult(0, 0, 0, 0, makeMysqlLongLongResultSet(unsigned)) 641 } 642 643 func makeMysqlInt24ResultSet(unsigned bool) *MysqlResultSet { 644 var rs = &MysqlResultSet{} 645 646 name := "Int24" 647 if unsigned { 648 name = name + "Uint" 649 } else { 650 name = name + "Int" 651 } 652 mysqlCol := new(MysqlColumn) 
653 mysqlCol.SetName(name) 654 mysqlCol.SetOrgName(name + "OrgName") 655 mysqlCol.SetColumnType(defines.MYSQL_TYPE_INT24) 656 mysqlCol.SetSchema(name + "Schema") 657 mysqlCol.SetTable(name + "Table") 658 mysqlCol.SetOrgTable(name + "Table") 659 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 660 mysqlCol.SetSigned(!unsigned) 661 662 rs.AddColumn(mysqlCol) 663 if unsigned { 664 //[0,16777215] 665 var cases = []uint32{0, 16777215, 4294967295} 666 for _, v := range cases { 667 var data = make([]interface{}, 1) 668 data[0] = v 669 rs.AddRow(data) 670 } 671 } else { 672 //[-8388608,8388607] 673 var cases = []int32{-2147483648, -8388608, 0, 8388607, 2147483647} 674 for _, v := range cases { 675 var data = make([]interface{}, 1) 676 data[0] = v 677 rs.AddRow(data) 678 } 679 } 680 681 return rs 682 } 683 684 func makeMysqlInt24Result(unsigned bool) *MysqlExecutionResult { 685 return NewMysqlExecutionResult(0, 0, 0, 0, makeMysqlInt24ResultSet(unsigned)) 686 } 687 688 func makeMysqlYearResultSet(unsigned bool) *MysqlResultSet { 689 var rs = &MysqlResultSet{} 690 691 name := "Year" 692 if unsigned { 693 name = name + "Uint" 694 } else { 695 name = name + "Int" 696 } 697 mysqlCol := new(MysqlColumn) 698 mysqlCol.SetName(name) 699 mysqlCol.SetOrgName(name + "OrgName") 700 mysqlCol.SetColumnType(defines.MYSQL_TYPE_YEAR) 701 mysqlCol.SetSchema(name + "Schema") 702 mysqlCol.SetTable(name + "Table") 703 mysqlCol.SetOrgTable(name + "Table") 704 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 705 mysqlCol.SetSigned(!unsigned) 706 707 rs.AddColumn(mysqlCol) 708 if unsigned { 709 var cases = []uint16{0, 1, 254, 255, 65535} 710 for _, v := range cases { 711 var data = make([]interface{}, 1) 712 data[0] = v 713 rs.AddRow(data) 714 } 715 } else { 716 var cases = []int16{-32768, 0, 32767} 717 for _, v := range cases { 718 var data = make([]interface{}, 1) 719 data[0] = v 720 rs.AddRow(data) 721 } 722 } 723 724 return rs 725 } 726 727 func makeMysqlYearResult(unsigned bool) 
*MysqlExecutionResult { 728 return NewMysqlExecutionResult(0, 0, 0, 0, makeMysqlYearResultSet(unsigned)) 729 } 730 731 func makeMysqlVarcharResultSet() *MysqlResultSet { 732 var rs = &MysqlResultSet{} 733 734 name := "Varchar" 735 736 mysqlCol := new(MysqlColumn) 737 mysqlCol.SetName(name) 738 mysqlCol.SetOrgName(name + "OrgName") 739 mysqlCol.SetColumnType(defines.MYSQL_TYPE_VARCHAR) 740 mysqlCol.SetSchema(name + "Schema") 741 mysqlCol.SetTable(name + "Table") 742 mysqlCol.SetOrgTable(name + "Table") 743 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 744 745 rs.AddColumn(mysqlCol) 746 747 var cases = []string{"abc", "abcde", "", "x-", "xx"} 748 for _, v := range cases { 749 var data = make([]interface{}, 1) 750 data[0] = v 751 rs.AddRow(data) 752 } 753 754 return rs 755 } 756 757 func makeMysqlVarcharResult() *MysqlExecutionResult { 758 return NewMysqlExecutionResult(0, 0, 0, 0, makeMysqlVarcharResultSet()) 759 } 760 761 func makeMysqlVarStringResultSet() *MysqlResultSet { 762 var rs = &MysqlResultSet{} 763 764 name := "Varstring" 765 766 mysqlCol := new(MysqlColumn) 767 mysqlCol.SetName(name) 768 mysqlCol.SetOrgName(name + "OrgName") 769 mysqlCol.SetColumnType(defines.MYSQL_TYPE_VAR_STRING) 770 mysqlCol.SetSchema(name + "Schema") 771 mysqlCol.SetTable(name + "Table") 772 mysqlCol.SetOrgTable(name + "Table") 773 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 774 775 rs.AddColumn(mysqlCol) 776 777 var cases = []string{"abc", "abcde", "", "x-", "xx"} 778 for _, v := range cases { 779 var data = make([]interface{}, 1) 780 data[0] = v 781 rs.AddRow(data) 782 } 783 784 return rs 785 } 786 787 func makeMysqlVarStringResult() *MysqlExecutionResult { 788 return NewMysqlExecutionResult(0, 0, 0, 0, makeMysqlVarStringResultSet()) 789 } 790 791 func makeMysqlStringResultSet() *MysqlResultSet { 792 var rs = &MysqlResultSet{} 793 794 name := "String" 795 796 mysqlCol := new(MysqlColumn) 797 mysqlCol.SetName(name) 798 mysqlCol.SetOrgName(name + "OrgName") 799 
mysqlCol.SetColumnType(defines.MYSQL_TYPE_STRING) 800 mysqlCol.SetSchema(name + "Schema") 801 mysqlCol.SetTable(name + "Table") 802 mysqlCol.SetOrgTable(name + "Table") 803 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 804 805 rs.AddColumn(mysqlCol) 806 807 var cases = []string{"abc", "abcde", "", "x-", "xx"} 808 for _, v := range cases { 809 var data = make([]interface{}, 1) 810 data[0] = v 811 rs.AddRow(data) 812 } 813 814 return rs 815 } 816 817 func makeMysqlStringResult() *MysqlExecutionResult { 818 return NewMysqlExecutionResult(0, 0, 0, 0, makeMysqlStringResultSet()) 819 } 820 821 func makeMysqlFloatResultSet() *MysqlResultSet { 822 var rs = &MysqlResultSet{} 823 824 name := "Float" 825 826 mysqlCol := new(MysqlColumn) 827 mysqlCol.SetName(name) 828 mysqlCol.SetOrgName(name + "OrgName") 829 mysqlCol.SetColumnType(defines.MYSQL_TYPE_FLOAT) 830 mysqlCol.SetSchema(name + "Schema") 831 mysqlCol.SetTable(name + "Table") 832 mysqlCol.SetOrgTable(name + "Table") 833 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 834 835 rs.AddColumn(mysqlCol) 836 837 var cases = []float32{math.MaxFloat32, math.SmallestNonzeroFloat32, -math.MaxFloat32, -math.SmallestNonzeroFloat32} 838 for _, v := range cases { 839 var data = make([]interface{}, 1) 840 data[0] = v 841 rs.AddRow(data) 842 } 843 844 return rs 845 } 846 847 func makeMysqlFloatResult() *MysqlExecutionResult { 848 return NewMysqlExecutionResult(0, 0, 0, 0, makeMysqlFloatResultSet()) 849 } 850 851 func makeMysqlDoubleResultSet() *MysqlResultSet { 852 var rs = &MysqlResultSet{} 853 854 name := "Double" 855 856 mysqlCol := new(MysqlColumn) 857 mysqlCol.SetName(name) 858 mysqlCol.SetOrgName(name + "OrgName") 859 mysqlCol.SetColumnType(defines.MYSQL_TYPE_DOUBLE) 860 mysqlCol.SetSchema(name + "Schema") 861 mysqlCol.SetTable(name + "Table") 862 mysqlCol.SetOrgTable(name + "Table") 863 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 864 865 rs.AddColumn(mysqlCol) 866 867 var cases = []float64{math.MaxFloat64, 
math.SmallestNonzeroFloat64, -math.MaxFloat64, -math.SmallestNonzeroFloat64} 868 for _, v := range cases { 869 var data = make([]interface{}, 1) 870 data[0] = v 871 rs.AddRow(data) 872 } 873 874 return rs 875 } 876 877 func makeMysqlDoubleResult() *MysqlExecutionResult { 878 return NewMysqlExecutionResult(0, 0, 0, 0, makeMysqlDoubleResultSet()) 879 } 880 881 func makeMysqlDateResultSet() *MysqlResultSet { 882 var rs = &MysqlResultSet{} 883 884 name := "Date" 885 886 mysqlCol := new(MysqlColumn) 887 mysqlCol.SetName(name) 888 mysqlCol.SetOrgName(name + "OrgName") 889 mysqlCol.SetColumnType(defines.MYSQL_TYPE_DATE) 890 mysqlCol.SetSchema(name + "Schema") 891 mysqlCol.SetTable(name + "Table") 892 mysqlCol.SetOrgTable(name + "Table") 893 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 894 895 rs.AddColumn(mysqlCol) 896 897 d1, _ := types.ParseDateCast("1997-01-01") 898 d2, _ := types.ParseDateCast("2008-02-02") 899 var cases = []types.Date{ 900 d1, 901 d2, 902 } 903 for _, v := range cases { 904 var data = make([]interface{}, 1) 905 data[0] = v 906 rs.AddRow(data) 907 } 908 909 return rs 910 } 911 912 func makeMysqlDateResult() *MysqlExecutionResult { 913 return NewMysqlExecutionResult(0, 0, 0, 0, makeMysqlDateResultSet()) 914 } 915 916 func makeMysqlTimeResultSet() *MysqlResultSet { 917 var rs = &MysqlResultSet{} 918 919 name := "Time" 920 921 mysqlCol := new(MysqlColumn) 922 mysqlCol.SetName(name) 923 mysqlCol.SetOrgName(name + "OrgName") 924 mysqlCol.SetColumnType(defines.MYSQL_TYPE_TIME) 925 mysqlCol.SetSchema(name + "Schema") 926 mysqlCol.SetTable(name + "Table") 927 mysqlCol.SetOrgTable(name + "Table") 928 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 929 930 rs.AddColumn(mysqlCol) 931 932 t1, _ := types.ParseTime("110:21:15", 0) 933 t2, _ := types.ParseTime("2018-04-28 10:21:15.123", 0) 934 t3, _ := types.ParseTime("-112:12:12", 0) 935 var cases = []types.Time{ 936 t1, 937 t2, 938 t3, 939 } 940 for _, v := range cases { 941 var data = make([]interface{}, 1) 
942 data[0] = v 943 rs.AddRow(data) 944 } 945 946 return rs 947 } 948 949 func makeMysqlTimeResult() *MysqlExecutionResult { 950 return NewMysqlExecutionResult(0, 0, 0, 0, makeMysqlTimeResultSet()) 951 } 952 953 func makeMysqlDatetimeResultSet() *MysqlResultSet { 954 var rs = &MysqlResultSet{} 955 956 name := "Date" 957 958 mysqlCol := new(MysqlColumn) 959 mysqlCol.SetName(name) 960 mysqlCol.SetOrgName(name + "OrgName") 961 mysqlCol.SetColumnType(defines.MYSQL_TYPE_DATETIME) 962 mysqlCol.SetSchema(name + "Schema") 963 mysqlCol.SetTable(name + "Table") 964 mysqlCol.SetOrgTable(name + "Table") 965 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 966 967 rs.AddColumn(mysqlCol) 968 969 d1, _ := types.ParseDatetime("2018-04-28 10:21:15", 0) 970 d2, _ := types.ParseDatetime("2018-04-28 10:21:15.123", 0) 971 d3, _ := types.ParseDatetime("2015-03-03 12:12:12", 0) 972 var cases = []types.Datetime{ 973 d1, 974 d2, 975 d3, 976 } 977 for _, v := range cases { 978 var data = make([]interface{}, 1) 979 data[0] = v 980 rs.AddRow(data) 981 } 982 983 return rs 984 } 985 986 func makeMysqlDatetimeResult() *MysqlExecutionResult { 987 return NewMysqlExecutionResult(0, 0, 0, 0, makeMysqlDatetimeResultSet()) 988 } 989 990 func make9ColumnsResultSet() *MysqlResultSet { 991 var rs = &MysqlResultSet{} 992 993 var columnTypes = []defines.MysqlType{ 994 defines.MYSQL_TYPE_TINY, 995 defines.MYSQL_TYPE_SHORT, 996 defines.MYSQL_TYPE_LONG, 997 defines.MYSQL_TYPE_LONGLONG, 998 defines.MYSQL_TYPE_VARCHAR, 999 defines.MYSQL_TYPE_FLOAT, 1000 defines.MYSQL_TYPE_DATE, 1001 defines.MYSQL_TYPE_TIME, 1002 defines.MYSQL_TYPE_DATETIME, 1003 defines.MYSQL_TYPE_DOUBLE, 1004 } 1005 1006 var names = []string{ 1007 "Tiny", 1008 "Short", 1009 "Long", 1010 "Longlong", 1011 "Varchar", 1012 "Float", 1013 "Date", 1014 "Time", 1015 "Datetime", 1016 "Double", 1017 } 1018 1019 d1, _ := types.ParseDateCast("1997-01-01") 1020 d2, _ := types.ParseDateCast("2008-02-02") 1021 1022 dt1, _ := types.ParseDatetime("2018-04-28 
10:21:15", 0) 1023 dt2, _ := types.ParseDatetime("2018-04-28 10:21:15.123", 0) 1024 dt3, _ := types.ParseDatetime("2015-03-03 12:12:12", 0) 1025 1026 t1, _ := types.ParseTime("2018-04-28 10:21:15", 0) 1027 t2, _ := types.ParseTime("2018-04-28 10:21:15.123", 0) 1028 t3, _ := types.ParseTime("2015-03-03 12:12:12", 0) 1029 1030 var cases = [][]interface{}{ 1031 {int8(-128), int16(-32768), int32(-2147483648), int64(-9223372036854775808), "abc", float32(math.MaxFloat32), d1, t1, dt1, float64(0.01)}, 1032 {int8(-127), int16(0), int32(0), int64(0), "abcde", float32(math.SmallestNonzeroFloat32), d2, t2, dt2, float64(0.01)}, 1033 {int8(127), int16(32767), int32(2147483647), int64(9223372036854775807), "", float32(-math.MaxFloat32), d1, t3, dt3, float64(0.01)}, 1034 {int8(126), int16(32766), int32(2147483646), int64(9223372036854775806), "x-", float32(-math.SmallestNonzeroFloat32), d2, t1, dt1, float64(0.01)}, 1035 } 1036 1037 for i, ct := range columnTypes { 1038 name := names[i] 1039 mysqlCol := new(MysqlColumn) 1040 mysqlCol.SetName(name) 1041 mysqlCol.SetOrgName(name + "OrgName") 1042 mysqlCol.SetColumnType(ct) 1043 mysqlCol.SetSchema(name + "Schema") 1044 mysqlCol.SetTable(name + "Table") 1045 mysqlCol.SetOrgTable(name + "Table") 1046 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 1047 1048 rs.AddColumn(mysqlCol) 1049 } 1050 1051 for _, v := range cases { 1052 rs.AddRow(v) 1053 } 1054 1055 return rs 1056 } 1057 1058 func makeMysql9ColumnsResult() *MysqlExecutionResult { 1059 return NewMysqlExecutionResult(0, 0, 0, 0, make9ColumnsResultSet()) 1060 } 1061 1062 func makeMoreThan16MBResultSet() *MysqlResultSet { 1063 var rs = &MysqlResultSet{} 1064 1065 var columnTypes = []defines.MysqlType{ 1066 defines.MYSQL_TYPE_LONGLONG, 1067 defines.MYSQL_TYPE_DOUBLE, 1068 defines.MYSQL_TYPE_VARCHAR, 1069 } 1070 1071 var names = []string{ 1072 "Longlong", 1073 "Double", 1074 "Varchar", 1075 } 1076 1077 var rowCase = []interface{}{int64(9223372036854775807), math.MaxFloat64, 
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"} 1078 1079 for i, ct := range columnTypes { 1080 name := names[i] 1081 mysqlCol := new(MysqlColumn) 1082 mysqlCol.SetName(name) 1083 mysqlCol.SetOrgName(name + "OrgName") 1084 mysqlCol.SetColumnType(ct) 1085 mysqlCol.SetSchema(name + "Schema") 1086 mysqlCol.SetTable(name + "Table") 1087 mysqlCol.SetOrgTable(name + "Table") 1088 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 1089 1090 rs.AddColumn(mysqlCol) 1091 } 1092 1093 //the size of the total result set will be more than 16MB 1094 for i := 0; i < 40000; i++ { 1095 rs.AddRow(rowCase) 1096 } 1097 1098 return rs 1099 } 1100 1101 // the size of resultset will be morethan 16MB 1102 func makeMoreThan16MBResult() *MysqlExecutionResult { 1103 return NewMysqlExecutionResult(0, 0, 0, 0, makeMoreThan16MBResultSet()) 1104 } 1105 1106 func make16MBRowResultSet() *MysqlResultSet { 1107 var rs = &MysqlResultSet{} 1108 1109 name := "Varstring" 1110 1111 mysqlCol := new(MysqlColumn) 1112 mysqlCol.SetName(name) 1113 mysqlCol.SetOrgName(name + "OrgName") 1114 mysqlCol.SetColumnType(defines.MYSQL_TYPE_VAR_STRING) 1115 mysqlCol.SetSchema(name + "Schema") 1116 mysqlCol.SetTable(name + "Table") 1117 mysqlCol.SetOrgTable(name + "Table") 1118 mysqlCol.SetCharset(uint16(Utf8mb4CollationID)) 1119 1120 rs.AddColumn(mysqlCol) 1121 1122 /* 1123 How to test the max size of the data in one packet that the client can received ? 1124 Environment: Mysql Version 8.0.23 1125 1. shell: mysql --help | grep allowed-packet 1126 something like: 1127 " 1128 --max-allowed-packet=# 1129 max-allowed-packet 16777216 1130 " 1131 so, we get: 1132 max-allowed-packet means : The maximum packet length to send to or receive from server. 1133 default value : 16777216 (16MB) 1134 2. 
shell execution: mysql -uroot -e "select repeat('a',16*1024*1024-4);" > 16MB-mysql.txt 1135 we get: ERROR 2020 (HY000) at line 1: Got packet bigger than 'max_allowed_packet' bytes 1136 3. shell execution: mysql -uroot -e "select repeat('a',16*1024*1024-5);" > 16MB-mysql.txt 1137 execution succeeded 1138 4. so, the max size of the data in one packet is (max-allowed-packet - 5). 1139 5. To change max-allowed-packet. 1140 shell execution: mysql max-allowed-packet=xxxxx .... 1141 */ 1142 1143 //test in shell : mysql -h 127.0.0.1 -P 6001 -udump -p111 -e "16mbrow" > 16mbrow.txt 1144 //max data size : 16 * 1024 * 1024 - 5 1145 var stuff = make([]byte, 16*1024*1024-5) 1146 for i := range stuff { 1147 stuff[i] = 'a' 1148 } 1149 1150 var rowCase = []interface{}{string(stuff)} 1151 for i := 0; i < 1; i++ { 1152 rs.AddRow(rowCase) 1153 } 1154 1155 return rs 1156 } 1157 1158 // the size of resultset row will be more than 16MB 1159 func make16MBRowResult() *MysqlExecutionResult { 1160 return NewMysqlExecutionResult(0, 0, 0, 0, make16MBRowResultSet()) 1161 } 1162 1163 func (tRM *TestRoutineManager) resultsetHandler(rs goetty.IOSession, msg interface{}, _ uint64) error { 1164 tRM.rwlock.Lock() 1165 routine := tRM.clients[rs] 1166 tRM.rwlock.Unlock() 1167 ctx := context.TODO() 1168 1169 pu, err := getParameterUnit("test/system_vars_config.toml", nil, nil) 1170 if err != nil { 1171 return err 1172 } 1173 1174 pro := routine.getProtocol().(*MysqlProtocolImpl) 1175 packet, ok := msg.(*Packet) 1176 pro.SetSequenceID(uint8(packet.SequenceID + 1)) 1177 if !ok { 1178 return moerr.NewInternalError(ctx, "message is not Packet") 1179 } 1180 1181 ses := NewSession(pro, nil, pu, nil, false) 1182 ses.SetRequestContext(ctx) 1183 pro.SetSession(ses) 1184 1185 length := packet.Length 1186 payload := packet.Payload 1187 for uint32(length) == MaxPayloadSize { 1188 var err error 1189 msg, err = pro.GetTcpConnection().Read(goetty.ReadOptions{}) 1190 if err != nil { 1191 return 
moerr.NewInternalError(ctx, "read msg error") 1192 } 1193 1194 packet, ok = msg.(*Packet) 1195 if !ok { 1196 return moerr.NewInternalError(ctx, "message is not Packet") 1197 } 1198 1199 pro.SetSequenceID(uint8(packet.SequenceID + 1)) 1200 payload = append(payload, packet.Payload...) 1201 length = packet.Length 1202 } 1203 1204 // finish handshake process 1205 if !pro.IsEstablished() { 1206 _, err := pro.HandleHandshake(ctx, payload) 1207 if err != nil { 1208 return err 1209 } 1210 pro.SetEstablished() 1211 return nil 1212 } 1213 1214 var req *Request 1215 var resp *Response 1216 req = pro.GetRequest(payload) 1217 switch req.GetCmd() { 1218 case COM_QUIT: 1219 resp = &Response{ 1220 category: OkResponse, 1221 status: 0, 1222 data: nil, 1223 } 1224 if err := pro.SendResponse(ctx, resp); err != nil { 1225 fmt.Printf("send response failed. error:%v", err) 1226 break 1227 } 1228 case COM_QUERY: 1229 var query = string(req.GetData().([]byte)) 1230 1231 switch query { 1232 case "tiny": 1233 resp = &Response{ 1234 category: ResultResponse, 1235 status: 0, 1236 cmd: 0, 1237 data: makeMysqlTinyResult(false), 1238 } 1239 case "tinyu": 1240 resp = &Response{ 1241 category: ResultResponse, 1242 status: 0, 1243 data: makeMysqlTinyResult(true), 1244 } 1245 case "short": 1246 resp = &Response{ 1247 category: ResultResponse, 1248 status: 0, 1249 data: makeMysqlShortResult(false), 1250 } 1251 case "shortu": 1252 resp = &Response{ 1253 category: ResultResponse, 1254 status: 0, 1255 data: makeMysqlShortResult(true), 1256 } 1257 case "long": 1258 resp = &Response{ 1259 category: ResultResponse, 1260 status: 0, 1261 data: makeMysqlLongResult(false), 1262 } 1263 case "longu": 1264 resp = &Response{ 1265 category: ResultResponse, 1266 status: 0, 1267 data: makeMysqlLongResult(true), 1268 } 1269 case "longlong": 1270 resp = &Response{ 1271 category: ResultResponse, 1272 status: 0, 1273 data: makeMysqlLongLongResult(false), 1274 } 1275 case "longlongu": 1276 resp = &Response{ 1277 category: 
ResultResponse, 1278 status: 0, 1279 data: makeMysqlLongLongResult(true), 1280 } 1281 case "int24": 1282 resp = &Response{ 1283 category: ResultResponse, 1284 status: 0, 1285 data: makeMysqlInt24Result(false), 1286 } 1287 case "int24u": 1288 resp = &Response{ 1289 category: ResultResponse, 1290 status: 0, 1291 data: makeMysqlInt24Result(true), 1292 } 1293 case "year": 1294 resp = &Response{ 1295 category: ResultResponse, 1296 status: 0, 1297 data: makeMysqlYearResult(false), 1298 } 1299 case "yearu": 1300 resp = &Response{ 1301 category: ResultResponse, 1302 status: 0, 1303 data: makeMysqlYearResult(true), 1304 } 1305 case "varchar": 1306 resp = &Response{ 1307 category: ResultResponse, 1308 status: 0, 1309 data: makeMysqlVarcharResult(), 1310 } 1311 case "varstring": 1312 resp = &Response{ 1313 category: ResultResponse, 1314 status: 0, 1315 data: makeMysqlVarStringResult(), 1316 } 1317 case "string": 1318 resp = &Response{ 1319 category: ResultResponse, 1320 status: 0, 1321 data: makeMysqlStringResult(), 1322 } 1323 case "float": 1324 resp = &Response{ 1325 category: ResultResponse, 1326 status: 0, 1327 data: makeMysqlFloatResult(), 1328 } 1329 case "double": 1330 resp = &Response{ 1331 category: ResultResponse, 1332 status: 0, 1333 data: makeMysqlDoubleResult(), 1334 } 1335 case "date": 1336 resp = &Response{ 1337 category: ResultResponse, 1338 status: 0, 1339 data: makeMysqlDateResult(), 1340 } 1341 case "time": 1342 resp = &Response{ 1343 category: ResultResponse, 1344 status: 0, 1345 data: makeMysqlTimeResult(), 1346 } 1347 case "datetime": 1348 resp = &Response{ 1349 category: ResultResponse, 1350 status: 0, 1351 data: makeMysqlDatetimeResult(), 1352 } 1353 case "9columns": 1354 resp = &Response{ 1355 category: ResultResponse, 1356 status: 0, 1357 data: makeMysql9ColumnsResult(), 1358 } 1359 case "16mb": 1360 resp = &Response{ 1361 category: ResultResponse, 1362 status: 0, 1363 data: makeMoreThan16MBResult(), 1364 } 1365 case "16mbrow": 1366 resp = &Response{ 
1367 category: ResultResponse, 1368 status: 0, 1369 data: make16MBRowResult(), 1370 } 1371 default: 1372 resp = &Response{ 1373 category: OkResponse, 1374 status: 0, 1375 data: nil, 1376 } 1377 } 1378 1379 if err := pro.SendResponse(ctx, resp); err != nil { 1380 fmt.Printf("send response failed. error:%v", err) 1381 break 1382 } 1383 case COM_PING: 1384 resp = NewResponse( 1385 OkResponse, 1386 0, 1387 int(COM_PING), 1388 nil, 1389 ) 1390 if err := pro.SendResponse(ctx, resp); err != nil { 1391 fmt.Printf("send response failed. error:%v", err) 1392 break 1393 } 1394 1395 default: 1396 fmt.Printf("unsupported command. 0x%x \n", req.cmd) 1397 } 1398 if req.cmd == COM_QUIT { 1399 return nil 1400 } 1401 return nil 1402 } 1403 1404 func TestMysqlResultSet(t *testing.T) { 1405 //client connection method: mysql -h 127.0.0.1 -P 6001 -udump -p 1406 //pwd: mysql-server-mysql-8.0.23/mysql-test 1407 //with mysqltest: mysqltest --test-file=t/1st.test --result-file=r/1st.result --user=dump -p111 -P 6001 --host=127.0.0.1 1408 1409 //test: 1410 //./mysql-test-run 1st --extern user=root --extern port=3306 --extern host=127.0.0.1 1411 // mysql5.7 failed 1412 // mysql-8.0.23 success 1413 //./mysql-test-run 1st --extern user=root --extern port=6001 --extern host=127.0.0.1 1414 // matrixone failed: mysql-test-run: *** ERROR: Could not connect to extern server using command: '/Users/pengzhen/Documents/mysql-server-mysql-8.0.23/bld/runtime_output_directory//mysql --no-defaults --user=root --user=root --port=6001 --host=127.0.0.1 --silent --database=mysql --execute="SHOW GLOBAL VARIABLES"' 1415 pu := config.NewParameterUnit(&config.FrontendParameters{}, nil, nil, nil, nil) 1416 _, err := toml.DecodeFile("test/system_vars_config.toml", pu.SV) 1417 if err != nil { 1418 panic(err) 1419 } 1420 1421 trm := NewTestRoutineManager(pu) 1422 1423 wg := sync.WaitGroup{} 1424 wg.Add(1) 1425 1426 go func() { 1427 defer wg.Done() 1428 echoServer(trm.resultsetHandler, trm, NewSqlCodec()) 1429 }() 1430 
1431 // to := NewTimeout(1*time.Minute, false) 1432 // for isClosed() && !to.isTimeout() { 1433 // } 1434 1435 time.Sleep(time.Second * 2) 1436 db, err := openDbConn(t, 6001) 1437 require.NoError(t, err) 1438 1439 do_query_resp_resultset(t, db, false, false, "tiny", makeMysqlTinyIntResultSet(false)) 1440 do_query_resp_resultset(t, db, false, false, "tinyu", makeMysqlTinyIntResultSet(true)) 1441 do_query_resp_resultset(t, db, false, false, "short", makeMysqlShortResultSet(false)) 1442 do_query_resp_resultset(t, db, false, false, "shortu", makeMysqlShortResultSet(true)) 1443 do_query_resp_resultset(t, db, false, false, "long", makeMysqlLongResultSet(false)) 1444 do_query_resp_resultset(t, db, false, false, "longu", makeMysqlLongResultSet(true)) 1445 do_query_resp_resultset(t, db, false, false, "longlong", makeMysqlLongLongResultSet(false)) 1446 do_query_resp_resultset(t, db, false, false, "longlongu", makeMysqlLongLongResultSet(true)) 1447 do_query_resp_resultset(t, db, false, false, "int24", makeMysqlInt24ResultSet(false)) 1448 do_query_resp_resultset(t, db, false, false, "int24u", makeMysqlInt24ResultSet(true)) 1449 do_query_resp_resultset(t, db, false, false, "year", makeMysqlYearResultSet(false)) 1450 do_query_resp_resultset(t, db, false, false, "yearu", makeMysqlYearResultSet(true)) 1451 do_query_resp_resultset(t, db, false, false, "varchar", makeMysqlVarcharResultSet()) 1452 do_query_resp_resultset(t, db, false, false, "varstring", makeMysqlVarStringResultSet()) 1453 do_query_resp_resultset(t, db, false, false, "string", makeMysqlStringResultSet()) 1454 do_query_resp_resultset(t, db, false, false, "float", makeMysqlFloatResultSet()) 1455 do_query_resp_resultset(t, db, false, false, "double", makeMysqlDoubleResultSet()) 1456 do_query_resp_resultset(t, db, false, false, "date", makeMysqlDateResultSet()) 1457 do_query_resp_resultset(t, db, false, false, "time", makeMysqlTimeResultSet()) 1458 do_query_resp_resultset(t, db, false, false, "datetime", 
makeMysqlDatetimeResultSet()) 1459 do_query_resp_resultset(t, db, false, false, "9columns", make9ColumnsResultSet()) 1460 do_query_resp_resultset(t, db, false, false, "16mbrow", make16MBRowResultSet()) 1461 do_query_resp_resultset(t, db, false, false, "16mb", makeMoreThan16MBResultSet()) 1462 1463 closeDbConn(t, db) 1464 1465 time.Sleep(time.Millisecond * 10) 1466 //close server 1467 setServer(1) 1468 wg.Wait() 1469 } 1470 1471 // func open_tls_db(t *testing.T, port int) *sql.DB { 1472 // tlsName := "custom" 1473 // rootCertPool := x509.NewCertPool() 1474 // pem, err := os.ReadFile("test/ca.pem") 1475 // if err != nil { 1476 // setServer(1) 1477 // require.NoError(t, err) 1478 // } 1479 // if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 1480 // log.Fatal("Failed to append PEM.") 1481 // } 1482 // clientCert := make([]tls.Certificate, 0, 1) 1483 // certs, err := tls.LoadX509KeyPair("test/client-cert2.pem", "test/client-key2.pem") 1484 // if err != nil { 1485 // setServer(1) 1486 // require.NoError(t, err) 1487 // } 1488 // clientCert = append(clientCert, certs) 1489 // err = mysqlDriver.RegisterTLSConfig(tlsName, &tls.Config{ 1490 // RootCAs: rootCertPool, 1491 // Certificates: clientCert, 1492 // MinVersion: tls.VersionTLS12, 1493 // InsecureSkipVerify: true, 1494 // }) 1495 // if err != nil { 1496 // setServer(1) 1497 // require.NoError(t, err) 1498 // } 1499 1500 // dsn := fmt.Sprintf("dump:111@tcp(127.0.0.1:%d)/?readTimeout=5s&timeout=5s&writeTimeout=5s&tls=%s", port, tlsName) 1501 // db, err := sql.Open("mysql", dsn) 1502 // if err != nil { 1503 // require.NoError(t, err) 1504 // } else { 1505 // db.SetConnMaxLifetime(time.Minute * 3) 1506 // db.SetMaxOpenConns(1) 1507 // db.SetMaxIdleConns(1) 1508 // time.Sleep(time.Millisecond * 100) 1509 1510 // // ping opens the connection 1511 // logutil.Info("start ping") 1512 // err = db.Ping() 1513 // if err != nil { 1514 // setServer(1) 1515 // require.NoError(t, err) 1516 // } 1517 // } 1518 // return db 1519 // 
} 1520 1521 func openDbConn(t *testing.T, port int) (*sql.DB, error) { 1522 dsn := fmt.Sprintf("dump:111@tcp(127.0.0.1:%d)/?readTimeout=10s&timeout=10s&writeTimeout=10s", port) 1523 db, err := sql.Open("mysql", dsn) 1524 if err != nil { 1525 return nil, err 1526 } else { 1527 db.SetConnMaxLifetime(time.Minute * 3) 1528 db.SetMaxOpenConns(1) 1529 db.SetMaxIdleConns(1) 1530 time.Sleep(time.Millisecond * 100) 1531 1532 //ping opens the connection 1533 err = db.Ping() 1534 if err != nil { 1535 return nil, err 1536 } 1537 } 1538 return db, err 1539 } 1540 1541 func closeDbConn(t *testing.T, db *sql.DB) { 1542 err := db.Close() 1543 require.NoError(t, err) 1544 } 1545 1546 func do_query_resp_resultset(t *testing.T, db *sql.DB, wantErr bool, skipResultsetCheck bool, query string, mrs *MysqlResultSet) { 1547 rows, err := db.Query(query) 1548 if wantErr { 1549 require.Error(t, err) 1550 require.True(t, rows == nil) 1551 return 1552 } 1553 require.NoError(t, err) 1554 1555 //column check 1556 columns, err := rows.Columns() 1557 require.NoError(t, err) 1558 require.True(t, len(columns) == len(mrs.Columns)) 1559 1560 //colType, err := rows.ColumnTypes() 1561 //require.NoError(t, err) 1562 //for i, ct := range colType { 1563 // fmt.Printf("column %d\n",i) 1564 // fmt.Printf("name %v \n",ct.Name()) 1565 // l,o := ct.Length() 1566 // fmt.Printf("length %v %v \n",l,o) 1567 // p,s,o := ct.DecimalSize() 1568 // fmt.Printf("decimalsize %v %v %v \n",p,s,o) 1569 // fmt.Printf("scantype %v \n",ct.ScanType()) 1570 // n,o := ct.Nullable() 1571 // fmt.Printf("nullable %v %v \n",n,o) 1572 // fmt.Printf("databaseTypeName %s \n",ct.DatabaseTypeName()) 1573 //} 1574 1575 values := make([][]byte, len(columns)) 1576 1577 // rows.Scan wants '[]interface{}' as an argument, so we must copy the 1578 // references into such a slice 1579 // See http://code.google.com/p/go-wiki/wiki/InterfaceSlice for details 1580 scanArgs := make([]interface{}, len(columns)) 1581 for i := uint64(0); i < 
mrs.GetColumnCount(); i++ { 1582 scanArgs[i] = &values[i] 1583 } 1584 1585 rowIdx := uint64(0) 1586 for rows.Next() { 1587 err = rows.Scan(scanArgs...) 1588 require.NoError(t, err) 1589 1590 //fmt.Println(rowIdx) 1591 //fmt.Println(mrs.GetRow(rowIdx)) 1592 // 1593 //for i := uint64(0); i < mrs.GetColumnCount(); i++ { 1594 // arg := scanArgs[i] 1595 // val := *(arg.(*[]byte)) 1596 // fmt.Printf("%v ",val) 1597 //} 1598 //fmt.Println() 1599 1600 if !skipResultsetCheck { 1601 for i := uint64(0); i < mrs.GetColumnCount(); i++ { 1602 arg := scanArgs[i] 1603 val := *(arg.(*[]byte)) 1604 1605 column, err := mrs.GetColumn(context.TODO(), i) 1606 require.NoError(t, err) 1607 1608 col, ok := column.(*MysqlColumn) 1609 require.True(t, ok) 1610 1611 isNUll, err := mrs.ColumnIsNull(context.TODO(), rowIdx, i) 1612 require.NoError(t, err) 1613 1614 if isNUll { 1615 require.True(t, val == nil) 1616 } else { 1617 var data []byte = nil 1618 switch col.ColumnType() { 1619 case defines.MYSQL_TYPE_TINY, defines.MYSQL_TYPE_SHORT, defines.MYSQL_TYPE_INT24, defines.MYSQL_TYPE_LONG, defines.MYSQL_TYPE_YEAR: 1620 value, err := mrs.GetInt64(context.TODO(), rowIdx, i) 1621 require.NoError(t, err) 1622 if col.ColumnType() == defines.MYSQL_TYPE_YEAR { 1623 if value == 0 { 1624 data = append(data, []byte("0000")...) 
1625 } else { 1626 data = strconv.AppendInt(data, value, 10) 1627 } 1628 } else { 1629 data = strconv.AppendInt(data, value, 10) 1630 } 1631 1632 case defines.MYSQL_TYPE_LONGLONG: 1633 if uint32(col.Flag())&defines.UNSIGNED_FLAG != 0 { 1634 value, err := mrs.GetUint64(context.TODO(), rowIdx, i) 1635 require.NoError(t, err) 1636 data = strconv.AppendUint(data, value, 10) 1637 } else { 1638 value, err := mrs.GetInt64(context.TODO(), rowIdx, i) 1639 require.NoError(t, err) 1640 data = strconv.AppendInt(data, value, 10) 1641 } 1642 case defines.MYSQL_TYPE_VARCHAR, defines.MYSQL_TYPE_VAR_STRING, defines.MYSQL_TYPE_STRING: 1643 value, err := mrs.GetString(context.TODO(), rowIdx, i) 1644 require.NoError(t, err) 1645 data = []byte(value) 1646 case defines.MYSQL_TYPE_FLOAT: 1647 value, err := mrs.GetFloat64(context.TODO(), rowIdx, i) 1648 require.NoError(t, err) 1649 data = strconv.AppendFloat(data, value, 'f', -1, 32) 1650 case defines.MYSQL_TYPE_DOUBLE: 1651 value, err := mrs.GetFloat64(context.TODO(), rowIdx, i) 1652 require.NoError(t, err) 1653 data = strconv.AppendFloat(data, value, 'f', -1, 64) 1654 case defines.MYSQL_TYPE_DATE: 1655 value, err := mrs.GetValue(context.TODO(), rowIdx, i) 1656 require.NoError(t, err) 1657 x := value.(types.Date).String() 1658 data = []byte(x) 1659 case defines.MYSQL_TYPE_TIME: 1660 value, err := mrs.GetValue(context.TODO(), rowIdx, i) 1661 require.NoError(t, err) 1662 x := value.(types.Time).String() 1663 data = []byte(x) 1664 case defines.MYSQL_TYPE_DATETIME: 1665 value, err := mrs.GetValue(context.TODO(), rowIdx, i) 1666 require.NoError(t, err) 1667 x := value.(types.Datetime).String() 1668 data = []byte(x) 1669 default: 1670 require.NoError(t, moerr.NewInternalError(context.TODO(), "unsupported type %v", col.ColumnType())) 1671 } 1672 //check 1673 ret := reflect.DeepEqual(data, val) 1674 //fmt.Println(i) 1675 //fmt.Println(data) 1676 //fmt.Println(val) 1677 require.True(t, ret) 1678 } 1679 } 1680 } 1681 1682 rowIdx++ 1683 } 1684 1685 
require.True(t, rowIdx == mrs.GetRowCount()) 1686 1687 err = rows.Err() 1688 require.NoError(t, err) 1689 } 1690 1691 func Test_writePackets(t *testing.T) { 1692 ctx := context.TODO() 1693 convey.Convey("writepackets 16MB succ", t, func() { 1694 ctrl := gomock.NewController(t) 1695 defer ctrl.Finish() 1696 ioses := mock_frontend.NewMockIOSession(ctrl) 1697 ioses.EXPECT().Write(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() 1698 ioses.EXPECT().RemoteAddress().Return("").AnyTimes() 1699 ioses.EXPECT().Ref().AnyTimes() 1700 sv, err := getSystemVariables("test/system_vars_config.toml") 1701 if err != nil { 1702 t.Error(err) 1703 } 1704 1705 proto := NewMysqlClientProtocol(0, ioses, 1024, sv) 1706 err = proto.writePackets(make([]byte, MaxPayloadSize)) 1707 convey.So(err, convey.ShouldBeNil) 1708 }) 1709 convey.Convey("writepackets 16MB failed", t, func() { 1710 ctrl := gomock.NewController(t) 1711 defer ctrl.Finish() 1712 ioses := mock_frontend.NewMockIOSession(ctrl) 1713 1714 cnt := 0 1715 ioses.EXPECT().Write(gomock.Any(), gomock.Any()).DoAndReturn(func(msg interface{}, opts goetty.WriteOptions) error { 1716 if cnt == 0 { 1717 cnt++ 1718 return nil 1719 } else { 1720 cnt++ 1721 return moerr.NewInternalError(ctx, "write and flush failed.") 1722 } 1723 }).AnyTimes() 1724 ioses.EXPECT().RemoteAddress().Return("").AnyTimes() 1725 ioses.EXPECT().Ref().AnyTimes() 1726 sv, err := getSystemVariables("test/system_vars_config.toml") 1727 if err != nil { 1728 t.Error(err) 1729 } 1730 1731 proto := NewMysqlClientProtocol(0, ioses, 1024, sv) 1732 err = proto.writePackets(make([]byte, MaxPayloadSize)) 1733 convey.So(err, convey.ShouldBeError) 1734 }) 1735 1736 convey.Convey("writepackets 16MB failed 2", t, func() { 1737 ctrl := gomock.NewController(t) 1738 defer ctrl.Finish() 1739 ioses := mock_frontend.NewMockIOSession(ctrl) 1740 1741 ioses.EXPECT().Write(gomock.Any(), gomock.Any()).DoAndReturn(func(msg interface{}, opts goetty.WriteOptions) error { 1742 return 
moerr.NewInternalError(ctx, "write and flush failed.") 1743 }).AnyTimes() 1744 ioses.EXPECT().RemoteAddress().Return("").AnyTimes() 1745 ioses.EXPECT().Ref().AnyTimes() 1746 sv, err := getSystemVariables("test/system_vars_config.toml") 1747 if err != nil { 1748 t.Error(err) 1749 } 1750 1751 proto := NewMysqlClientProtocol(0, ioses, 1024, sv) 1752 err = proto.writePackets(make([]byte, MaxPayloadSize)) 1753 convey.So(err, convey.ShouldBeError) 1754 }) 1755 } 1756 1757 func Test_openpacket(t *testing.T) { 1758 convey.Convey("openpacket succ", t, func() { 1759 ctrl := gomock.NewController(t) 1760 defer ctrl.Finish() 1761 ioses := mock_frontend.NewMockIOSession(ctrl) 1762 1763 ioses.EXPECT().OutBuf().Return(buf.NewByteBuf(1024)).AnyTimes() 1764 ioses.EXPECT().RemoteAddress().Return("").AnyTimes() 1765 ioses.EXPECT().Ref().AnyTimes() 1766 sv, err := getSystemVariables("test/system_vars_config.toml") 1767 if err != nil { 1768 t.Error(err) 1769 } 1770 1771 proto := NewMysqlClientProtocol(0, ioses, 1024, sv) 1772 1773 err = proto.openPacket() 1774 convey.So(err, convey.ShouldBeNil) 1775 headLen := proto.tcpConn.OutBuf().GetWriteIndex() - proto.beginWriteIndex 1776 convey.So(headLen, convey.ShouldEqual, HeaderLengthOfTheProtocol) 1777 }) 1778 1779 convey.Convey("fillpacket succ", t, func() { 1780 ctrl := gomock.NewController(t) 1781 defer ctrl.Finish() 1782 ioses := mock_frontend.NewMockIOSession(ctrl) 1783 1784 ioses.EXPECT().OutBuf().Return(buf.NewByteBuf(1024)).AnyTimes() 1785 ioses.EXPECT().Flush(gomock.Any()).Return(nil).AnyTimes() 1786 ioses.EXPECT().RemoteAddress().Return("").AnyTimes() 1787 ioses.EXPECT().Ref().AnyTimes() 1788 pu, err := getParameterUnit("test/system_vars_config.toml", nil, nil) 1789 if err != nil { 1790 t.Error(err) 1791 } 1792 1793 proto := NewMysqlClientProtocol(0, ioses, 1024, pu.SV) 1794 // fill proto.ses 1795 ses := NewSession(proto, nil, pu, nil, false) 1796 ses.SetRequestContext(context.TODO()) 1797 proto.ses = ses 1798 1799 err = 
proto.fillPacket(make([]byte, MaxPayloadSize)...) 1800 convey.So(err, convey.ShouldBeNil) 1801 1802 err = proto.closePacket(true) 1803 convey.So(err, convey.ShouldBeNil) 1804 1805 proto.append(nil, make([]byte, 1024)...) 1806 }) 1807 1808 convey.Convey("closepacket falied.", t, func() { 1809 ctrl := gomock.NewController(t) 1810 defer ctrl.Finish() 1811 ioses := mock_frontend.NewMockIOSession(ctrl) 1812 1813 ioses.EXPECT().OutBuf().Return(buf.NewByteBuf(1024)).AnyTimes() 1814 ioses.EXPECT().RemoteAddress().Return("").AnyTimes() 1815 ioses.EXPECT().Ref().AnyTimes() 1816 pu, err := getParameterUnit("test/system_vars_config.toml", nil, nil) 1817 if err != nil { 1818 t.Error(err) 1819 } 1820 1821 proto := NewMysqlClientProtocol(0, ioses, 1024, pu.SV) 1822 // fill proto.ses 1823 ses := NewSession(proto, nil, pu, nil, false) 1824 ses.SetRequestContext(context.TODO()) 1825 proto.ses = ses 1826 1827 err = proto.openPacket() 1828 convey.So(err, convey.ShouldBeNil) 1829 1830 proto.beginWriteIndex = proto.tcpConn.OutBuf().GetWriteIndex() 1831 err = proto.closePacket(true) 1832 convey.So(err, convey.ShouldBeError) 1833 }) 1834 1835 convey.Convey("append -- data checks", t, func() { 1836 ctrl := gomock.NewController(t) 1837 defer ctrl.Finish() 1838 ioses := mock_frontend.NewMockIOSession(ctrl) 1839 1840 ioses.EXPECT().OutBuf().Return(buf.NewByteBuf(1024)).AnyTimes() 1841 ioses.EXPECT().Flush(gomock.Any()).Return(nil).AnyTimes() 1842 ioses.EXPECT().RemoteAddress().Return("").AnyTimes() 1843 ioses.EXPECT().Ref().AnyTimes() 1844 sv, err := getSystemVariables("test/system_vars_config.toml") 1845 if err != nil { 1846 t.Error(err) 1847 } 1848 1849 proto := NewMysqlClientProtocol(0, ioses, 1024, sv) 1850 1851 mysqlPack := func(payload []byte) []byte { 1852 n := len(payload) 1853 var curLen int 1854 var header [4]byte 1855 var data []byte = nil 1856 var sequenceId byte = 0 1857 for i := 0; i < n; i += curLen { 1858 curLen = Min(int(MaxPayloadSize), n-i) 1859 
binary.LittleEndian.PutUint32(header[:], uint32(curLen)) 1860 header[3] = sequenceId 1861 sequenceId++ 1862 data = append(data, header[:]...) 1863 data = append(data, payload[i:i+curLen]...) 1864 if i+curLen == n && curLen == int(MaxPayloadSize) { 1865 binary.LittleEndian.PutUint32(header[:], uint32(0)) 1866 header[3] = sequenceId 1867 sequenceId++ 1868 data = append(data, header[:]...) 1869 } 1870 } 1871 return data 1872 } 1873 1874 data16MB := func(cnt int) []byte { 1875 data := make([]byte, cnt*int(MaxPayloadSize)) 1876 return data 1877 } 1878 1879 type kase struct { 1880 data []byte 1881 len int 1882 } 1883 1884 kases := []kase{ 1885 { 1886 data: []byte{1, 2, 3, 4}, 1887 len: HeaderLengthOfTheProtocol + 4, 1888 }, 1889 { 1890 data: data16MB(1), 1891 len: HeaderLengthOfTheProtocol + int(MaxPayloadSize) + HeaderLengthOfTheProtocol, 1892 }, 1893 { 1894 data: data16MB(2), 1895 len: HeaderLengthOfTheProtocol + int(MaxPayloadSize) + HeaderLengthOfTheProtocol + int(MaxPayloadSize) + HeaderLengthOfTheProtocol, 1896 }, 1897 { 1898 data: data16MB(3), 1899 len: HeaderLengthOfTheProtocol + int(MaxPayloadSize) + 1900 HeaderLengthOfTheProtocol + int(MaxPayloadSize) + 1901 HeaderLengthOfTheProtocol + int(MaxPayloadSize) + 1902 HeaderLengthOfTheProtocol, 1903 }, 1904 { 1905 data: data16MB(4), 1906 len: HeaderLengthOfTheProtocol + int(MaxPayloadSize) + 1907 HeaderLengthOfTheProtocol + int(MaxPayloadSize) + 1908 HeaderLengthOfTheProtocol + int(MaxPayloadSize) + 1909 HeaderLengthOfTheProtocol + int(MaxPayloadSize) + 1910 HeaderLengthOfTheProtocol, 1911 }, 1912 } 1913 1914 for _, c := range kases { 1915 proto.SetSequenceID(0) 1916 1917 err = proto.openRow(nil) 1918 convey.So(err, convey.ShouldBeNil) 1919 beginIdx := proto.beginWriteIndex 1920 1921 rawBuf := proto.append(nil, c.data...) 
1922 1923 err = proto.closeRow(nil) 1924 convey.So(err, convey.ShouldBeNil) 1925 1926 want := mysqlPack(c.data) 1927 1928 convey.So(c.len, convey.ShouldEqual, len(want)) 1929 1930 buf := proto.tcpConn.OutBuf() 1931 widx := buf.GetWriteIndex() 1932 res := rawBuf[beginIdx:widx] 1933 1934 convey.So(bytes.Equal(res, want), convey.ShouldBeTrue) 1935 } 1936 }) 1937 } 1938 1939 func TestSendPrepareResponse(t *testing.T) { 1940 ctx := context.TODO() 1941 convey.Convey("send Prepare response succ", t, func() { 1942 ctrl := gomock.NewController(t) 1943 defer ctrl.Finish() 1944 ioses := mock_frontend.NewMockIOSession(ctrl) 1945 1946 ioses.EXPECT().OutBuf().Return(buf.NewByteBuf(1024)).AnyTimes() 1947 ioses.EXPECT().Write(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() 1948 ioses.EXPECT().RemoteAddress().Return("").AnyTimes() 1949 ioses.EXPECT().Ref().AnyTimes() 1950 sv, err := getSystemVariables("test/system_vars_config.toml") 1951 if err != nil { 1952 t.Error(err) 1953 } 1954 1955 proto := NewMysqlClientProtocol(0, ioses, 1024, sv) 1956 1957 st := tree.NewPrepareString(tree.Identifier(getPrepareStmtName(1)), "select ?, 1") 1958 stmts, err := mysql.Parse(ctx, st.Sql) 1959 if err != nil { 1960 t.Error(err) 1961 } 1962 compCtx := plan.NewEmptyCompilerContext() 1963 preparePlan, err := buildPlan(context.TODO(), nil, compCtx, st) 1964 if err != nil { 1965 t.Error(err) 1966 } 1967 prepareStmt := &PrepareStmt{ 1968 Name: preparePlan.GetDcl().GetPrepare().GetName(), 1969 PreparePlan: preparePlan, 1970 PrepareStmt: stmts[0], 1971 } 1972 err = proto.SendPrepareResponse(ctx, prepareStmt) 1973 1974 convey.So(err, convey.ShouldBeNil) 1975 }) 1976 1977 convey.Convey("send Prepare response error", t, func() { 1978 ctrl := gomock.NewController(t) 1979 defer ctrl.Finish() 1980 ioses := mock_frontend.NewMockIOSession(ctrl) 1981 1982 ioses.EXPECT().OutBuf().Return(buf.NewByteBuf(1024)).AnyTimes() 1983 ioses.EXPECT().Write(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() 1984 
ioses.EXPECT().RemoteAddress().Return("").AnyTimes() 1985 ioses.EXPECT().Ref().AnyTimes() 1986 sv, err := getSystemVariables("test/system_vars_config.toml") 1987 if err != nil { 1988 t.Error(err) 1989 } 1990 1991 proto := NewMysqlClientProtocol(0, ioses, 1024, sv) 1992 1993 st := tree.NewPrepareString("stmt1", "select ?, 1") 1994 stmts, err := mysql.Parse(ctx, st.Sql) 1995 if err != nil { 1996 t.Error(err) 1997 } 1998 compCtx := plan.NewEmptyCompilerContext() 1999 preparePlan, err := buildPlan(context.TODO(), nil, compCtx, st) 2000 if err != nil { 2001 t.Error(err) 2002 } 2003 prepareStmt := &PrepareStmt{ 2004 Name: preparePlan.GetDcl().GetPrepare().GetName(), 2005 PreparePlan: preparePlan, 2006 PrepareStmt: stmts[0], 2007 } 2008 err = proto.SendPrepareResponse(ctx, prepareStmt) 2009 2010 convey.So(err, convey.ShouldBeError) 2011 }) 2012 } 2013 2014 func FuzzParseExecuteData(f *testing.F) { 2015 ctx := context.TODO() 2016 2017 ctrl := gomock.NewController(f) 2018 defer ctrl.Finish() 2019 ioses := mock_frontend.NewMockIOSession(ctrl) 2020 2021 ioses.EXPECT().OutBuf().Return(buf.NewByteBuf(1024)).AnyTimes() 2022 ioses.EXPECT().Write(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() 2023 ioses.EXPECT().RemoteAddress().Return("").AnyTimes() 2024 ioses.EXPECT().Ref().AnyTimes() 2025 sv, err := getSystemVariables("test/system_vars_config.toml") 2026 if err != nil { 2027 f.Error(err) 2028 } 2029 2030 proto := NewMysqlClientProtocol(0, ioses, 1024, sv) 2031 2032 st := tree.NewPrepareString(tree.Identifier(getPrepareStmtName(1)), "select ?, 1") 2033 stmts, err := mysql.Parse(ctx, st.Sql) 2034 if err != nil { 2035 f.Error(err) 2036 } 2037 compCtx := plan.NewEmptyCompilerContext() 2038 preparePlan, err := buildPlan(context.TODO(), nil, compCtx, st) 2039 if err != nil { 2040 f.Error(err) 2041 } 2042 prepareStmt := &PrepareStmt{ 2043 Name: preparePlan.GetDcl().GetPrepare().GetName(), 2044 PreparePlan: preparePlan, 2045 PrepareStmt: stmts[0], 2046 } 2047 2048 var testData 
[]byte 2049 testData = append(testData, 0) //flag 2050 testData = append(testData, 0, 0, 0, 0) // skip iteration-count 2051 nullBitmapLen := (1 + 7) >> 3 2052 //nullBitmapLen 2053 for i := 0; i < nullBitmapLen; i++ { 2054 testData = append(testData, 0) 2055 } 2056 testData = append(testData, 1) // new param bound flag 2057 testData = append(testData, uint8(defines.MYSQL_TYPE_TINY)) // type 2058 testData = append(testData, 0) //is unsigned 2059 testData = append(testData, 10) //tiny value 2060 2061 f.Add(testData) 2062 2063 testData = []byte{} 2064 testData = append(testData, 0) //flag 2065 testData = append(testData, 0, 0, 0, 0) // skip iteration-count 2066 nullBitmapLen = (1 + 7) >> 3 2067 //nullBitmapLen 2068 for i := 0; i < nullBitmapLen; i++ { 2069 testData = append(testData, 0) 2070 } 2071 testData = append(testData, 1) // new param bound flag 2072 testData = append(testData, uint8(defines.MYSQL_TYPE_TINY)) // type 2073 testData = append(testData, 0) //is unsigned 2074 testData = append(testData, 4) //tiny value 2075 f.Add(testData) 2076 2077 f.Fuzz(func(t *testing.T, data []byte) { 2078 proto.ParseExecuteData(ctx, prepareStmt, data, 0) 2079 }) 2080 } 2081 2082 func TestParseExecuteData(t *testing.T) { 2083 ctx := context.TODO() 2084 convey.Convey("parseExecuteData succ", t, func() { 2085 ctrl := gomock.NewController(t) 2086 defer ctrl.Finish() 2087 ioses := mock_frontend.NewMockIOSession(ctrl) 2088 2089 ioses.EXPECT().OutBuf().Return(buf.NewByteBuf(1024)).AnyTimes() 2090 ioses.EXPECT().Write(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() 2091 ioses.EXPECT().RemoteAddress().Return("").AnyTimes() 2092 ioses.EXPECT().Ref().AnyTimes() 2093 sv, err := getSystemVariables("test/system_vars_config.toml") 2094 if err != nil { 2095 t.Error(err) 2096 } 2097 2098 proto := NewMysqlClientProtocol(0, ioses, 1024, sv) 2099 2100 st := tree.NewPrepareString(tree.Identifier(getPrepareStmtName(1)), "select ?, 1") 2101 stmts, err := mysql.Parse(ctx, st.Sql) 2102 if err != nil 
{
			t.Error(err)
		}
		compCtx := plan.NewEmptyCompilerContext()
		preparePlan, err := buildPlan(context.TODO(), nil, compCtx, st)
		if err != nil {
			t.Error(err)
		}
		prepareStmt := &PrepareStmt{
			Name:        preparePlan.GetDcl().GetPrepare().GetName(),
			PreparePlan: preparePlan,
			PrepareStmt: stmts[0],
		}

		var testData []byte
		testData = append(testData, 0)          //flag
		testData = append(testData, 0, 0, 0, 0) // skip iteration-count
		nullBitmapLen := (1 + 7) >> 3
		//nullBitmapLen
		for i := 0; i < nullBitmapLen; i++ {
			testData = append(testData, 0)
		}
		testData = append(testData, 1)                              // new param bound flag
		testData = append(testData, uint8(defines.MYSQL_TYPE_TINY)) // type
		testData = append(testData, 0)                              //is unsigned
		testData = append(testData, 10)                             //tiny value

		names, vars, err := proto.ParseExecuteData(ctx, prepareStmt, testData, 0)
		convey.So(err, convey.ShouldBeNil)
		// NOTE(review): the three convey.ShouldEqual calls below are not wrapped
		// in convey.So, so their results are discarded and nothing is asserted.
		// Consider convey.So(len(names), convey.ShouldEqual, 1) and friends once
		// the expected values (and the concrete type held in vars[0]) are confirmed.
		convey.ShouldEqual(len(names), 1)
		convey.ShouldEqual(len(vars), 1)
		convey.ShouldEqual(vars[0], 10)
	})

}

// Test_resultset drives the result-set writers of MysqlProtocolImpl against a
// mocked IOSession whose writes always succeed, checking that each writer
// returns no error for the result set built by make9ColumnsResultSet.
func Test_resultset(t *testing.T) {
	ctx := context.TODO()

	// newProto builds a protocol over a permissive mock IOSession and attaches
	// a freshly initialised Session whose request context is ctx. Setup errors
	// go through convey.So so the current Convey scope stops instead of
	// continuing with nil configuration.
	newProto := func(ctrl *gomock.Controller) *MysqlProtocolImpl {
		ioses := mock_frontend.NewMockIOSession(ctrl)
		ioses.EXPECT().OutBuf().Return(buf.NewByteBuf(1024)).AnyTimes()
		ioses.EXPECT().Write(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
		ioses.EXPECT().RemoteAddress().Return("").AnyTimes()
		ioses.EXPECT().Ref().AnyTimes()

		sv, err := getSystemVariables("test/system_vars_config.toml")
		convey.So(err, convey.ShouldBeNil)
		proto := NewMysqlClientProtocol(0, ioses, 1024, sv)

		eng := mock_frontend.NewMockEngine(ctrl)
		txnClient := mock_frontend.NewMockTxnClient(ctrl)
		pu, err := getParameterUnit("test/system_vars_config.toml", eng, txnClient)
		convey.So(err, convey.ShouldBeNil)

		var gSys GlobalSystemVariables
		InitGlobalSystemVariables(&gSys)
		ses := NewSession(proto, nil, pu, &gSys, false)
		ses.SetRequestContext(ctx)
		proto.ses = ses
		return proto
	}

	convey.Convey("send result set batch row succ", t, func() {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		proto := newProto(ctrl)

		res := make9ColumnsResultSet()
		err := proto.SendResultSetTextBatchRow(res, uint64(len(res.Data)))
		convey.So(err, convey.ShouldBeNil)
	})

	convey.Convey("send result set batch row speedup succ", t, func() {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		proto := newProto(ctrl)

		res := make9ColumnsResultSet()
		err := proto.SendResultSetTextBatchRowSpeedup(res, uint64(len(res.Data)))
		convey.So(err, convey.ShouldBeNil)
	})

	convey.Convey("send result set succ", t, func() {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		proto := newProto(ctrl)

		res := make9ColumnsResultSet()
		err := proto.sendResultSet(ctx, res, int(COM_QUERY), 0, 0)
		convey.So(err, convey.ShouldBeNil)

		err = proto.SendResultSetTextRow(res, 0)
		convey.So(err, convey.ShouldBeNil)
	})

	convey.Convey("send binary result set succ", t, func() {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		proto := newProto(ctrl)
		// Mark the session as executing a prepared statement.
		// NOTE(review): presumably this selects the binary row encoding inside
		// the writer — confirm against SendResultSetTextBatchRowSpeedup.
		proto.ses.cmd = COM_STMT_EXECUTE

		res := make9ColumnsResultSet()
		err := proto.SendResultSetTextBatchRowSpeedup(res, 0)
		convey.So(err, convey.ShouldBeNil)
	})
}

// Test_send_packet checks that the ERR and EOF/OK packet writers succeed over
// a mocked IOSession whose writes always succeed.
func Test_send_packet(t *testing.T) {
	// newProto builds a protocol over a permissive mock IOSession; a config
	// load failure stops the current Convey scope via convey.So.
	newProto := func(ctrl *gomock.Controller) *MysqlProtocolImpl {
		ioses := mock_frontend.NewMockIOSession(ctrl)
		ioses.EXPECT().OutBuf().Return(buf.NewByteBuf(1024)).AnyTimes()
		ioses.EXPECT().Write(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
		ioses.EXPECT().RemoteAddress().Return("").AnyTimes()
		ioses.EXPECT().Ref().AnyTimes()

		sv, err := getSystemVariables("test/system_vars_config.toml")
		convey.So(err, convey.ShouldBeNil)
		return NewMysqlClientProtocol(0, ioses, 1024, sv)
	}

	convey.Convey("send err packet", t, func() {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		proto := newProto(ctrl)

		err := proto.sendErrPacket(1, "fake state", "fake error")
		convey.So(err, convey.ShouldBeNil)
	})
	convey.Convey("send eof packet", t, func() {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		proto := newProto(ctrl)

		err := proto.sendEOFPacket(1, 0)
		convey.So(err, convey.ShouldBeNil)

		err = proto.SendEOFPacketIf(1, 0)
		convey.So(err, convey.ShouldBeNil)

		err = proto.sendEOFOrOkPacket(1, 0)
		convey.So(err, convey.ShouldBeNil)
	})
}

// Test_analyse320resp checks analyseHandshakeResponse320, the parser for a
// handshake response whose capability flags do not include CLIENT_PROTOCOL_41.
// Layout used below: int<2> capabilities, int<3> max-packet size, NUL-terminated
// user name, auth response and, with CLIENT_CONNECT_WITH_DB, a database name.
func Test_analyse320resp(t *testing.T) {
	convey.Convey("analyse 320 resp succ", t, func() {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		ioses := mock_frontend.NewMockIOSession(ctrl)
		ioses.EXPECT().RemoteAddress().Return("").AnyTimes()
		ioses.EXPECT().Ref().AnyTimes()
		sv, err := getSystemVariables("test/system_vars_config.toml")
		convey.So(err, convey.ShouldBeNil)

		proto := NewMysqlClientProtocol(0, ioses, 1024, sv)

		var data []byte
		// Named capability (not cap) so the builtin cap is not shadowed.
		var capability uint16
		capability |= uint16(CLIENT_CONNECT_WITH_DB)
		var header [2]byte
		proto.io.WriteUint16(header[:], 0, capability)
		//int<2> capabilities flags, CLIENT_PROTOCOL_41 never set
		data = append(data, header[:]...)
		//int<3> max-packet size
		data = append(data, 0xff, 0xff, 0xff)
		//string[NUL] username
		username := "abc"
		data = append(data, []byte(username)...)
		data = append(data, 0x0)
		//auth response
		authResp := []byte{0x1, 0x2, 0x3, 0x4}
		data = append(data, authResp...)
		data = append(data, 0x0)
		//database
		dbName := "T"
		data = append(data, []byte(dbName)...)
		data = append(data, 0x0)

		ok, resp320, err := proto.analyseHandshakeResponse320(context.TODO(), data)
		convey.So(err, convey.ShouldBeNil)
		convey.So(ok, convey.ShouldBeTrue)

		convey.So(resp320.username, convey.ShouldEqual, username)
		convey.So(bytes.Equal(resp320.authResponse, authResp), convey.ShouldBeTrue)
		convey.So(resp320.database, convey.ShouldEqual, dbName)
	})

	convey.Convey("analyse 320 resp failed", t, func() {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		ioses := mock_frontend.NewMockIOSession(ctrl)
		ioses.EXPECT().RemoteAddress().Return("").AnyTimes()
		ioses.EXPECT().Ref().AnyTimes()
		sv, err := getSystemVariables("test/system_vars_config.toml")
		convey.So(err, convey.ShouldBeNil)

		proto := NewMysqlClientProtocol(0, ioses, 1024, sv)

		// kase pairs a raw packet with the ok flag the parser should return.
		type kase struct {
			data []byte
			res  bool
		}

		// Truncated packets are rejected; with CLIENT_CONNECT_WITH_DB set the
		// packet is only accepted once both trailing fields are NUL-terminated.
		kases := []kase{
			{data: []byte{0}, res: false},
			{data: []byte{0, 0, 0, 0}, res: false},
			{data: []byte{0, 0, 1, 2, 3}, res: false},
			{data: []byte{0, 0, 1, 2, 3, 'a', 0}, res: true},
			{data: []byte{0, 0, 1, 2, 3, 'a', 0, 1, 2, 3}, res: true},
			{data: []byte{uint8(CLIENT_CONNECT_WITH_DB), 0, 1, 2, 3, 'a', 0}, res: false},
			{data: []byte{uint8(CLIENT_CONNECT_WITH_DB), 0, 1, 2, 3, 'a', 0, 'b', 'c'}, res: false},
			{data: []byte{uint8(CLIENT_CONNECT_WITH_DB), 0, 1, 2, 3, 'a', 0, 'b', 'c', 0}, res: false},
			{data: []byte{uint8(CLIENT_CONNECT_WITH_DB), 0, 1, 2, 3, 'a', 0, 'b', 'c', 0, 'd', 'e'}, res: false},
			{data: []byte{uint8(CLIENT_CONNECT_WITH_DB), 0, 1, 2, 3, 'a', 0, 'b', 'c', 0, 'd', 'e', 0}, res: true},
		}

		for _, c := range kases {
			ok, _, _ := proto.analyseHandshakeResponse320(context.TODO(), c.data)
			convey.So(ok, convey.ShouldEqual, c.res)
		}
	})
}

// Test_analyse41resp checks analyseHandshakeResponse41, the parser for a
// handshake response with CLIENT_PROTOCOL_41 set. Layout used below: int<4>
// capabilities, int<4> max-packet size, int<1> character set, 23 reserved zero
// bytes, NUL-terminated user name, auth response, database and, with
// CLIENT_PLUGIN_AUTH, the auth plugin name.
func Test_analyse41resp(t *testing.T) {
	convey.Convey("analyse 41 resp succ", t, func() {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		ioses := mock_frontend.NewMockIOSession(ctrl)
		ioses.EXPECT().RemoteAddress().Return("").AnyTimes()
		ioses.EXPECT().Ref().AnyTimes()
		sv, err := getSystemVariables("test/system_vars_config.toml")
		convey.So(err, convey.ShouldBeNil)

		proto := NewMysqlClientProtocol(0, ioses, 1024, sv)

		var data []byte
		// Named capability (not cap) so the builtin cap is not shadowed.
		var capability uint32
		capability |= CLIENT_PROTOCOL_41 | CLIENT_CONNECT_WITH_DB
		var header [4]byte
		proto.io.WriteUint32(header[:], 0, capability)
		//int<4> capabilities flags of the client, CLIENT_PROTOCOL_41 always set
		data = append(data, header[:]...)
		//int<4> max-packet size
		data = append(data, 0xff, 0xff, 0xff, 0xff)
		//int<1> character set
		data = append(data, 0x1)
		//string[23] reserved (all [0])
		data = append(data, make([]byte, 23)...)
		//string[NUL] username
		username := "abc"
		data = append(data, []byte(username)...)
		data = append(data, 0x0)
		//auth response
		authResp := []byte{0x1, 0x2, 0x3, 0x4}
		data = append(data, authResp...)
		data = append(data, 0x0)
		//database
		dbName := "T"
		data = append(data, []byte(dbName)...)
		data = append(data, 0x0)

		ok, resp41, err := proto.analyseHandshakeResponse41(context.TODO(), data)
		convey.So(err, convey.ShouldBeNil)
		convey.So(ok, convey.ShouldBeTrue)

		convey.So(resp41.username, convey.ShouldEqual, username)
		convey.So(bytes.Equal(resp41.authResponse, authResp), convey.ShouldBeTrue)
		convey.So(resp41.database, convey.ShouldEqual, dbName)
	})

	convey.Convey("analyse 41 resp failed", t, func() {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		ioses := mock_frontend.NewMockIOSession(ctrl)
		ioses.EXPECT().Write(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
		ioses.EXPECT().Read(gomock.Any()).Return(new(Packet), nil).AnyTimes()
		ioses.EXPECT().RemoteAddress().Return("").AnyTimes()
		ioses.EXPECT().Ref().AnyTimes()
		sv, err := getSystemVariables("test/system_vars_config.toml")
		convey.So(err, convey.ShouldBeNil)

		proto := NewMysqlClientProtocol(0, ioses, 1024, sv)

		// kase pairs a raw packet with the ok flag the parser should return.
		type kase struct {
			data []byte
			res  bool
		}

		var capability uint32
		capability |= CLIENT_PROTOCOL_41 | CLIENT_CONNECT_WITH_DB | CLIENT_PLUGIN_AUTH
		var header [4]byte
		proto.io.WriteUint32(header[:], 0, capability)

		// header[:] has len == cap == 4, so every append below copies into a
		// fresh backing array and the cases never alias each other.
		kases := []kase{
			{data: []byte{0}, res: false},
			{data: []byte{0, 0, 0, 0}, res: false},
			{data: append(header[:], []byte{
				0, 0, 0,
			}...), res: false},
			{data: append(header[:], []byte{
				0, 0, 0, 0,
			}...), res: false},
			{data: append(header[:], []byte{
				0, 0, 0, 0,
				0,
				0, 0, 0,
			}...), res: false},
			{data: append(header[:], []byte{
				0, 0, 0, 0,
				0,
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			}...), res: false},
			{data: append(header[:], []byte{
				0, 0, 0, 0,
				0,
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				'a', 'b', 'c',
			}...), res: false},
			{data: append(header[:], []byte{
				0, 0, 0, 0,
				0,
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				'a', 'b', 'c', 0,
				'd', 'e', 'f',
			}...), res: false},
			{data: append(header[:], []byte{
				0, 0, 0, 0,
				0,
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				'a', 'b', 'c', 0,
				'd', 'e', 'f', 0,
				'T',
			}...), res: false},
			{data: append(header[:], []byte{
				0, 0, 0, 0,
				0,
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				'a', 'b', 'c', 0,
				'd', 'e', 'f', 0,
				'T', 0,
				'm', 'y', 's',
			}...), res: false},
			{data: append(header[:], []byte{
				0, 0, 0, 0,
				0,
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				'a', 'b', 'c', 0,
				'd', 'e', 'f', 0,
				'T', 0,
				'm', 'y', 's', 'q', 'l', '_', 'n', 'a', 't', 'i', 'v', 'e', '_', 'p', 'a', 's', 's', 'w', 'o', 'r', 'x', 0,
			}...), res: true},
			{data: append(header[:], []byte{
				0, 0, 0, 0,
				0,
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				'a', 'b', 'c', 0,
				'd', 'e', 'f', 0,
				'T', 0,
				'm', 'y', 's', 'q', 'l', '_', 'n', 'a', 't', 'i', 'v', 'e', '_', 'p', 'a', 's', 's', 'w', 'o', 'r', 'd', 0,
			}...), res: true},
		}

		for _, c := range kases {
			ok, _, _ := proto.analyseHandshakeResponse41(context.TODO(), c.data)
			convey.So(ok, convey.ShouldEqual, c.res)
		}
	})
}

// Test_handleHandshake feeds HandleHandshake progressively longer payloads on
// a protocol with user checking skipped: the first two are rejected with an
// error, the final NUL-terminated one is accepted.
func Test_handleHandshake(t *testing.T) {
	ctx := context.TODO()
	convey.Convey("handleHandshake succ", t, func() {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		ioses := mock_frontend.NewMockIOSession(ctrl)
		ioses.EXPECT().Write(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
		ioses.EXPECT().RemoteAddress().Return("").AnyTimes()
		ioses.EXPECT().Ref().AnyTimes()
		var IO IOPackageImpl
		var SV = &config.FrontendParameters{}
		mp := &MysqlProtocolImpl{SV: SV}
		mp.io = &IO
		mp.tcpConn = ioses
		mp.SetSkipCheckUser(true)
		payload := []byte{'a'}
		_, err := mp.HandleHandshake(ctx, payload)
		convey.So(err, convey.ShouldNotBeNil)

		payload = append(payload, []byte{'b', 'c'}...)
		_, err = mp.HandleHandshake(ctx, payload)
		convey.So(err, convey.ShouldNotBeNil)

		payload = append(payload, []byte{'c', 'd', 0}...)
		_, err = mp.HandleHandshake(ctx, payload)
		convey.So(err, convey.ShouldBeNil)
	})
}

// Test_handleHandshake_Recover fuzzes HandleHandshake with random byte-slice
// and string payloads. Returned errors are deliberately ignored: the test only
// checks that no payload makes HandleHandshake panic.
func Test_handleHandshake_Recover(t *testing.T) {
	f := fuzz.New()
	count := 10000

	ctx := context.TODO()
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ioses := mock_frontend.NewMockIOSession(ctrl)
	ioses.EXPECT().Write(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
	ioses.EXPECT().RemoteAddress().Return("").AnyTimes()
	ioses.EXPECT().Ref().AnyTimes()
	convey.Convey("handleHandshake succ", t, func() {
		var IO IOPackageImpl
		var SV = &config.FrontendParameters{}
		mp := &MysqlProtocolImpl{SV: SV}
		mp.io = &IO
		mp.tcpConn = ioses
		mp.SetSkipCheckUser(true)

		var payload []byte
		for i := 0; i < count; i++ {
			f.Fuzz(&payload)
			_, _ = mp.HandleHandshake(ctx, payload)
		}

		var payload2 string
		for i := 0; i < count; i++ {
			f.Fuzz(&payload2)
			_, _ = mp.HandleHandshake(ctx, []byte(payload2))
		}
	})
}

// TestMysqlProtocolImpl_Close checks that after Quit the protocol's salt and
// its strconvBuffer, lenEncBuffer and binaryNullBuffer are all nil.
func TestMysqlProtocolImpl_Close(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ioses := mock_frontend.NewMockIOSession(ctrl)
	ioses.EXPECT().Write(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
	ioses.EXPECT().Read(gomock.Any()).Return(new(Packet), nil).AnyTimes()
	ioses.EXPECT().RemoteAddress().Return("").AnyTimes()
	ioses.EXPECT().Ref().AnyTimes()
	ioses.EXPECT().Disconnect().AnyTimes()
	sv, err := getSystemVariables("test/system_vars_config.toml")
	// Stop here on failure: continuing with a nil sv only produces a
	// confusing follow-on failure.
	require.NoError(t, err)

	proto := NewMysqlClientProtocol(0, ioses, 1024, sv)
	proto.Quit()
	assert.Nil(t, proto.GetSalt())
	assert.Nil(t, proto.strconvBuffer)
	assert.Nil(t, proto.lenEncBuffer)
	assert.Nil(t, proto.binaryNullBuffer)
}