github.com/dolthub/go-mysql-server@v0.18.0/server/handler_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package server 16 17 import ( 18 "context" 19 "fmt" 20 "io" 21 "net" 22 "strconv" 23 "testing" 24 "time" 25 26 "github.com/dolthub/vitess/go/mysql" 27 "github.com/dolthub/vitess/go/sqltypes" 28 "github.com/dolthub/vitess/go/vt/proto/query" 29 "github.com/stretchr/testify/assert" 30 "github.com/stretchr/testify/require" 31 32 sqle "github.com/dolthub/go-mysql-server" 33 "github.com/dolthub/go-mysql-server/memory" 34 "github.com/dolthub/go-mysql-server/sql" 35 "github.com/dolthub/go-mysql-server/sql/analyzer" 36 "github.com/dolthub/go-mysql-server/sql/types" 37 "github.com/dolthub/go-mysql-server/sql/variables" 38 ) 39 40 var samplePrepareData = &mysql.PrepareData{ 41 StatementID: 42, 42 ParamsCount: 1, 43 } 44 45 func TestHandlerOutput(t *testing.T) { 46 e, pro := setupMemDB(require.New(t)) 47 dbFunc := pro.Database 48 49 dummyConn := newConn(1) 50 handler := &Handler{ 51 e: e, 52 sm: NewSessionManager( 53 testSessionBuilder(pro), 54 sql.NoopTracer, 55 dbFunc, 56 sql.NewMemoryManager(nil), 57 sqle.NewProcessList(), 58 "foo", 59 ), 60 readTimeout: time.Second, 61 } 62 handler.NewConnection(dummyConn) 63 64 type expectedValues struct { 65 callsToCallback int 66 lenLastBatch int 67 lastRowsAffected uint64 68 } 69 70 tests := []struct { 71 name string 72 handler *Handler 73 conn 
*mysql.Conn 74 query string 75 expected expectedValues 76 }{ 77 { 78 name: "select all without limit", 79 handler: handler, 80 conn: dummyConn, 81 query: "SELECT * FROM test", 82 expected: expectedValues{ 83 callsToCallback: 8, 84 lenLastBatch: 114, 85 lastRowsAffected: uint64(114), 86 }, 87 }, 88 { 89 name: "with limit equal to batch capacity", 90 handler: handler, 91 conn: dummyConn, 92 query: "SELECT * FROM test limit 100", 93 expected: expectedValues{ 94 callsToCallback: 1, 95 lenLastBatch: 100, 96 lastRowsAffected: uint64(100), 97 }, 98 }, 99 { 100 name: "with limit less than batch capacity", 101 handler: handler, 102 conn: dummyConn, 103 query: "SELECT * FROM test limit 60", 104 expected: expectedValues{ 105 callsToCallback: 1, 106 lenLastBatch: 60, 107 lastRowsAffected: uint64(60), 108 }, 109 }, 110 { 111 name: "with limit greater than batch capacity", 112 handler: handler, 113 conn: dummyConn, 114 query: "SELECT * FROM test limit 200", 115 expected: expectedValues{ 116 callsToCallback: 2, 117 lenLastBatch: 72, 118 lastRowsAffected: uint64(72), 119 }, 120 }, 121 { 122 name: "with limit set to a number not multiple of the batch capacity", 123 handler: handler, 124 conn: dummyConn, 125 query: "SELECT * FROM test limit 530", 126 expected: expectedValues{ 127 callsToCallback: 5, 128 lenLastBatch: 18, 129 lastRowsAffected: uint64(18), 130 }, 131 }, 132 { 133 name: "with limit zero", 134 handler: handler, 135 conn: dummyConn, 136 query: "SELECT * FROM test limit 0", 137 expected: expectedValues{ 138 callsToCallback: 1, 139 lenLastBatch: 0, 140 lastRowsAffected: uint64(0), 141 }, 142 }, 143 } 144 145 for _, test := range tests { 146 t.Run(test.name, func(t *testing.T) { 147 var callsToCallback int 148 var lenLastBatch int 149 var lastRowsAffected uint64 150 handler.ComInitDB(test.conn, "test") 151 err := handler.ComQuery(test.conn, test.query, func(res *sqltypes.Result, more bool) error { 152 callsToCallback++ 153 lenLastBatch = len(res.Rows) 154 lastRowsAffected = 
res.RowsAffected 155 return nil 156 }) 157 158 require.NoError(t, err) 159 assert.Equal(t, test.expected.callsToCallback, callsToCallback) 160 assert.Equal(t, test.expected.lenLastBatch, lenLastBatch) 161 assert.Equal(t, test.expected.lastRowsAffected, lastRowsAffected) 162 163 }) 164 } 165 166 t.Run("sum aggregation type is correct", func(t *testing.T) { 167 handler.ComInitDB(dummyConn, "test") 168 var result *sqltypes.Result 169 err := handler.ComQuery(dummyConn, "select sum(1) from test", func(res *sqltypes.Result, more bool) error { 170 result = res 171 return nil 172 }) 173 require.NoError(t, err) 174 require.Equal(t, 1, len(result.Rows)) 175 require.Equal(t, sqltypes.Float64, result.Rows[0][0].Type()) 176 require.Equal(t, []byte("1010"), result.Rows[0][0].ToBytes()) 177 }) 178 179 t.Run("avg aggregation type is correct", func(t *testing.T) { 180 handler.ComInitDB(dummyConn, "test") 181 var result *sqltypes.Result 182 err := handler.ComQuery(dummyConn, "select avg(1) from test", func(res *sqltypes.Result, more bool) error { 183 result = res 184 return nil 185 }) 186 require.NoError(t, err) 187 require.Equal(t, 1, len(result.Rows)) 188 require.Equal(t, sqltypes.Float64, result.Rows[0][0].Type()) 189 require.Equal(t, []byte("1"), result.Rows[0][0].ToBytes()) 190 }) 191 192 t.Run("if() type is correct", func(t *testing.T) { 193 handler.ComInitDB(dummyConn, "test") 194 var result *sqltypes.Result 195 err := handler.ComQuery(dummyConn, "select if(1, 123, 'def')", func(res *sqltypes.Result, more bool) error { 196 result = res 197 return nil 198 }) 199 require.NoError(t, err) 200 require.Equal(t, 1, len(result.Rows)) 201 require.Equal(t, sqltypes.Text, result.Rows[0][0].Type()) 202 require.Equal(t, []byte("123"), result.Rows[0][0].ToBytes()) 203 204 err = handler.ComQuery(dummyConn, "select if(0, 123, 456)", func(res *sqltypes.Result, more bool) error { 205 result = res 206 return nil 207 }) 208 require.NoError(t, err) 209 require.Equal(t, 1, len(result.Rows)) 210 
require.Equal(t, sqltypes.Int64, result.Rows[0][0].Type()) 211 require.Equal(t, []byte("456"), result.Rows[0][0].ToBytes()) 212 }) 213 } 214 215 func TestHandlerErrors(t *testing.T) { 216 e, pro := setupMemDB(require.New(t)) 217 dbFunc := pro.Database 218 219 dummyConn := newConn(1) 220 handler := &Handler{ 221 e: e, 222 sm: NewSessionManager( 223 testSessionBuilder(pro), 224 sql.NoopTracer, 225 dbFunc, 226 sql.NewMemoryManager(nil), 227 sqle.NewProcessList(), 228 "foo", 229 ), 230 readTimeout: time.Second, 231 } 232 handler.NewConnection(dummyConn) 233 234 type expectedValues struct { 235 callsToCallback int 236 lenLastBatch int 237 lastRowsAffected uint64 238 } 239 240 setupCommands := []string{"CREATE TABLE `test_table` ( `id` INT NOT NULL PRIMARY KEY, `v` INT );"} 241 242 tests := []struct { 243 name string 244 query string 245 expectedErrorCode int 246 }{ 247 { 248 name: "insert with nonexistent field name", 249 query: "INSERT INTO `test_table` (`id`, `v_`) VALUES (1, 2)", 250 expectedErrorCode: mysql.ERBadFieldError, 251 }, 252 { 253 name: "insert into nonexistent table", 254 query: "INSERT INTO `test`.`no_such_table` (`id`, `v`) VALUES (1, 2)", 255 expectedErrorCode: mysql.ERNoSuchTable, 256 }, 257 { 258 name: "insert into same column twice", 259 query: "INSERT INTO `test`.`test_table` (`id`, `id`, `v`) VALUES (1, 2, 3)", 260 expectedErrorCode: mysql.ERFieldSpecifiedTwice, 261 }, 262 } 263 264 handler.ComInitDB(dummyConn, "test") 265 for _, setupCommand := range setupCommands { 266 err := handler.ComQuery(dummyConn, setupCommand, func(res *sqltypes.Result, more bool) error { 267 return nil 268 }) 269 require.NoError(t, err) 270 } 271 272 for _, test := range tests { 273 t.Run(test.name, func(t *testing.T) { 274 err := handler.ComQuery(dummyConn, test.query, func(res *sqltypes.Result, more bool) error { 275 return nil 276 }) 277 require.NotNil(t, err) 278 sqlErr, isSqlError := err.(*mysql.SQLError) 279 require.True(t, isSqlError) 280 require.Equal(t, 
test.expectedErrorCode, sqlErr.Number()) 281 }) 282 } 283 } 284 285 // TestHandlerComReset asserts that the Handler.ComResetConnection method correctly clears all session 286 // state (e.g. table locks, prepared statements, user variables, session variables), and keeps the current 287 // database selected. 288 func TestHandlerComResetConnection(t *testing.T) { 289 e, pro := setupMemDB(require.New(t)) 290 dummyConn := newConn(1) 291 dbFunc := pro.Database 292 293 handler := &Handler{ 294 e: e, 295 sm: NewSessionManager( 296 testSessionBuilder(pro), 297 sql.NoopTracer, 298 dbFunc, 299 sql.NewMemoryManager(nil), 300 sqle.NewProcessList(), 301 "foo", 302 ), 303 } 304 handler.NewConnection(dummyConn) 305 handler.ComInitDB(dummyConn, "test") 306 307 prepareData := &mysql.PrepareData{ 308 StatementID: 0, 309 PrepareStmt: "select 42 + ? from dual", 310 ParamsCount: 0, 311 ParamsType: nil, 312 ColumnNames: nil, 313 BindVars: map[string]*query.BindVariable{ 314 "v1": {Type: query.Type_INT8, Value: []byte("5")}, 315 }, 316 } 317 318 // Create a prepared statement, a table lock, and a user var in the current session 319 _, err := handler.ComPrepare(dummyConn, prepareData.PrepareStmt, prepareData) 320 require.NoError(t, err) 321 _, cached := e.PreparedDataCache.GetCachedStmt(dummyConn.ConnectionID, prepareData.PrepareStmt) 322 require.True(t, cached) 323 err = handler.ComQuery(dummyConn, "SET @userVar = 42;", func(res *sqltypes.Result, more bool) error { 324 return nil 325 }) 326 require.NoError(t, err) 327 328 // Reset the connection to clear all session state 329 err = handler.ComResetConnection(dummyConn) 330 require.NoError(t, err) 331 332 // Assert that the session is clean – the selected database should not change, and all session state 333 // such as user vars, session vars, prepared statements, table locks, and temporary tables should be cleared. 
334 err = handler.ComQuery(dummyConn, "SELECT database()", func(res *sqltypes.Result, more bool) error { 335 require.Equal(t, "test", res.Rows[0][0].ToString()) 336 return nil 337 }) 338 require.NoError(t, err) 339 _, cached = e.PreparedDataCache.GetCachedStmt(dummyConn.ConnectionID, prepareData.PrepareStmt) 340 require.False(t, cached) 341 err = handler.ComQuery(dummyConn, "SELECT @userVar;", func(res *sqltypes.Result, more bool) error { 342 require.True(t, res.Rows[0][0].IsNull()) 343 return nil 344 }) 345 require.NoError(t, err) 346 } 347 348 func TestHandlerComPrepare(t *testing.T) { 349 e, pro := setupMemDB(require.New(t)) 350 dummyConn := newConn(1) 351 dbFunc := pro.Database 352 353 handler := &Handler{ 354 e: e, 355 sm: NewSessionManager( 356 testSessionBuilder(pro), 357 sql.NoopTracer, 358 dbFunc, 359 sql.NewMemoryManager(nil), 360 sqle.NewProcessList(), 361 "foo", 362 ), 363 } 364 handler.NewConnection(dummyConn) 365 366 type testcase struct { 367 name string 368 statement string 369 expected []*query.Field 370 expectedErr *mysql.SQLError 371 } 372 373 for _, test := range []testcase{ 374 { 375 name: "insert statement returns nil schema", 376 statement: "insert into test (c1) values (?)", 377 expected: nil, 378 }, 379 { 380 name: "update statement returns nil schema", 381 statement: "update test set c1 = ?", 382 expected: nil, 383 }, 384 { 385 name: "delete statement returns nil schema", 386 statement: "delete from test where c1 = ?", 387 expected: nil, 388 }, 389 { 390 name: "select statement returns non-nil schema", 391 statement: "select c1 from test where c1 > ?", 392 expected: []*query.Field{ 393 {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 394 }, 395 }, 396 { 397 name: "errors are cast to SQLError", 398 statement: "SELECT * from doesnotexist LIMIT ?", 399 expectedErr: 
mysql.NewSQLError(mysql.ERNoSuchTable, "", "table not found: %s", "doesnotexist"), 400 }, 401 } { 402 t.Run(test.name, func(t *testing.T) { 403 handler.ComInitDB(dummyConn, "test") 404 schema, err := handler.ComPrepare(dummyConn, test.statement, samplePrepareData) 405 if test.expectedErr == nil { 406 require.NoError(t, err) 407 require.Equal(t, test.expected, schema) 408 } else { 409 require.NotNil(t, err) 410 sqlErr, isSqlError := err.(*mysql.SQLError) 411 require.True(t, isSqlError) 412 require.Equal(t, test.expectedErr.Number(), sqlErr.Number()) 413 require.Equal(t, test.expectedErr.SQLState(), sqlErr.SQLState()) 414 require.Equal(t, test.expectedErr.Error(), sqlErr.Error()) 415 } 416 }) 417 } 418 } 419 420 func TestHandlerComPrepareExecute(t *testing.T) { 421 e, pro := setupMemDB(require.New(t)) 422 dummyConn := newConn(1) 423 dbFunc := pro.Database 424 425 handler := &Handler{ 426 e: e, 427 sm: NewSessionManager( 428 testSessionBuilder(pro), 429 sql.NoopTracer, 430 dbFunc, 431 sql.NewMemoryManager(nil), 432 sqle.NewProcessList(), 433 "foo", 434 ), 435 } 436 handler.NewConnection(dummyConn) 437 438 type testcase struct { 439 name string 440 prepare *mysql.PrepareData 441 execute map[string]*query.BindVariable 442 schema []*query.Field 443 expected []sql.Row 444 } 445 446 for _, test := range []testcase{ 447 { 448 name: "select statement returns nil schema", 449 prepare: &mysql.PrepareData{ 450 StatementID: 0, 451 PrepareStmt: "select c1 from test where c1 < ?", 452 ParamsCount: 0, 453 ParamsType: nil, 454 ColumnNames: nil, 455 BindVars: map[string]*query.BindVariable{ 456 "v1": {Type: query.Type_INT8, Value: []byte("5")}, 457 }, 458 }, 459 schema: []*query.Field{ 460 {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 461 }, 462 expected: []sql.Row{ 463 {0}, {1}, {2}, {3}, {4}, 464 }, 465 }, 466 } { 467 
t.Run(test.name, func(t *testing.T) { 468 handler.ComInitDB(dummyConn, "test") 469 schema, err := handler.ComPrepare(dummyConn, test.prepare.PrepareStmt, samplePrepareData) 470 require.NoError(t, err) 471 require.Equal(t, test.schema, schema) 472 473 var res []sql.Row 474 callback := func(r *sqltypes.Result) error { 475 for _, r := range r.Rows { 476 var vals []interface{} 477 for _, v := range r { 478 val, err := strconv.ParseInt(string(v.Raw()), 0, 64) 479 if err != nil { 480 return err 481 } 482 vals = append(vals, int(val)) 483 } 484 res = append(res, sql.NewRow(vals...)) 485 } 486 return nil 487 } 488 err = handler.ComStmtExecute(dummyConn, test.prepare, callback) 489 require.NoError(t, err) 490 require.Equal(t, test.expected, res) 491 }) 492 } 493 } 494 495 func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) { 496 e, pro := setupMemDB(require.New(t)) 497 dummyConn := newConn(1) 498 dbFunc := pro.Database 499 500 handler := &Handler{ 501 e: e, 502 sm: NewSessionManager( 503 testSessionBuilder(pro), 504 sql.NoopTracer, 505 dbFunc, 506 sql.NewMemoryManager(nil), 507 sqle.NewProcessList(), 508 "foo", 509 ), 510 } 511 handler.NewConnection(dummyConn) 512 analyzer.SetPreparedStmts(true) 513 defer func() { 514 analyzer.SetPreparedStmts(false) 515 }() 516 type testcase struct { 517 name string 518 prepare *mysql.PrepareData 519 execute map[string]*query.BindVariable 520 schema []*query.Field 521 expected []sql.Row 522 } 523 524 for _, test := range []testcase{ 525 { 526 name: "select statement returns nil schema", 527 prepare: &mysql.PrepareData{ 528 StatementID: 0, 529 PrepareStmt: "select c1 from test where c1 < ?", 530 ParamsCount: 0, 531 ParamsType: nil, 532 ColumnNames: nil, 533 BindVars: map[string]*query.BindVariable{ 534 "v1": {Type: query.Type_INT8, Value: []byte("5")}, 535 }, 536 }, 537 schema: []*query.Field{ 538 {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, Charset: 
uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 539 }, 540 expected: []sql.Row{ 541 {0}, {1}, {2}, {3}, {4}, 542 }, 543 }, 544 } { 545 t.Run(test.name, func(t *testing.T) { 546 handler.ComInitDB(dummyConn, "test") 547 schema, err := handler.ComPrepare(dummyConn, test.prepare.PrepareStmt, samplePrepareData) 548 require.NoError(t, err) 549 require.Equal(t, test.schema, schema) 550 551 var res []sql.Row 552 callback := func(r *sqltypes.Result) error { 553 for _, r := range r.Rows { 554 var vals []interface{} 555 for _, v := range r { 556 val, err := strconv.ParseInt(string(v.Raw()), 0, 64) 557 if err != nil { 558 return err 559 } 560 vals = append(vals, int(val)) 561 } 562 res = append(res, sql.NewRow(vals...)) 563 } 564 return nil 565 } 566 err = handler.ComStmtExecute(dummyConn, test.prepare, callback) 567 require.NoError(t, err) 568 require.Equal(t, test.expected, res) 569 }) 570 } 571 } 572 573 type TestListener struct { 574 Connections int 575 Queries int 576 Disconnects int 577 Successes int 578 Failures int 579 } 580 581 func (tl *TestListener) ClientConnected() { 582 tl.Connections++ 583 } 584 585 func (tl *TestListener) ClientDisconnected() { 586 tl.Disconnects++ 587 } 588 589 func (tl *TestListener) QueryStarted() { 590 tl.Queries++ 591 } 592 593 func (tl *TestListener) QueryCompleted(success bool, duration time.Duration) { 594 if success { 595 tl.Successes++ 596 } else { 597 tl.Failures++ 598 } 599 } 600 601 func TestServerEventListener(t *testing.T) { 602 require := require.New(t) 603 e, pro := setupMemDB(require) 604 dbFunc := pro.Database 605 606 listener := &TestListener{} 607 handler := &Handler{ 608 e: e, 609 sm: NewSessionManager( 610 func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) { 611 return sql.NewBaseSessionWithClientServer(addr, sql.Client{Capabilities: conn.Capabilities}, conn.ConnectionID), nil 612 }, 613 sql.NoopTracer, 614 dbFunc, 615 e.MemoryManager, 616 
e.ProcessList, 617 "foo", 618 ), 619 sel: listener, 620 } 621 622 cb := func(res *sqltypes.Result, more bool) error { 623 return nil 624 } 625 626 require.Equal(listener.Connections, 0) 627 require.Equal(listener.Disconnects, 0) 628 require.Equal(listener.Queries, 0) 629 require.Equal(listener.Successes, 0) 630 require.Equal(listener.Failures, 0) 631 632 conn1 := newConn(1) 633 handler.NewConnection(conn1) 634 require.Equal(listener.Connections, 1) 635 require.Equal(listener.Disconnects, 0) 636 637 err := handler.sm.SetDB(conn1, "test") 638 require.NoError(err) 639 640 err = handler.ComQuery(conn1, "SELECT 1", cb) 641 require.NoError(err) 642 require.Equal(listener.Queries, 1) 643 require.Equal(listener.Successes, 1) 644 require.Equal(listener.Failures, 0) 645 646 conn2 := newConn(2) 647 handler.NewConnection(conn2) 648 require.Equal(listener.Connections, 2) 649 require.Equal(listener.Disconnects, 0) 650 651 handler.ComInitDB(conn2, "test") 652 err = handler.ComQuery(conn2, "select 1", cb) 653 require.NoError(err) 654 require.Equal(listener.Queries, 2) 655 require.Equal(listener.Successes, 2) 656 require.Equal(listener.Failures, 0) 657 658 err = handler.ComQuery(conn1, "select bad_col from bad_table with illegal syntax", cb) 659 require.Error(err) 660 require.Equal(listener.Queries, 3) 661 require.Equal(listener.Successes, 2) 662 require.Equal(listener.Failures, 1) 663 664 handler.ConnectionClosed(conn1) 665 require.Equal(listener.Connections, 2) 666 require.Equal(listener.Disconnects, 1) 667 668 handler.ConnectionClosed(conn2) 669 require.Equal(listener.Connections, 2) 670 require.Equal(listener.Disconnects, 2) 671 672 conn3 := newConn(3) 673 query := "SELECT ?" 
674 _, err = handler.ComPrepare(conn3, query, samplePrepareData) 675 require.NoError(err) 676 require.Equal(1, len(e.PreparedDataCache.CachedStatementsForSession(conn3.ConnectionID))) 677 require.NotNil(e.PreparedDataCache.GetCachedStmt(conn3.ConnectionID, query)) 678 679 handler.ConnectionClosed(conn3) 680 require.Equal(0, len(e.PreparedDataCache.CachedStatementsForSession(conn3.ConnectionID))) 681 } 682 683 func TestHandlerKill(t *testing.T) { 684 require := require.New(t) 685 e, pro := setupMemDB(require) 686 dbFunc := pro.Database 687 688 handler := &Handler{ 689 e: e, 690 sm: NewSessionManager( 691 func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) { 692 return sql.NewBaseSessionWithClientServer(addr, sql.Client{Capabilities: conn.Capabilities}, conn.ConnectionID), nil 693 }, 694 sql.NoopTracer, 695 dbFunc, 696 e.MemoryManager, 697 e.ProcessList, 698 "foo", 699 ), 700 } 701 702 conn1 := newConn(1) 703 handler.NewConnection(conn1) 704 705 conn2 := newConn(2) 706 handler.NewConnection(conn2) 707 708 require.Len(handler.sm.connections, 2) 709 require.Len(handler.sm.sessions, 0) 710 711 handler.ComInitDB(conn2, "test") 712 err := handler.ComQuery(conn2, "KILL QUERY 1", func(res *sqltypes.Result, more bool) error { 713 return nil 714 }) 715 require.NoError(err) 716 717 require.False(conn1.Conn.(*mockConn).closed) 718 require.Len(handler.sm.connections, 2) 719 require.Len(handler.sm.sessions, 1) 720 721 err = handler.sm.SetDB(conn1, "test") 722 require.NoError(err) 723 ctx1, err := handler.sm.NewContextWithQuery(conn1, "SELECT 1") 724 require.NoError(err) 725 ctx1, err = handler.e.ProcessList.BeginQuery(ctx1, "SELECT 1") 726 require.NoError(err) 727 728 err = handler.ComQuery(conn2, "KILL "+fmt.Sprint(ctx1.ID()), func(res *sqltypes.Result, more bool) error { 729 return nil 730 }) 731 require.NoError(err) 732 733 require.Error(ctx1.Err()) 734 require.True(conn1.Conn.(*mockConn).closed) 735 handler.ConnectionClosed(conn1) 736 
require.Len(handler.sm.sessions, 1) 737 } 738 739 func TestSchemaToFields(t *testing.T) { 740 require := require.New(t) 741 742 schema := sql.Schema{ 743 // Blob, Text, and JSON Types 744 {Name: "tinyblob", Source: "table1", DatabaseSource: "db1", Type: types.TinyBlob}, 745 {Name: "blob", Source: "table1", DatabaseSource: "db1", Type: types.Blob}, 746 {Name: "mediumblob", Source: "table1", DatabaseSource: "db1", Type: types.MediumBlob}, 747 {Name: "longblob", Source: "table1", DatabaseSource: "db1", Type: types.LongBlob}, 748 {Name: "tinytext", Source: "table1", DatabaseSource: "db1", Type: types.TinyText}, 749 {Name: "text", Source: "table1", DatabaseSource: "db1", Type: types.Text}, 750 {Name: "mediumtext", Source: "table1", DatabaseSource: "db1", Type: types.MediumText}, 751 {Name: "longtext", Source: "table1", DatabaseSource: "db1", Type: types.LongText}, 752 {Name: "json", Source: "table1", DatabaseSource: "db1", Type: types.JSON}, 753 754 // Geometry Types 755 {Name: "geometry", Source: "table1", DatabaseSource: "db1", Type: types.GeometryType{}}, 756 {Name: "point", Source: "table1", DatabaseSource: "db1", Type: types.PointType{}}, 757 {Name: "polygon", Source: "table1", DatabaseSource: "db1", Type: types.PolygonType{}}, 758 {Name: "linestring", Source: "table1", DatabaseSource: "db1", Type: types.LineStringType{}}, 759 760 // Integer Types 761 {Name: "uint8", Source: "table1", DatabaseSource: "db1", Type: types.Uint8}, 762 {Name: "int8", Source: "table1", DatabaseSource: "db1", Type: types.Int8}, 763 {Name: "uint16", Source: "table1", DatabaseSource: "db1", Type: types.Uint16}, 764 {Name: "int16", Source: "table1", DatabaseSource: "db1", Type: types.Int16}, 765 {Name: "uint24", Source: "table1", DatabaseSource: "db1", Type: types.Uint24}, 766 {Name: "int24", Source: "table1", DatabaseSource: "db1", Type: types.Int24}, 767 {Name: "uint32", Source: "table1", DatabaseSource: "db1", Type: types.Uint32}, 768 {Name: "int32", Source: "table1", DatabaseSource: 
"db1", Type: types.Int32}, 769 {Name: "uint64", Source: "table1", DatabaseSource: "db1", Type: types.Uint64}, 770 {Name: "int64", Source: "table1", DatabaseSource: "db1", Type: types.Int64}, 771 772 // Floating Point and Decimal Types 773 {Name: "float32", Source: "table1", DatabaseSource: "db1", Type: types.Float32}, 774 {Name: "float64", Source: "table1", DatabaseSource: "db1", Type: types.Float64}, 775 {Name: "decimal10_0", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateDecimalType(10, 0)}, 776 {Name: "decimal60_30", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateDecimalType(60, 30)}, 777 778 // Char, Binary, and Bit Types 779 {Name: "varchar50", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateString(sqltypes.VarChar, 50, sql.Collation_Default)}, 780 {Name: "varbinary12345", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateBinary(sqltypes.VarBinary, 12345)}, 781 {Name: "binary123", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateBinary(sqltypes.Binary, 123)}, 782 {Name: "char123", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateString(sqltypes.Char, 123, sql.Collation_Default)}, 783 {Name: "bit12", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateBitType(12)}, 784 785 // Dates 786 {Name: "datetime", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateDatetimeType(sqltypes.Datetime, 0)}, 787 {Name: "timestamp", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateDatetimeType(sqltypes.Timestamp, 0)}, 788 {Name: "date", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateDatetimeType(sqltypes.Date, 0)}, 789 {Name: "time", Source: "table1", DatabaseSource: "db1", Type: types.Time}, 790 {Name: "year", Source: "table1", DatabaseSource: "db1", Type: types.Year}, 791 792 // Set and Enum Types 793 {Name: "set", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateSetType([]string{"one", "two", "three", "four"}, 
sql.Collation_Default)}, 794 {Name: "enum", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateEnumType([]string{"one", "two", "three", "four"}, sql.Collation_Default)}, 795 } 796 797 expected := []*query.Field{ 798 // Blob, Text, and JSON Types 799 {Name: "tinyblob", OrgName: "tinyblob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 255, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 800 {Name: "blob", OrgName: "blob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 65_535, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 801 {Name: "mediumblob", OrgName: "mediumblob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 16_777_215, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 802 {Name: "longblob", OrgName: "longblob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 803 {Name: "tinytext", OrgName: "tinytext", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 1020, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 804 {Name: "text", OrgName: "text", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 262_140, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 805 {Name: "mediumtext", OrgName: "mediumtext", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 67_108_860, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 806 {Name: "longtext", OrgName: "longtext", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: 
uint32(sql.CharacterSet_utf8mb4), ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 807 {Name: "json", OrgName: "json", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_JSON, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 808 809 // Geometry Types 810 {Name: "geometry", OrgName: "geometry", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 811 {Name: "point", OrgName: "point", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 812 {Name: "polygon", OrgName: "polygon", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 813 {Name: "linestring", OrgName: "linestring", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 814 815 // Integer Types 816 {Name: "uint8", OrgName: "uint8", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT8, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 3, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)}, 817 {Name: "int8", OrgName: "int8", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT8, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 4, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 818 {Name: "uint16", OrgName: "uint16", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 5, Flags: 
uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)}, 819 {Name: "int16", OrgName: "int16", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 820 {Name: "uint24", OrgName: "uint24", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT24, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 8, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)}, 821 {Name: "int24", OrgName: "int24", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT24, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 9, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 822 {Name: "uint32", OrgName: "uint32", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 10, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)}, 823 {Name: "int32", OrgName: "int32", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 824 {Name: "uint64", OrgName: "uint64", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT64, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 20, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)}, 825 {Name: "int64", OrgName: "int64", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT64, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 20, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 826 827 // Floating Point and Decimal Types 828 {Name: "float32", OrgName: "float32", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_FLOAT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 12, Flags: 
uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 829 {Name: "float64", OrgName: "float64", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_FLOAT64, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 22, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 830 {Name: "decimal10_0", OrgName: "decimal10_0", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DECIMAL, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Decimals: 0, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 831 {Name: "decimal60_30", OrgName: "decimal60_30", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DECIMAL, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 62, Decimals: 30, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 832 833 // Char, Binary, and Bit Types 834 {Name: "varchar50", OrgName: "varchar50", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_VARCHAR, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 50 * 4, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 835 {Name: "varbinary12345", OrgName: "varbinary12345", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_VARBINARY, Charset: mysql.CharacterSetBinary, ColumnLength: 12345, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 836 {Name: "binary123", OrgName: "binary123", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BINARY, Charset: mysql.CharacterSetBinary, ColumnLength: 123, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 837 {Name: "char123", OrgName: "char123", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_CHAR, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 123 * 4, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 838 {Name: "bit12", OrgName: "bit12", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BIT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 12, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 839 840 // Dates 841 
{Name: "datetime", OrgName: "datetime", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DATETIME, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 26, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 842 {Name: "timestamp", OrgName: "timestamp", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TIMESTAMP, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 26, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 843 {Name: "date", OrgName: "date", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DATE, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 10, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 844 {Name: "time", OrgName: "time", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TIME, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 17, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 845 {Name: "year", OrgName: "year", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_YEAR, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 4, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 846 847 // Set and Enum Types 848 {Name: "set", OrgName: "set", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_SET, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 72, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 849 {Name: "enum", OrgName: "enum", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_ENUM, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 20, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, 850 } 851 852 require.Equal(len(schema), len(expected)) 853 854 e, pro := setupMemDB(require) 855 dbFunc := pro.Database 856 857 handler := &Handler{ 858 e: e, 859 sm: NewSessionManager( 860 testSessionBuilder(pro), 861 sql.NoopTracer, 862 dbFunc, 863 sql.NewMemoryManager(nil), 864 sqle.NewProcessList(), 865 "foo", 866 ), 867 readTimeout: time.Second, 868 } 869 870 conn := newConn(1) 871 
handler.NewConnection(conn)

	ctx, err := handler.sm.NewContextWithQuery(conn, "SELECT 1")
	require.NoError(err)

	fields := schemaToFields(ctx, schema)
	for i := 0; i < len(fields); i++ {
		t.Run(schema[i].Name, func(t *testing.T) {
			assert.Equal(t, expected[i], fields[i])
		})
	}
}

// TestHandlerMaxTextResponseBytes tests that the handler calculates the correct max text response byte
// metadata for TEXT types, including honoring the character_set_results session variable. This is tested
// here, instead of in string type unit tests, because of the dependency on system variables being loaded.
func TestHandlerMaxTextResponseBytes(t *testing.T) {
	variables.InitSystemVariables()
	session := sql.NewBaseSession()
	ctx := sql.NewContext(
		context.Background(),
		sql.WithSession(session),
	)

	tinyTextUtf8mb4 := types.MustCreateString(sqltypes.Text, types.TinyTextBlobMax, sql.Collation_Default)
	textUtf8mb4 := types.MustCreateString(sqltypes.Text, types.TextBlobMax, sql.Collation_Default)
	mediumTextUtf8mb4 := types.MustCreateString(sqltypes.Text, types.MediumTextBlobMax, sql.Collation_Default)
	longTextUtf8mb4 := types.MustCreateString(sqltypes.Text, types.LongTextBlobMax, sql.Collation_Default)

	// In every case below LONGTEXT is expected at LongTextBlobMax with no multiplier
	// (presumably because multiplying would overflow the protocol's column length field).

	// When character_set_results is set to utf8mb4, the multibyte character multiplier is 4
	require.NoError(t, session.SetSessionVariable(ctx, "character_set_results", "utf8mb4"))
	require.EqualValues(t, types.TinyTextBlobMax*4, tinyTextUtf8mb4.MaxTextResponseByteLength(ctx))
	require.EqualValues(t, types.TextBlobMax*4, textUtf8mb4.MaxTextResponseByteLength(ctx))
	require.EqualValues(t, types.MediumTextBlobMax*4, mediumTextUtf8mb4.MaxTextResponseByteLength(ctx))
	require.EqualValues(t, types.LongTextBlobMax, longTextUtf8mb4.MaxTextResponseByteLength(ctx))

	// When character_set_results is set to utf8mb3, the multibyte character multiplier is 3
	require.NoError(t, session.SetSessionVariable(ctx, "character_set_results", "utf8mb3"))
	require.EqualValues(t, types.TinyTextBlobMax*3, tinyTextUtf8mb4.MaxTextResponseByteLength(ctx))
	require.EqualValues(t, types.TextBlobMax*3, textUtf8mb4.MaxTextResponseByteLength(ctx))
	require.EqualValues(t, types.MediumTextBlobMax*3, mediumTextUtf8mb4.MaxTextResponseByteLength(ctx))
	require.EqualValues(t, types.LongTextBlobMax, longTextUtf8mb4.MaxTextResponseByteLength(ctx))

	// When character_set_results is set to utf8, the multibyte character multiplier is 3
	require.NoError(t, session.SetSessionVariable(ctx, "character_set_results", "utf8"))
	require.EqualValues(t, types.TinyTextBlobMax*3, tinyTextUtf8mb4.MaxTextResponseByteLength(ctx))
	require.EqualValues(t, types.TextBlobMax*3, textUtf8mb4.MaxTextResponseByteLength(ctx))
	require.EqualValues(t, types.MediumTextBlobMax*3, mediumTextUtf8mb4.MaxTextResponseByteLength(ctx))
	require.EqualValues(t, types.LongTextBlobMax, longTextUtf8mb4.MaxTextResponseByteLength(ctx))

	// When character_set_results is set to NULL, the multibyte character multiplier is taken from
	// the type's charset (4 in this case)
	require.NoError(t, session.SetSessionVariable(ctx, "character_set_results", nil))
	require.EqualValues(t, types.TinyTextBlobMax*4, tinyTextUtf8mb4.MaxTextResponseByteLength(ctx))
	require.EqualValues(t, types.TextBlobMax*4, textUtf8mb4.MaxTextResponseByteLength(ctx))
	require.EqualValues(t, types.MediumTextBlobMax*4, mediumTextUtf8mb4.MaxTextResponseByteLength(ctx))
	require.EqualValues(t, types.LongTextBlobMax, longTextUtf8mb4.MaxTextResponseByteLength(ctx))
}

// TestHandlerTimeout checks the handler's readTimeout: with a 1s timeout a query
// that sleeps 2s fails with a "row read wait" error while a 0.5s sleep succeeds,
// and a handler with no readTimeout lets the 2s sleep complete.
func TestHandlerTimeout(t *testing.T) {
	require := require.New(t)

	e, pro := setupMemDB(require)
	dbFunc := pro.Database

	e2, pro2 := setupMemDB(require)
	dbFunc2 := pro2.Database

	timeOutHandler := &Handler{
		e: e,
		sm: NewSessionManager(testSessionBuilder(pro),
			sql.NoopTracer,
			dbFunc,
			sql.NewMemoryManager(nil),
			sqle.NewProcessList(),
			"foo"),
		readTimeout: 1 * time.Second,
	}

	noTimeOutHandler := &Handler{
		e: e2,
		sm: NewSessionManager(testSessionBuilder(pro2),
			sql.NoopTracer,
			dbFunc2,
			sql.NewMemoryManager(nil),
			sqle.NewProcessList(),
			"foo"),
	}
	require.Equal(1*time.Second, timeOutHandler.readTimeout)
	require.Equal(0*time.Second, noTimeOutHandler.readTimeout)

	connTimeout := newConn(1)
	timeOutHandler.NewConnection(connTimeout)

	connNoTimeout := newConn(2)
	noTimeOutHandler.NewConnection(connNoTimeout)

	// Selecting the database must succeed; otherwise the query assertions below
	// would be testing the wrong failure.
	require.NoError(timeOutHandler.ComInitDB(connTimeout, "test"))

	// SLEEP(2) outlasts the 1s readTimeout, so the query is aborted.
	err := timeOutHandler.ComQuery(connTimeout, "SELECT SLEEP(2)", func(res *sqltypes.Result, more bool) error {
		return nil
	})
	require.EqualError(err, "row read wait bigger than connection timeout (errno 1105) (sqlstate HY000)")

	// SLEEP(0.5) fits within the timeout on the same connection.
	err = timeOutHandler.ComQuery(connTimeout, "SELECT SLEEP(0.5)", func(res *sqltypes.Result, more bool) error {
		return nil
	})
	require.NoError(err)

	// Without a readTimeout the 2s sleep runs to completion.
	require.NoError(noTimeOutHandler.ComInitDB(connNoTimeout, "test"))
	err = noTimeOutHandler.ComQuery(connNoTimeout, "SELECT SLEEP(2)", func(res *sqltypes.Result, more bool) error {
		return nil
	})
	require.NoError(err)
}

// TestOkClosedConnection starts a local TCP test server and dials it, then runs a
// query that sleeps for four tcpCheckerSleepDuration intervals on a handler and
// expects it to complete without error.
// NOTE(review): the dialed TCP conn is not wired into the handler (which uses a
// mock mysql.Conn); confirm this still exercises the intended connection check.
func TestOkClosedConnection(t *testing.T) {
	require := require.New(t)
	e, pro := setupMemDB(require)
	dbFunc := pro.Database

	port, err := getFreePort()
	require.NoError(err)

	ready := make(chan struct{})
	go okTestServer(t, ready, port)
	<-ready
	conn, err := net.Dial("tcp", "localhost:"+port)
	require.NoError(err)
	defer func() {
		_ = conn.Close()
	}()

	h := &Handler{
		e: e,
		sm: NewSessionManager(
			testSessionBuilder(pro),
			sql.NoopTracer,
			dbFunc,
			sql.NewMemoryManager(nil),
			sqle.NewProcessList(),
			"foo",
		),
	}
	c := newConn(1)
	h.NewConnection(c)

	// %d of a Duration divided by time.Second yields the whole number of seconds.
	q := fmt.Sprintf("SELECT SLEEP(%d)", (tcpCheckerSleepDuration * 4 / time.Second))
	require.NoError(h.ComInitDB(c, "test"))
	err = h.ComQuery(c, q, func(res *sqltypes.Result, more bool) error {
		return nil
	})
	require.NoError(err)
}

// Tests the CLIENT_FOUND_ROWS capabilities flag
func TestHandlerFoundRowsCapabilities(t *testing.T) {
	e, pro := setupMemDB(require.New(t))
	dbFunc := pro.Database
	dummyConn := newConn(1)

	// Set the capabilities to include found rows
	dummyConn.Capabilities = mysql.CapabilityClientFoundRows

	// Setup the handler
	handler := &Handler{
		e: e,
		sm: NewSessionManager(
			testSessionBuilder(pro),
			sql.NoopTracer,
			dbFunc,
			sql.NewMemoryManager(nil),
			sqle.NewProcessList(),
			"foo",
		),
	}

	tests := []struct {
		name                 string
		handler              *Handler
		conn                 *mysql.Conn
		query                string
		expectedRowsAffected uint64
	}{
		{
			name:                 "Update query should return number of rows matched instead of rows affected",
			handler:              handler,
			conn:                 dummyConn,
			query:                "UPDATE test set c1 = c1 where c1 < 10",
			expectedRowsAffected: uint64(10),
		},
		{
			name:                 "INSERT ON UPDATE returns +1 for every row that already exists",
			handler:              handler,
			conn:                 dummyConn,
			query:                "INSERT INTO test VALUES (1), (2), (3) ON DUPLICATE KEY UPDATE c1=c1",
			expectedRowsAffected: uint64(3),
		},
		{
			name:                 "SQL_CALC_ROWS should not affect CLIENT_FOUND_ROWS output",
			handler:              handler,
			conn:                 dummyConn,
			query:                "SELECT SQL_CALC_FOUND_ROWS * FROM test LIMIT 5",
			expectedRowsAffected: uint64(5),
		},
		{
			name:                 "INSERT returns rows affected",
			handler:              handler,
			conn:                 dummyConn,
			query:                "INSERT into test VALUES (10000),(10001),(10002)",
			expectedRowsAffected: uint64(3),
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			require.NoError(t, test.handler.ComInitDB(test.conn, "test"))
			var rowsAffected uint64
			// Keep the RowsAffected of the last result batch delivered to the callback.
			err := test.handler.ComQuery(test.conn, test.query, func(res *sqltypes.Result, more bool) error {
				rowsAffected = uint64(res.RowsAffected)
				return nil
			})

			require.NoError(t, err)
			require.Equal(t, test.expectedRowsAffected, rowsAffected)
		})
	}
}

// setupMemDB builds an engine over an in-memory provider holding one database,
// "test", with a table "test" whose single int32 primary-key column c1 is
// populated with the values 0..1009. Insert failures fail the test via require.
func setupMemDB(require *require.Assertions) (*sqle.Engine, *memory.DbProvider) {
	db := memory.NewDatabase("test")
	pro := memory.NewDBProvider(db)
	e := sqle.NewDefault(pro)
	ctx := sql.NewContext(context.Background(), sql.WithSession(memory.NewSession(sql.NewBaseSession(), pro)))

	tableTest := memory.NewTable(db, "test", sql.NewPrimaryKeySchema(sql.Schema{{Name: "c1", Type: types.Int32, Source: "test"}}), nil)
	tableTest.EnablePrimaryKeyIndexes()

	for i := 0; i < 1010; i++ {
		require.NoError(tableTest.Insert(
			ctx,
			sql.NewRow(int32(i)),
		))
	}

	db.AddTable("test", tableTest)

	return e, pro
}

// getFreePort asks the OS for an unused TCP port on localhost and returns it as a
// decimal string. The probe listener is closed before returning, so the port is
// only likely (not guaranteed) to still be free when the caller binds it.
func getFreePort() (string, error) {
	addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
	if err != nil {
		return "", err
	}

	l, err := net.ListenTCP("tcp", addr)
	if err != nil {
		return "", err
	}
	defer l.Close()
	return strconv.Itoa(l.Addr().(*net.TCPAddr).Port), nil
}

// testServer listens on port, closes ready once the listen attempt has finished,
// and accepts a single connection. If breakConn is false it drains the connection
// until EOF and then closes it. If breakConn is true it returns without reading or
// closing the connection.
//
// testServer is meant to run in its own goroutine. It therefore reports problems
// with t.Errorf rather than t.Fatal, because FailNow may only be called from the
// test goroutine.
func testServer(t *testing.T, ready chan struct{}, port string, breakConn bool) {
	l, err := net.Listen("tcp", ":"+port)
	if err != nil {
		// Report before signalling: the close of ready orders this call before the
		// test goroutine can finish, and closing it keeps the test from blocking forever.
		t.Errorf("listening on port %s: %v", port, err)
		close(ready)
		return
	}
	// Only defer Close after a successful Listen; on error l is a nil interface.
	defer func() {
		_ = l.Close()
	}()
	close(ready)

	conn, err := l.Accept()
	if err != nil {
		return
	}

	if !breakConn {
		defer func() {
			_ = conn.Close()
		}()

		if _, err = io.ReadAll(conn); err != nil {
			t.Errorf("reading from test connection: %v", err)
		}
	} // else: dirty return without closing or reading to force the socket into TIME_WAIT
}

// okTestServer runs testServer in its well-behaved mode: the accepted connection
// is drained and closed cleanly.
func okTestServer(t *testing.T, ready chan struct{}, port string) {
	testServer(t, ready, port, false)
}

// This session builder is used as dummy mysql Conn is not complete and
// causes panic when accessing remote address.
func testSessionBuilder(pro *memory.DbProvider) func(ctx context.Context, c *mysql.Conn, addr string) (sql.Session, error) {
	return func(ctx context.Context, c *mysql.Conn, addr string) (sql.Session, error) {
		// The client address is hard-coded instead of reading c's remote address.
		base := sql.NewBaseSessionWithClientServer(addr, sql.Client{Address: "127.0.0.1:34567", User: c.User, Capabilities: c.Capabilities}, c.ConnectionID)
		return memory.NewSession(base, pro), nil
	}
}

// mockConn is a stub net.Conn. Close only records that it was called, and
// RemoteAddr returns a fixed mockAddr. Every other embedded net.Conn method is
// unimplemented (nil embedded interface) and panics if called.
type mockConn struct {
	net.Conn
	closed bool
}

// Close marks the connection as closed and never fails.
func (c *mockConn) Close() error {
	c.closed = true
	return nil
}

// RemoteAddr returns a fixed placeholder address.
func (c *mockConn) RemoteAddr() net.Addr {
	return mockAddr{}
}

// mockAddr is a constant net.Addr reporting network "tcp" and address "localhost".
type mockAddr struct{}

// Network returns "tcp".
func (mockAddr) Network() string {
	return "tcp"
}

// String returns "localhost".
func (mockAddr) String() string {
	return "localhost"
}

// newConn returns a mysql.Conn with the given connection ID, backed by a fresh
// mockConn rather than a real socket.
func newConn(id uint32) *mysql.Conn {
	return &mysql.Conn{
		ConnectionID: id,
		Conn:         new(mockConn),
	}
}