vitess.io/vitess@v0.16.2/go/mysql/server_flaky_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package mysql 18 19 import ( 20 "context" 21 "crypto/tls" 22 "fmt" 23 "net" 24 "os" 25 "os/exec" 26 "path" 27 "strings" 28 "sync" 29 "testing" 30 "time" 31 32 "github.com/stretchr/testify/assert" 33 "github.com/stretchr/testify/require" 34 35 "vitess.io/vitess/go/sqltypes" 36 "vitess.io/vitess/go/test/utils" 37 vtenv "vitess.io/vitess/go/vt/env" 38 "vitess.io/vitess/go/vt/tlstest" 39 "vitess.io/vitess/go/vt/vterrors" 40 "vitess.io/vitess/go/vt/vttls" 41 42 querypb "vitess.io/vitess/go/vt/proto/query" 43 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 44 ) 45 46 var selectRowsResult = &sqltypes.Result{ 47 Fields: []*querypb.Field{ 48 { 49 Name: "id", 50 Type: querypb.Type_INT32, 51 }, 52 { 53 Name: "name", 54 Type: querypb.Type_VARCHAR, 55 }, 56 }, 57 Rows: [][]sqltypes.Value{ 58 { 59 sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")), 60 sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")), 61 }, 62 { 63 sqltypes.MakeTrusted(querypb.Type_INT32, []byte("20")), 64 sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nicer name")), 65 }, 66 }, 67 } 68 69 type testHandler struct { 70 UnimplementedHandler 71 mu sync.Mutex 72 lastConn *Conn 73 result *sqltypes.Result 74 err error 75 warnings uint16 76 } 77 78 func (th *testHandler) LastConn() *Conn { 79 th.mu.Lock() 80 defer th.mu.Unlock() 81 return th.lastConn 82 } 83 84 func (th *testHandler) Result() *sqltypes.Result { 85 th.mu.Lock() 86 defer th.mu.Unlock() 87 return th.result 88 } 89 90 func (th *testHandler) SetErr(err error) { 91 th.mu.Lock() 92 defer th.mu.Unlock() 93 th.err = err 94 } 95 96 func (th *testHandler) Err() error { 97 th.mu.Lock() 98 defer th.mu.Unlock() 99 return th.err 100 } 101 102 func (th *testHandler) SetWarnings(count uint16) { 103 th.mu.Lock() 104 defer th.mu.Unlock() 105 th.warnings = count 106 } 107 108 func (th *testHandler) NewConnection(c *Conn) { 109 th.mu.Lock() 110 defer th.mu.Unlock() 111 th.lastConn = c 112 } 113 114 func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error { 115 if result := th.Result(); result != nil { 116 callback(result) 117 return nil 118 } 119 120 switch query { 121 case "error": 122 return th.Err() 123 case "panic": 124 panic("test panic attack!") 125 case "select rows": 126 callback(selectRowsResult) 127 case "error after send": 128 callback(selectRowsResult) 129 return th.Err() 130 case "insert": 131 callback(&sqltypes.Result{ 132 RowsAffected: 123, 133 InsertID: 123456789, 134 }) 135 case "schema echo": 136 callback(&sqltypes.Result{ 137 Fields: []*querypb.Field{ 138 { 139 Name: "schema_name", 140 Type: querypb.Type_VARCHAR, 141 }, 142 }, 143 Rows: [][]sqltypes.Value{ 144 { 145 sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte(c.schemaName)), 146 }, 147 }, 148 }) 149 case "ssl echo": 150 value := "OFF" 151 if c.Capabilities&CapabilityClientSSL > 0 { 152 value = "ON" 153 } 154 callback(&sqltypes.Result{ 155 Fields: []*querypb.Field{ 156 { 157 Name: "ssl_flag", 158 Type: querypb.Type_VARCHAR, 159 }, 160 }, 161 Rows: [][]sqltypes.Value{ 162 { 163 sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte(value)), 164 }, 165 }, 166 }) 167 case "userData echo": 168 callback(&sqltypes.Result{ 169 Fields: []*querypb.Field{ 170 { 171 Name: "user", 172 Type: querypb.Type_VARCHAR, 173 }, 174 { 175 Name: "user_data", 176 Type: querypb.Type_VARCHAR, 177 }, 178 }, 179 Rows: [][]sqltypes.Value{ 180 { 181 sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte(c.User)), 182 sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte(c.UserData.Get().Username)), 183 }, 184 }, 185 }) 186 case "50ms delay": 187 callback(&sqltypes.Result{ 188 Fields: []*querypb.Field{{ 189 Name: "result", 190 Type: querypb.Type_VARCHAR, 191 }}, 192 }) 193 time.Sleep(50 * time.Millisecond) 194 callback(&sqltypes.Result{ 195 Rows: [][]sqltypes.Value{{ 196 sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("delayed")), 197 }}, 198 }) 199 default: 200 if strings.HasPrefix(query, benchmarkQueryPrefix) { 201 callback(&sqltypes.Result{ 202 Fields: []*querypb.Field{ 203 { 204 Name: "result", 205 Type: querypb.Type_VARCHAR, 206 }, 207 }, 208 Rows: [][]sqltypes.Value{ 209 { 210 sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte(query)), 211 }, 212 }, 213 }) 214 } 215 216 callback(&sqltypes.Result{}) 217 } 218 return nil 219 } 220 221 func (th *testHandler) ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) { 222 return nil, nil 223 } 224 225 func (th *testHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error { 226 return nil 227 } 228 229 func (th *testHandler) ComRegisterReplica(c *Conn, replicaHost string, replicaPort uint16, replicaUser string, replicaPassword string) error { 230 return nil 231 } 232 func (th *testHandler) ComBinlogDump(c *Conn, logFile string, binlogPos uint32) error { 233 return nil 234 } 235 func (th *testHandler) ComBinlogDumpGTID(c *Conn, logFile string, logPos uint64, gtidSet GTIDSet) error { 236 return nil 237 } 238 239 func (th *testHandler) WarningCount(c *Conn) uint16 { 240 th.mu.Lock() 241 defer th.mu.Unlock() 242 return th.warnings 243 } 244 245 func getHostPort(t *testing.T, a net.Addr) (string, int) { 246 host := a.(*net.TCPAddr).IP.String() 247 port := a.(*net.TCPAddr).Port 248 t.Logf("listening on address '%v' port %v", host, port) 249 return host, port 250 } 251 252 func TestConnectionFromListener(t *testing.T) { 253 th := &testHandler{} 254 255 authServer := NewAuthServerStatic("", "", 0) 256 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 257 Password: "password1", 258 UserData: "userData1", 259 }} 260 defer authServer.close() 261 // Make sure we can create our own net.Listener for use with the mysql 262 // listener 263 listener, err := net.Listen("tcp", "127.0.0.1:") 264 require.NoError(t, err, "net.Listener failed") 265 266 l, err := NewFromListener(listener, authServer, th, 0, 0, false) 267 require.NoError(t, err, "NewListener failed") 268 defer l.Close() 269 go l.Accept() 270 271 host, port := getHostPort(t, l.Addr()) 272 fmt.Printf("host: %s, port: %d\n", host, port) 273 // Setup the right parameters. 274 params := &ConnParams{ 275 Host: host, 276 Port: port, 277 Uname: "user1", 278 Pass: "password1", 279 } 280 281 c, err := Connect(context.Background(), params) 282 require.NoError(t, err, "Should be able to connect to server") 283 c.Close() 284 } 285 286 func TestConnectionWithoutSourceHost(t *testing.T) { 287 th := &testHandler{} 288 289 authServer := NewAuthServerStatic("", "", 0) 290 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 291 Password: "password1", 292 UserData: "userData1", 293 }} 294 defer authServer.close() 295 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 296 require.NoError(t, err, "NewListener failed") 297 defer l.Close() 298 go l.Accept() 299 300 host, port := getHostPort(t, l.Addr()) 301 302 // Setup the right parameters. 303 params := &ConnParams{ 304 Host: host, 305 Port: port, 306 Uname: "user1", 307 Pass: "password1", 308 } 309 310 c, err := Connect(context.Background(), params) 311 require.NoError(t, err, "Should be able to connect to server") 312 c.Close() 313 } 314 315 func TestConnectionWithSourceHost(t *testing.T) { 316 th := &testHandler{} 317 318 authServer := NewAuthServerStatic("", "", 0) 319 authServer.entries["user1"] = []*AuthServerStaticEntry{ 320 { 321 Password: "password1", 322 UserData: "userData1", 323 SourceHost: "localhost", 324 }, 325 } 326 defer authServer.close() 327 328 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 329 require.NoError(t, err, "NewListener failed") 330 defer l.Close() 331 go l.Accept() 332 333 host, port := getHostPort(t, l.Addr()) 334 335 // Setup the right parameters. 336 params := &ConnParams{ 337 Host: host, 338 Port: port, 339 Uname: "user1", 340 Pass: "password1", 341 } 342 343 _, err = Connect(context.Background(), params) 344 // target is localhost, should not work from tcp connection 345 require.EqualError(t, err, "Access denied for user 'user1' (errno 1045) (sqlstate 28000)", "Should not be able to connect to server") 346 } 347 348 func TestConnectionUseMysqlNativePasswordWithSourceHost(t *testing.T) { 349 th := &testHandler{} 350 351 authServer := NewAuthServerStatic("", "", 0) 352 authServer.entries["user1"] = []*AuthServerStaticEntry{ 353 { 354 MysqlNativePassword: "*9E128DA0C64A6FCCCDCFBDD0FC0A2C967C6DB36F", 355 UserData: "userData1", 356 SourceHost: "localhost", 357 }, 358 } 359 defer authServer.close() 360 361 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 362 require.NoError(t, err, "NewListener failed") 363 defer l.Close() 364 go l.Accept() 365 366 host, port := getHostPort(t, l.Addr()) 367 368 // Setup the right parameters. 369 params := &ConnParams{ 370 Host: host, 371 Port: port, 372 Uname: "user1", 373 Pass: "mysql_password", 374 } 375 376 _, err = Connect(context.Background(), params) 377 // target is localhost, should not work from tcp connection 378 require.EqualError(t, err, "Access denied for user 'user1' (errno 1045) (sqlstate 28000)", "Should not be able to connect to server") 379 } 380 381 func TestConnectionUnixSocket(t *testing.T) { 382 th := &testHandler{} 383 384 authServer := NewAuthServerStatic("", "", 0) 385 authServer.entries["user1"] = []*AuthServerStaticEntry{ 386 { 387 Password: "password1", 388 UserData: "userData1", 389 SourceHost: "localhost", 390 }, 391 } 392 defer authServer.close() 393 394 unixSocket, err := os.CreateTemp("", "mysql_vitess_test.sock") 395 require.NoError(t, err, "Failed to create temp file") 396 397 os.Remove(unixSocket.Name()) 398 399 l, err := NewListener("unix", unixSocket.Name(), authServer, th, 0, 0, false, false) 400 require.NoError(t, err, "NewListener failed") 401 defer l.Close() 402 go l.Accept() 403 404 // Setup the right parameters. 405 params := &ConnParams{ 406 UnixSocket: unixSocket.Name(), 407 Uname: "user1", 408 Pass: "password1", 409 } 410 411 c, err := Connect(context.Background(), params) 412 require.NoError(t, err, "Should be able to connect to server") 413 c.Close() 414 } 415 416 func TestClientFoundRows(t *testing.T) { 417 th := &testHandler{} 418 419 authServer := NewAuthServerStatic("", "", 0) 420 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 421 Password: "password1", 422 UserData: "userData1", 423 }} 424 defer authServer.close() 425 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 426 require.NoError(t, err, "NewListener failed") 427 defer l.Close() 428 go l.Accept() 429 430 host, port := getHostPort(t, l.Addr()) 431 432 // Setup the right parameters. 433 params := &ConnParams{ 434 Host: host, 435 Port: port, 436 Uname: "user1", 437 Pass: "password1", 438 } 439 440 // Test without flag. 441 c, err := Connect(context.Background(), params) 442 require.NoError(t, err, "Connect failed") 443 foundRows := th.LastConn().Capabilities & CapabilityClientFoundRows 444 assert.Equal(t, uint32(0), foundRows, "FoundRows flag: %x, second bit must be 0", th.LastConn().Capabilities) 445 c.Close() 446 assert.True(t, c.IsClosed(), "IsClosed should be true on Close-d connection.") 447 448 // Test with flag. 449 params.Flags |= CapabilityClientFoundRows 450 c, err = Connect(context.Background(), params) 451 require.NoError(t, err, "Connect failed") 452 foundRows = th.LastConn().Capabilities & CapabilityClientFoundRows 453 assert.NotZero(t, foundRows, "FoundRows flag: %x, second bit must be set", th.LastConn().Capabilities) 454 c.Close() 455 } 456 457 func TestConnCounts(t *testing.T) { 458 th := &testHandler{} 459 460 initialNumUsers := len(connCountPerUser.Counts()) 461 462 // FIXME: we should be able to ResetAll counters instead of computing a delta, but it doesn't work for some reason 463 // connCountPerUser.ResetAll() 464 465 user := "anotherNotYetConnectedUser1" 466 passwd := "password1" 467 468 authServer := NewAuthServerStatic("", "", 0) 469 authServer.entries[user] = []*AuthServerStaticEntry{{ 470 Password: passwd, 471 UserData: "userData1", 472 }} 473 defer authServer.close() 474 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 475 require.NoError(t, err, "NewListener failed") 476 defer l.Close() 477 go l.Accept() 478 479 host, port := getHostPort(t, l.Addr()) 480 481 // Test with one new connection. 482 params := &ConnParams{ 483 Host: host, 484 Port: port, 485 Uname: user, 486 Pass: passwd, 487 } 488 489 c, err := Connect(context.Background(), params) 490 require.NoError(t, err, "Connect failed") 491 492 connCounts := connCountPerUser.Counts() 493 assert.Equal(t, 1, len(connCounts)-initialNumUsers) 494 checkCountsForUser(t, user, 1) 495 496 // Test with a second new connection. 497 c2, err := Connect(context.Background(), params) 498 require.NoError(t, err) 499 connCounts = connCountPerUser.Counts() 500 // There is still only one new user. 501 assert.Equal(t, 1, len(connCounts)-initialNumUsers) 502 checkCountsForUser(t, user, 2) 503 504 // Test after closing connections. time.Sleep lets it work, but seems flakey. 505 c.Close() 506 //time.Sleep(10 * time.Millisecond) 507 //checkCountsForUser(t, user, 1) 508 509 c2.Close() 510 //time.Sleep(10 * time.Millisecond) 511 //checkCountsForUser(t, user, 0) 512 } 513 514 func checkCountsForUser(t *testing.T, user string, expected int64) { 515 connCounts := connCountPerUser.Counts() 516 517 userCount, ok := connCounts[user] 518 assert.True(t, ok, "No count found for user %s", user) 519 assert.Equal(t, expected, userCount) 520 } 521 522 func TestServer(t *testing.T) { 523 th := &testHandler{} 524 525 authServer := NewAuthServerStatic("", "", 0) 526 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 527 Password: "password1", 528 UserData: "userData1", 529 }} 530 defer authServer.close() 531 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 532 require.NoError(t, err) 533 l.SlowConnectWarnThreshold.Set(time.Nanosecond * 1) 534 defer l.Close() 535 go l.Accept() 536 537 host, port := getHostPort(t, l.Addr()) 538 539 // Setup the right parameters. 540 params := &ConnParams{ 541 Host: host, 542 Port: port, 543 Uname: "user1", 544 Pass: "password1", 545 } 546 547 // Run a 'select rows' command with results. 548 output, err := runMysqlWithErr(t, params, "select rows") 549 require.NoError(t, err) 550 551 assert.Contains(t, output, "nice name", "Unexpected output for 'select rows'") 552 assert.Contains(t, output, "nicer name", "Unexpected output for 'select rows'") 553 assert.Contains(t, output, "2 rows in set", "Unexpected output for 'select rows'") 554 assert.NotContains(t, output, "warnings") 555 556 // Run a 'select rows' command with warnings 557 th.SetWarnings(13) 558 output, err = runMysqlWithErr(t, params, "select rows") 559 require.NoError(t, err) 560 assert.Contains(t, output, "nice name", "Unexpected output for 'select rows'") 561 assert.Contains(t, output, "nicer name", "Unexpected output for 'select rows'") 562 assert.Contains(t, output, "2 rows in set", "Unexpected output for 'select rows'") 563 assert.Contains(t, output, "13 warnings", "Unexpected output for 'select rows'") 564 th.SetWarnings(0) 565 566 // If there's an error after streaming has started, 567 // we should get a 2013 568 th.SetErr(NewSQLError(ERUnknownComError, SSNetError, "forced error after send")) 569 output, err = runMysqlWithErr(t, params, "error after send") 570 require.Error(t, err) 571 assert.Contains(t, output, "ERROR 2013 (HY000)", "Unexpected output for 'panic'") 572 // MariaDB might not print the MySQL bit here 573 assert.Regexp(t, `Lost connection to( MySQL)? server during query`, output, "Unexpected output for 'panic': %v", output) 574 575 // Run an 'insert' command, no rows, but rows affected. 576 output, err = runMysqlWithErr(t, params, "insert") 577 require.NoError(t, err) 578 assert.Contains(t, output, "Query OK, 123 rows affected", "Unexpected output for 'insert'") 579 580 // Run a 'schema echo' command, to make sure db name is right. 581 params.DbName = "XXXfancyXXX" 582 output, err = runMysqlWithErr(t, params, "schema echo") 583 require.NoError(t, err) 584 assert.Contains(t, output, params.DbName, "Unexpected output for 'schema echo'") 585 586 // Sanity check: make sure this didn't go through SSL 587 output, err = runMysqlWithErr(t, params, "ssl echo") 588 require.NoError(t, err) 589 assert.Contains(t, output, "ssl_flag") 590 assert.Contains(t, output, "OFF") 591 assert.Contains(t, output, "1 row in set", "Unexpected output for 'ssl echo': %v", output) 592 593 // UserData check: checks the server user data is correct. 594 output, err = runMysqlWithErr(t, params, "userData echo") 595 require.NoError(t, err) 596 assert.Contains(t, output, "user1") 597 assert.Contains(t, output, "user_data") 598 assert.Contains(t, output, "userData1", "Unexpected output for 'userData echo': %v", output) 599 600 // Permissions check: check a bad password is rejected. 601 params.Pass = "bad" 602 output, err = runMysqlWithErr(t, params, "select rows") 603 require.Error(t, err) 604 assert.Contains(t, output, "1045") 605 assert.Contains(t, output, "28000") 606 assert.Contains(t, output, "Access denied", "Unexpected output for invalid password: %v", output) 607 608 // Permissions check: check an unknown user is rejected. 609 params.Pass = "password1" 610 params.Uname = "user2" 611 output, err = runMysqlWithErr(t, params, "select rows") 612 require.Error(t, err) 613 assert.Contains(t, output, "1045") 614 assert.Contains(t, output, "28000") 615 assert.Contains(t, output, "Access denied", "Unexpected output for invalid password: %v", output) 616 617 // Uncomment to leave setup up for a while, to run tests manually. 618 // fmt.Printf("Listening to server on host '%v' port '%v'.\n", host, port) 619 // time.Sleep(60 * time.Minute) 620 } 621 622 func TestServerStats(t *testing.T) { 623 th := &testHandler{} 624 625 authServer := NewAuthServerStatic("", "", 0) 626 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 627 Password: "password1", 628 UserData: "userData1", 629 }} 630 defer authServer.close() 631 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 632 require.NoError(t, err) 633 l.SlowConnectWarnThreshold.Set(time.Nanosecond * 1) 634 defer l.Close() 635 go l.Accept() 636 637 host, port := getHostPort(t, l.Addr()) 638 639 // Setup the right parameters. 640 params := &ConnParams{ 641 Host: host, 642 Port: port, 643 Uname: "user1", 644 Pass: "password1", 645 } 646 647 timings.Reset() 648 connAccept.Reset() 649 connCount.Reset() 650 connSlow.Reset() 651 connRefuse.Reset() 652 653 // Run an 'error' command. 654 th.SetErr(NewSQLError(ERUnknownComError, SSNetError, "forced query error")) 655 output, ok := runMysql(t, params, "error") 656 require.False(t, ok, "mysql should have failed: %v", output) 657 658 assert.Contains(t, output, "ERROR 1047 (08S01)") 659 assert.Contains(t, output, "forced query error", "Unexpected output for 'error': %v", output) 660 661 assert.EqualValues(t, 0, connCount.Get(), "connCount") 662 assert.EqualValues(t, 1, connAccept.Get(), "connAccept") 663 assert.EqualValues(t, 1, connSlow.Get(), "connSlow") 664 assert.EqualValues(t, 0, connRefuse.Get(), "connRefuse") 665 666 expectedTimingDeltas := map[string]int64{ 667 "All": 2, 668 connectTimingKey: 1, 669 queryTimingKey: 1, 670 } 671 gotTimingCounts := timings.Counts() 672 for key, got := range gotTimingCounts { 673 expected := expectedTimingDeltas[key] 674 assert.GreaterOrEqual(t, got, expected, "Expected Timing count delta %s should be >= %d, got %d", key, expected, got) 675 } 676 677 // Set the slow connect threshold to something high that we don't expect to trigger 678 l.SlowConnectWarnThreshold.Set(time.Second * 1) 679 680 // Run a 'panic' command, other side should panic, recover and 681 // close the connection. 682 output, err = runMysqlWithErr(t, params, "panic") 683 require.Error(t, err) 684 assert.Contains(t, output, "ERROR 2013 (HY000)") 685 // MariaDB might not print the MySQL bit here 686 assert.Regexp(t, `Lost connection to( MySQL)? server during query`, output, "Unexpected output for 'panic': %v", output) 687 688 assert.EqualValues(t, 0, connCount.Get(), "connCount") 689 assert.EqualValues(t, 2, connAccept.Get(), "connAccept") 690 assert.EqualValues(t, 1, connSlow.Get(), "connSlow") 691 assert.EqualValues(t, 0, connRefuse.Get(), "connRefuse") 692 } 693 694 // TestClearTextServer creates a Server that needs clear text 695 // passwords from the client. 696 func TestClearTextServer(t *testing.T) { 697 th := &testHandler{} 698 699 authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword) 700 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 701 Password: "password1", 702 UserData: "userData1", 703 }} 704 defer authServer.close() 705 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 706 require.NoError(t, err) 707 defer l.Close() 708 go l.Accept() 709 710 host, port := getHostPort(t, l.Addr()) 711 712 version, _ := runMysql(t, nil, "--version") 713 isMariaDB := strings.Contains(version, "MariaDB") 714 715 // Setup the right parameters. 716 params := &ConnParams{ 717 Host: host, 718 Port: port, 719 Uname: "user1", 720 Pass: "password1", 721 } 722 723 // Run a 'select rows' command with results. This should fail 724 // as clear text is not enabled by default on the client 725 // (except MariaDB). 726 l.AllowClearTextWithoutTLS.Set(true) 727 sql := "select rows" 728 output, ok := runMysql(t, params, sql) 729 if ok { 730 if isMariaDB { 731 t.Logf("mysql should have failed but returned: %v\nbut letting it go on MariaDB", output) 732 } else { 733 require.Fail(t, "mysql should have failed but returned: %v", output) 734 } 735 } else { 736 if strings.Contains(output, "No such file or directory") { 737 t.Logf("skipping mysql clear text tests, as the clear text plugin cannot be loaded: %v", err) 738 return 739 } 740 assert.Contains(t, output, "plugin not enabled", "Unexpected output for 'select rows': %v", output) 741 } 742 743 // Now enable clear text plugin in client, but server requires SSL. 744 l.AllowClearTextWithoutTLS.Set(false) 745 if !isMariaDB { 746 sql = enableCleartextPluginPrefix + sql 747 } 748 output, ok = runMysql(t, params, sql) 749 assert.False(t, ok, "mysql should have failed but returned: %v", output) 750 assert.Contains(t, output, "Cannot use clear text authentication over non-SSL connections", "Unexpected output for 'select rows': %v", output) 751 752 // Now enable clear text plugin, it should now work. 753 l.AllowClearTextWithoutTLS.Set(true) 754 output, ok = runMysql(t, params, sql) 755 require.True(t, ok, "mysql failed: %v", output) 756 757 assert.Contains(t, output, "nice name", "Unexpected output for 'select rows'") 758 assert.Contains(t, output, "nicer name", "Unexpected output for 'select rows'") 759 assert.Contains(t, output, "2 rows in set", "Unexpected output for 'select rows'") 760 761 // Change password, make sure server rejects us. 762 params.Pass = "bad" 763 output, ok = runMysql(t, params, sql) 764 assert.False(t, ok, "mysql should have failed but returned: %v", output) 765 assert.Contains(t, output, "Access denied for user 'user1'", "Unexpected output for 'select rows': %v", output) 766 } 767 768 // TestDialogServer creates a Server that uses the dialog plugin on the client. 769 func TestDialogServer(t *testing.T) { 770 th := &testHandler{} 771 772 authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlDialog) 773 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 774 Password: "password1", 775 UserData: "userData1", 776 }} 777 defer authServer.close() 778 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 779 require.NoError(t, err) 780 l.AllowClearTextWithoutTLS.Set(true) 781 defer l.Close() 782 go l.Accept() 783 784 host, port := getHostPort(t, l.Addr()) 785 786 // Setup the right parameters. 787 params := &ConnParams{ 788 Host: host, 789 Port: port, 790 Uname: "user1", 791 Pass: "password1", 792 SslMode: vttls.Disabled, 793 } 794 sql := "select rows" 795 output, ok := runMysql(t, params, sql) 796 if strings.Contains(output, "No such file or directory") || strings.Contains(output, "Authentication plugin 'dialog' cannot be loaded") { 797 t.Logf("skipping dialog plugin tests, as the dialog plugin cannot be loaded: %v", err) 798 return 799 } 800 require.True(t, ok, "mysql failed: %v", output) 801 assert.Contains(t, output, "nice name", "Unexpected output for 'select rows': %v", output) 802 assert.Contains(t, output, "nicer name", "Unexpected output for 'select rows': %v", output) 803 assert.Contains(t, output, "2 rows in set", "Unexpected output for 'select rows': %v", output) 804 } 805 806 // TestTLSServer creates a Server with TLS support, then uses mysql 807 // client to connect to it. 808 func TestTLSServer(t *testing.T) { 809 th := &testHandler{} 810 811 authServer := NewAuthServerStatic("", "", 0) 812 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 813 Password: "password1", 814 }} 815 defer authServer.close() 816 817 // Create the listener, so we can get its host. 818 // Below, we are enabling --ssl-verify-server-cert, which adds 819 // a check that the common name of the certificate matches the 820 // server host name we connect to. 821 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 822 require.NoError(t, err) 823 defer l.Close() 824 825 host := l.Addr().(*net.TCPAddr).IP.String() 826 port := l.Addr().(*net.TCPAddr).Port 827 828 // Create the certs. 829 root := t.TempDir() 830 tlstest.CreateCA(root) 831 tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") 832 tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") 833 834 // Create the server with TLS config. 835 serverConfig, err := vttls.ServerConfig( 836 path.Join(root, "server-cert.pem"), 837 path.Join(root, "server-key.pem"), 838 path.Join(root, "ca-cert.pem"), 839 "", 840 "", 841 tls.VersionTLS12) 842 require.NoError(t, err) 843 l.TLSConfig.Store(serverConfig) 844 845 var wg sync.WaitGroup 846 wg.Add(1) 847 go func(l *Listener) { 848 wg.Done() 849 l.Accept() 850 }(l) 851 // This is ensure the listener is called 852 wg.Wait() 853 // Sleep so that the Accept function is called as well.' 854 time.Sleep(3 * time.Second) 855 856 connCountByTLSVer.ResetAll() 857 // Setup the right parameters. 858 params := &ConnParams{ 859 Host: host, 860 Port: port, 861 Uname: "user1", 862 Pass: "password1", 863 // SSL flags. 864 SslMode: vttls.VerifyIdentity, 865 SslCa: path.Join(root, "ca-cert.pem"), 866 SslCert: path.Join(root, "client-cert.pem"), 867 SslKey: path.Join(root, "client-key.pem"), 868 ServerName: "server.example.com", 869 } 870 871 // Run a 'select rows' command with results. 872 conn, err := Connect(context.Background(), params) 873 //output, ok := runMysql(t, params, "select rows") 874 require.NoError(t, err) 875 results, err := conn.ExecuteFetch("select rows", 1000, true) 876 require.NoError(t, err) 877 output := "" 878 for _, row := range results.Rows { 879 r := make([]string, 0) 880 for _, col := range row { 881 r = append(r, col.String()) 882 } 883 output = output + strings.Join(r, ",") + "\n" 884 } 885 886 assert.Equal(t, "nice name", results.Rows[0][1].ToString()) 887 assert.Equal(t, "nicer name", results.Rows[1][1].ToString()) 888 assert.Equal(t, 2, len(results.Rows)) 889 890 // make sure this went through SSL 891 results, err = conn.ExecuteFetch("ssl echo", 1000, true) 892 require.NoError(t, err) 893 assert.Equal(t, "ON", results.Rows[0][0].ToString()) 894 895 // Find out which TLS version the connection actually used, 896 // so we can check that the corresponding counter was incremented. 897 tlsVersion := conn.conn.(*tls.Conn).ConnectionState().Version 898 899 checkCountForTLSVer(t, tlsVersionToString(tlsVersion), 1) 900 conn.Close() 901 902 } 903 904 // TestTLSRequired creates a Server with TLS required, then tests that an insecure mysql 905 // client is rejected 906 func TestTLSRequired(t *testing.T) { 907 th := &testHandler{} 908 909 authServer := NewAuthServerStatic("", "", 0) 910 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 911 Password: "password1", 912 }} 913 defer authServer.close() 914 915 // Create the listener, so we can get its host. 916 // Below, we are enabling --ssl-verify-server-cert, which adds 917 // a check that the common name of the certificate matches the 918 // server host name we connect to. 919 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 920 require.NoError(t, err) 921 defer l.Close() 922 923 host := l.Addr().(*net.TCPAddr).IP.String() 924 port := l.Addr().(*net.TCPAddr).Port 925 926 // Create the certs. 927 root := t.TempDir() 928 tlstest.CreateCA(root) 929 tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") 930 tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") 931 tlstest.CreateSignedCert(root, tlstest.CA, "03", "revoked-client", "Revoked Client Cert") 932 tlstest.RevokeCertAndRegenerateCRL(root, tlstest.CA, "revoked-client") 933 934 // Create the server with TLS config. 935 serverConfig, err := vttls.ServerConfig( 936 path.Join(root, "server-cert.pem"), 937 path.Join(root, "server-key.pem"), 938 path.Join(root, "ca-cert.pem"), 939 path.Join(root, "ca-crl.pem"), 940 "", 941 tls.VersionTLS12) 942 require.NoError(t, err) 943 l.TLSConfig.Store(serverConfig) 944 l.RequireSecureTransport = true 945 946 var wg sync.WaitGroup 947 wg.Add(1) 948 go func(l *Listener) { 949 wg.Done() 950 l.Accept() 951 }(l) 952 // This is ensure the listener is called 953 wg.Wait() 954 // Sleep so that the Accept function is called as well.' 955 time.Sleep(3 * time.Second) 956 957 // Setup conn params without SSL. 958 params := &ConnParams{ 959 Host: host, 960 Port: port, 961 Uname: "user1", 962 Pass: "password1", 963 SslMode: vttls.Disabled, 964 ServerName: "server.example.com", 965 } 966 conn, err := Connect(context.Background(), params) 967 require.NotNil(t, err) 968 require.Contains(t, err.Error(), "Code: UNAVAILABLE") 969 require.Contains(t, err.Error(), "server does not allow insecure connections, client must use SSL/TLS") 970 require.Contains(t, err.Error(), "(errno 1105) (sqlstate HY000)") 971 if conn != nil { 972 conn.Close() 973 } 974 975 // setup conn params with TLS 976 params.SslMode = vttls.VerifyIdentity 977 params.SslCa = path.Join(root, "ca-cert.pem") 978 params.SslCert = path.Join(root, "client-cert.pem") 979 params.SslKey = path.Join(root, "client-key.pem") 980 981 conn, err = Connect(context.Background(), params) 982 require.NoError(t, err) 983 if conn != nil { 984 conn.Close() 985 } 986 987 // setup conn params with TLS, but with a revoked client certificate 988 params.SslCert = path.Join(root, "revoked-client-cert.pem") 989 params.SslKey = path.Join(root, "revoked-client-key.pem") 990 conn, err = Connect(context.Background(), params) 991 require.NotNil(t, err) 992 require.Contains(t, err.Error(), "remote error: tls: bad certificate") 993 if conn != nil { 994 conn.Close() 995 } 996 } 997 998 func TestCachingSha2PasswordAuthWithTLS(t *testing.T) { 999 th := &testHandler{} 1000 1001 authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, CachingSha2Password) 1002 authServer.entries["user1"] = []*AuthServerStaticEntry{ 1003 {Password: "password1"}, 1004 } 1005 defer authServer.close() 1006 1007 // Create the listener, so we can get its host. 1008 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 1009 require.NoError(t, err, "NewListener failed: %v", err) 1010 defer l.Close() 1011 host := l.Addr().(*net.TCPAddr).IP.String() 1012 port := l.Addr().(*net.TCPAddr).Port 1013 1014 // Create the certs. 1015 root := t.TempDir() 1016 tlstest.CreateCA(root) 1017 tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") 1018 tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") 1019 1020 // Create the server with TLS config. 1021 serverConfig, err := vttls.ServerConfig( 1022 path.Join(root, "server-cert.pem"), 1023 path.Join(root, "server-key.pem"), 1024 path.Join(root, "ca-cert.pem"), 1025 "", 1026 "", 1027 tls.VersionTLS12) 1028 require.NoError(t, err, "TLSServerConfig failed: %v", err) 1029 1030 l.TLSConfig.Store(serverConfig) 1031 go func() { 1032 l.Accept() 1033 }() 1034 1035 // Setup the right parameters. 1036 params := &ConnParams{ 1037 Host: host, 1038 Port: port, 1039 Uname: "user1", 1040 Pass: "password1", 1041 // SSL flags. 1042 SslMode: vttls.VerifyIdentity, 1043 SslCa: path.Join(root, "ca-cert.pem"), 1044 SslCert: path.Join(root, "client-cert.pem"), 1045 SslKey: path.Join(root, "client-key.pem"), 1046 ServerName: "server.example.com", 1047 } 1048 1049 // Connection should fail, as server requires SSL for caching_sha2_password. 1050 ctx := context.Background() 1051 1052 conn, err := Connect(ctx, params) 1053 require.NoError(t, err, "unexpected connection error: %v", err) 1054 1055 defer conn.Close() 1056 1057 // Run a 'select rows' command with results. 1058 result, err := conn.ExecuteFetch("select rows", 10000, true) 1059 require.NoError(t, err, "ExecuteFetch failed: %v", err) 1060 1061 utils.MustMatch(t, result, selectRowsResult) 1062 1063 // Send a ComQuit to avoid the error message on the server side. 1064 conn.writeComQuit() 1065 } 1066 1067 type alwaysFallbackAuth struct{} 1068 1069 func (a *alwaysFallbackAuth) UserEntryWithCacheHash(conn *Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, CacheState, error) { 1070 return &StaticUserData{}, AuthNeedMoreData, nil 1071 } 1072 1073 // newAuthServerAlwaysFallback returns a new empty AuthServerStatic 1074 // which will always request more data to trigger fallback auth path 1075 // for caching sha2. 1076 func newAuthServerAlwaysFallback(file, jsonConfig string, reloadInterval time.Duration) *AuthServerStatic { 1077 a := &AuthServerStatic{ 1078 file: file, 1079 jsonConfig: jsonConfig, 1080 reloadInterval: reloadInterval, 1081 entries: make(map[string][]*AuthServerStaticEntry), 1082 } 1083 1084 authMethod := NewSha2CachingAuthMethod(&alwaysFallbackAuth{}, a, a) 1085 a.methods = []AuthMethod{authMethod} 1086 1087 a.reload() 1088 a.installSignalHandlers() 1089 return a 1090 } 1091 1092 func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { 1093 th := &testHandler{} 1094 1095 authServer := newAuthServerAlwaysFallback("", "", 0) 1096 authServer.entries["user1"] = []*AuthServerStaticEntry{ 1097 {Password: "password1"}, 1098 } 1099 defer authServer.close() 1100 1101 // Create the listener, so we can get its host. 1102 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 1103 require.NoError(t, err, "NewListener failed: %v", err) 1104 defer l.Close() 1105 host := l.Addr().(*net.TCPAddr).IP.String() 1106 port := l.Addr().(*net.TCPAddr).Port 1107 1108 // Create the certs. 1109 root := t.TempDir() 1110 tlstest.CreateCA(root) 1111 tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") 1112 tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") 1113 1114 // Create the server with TLS config. 1115 serverConfig, err := vttls.ServerConfig( 1116 path.Join(root, "server-cert.pem"), 1117 path.Join(root, "server-key.pem"), 1118 path.Join(root, "ca-cert.pem"), 1119 "", 1120 "", 1121 tls.VersionTLS12) 1122 require.NoError(t, err, "TLSServerConfig failed: %v", err) 1123 1124 l.TLSConfig.Store(serverConfig) 1125 go func() { 1126 l.Accept() 1127 }() 1128 1129 // Setup the right parameters. 1130 params := &ConnParams{ 1131 Host: host, 1132 Port: port, 1133 Uname: "user1", 1134 Pass: "password1", 1135 // SSL flags. 1136 SslMode: vttls.VerifyIdentity, 1137 SslCa: path.Join(root, "ca-cert.pem"), 1138 SslCert: path.Join(root, "client-cert.pem"), 1139 SslKey: path.Join(root, "client-key.pem"), 1140 ServerName: "server.example.com", 1141 } 1142 1143 // Connection should fail, as server requires SSL for caching_sha2_password. 1144 ctx := context.Background() 1145 1146 conn, err := Connect(ctx, params) 1147 require.NoError(t, err, "unexpected connection error: %v", err) 1148 1149 defer conn.Close() 1150 1151 // Run a 'select rows' command with results. 1152 result, err := conn.ExecuteFetch("select rows", 10000, true) 1153 require.NoError(t, err, "ExecuteFetch failed: %v", err) 1154 1155 utils.MustMatch(t, result, selectRowsResult) 1156 1157 // Send a ComQuit to avoid the error message on the server side. 1158 conn.writeComQuit() 1159 } 1160 1161 func TestCachingSha2PasswordAuthWithoutTLS(t *testing.T) { 1162 th := &testHandler{} 1163 1164 authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, CachingSha2Password) 1165 authServer.entries["user1"] = []*AuthServerStaticEntry{ 1166 {Password: "password1"}, 1167 } 1168 defer authServer.close() 1169 1170 // Create the listener. 1171 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 1172 require.NoError(t, err, "NewListener failed: %v", err) 1173 defer l.Close() 1174 host := l.Addr().(*net.TCPAddr).IP.String() 1175 port := l.Addr().(*net.TCPAddr).Port 1176 go func() { 1177 l.Accept() 1178 }() 1179 1180 // Setup the right parameters. 1181 params := &ConnParams{ 1182 Host: host, 1183 Port: port, 1184 Uname: "user1", 1185 Pass: "password1", 1186 SslMode: vttls.Disabled, 1187 } 1188 1189 // Connection should fail, as server requires SSL for caching_sha2_password. 1190 ctx := context.Background() 1191 _, err = Connect(ctx, params) 1192 if err == nil || !strings.Contains(err.Error(), "No authentication methods available for authentication") { 1193 t.Fatalf("unexpected connection error: %v", err) 1194 } 1195 } 1196 1197 func checkCountForTLSVer(t *testing.T, version string, expected int64) { 1198 connCounts := connCountByTLSVer.Counts() 1199 count, ok := connCounts[version] 1200 assert.True(t, ok, "No count found for version %s", version) 1201 assert.Equal(t, expected, count, "Unexpected connection count for version %s", version) 1202 } 1203 1204 func TestErrorCodes(t *testing.T) { 1205 th := &testHandler{} 1206 1207 authServer := NewAuthServerStatic("", "", 0) 1208 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 1209 Password: "password1", 1210 UserData: "userData1", 1211 }} 1212 defer authServer.close() 1213 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 1214 require.NoError(t, err) 1215 defer l.Close() 1216 go l.Accept() 1217 1218 host, port := getHostPort(t, l.Addr()) 1219 1220 // Setup the right parameters. 1221 params := &ConnParams{ 1222 Host: host, 1223 Port: port, 1224 Uname: "user1", 1225 Pass: "password1", 1226 } 1227 1228 ctx := context.Background() 1229 client, err := Connect(ctx, params) 1230 require.NoError(t, err) 1231 1232 // Test that the right mysql errno/sqlstate are returned for various 1233 // internal vitess errors 1234 tests := []struct { 1235 err error 1236 code int 1237 sqlState string 1238 text string 1239 }{ 1240 { 1241 err: vterrors.Errorf( 1242 vtrpcpb.Code_INVALID_ARGUMENT, 1243 "invalid argument"), 1244 code: ERUnknownError, 1245 sqlState: SSUnknownSQLState, 1246 text: "invalid argument", 1247 }, 1248 { 1249 err: vterrors.Errorf( 1250 vtrpcpb.Code_INVALID_ARGUMENT, 1251 "(errno %v) (sqlstate %v) invalid argument with errno", ERDupEntry, SSConstraintViolation), 1252 code: ERDupEntry, 1253 sqlState: SSConstraintViolation, 1254 text: "invalid argument with errno", 1255 }, 1256 { 1257 err: vterrors.Errorf( 1258 vtrpcpb.Code_DEADLINE_EXCEEDED, 1259 "connection deadline exceeded"), 1260 code: ERQueryInterrupted, 1261 sqlState: SSQueryInterrupted, 1262 text: "deadline exceeded", 1263 }, 1264 { 1265 err: vterrors.Errorf( 1266 vtrpcpb.Code_RESOURCE_EXHAUSTED, 1267 "query pool timeout"), 1268 code: ERTooManyUserConnections, 1269 sqlState: SSClientError, 1270 text: "resource exhausted", 1271 }, 1272 { 1273 err: vterrors.Wrap(vterrors.Errorf(vtrpcpb.Code_ABORTED, "Row count exceeded 10000"), "wrapped"), 1274 code: ERQueryInterrupted, 1275 sqlState: SSQueryInterrupted, 1276 text: "aborted", 1277 }, 1278 } 1279 1280 for _, test := range tests { 1281 t.Run(test.err.Error(), func(t *testing.T) { 1282 th.SetErr(NewSQLErrorFromError(test.err)) 1283 rs, err := client.ExecuteFetch("error", 100, false) 1284 require.Error(t, err, "mysql should have failed but returned: %v", rs) 1285 serr, ok := err.(*SQLError) 1286 require.True(t, ok, "mysql should have returned a SQLError") 1287 1288 assert.Equal(t, test.code, serr.Number(), "error in %s: want code %v got %v", test.text, test.code, serr.Number()) 1289 assert.Equal(t, test.sqlState, serr.SQLState(), "error in %s: want sqlState %v got %v", test.text, test.sqlState, serr.SQLState()) 1290 assert.Contains(t, serr.Error(), test.err.Error()) 1291 }) 1292 } 1293 } 1294 1295 const enableCleartextPluginPrefix = "enable-cleartext-plugin: " 1296 1297 // runMysql forks a mysql command line process connecting to the provided server. 1298 func runMysql(t *testing.T, params *ConnParams, command string) (string, bool) { 1299 output, err := runMysqlWithErr(t, params, command) 1300 if err != nil { 1301 return output, false 1302 } 1303 return output, true 1304 1305 } 1306 func runMysqlWithErr(t *testing.T, params *ConnParams, command string) (string, error) { 1307 dir, err := vtenv.VtMysqlRoot() 1308 require.NoError(t, err) 1309 name, err := binaryPath(dir, "mysql") 1310 require.NoError(t, err) 1311 // The args contain '-v' 3 times, to switch to very verbose output. 1312 // In particular, it has the message: 1313 // Query OK, 1 row affected (0.00 sec) 1314 args := []string{ 1315 "-v", "-v", "-v", 1316 } 1317 if strings.HasPrefix(command, enableCleartextPluginPrefix) { 1318 command = command[len(enableCleartextPluginPrefix):] 1319 args = append(args, "--enable-cleartext-plugin") 1320 } 1321 if command == "--version" { 1322 args = append(args, command) 1323 } else { 1324 args = append(args, "-e", command) 1325 if params.UnixSocket != "" { 1326 args = append(args, "-S", params.UnixSocket) 1327 } else { 1328 args = append(args, 1329 "-h", params.Host, 1330 "-P", fmt.Sprintf("%v", params.Port)) 1331 } 1332 if params.Uname != "" { 1333 args = append(args, "-u", params.Uname) 1334 } 1335 if params.Pass != "" { 1336 args = append(args, "-p"+params.Pass) 1337 } 1338 if params.DbName != "" { 1339 args = append(args, "-D", params.DbName) 1340 } 1341 if params.SslEnabled() { 1342 args = append(args, 1343 "--ssl", 1344 "--ssl-ca", params.SslCa, 1345 "--ssl-cert", params.SslCert, 1346 "--ssl-key", params.SslKey, 1347 "--ssl-verify-server-cert") 1348 } 1349 } 1350 env := []string{ 1351 "LD_LIBRARY_PATH=" + path.Join(dir, "lib/mysql"), 1352 } 1353 1354 t.Logf("Running mysql command: %v %v", name, args) 1355 cmd := exec.Command(name, args...) 1356 cmd.Env = env 1357 cmd.Dir = dir 1358 out, err := cmd.CombinedOutput() 1359 output := string(out) 1360 if err != nil { 1361 return output, err 1362 } 1363 return output, nil 1364 } 1365 1366 // binaryPath does a limited path lookup for a command, 1367 // searching only within sbin and bin in the given root. 1368 // 1369 // FIXME(alainjobart) move this to vt/env, and use it from 1370 // go/vt/mysqlctl too. 1371 func binaryPath(root, binary string) (string, error) { 1372 subdirs := []string{"sbin", "bin"} 1373 for _, subdir := range subdirs { 1374 binPath := path.Join(root, subdir, binary) 1375 if _, err := os.Stat(binPath); err == nil { 1376 return binPath, nil 1377 } 1378 } 1379 return "", fmt.Errorf("%s not found in any of %s/{%s}", 1380 binary, root, strings.Join(subdirs, ",")) 1381 } 1382 1383 func TestListenerShutdown(t *testing.T) { 1384 th := &testHandler{} 1385 authServer := NewAuthServerStatic("", "", 0) 1386 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 1387 Password: "password1", 1388 UserData: "userData1", 1389 }} 1390 defer authServer.close() 1391 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 1392 require.NoError(t, err) 1393 defer l.Close() 1394 go l.Accept() 1395 1396 host, port := getHostPort(t, l.Addr()) 1397 1398 // Setup the right parameters. 1399 params := &ConnParams{ 1400 Host: host, 1401 Port: port, 1402 Uname: "user1", 1403 Pass: "password1", 1404 } 1405 connRefuse.Reset() 1406 1407 ctx, cancel := context.WithCancel(context.Background()) 1408 defer cancel() 1409 1410 conn, err := Connect(ctx, params) 1411 require.NoError(t, err) 1412 1413 err = conn.Ping() 1414 require.NoError(t, err) 1415 1416 l.Shutdown() 1417 1418 assert.EqualValues(t, 1, connRefuse.Get(), "connRefuse") 1419 1420 err = conn.Ping() 1421 require.EqualError(t, err, "Server shutdown in progress (errno 1053) (sqlstate 08S01)") 1422 sqlErr, ok := err.(*SQLError) 1423 require.True(t, ok, "Wrong error type: %T", err) 1424 1425 require.Equal(t, ERServerShutdown, sqlErr.Number()) 1426 require.Equal(t, SSNetError, sqlErr.SQLState()) 1427 require.Equal(t, "Server shutdown in progress", sqlErr.Message) 1428 } 1429 1430 func TestParseConnAttrs(t *testing.T) { 1431 expected := map[string]string{ 1432 "_client_version": "8.0.11", 1433 "program_name": "mysql", 1434 "_pid": "22850", 1435 "_platform": "x86_64", 1436 "_os": "linux-glibc2.12", 1437 "_client_name": "libmysql", 1438 } 1439 1440 data := []byte{0x70, 0x04, 0x5f, 0x70, 0x69, 0x64, 0x05, 0x32, 0x32, 0x38, 0x35, 0x30, 0x09, 0x5f, 0x70, 0x6c, 1441 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x06, 0x78, 0x38, 0x36, 0x5f, 0x36, 0x34, 0x03, 0x5f, 0x6f, 1442 0x73, 0x0f, 0x6c, 0x69, 0x6e, 0x75, 0x78, 0x2d, 0x67, 0x6c, 0x69, 0x62, 0x63, 0x32, 0x2e, 0x31, 1443 0x32, 0x0c, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x08, 0x6c, 1444 0x69, 0x62, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x0f, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 1445 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x06, 0x38, 0x2e, 0x30, 0x2e, 0x31, 0x31, 0x0c, 0x70, 1446 0x72, 0x6f, 0x67, 0x72, 0x61, 0x6d, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x05, 0x6d, 0x79, 0x73, 0x71, 0x6c} 1447 1448 attrs, pos, err := parseConnAttrs(data, 0) 1449 require.NoError(t, err) 1450 require.Equal(t, 113, pos) 1451 for k, v := range expected { 1452 val, ok := attrs[k] 1453 require.True(t, ok, "Error reading key %s from connection attributes: attrs: %-v", k, attrs) 1454 require.Equal(t, v, val, "Unexpected value found in attrs for key %s", k) 1455 } 1456 } 1457 1458 func TestServerFlush(t *testing.T) { 1459 defer func(saved time.Duration) { mysqlServerFlushDelay = saved }(mysqlServerFlushDelay) 1460 mysqlServerFlushDelay = 10 * time.Millisecond 1461 1462 th := &testHandler{} 1463 1464 l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false) 1465 require.NoError(t, err) 1466 defer l.Close() 1467 go l.Accept() 1468 1469 host, port := getHostPort(t, l.Addr()) 1470 params := &ConnParams{ 1471 Host: host, 1472 Port: port, 1473 } 1474 1475 c, err := Connect(context.Background(), params) 1476 require.NoError(t, err) 1477 defer c.Close() 1478 1479 start := time.Now() 1480 err = c.ExecuteStreamFetch("50ms delay") 1481 require.NoError(t, err) 1482 1483 flds, err := c.Fields() 1484 require.NoError(t, err) 1485 if duration, want := time.Since(start), 20*time.Millisecond; duration < mysqlServerFlushDelay || duration > want { 1486 assert.Fail(t, "duration out of expected range", "duration: %v, want between %v and %v", duration.String(), (mysqlServerFlushDelay).String(), want.String()) 1487 } 1488 want1 := []*querypb.Field{{ 1489 Name: "result", 1490 Type: querypb.Type_VARCHAR, 1491 }} 1492 assert.Equal(t, want1, flds) 1493 1494 row, err := c.FetchNext(nil) 1495 require.NoError(t, err) 1496 if duration, want := time.Since(start), 50*time.Millisecond; duration < want { 1497 assert.Fail(t, "duration is too low", "duration: %v, want > %v", duration, want) 1498 } 1499 want2 := []sqltypes.Value{sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("delayed"))} 1500 assert.Equal(t, want2, row) 1501 1502 row, err = c.FetchNext(nil) 1503 require.NoError(t, err) 1504 assert.Nil(t, row) 1505 }