github.com/matrixorigin/matrixone@v1.2.0/pkg/proxy/tunnel_test.go (about) 1 // Copyright 2021 - 2023 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package proxy 16 17 import ( 18 "bytes" 19 "context" 20 "fmt" 21 "io" 22 "math/rand" 23 "net" 24 "testing" 25 "testing/iotest" 26 "time" 27 28 "github.com/lni/goutils/leaktest" 29 "github.com/matrixorigin/matrixone/pkg/common/runtime" 30 "github.com/stretchr/testify/require" 31 ) 32 33 func TestTunnelClientToServer(t *testing.T) { 34 defer leaktest.AfterTest(t)() 35 36 runtime.SetupProcessLevelRuntime(runtime.DefaultRuntime()) 37 baseCtx := context.Background() 38 rt := runtime.DefaultRuntime() 39 logger := rt.Logger() 40 41 tu := newTunnel(baseCtx, logger, nil) 42 defer func() { _ = tu.Close() }() 43 44 clientProxy, client := net.Pipe() 45 serverProxy, server := net.Pipe() 46 47 cc := newMockClientConn(clientProxy, "t1", clientInfo{}, nil, tu) 48 require.NotNil(t, cc) 49 50 sc := newMockServerConn(serverProxy) 51 require.NotNil(t, sc) 52 53 err := tu.run(cc, sc) 54 require.NoError(t, err) 55 require.Nil(t, tu.ctx.Err()) 56 57 func() { 58 tu.mu.Lock() 59 defer tu.mu.Unlock() 60 require.True(t, tu.mu.started) 61 }() 62 63 tu.mu.Lock() 64 csp := tu.mu.csp 65 scp := tu.mu.scp 66 tu.mu.Unlock() 67 68 barrierStart, barrierEnd := make(chan struct{}), make(chan struct{}) 69 csp.testHelper.beforeSend = func() { 70 <-barrierStart 71 <-barrierEnd 72 } 73 74 csp.mu.Lock() 75 require.True(t, csp.mu.started) 76 csp.mu.Unlock() 77 78 scp.mu.Lock() 79 require.True(t, scp.mu.started) 80 scp.mu.Unlock() 81 82 // Client writes some MySQL packets. 83 sendEventCh := make(chan struct{}, 1) 84 errChan := make(chan error, 1) 85 go func() { 86 <-sendEventCh 87 if _, err := client.Write(makeSimplePacket("select 1")); err != nil { 88 errChan <- err 89 return 90 } 91 92 <-sendEventCh 93 if _, err := client.Write(makeSimplePacket("begin")); err != nil { 94 errChan <- err 95 return 96 } 97 98 <-sendEventCh 99 if _, err := client.Write(makeSimplePacket("select 1")); err != nil { 100 errChan <- err 101 return 102 } 103 104 <-sendEventCh 105 if _, err := client.Write(makeSimplePacket("commit")); err != nil { 106 errChan <- err 107 return 108 } 109 }() 110 111 sendEventCh <- struct{}{} 112 barrierStart <- struct{}{} 113 scp.mu.Lock() 114 require.Equal(t, true, scp.safeToTransferLocked()) 115 scp.mu.Unlock() 116 barrierEnd <- struct{}{} 117 118 ret := make([]byte, 30) 119 n, err := server.Read(ret) 120 require.NoError(t, err) 121 require.Equal(t, 13, n) 122 l, err := packetLen(ret) 123 require.NoError(t, err) 124 require.Equal(t, 9, int(l)) 125 require.Equal(t, comQuery, int(ret[4])) 126 require.Equal(t, "select 1", string(ret[5:13])) 127 128 sendEventCh <- struct{}{} 129 barrierStart <- struct{}{} 130 scp.mu.Lock() 131 require.Equal(t, true, scp.safeToTransferLocked()) 132 scp.mu.Unlock() 133 barrierEnd <- struct{}{} 134 135 ret = make([]byte, 30) 136 n, err = server.Read(ret) 137 require.NoError(t, err) 138 require.Equal(t, 10, n) 139 l, err = packetLen(ret) 140 require.NoError(t, err) 141 require.Equal(t, 6, int(l)) 142 require.Equal(t, comQuery, int(ret[4])) 143 require.Equal(t, "begin", string(ret[5:10])) 144 145 sendEventCh <- struct{}{} 146 barrierStart <- struct{}{} 147 scp.mu.Lock() 148 // in txn 149 require.Equal(t, true, scp.safeToTransferLocked()) 150 scp.mu.Unlock() 151 barrierEnd <- struct{}{} 152 153 ret = make([]byte, 30) 154 n, err = server.Read(ret) 155 require.NoError(t, err) 156 require.Equal(t, 13, n) 157 l, err = packetLen(ret) 158 require.NoError(t, err) 159 require.Equal(t, 9, int(l)) 160 require.Equal(t, comQuery, int(ret[4])) 161 require.Equal(t, "select 1", string(ret[5:13])) 162 163 sendEventCh <- struct{}{} 164 barrierStart <- struct{}{} 165 scp.mu.Lock() 166 // out of txn 167 require.Equal(t, true, scp.safeToTransferLocked()) 168 scp.mu.Unlock() 169 barrierEnd <- struct{}{} 170 171 ret = make([]byte, 30) 172 n, err = server.Read(ret) 173 require.NoError(t, err) 174 require.Equal(t, 11, n) 175 l, err = packetLen(ret) 176 require.NoError(t, err) 177 require.Equal(t, 7, int(l)) 178 require.Equal(t, comQuery, int(ret[4])) 179 require.Equal(t, "commit", string(ret[5:11])) 180 select { 181 case err = <-errChan: 182 t.Fatalf("require no error, but got %v", err) 183 default: 184 } 185 } 186 187 func TestTunnelServerClient(t *testing.T) { 188 defer leaktest.AfterTest(t)() 189 190 runtime.SetupProcessLevelRuntime(runtime.DefaultRuntime()) 191 baseCtx := context.Background() 192 193 rt := runtime.DefaultRuntime() 194 logger := rt.Logger() 195 196 tu := newTunnel(baseCtx, logger, nil) 197 defer func() { _ = tu.Close() }() 198 199 clientProxy, client := net.Pipe() 200 serverProxy, server := net.Pipe() 201 202 cc := newMockClientConn(clientProxy, "t1", clientInfo{}, nil, tu) 203 require.NotNil(t, cc) 204 205 sc := newMockServerConn(serverProxy) 206 require.NotNil(t, sc) 207 208 err := tu.run(cc, sc) 209 require.NoError(t, err) 210 require.Nil(t, tu.ctx.Err()) 211 212 func() { 213 tu.mu.Lock() 214 defer tu.mu.Unlock() 215 require.True(t, tu.mu.started) 216 }() 217 218 tu.mu.Lock() 219 csp := tu.mu.csp 220 scp := tu.mu.scp 221 tu.mu.Unlock() 222 223 csp.mu.Lock() 224 require.True(t, csp.mu.started) 225 csp.mu.Unlock() 226 227 scp.mu.Lock() 228 require.True(t, scp.mu.started) 229 scp.mu.Unlock() 230 231 // Client writes some MySQL packets. 232 recvEventCh := make(chan struct{}, 1) 233 errChan := make(chan error, 1) 234 go func() { 235 <-recvEventCh 236 if _, err := server.Write(makeSimplePacket("123456")); err != nil { 237 errChan <- err 238 return 239 } 240 }() 241 242 recvEventCh <- struct{}{} 243 ret := make([]byte, 30) 244 n, err := client.Read(ret) 245 require.NoError(t, err) 246 require.Equal(t, 11, n) 247 l, err := packetLen(ret) 248 require.NoError(t, err) 249 require.Equal(t, 7, int(l)) 250 require.Equal(t, comQuery, int(ret[4])) 251 require.Equal(t, "123456", string(ret[5:11])) 252 select { 253 case err = <-errChan: 254 t.Fatalf("require no error, but got %v", err) 255 default: 256 } 257 } 258 259 func TestTunnelClose(t *testing.T) { 260 defer leaktest.AfterTest(t)() 261 262 ctx := context.Background() 263 rt := runtime.DefaultRuntime() 264 for _, withRun := range []bool{true, false} { 265 t.Run(fmt.Sprintf("withRun=%t", withRun), func(t *testing.T) { 266 f := newTunnel(ctx, rt.Logger(), nil) 267 defer f.Close() 268 269 if withRun { 270 p1, p2 := net.Pipe() 271 272 cc := newMockClientConn(p1, "t1", clientInfo{}, nil, nil) 273 require.NotNil(t, cc) 274 275 sc := newMockServerConn(p2) 276 require.NotNil(t, sc) 277 278 err := f.run(cc, sc) 279 require.NoError(t, err) 280 } 281 282 require.Nil(t, f.ctx.Err()) 283 f.Close() 284 require.EqualError(t, f.ctx.Err(), context.Canceled.Error()) 285 }) 286 } 287 } 288 289 func TestTunnelReplaceConn(t *testing.T) { 290 defer leaktest.AfterTest(t)() 291 292 rt := runtime.DefaultRuntime() 293 ctx := context.Background() 294 clientProxy, client := net.Pipe() 295 serverProxy, server := net.Pipe() 296 297 tu := newTunnel(ctx, rt.Logger(), nil) 298 defer tu.Close() 299 300 cc := newMockClientConn(clientProxy, "t1", clientInfo{}, nil, tu) 301 require.NotNil(t, cc) 302 303 sc := newMockServerConn(serverProxy) 304 require.NotNil(t, sc) 305 306 err := tu.run(cc, sc) 307 require.NoError(t, err) 308 309 c, s := tu.getConns() 310 require.Equal(t, clientProxy, c.Conn) 311 require.Equal(t, serverProxy, s.Conn) 312 313 csp, scp := tu.getPipes() 314 require.NoError(t, csp.pause(ctx)) 315 require.NoError(t, scp.pause(ctx)) 316 317 newServerProxy, newServer := net.Pipe() 318 tu.replaceServerConn(newMySQLConn("server", newServerProxy, 0, nil, nil, 0), false) 319 require.NoError(t, tu.kickoff()) 320 321 go func() { 322 _, _ = newServer.Write(makeSimplePacket("123456")) 323 }() 324 ret := make([]byte, 30) 325 n, err := client.Read(ret) 326 require.NoError(t, err) 327 l, err := packetLen(ret) 328 require.NoError(t, err) 329 require.Equal(t, 11, n) 330 require.Equal(t, 7, int(l)) 331 require.Equal(t, comQuery, int(ret[4])) 332 require.Equal(t, "123456", string(ret[5:11])) 333 334 _, err = server.Write([]byte("closed error")) 335 require.Regexp(t, "closed pipe", err) 336 } 337 338 func TestPipeCancelError(t *testing.T) { 339 defer leaktest.AfterTest(t)() 340 341 ctx, cancel := context.WithCancel(context.Background()) 342 // cancel the context immediately 343 cancel() 344 345 clientProxy, serverProxy := net.Pipe() 346 defer clientProxy.Close() 347 defer serverProxy.Close() 348 349 rt := runtime.DefaultRuntime() 350 runtime.SetupProcessLevelRuntime(rt) 351 logger := rt.Logger() 352 tun := newTunnel(ctx, logger, newCounterSet()) 353 354 cc := newMySQLConn("client", clientProxy, 0, nil, nil, 0) 355 sc := newMySQLConn("server", serverProxy, 0, nil, nil, 0) 356 p := tun.newPipe(pipeClientToServer, cc, sc) 357 err := p.kickoff(ctx, nil) 358 require.EqualError(t, err, context.Canceled.Error()) 359 p.mu.Lock() 360 require.True(t, p.mu.closed) 361 p.mu.Unlock() 362 363 p.mu.Lock() 364 p.mu.started = true 365 p.mu.Unlock() 366 require.EqualError(t, p.pause(ctx), errPipeClosed.Error()) 367 } 368 369 func TestPipeStart(t *testing.T) { 370 defer leaktest.AfterTest(t)() 371 372 ctx, cancel := context.WithCancel(context.Background()) 373 defer cancel() 374 375 clientProxy, serverProxy := net.Pipe() 376 defer clientProxy.Close() 377 defer serverProxy.Close() 378 379 rt := runtime.DefaultRuntime() 380 runtime.SetupProcessLevelRuntime(rt) 381 logger := rt.Logger() 382 tun := newTunnel(ctx, logger, newCounterSet()) 383 384 cc := newMySQLConn("client", clientProxy, 0, nil, nil, 0) 385 sc := newMySQLConn("server", serverProxy, 0, nil, nil, 0) 386 p := tun.newPipe(pipeClientToServer, cc, sc) 387 388 errCh := make(chan error) 389 go func() { 390 errCh <- p.waitReady(ctx) 391 }() 392 393 select { 394 case <-errCh: 395 t.Fatal("expected not started") 396 default: 397 } 398 399 go func() { _ = p.kickoff(ctx, nil) }() 400 401 var lastErr error 402 require.Eventually(t, func() bool { 403 select { 404 case lastErr = <-errCh: 405 return true 406 default: 407 return false 408 } 409 }, 10*time.Second, 100*time.Millisecond) 410 require.NoError(t, lastErr) 411 } 412 413 func TestPipeStartAndPause(t *testing.T) { 414 defer leaktest.AfterTest(t)() 415 416 ctx, cancel := context.WithCancel(context.Background()) 417 defer cancel() 418 419 clientProxy, serverProxy := net.Pipe() 420 defer clientProxy.Close() 421 defer serverProxy.Close() 422 423 rt := runtime.DefaultRuntime() 424 runtime.SetupProcessLevelRuntime(rt) 425 logger := rt.Logger() 426 tun := newTunnel(ctx, logger, newCounterSet()) 427 428 cc := newMySQLConn("client", clientProxy, 0, nil, nil, 0) 429 sc := newMySQLConn("server", serverProxy, 0, nil, nil, 0) 430 p := tun.newPipe(pipeClientToServer, cc, sc) 431 432 errCh := make(chan error, 2) 433 go func() { errCh <- p.kickoff(ctx, nil) }() 434 go func() { errCh <- p.kickoff(ctx, nil) }() 435 go func() { errCh <- p.kickoff(ctx, nil) }() 436 err := <-errCh 437 require.NoError(t, err) 438 err = <-errCh 439 require.NoError(t, err) 440 441 err = p.waitReady(ctx) 442 require.NoError(t, err) 443 err = p.pause(ctx) 444 require.NoError(t, err) 445 446 err = <-errCh 447 require.NoError(t, err) 448 p.mu.Lock() 449 require.False(t, p.mu.started) 450 require.False(t, p.mu.inPreRecv) 451 require.False(t, p.mu.paused) 452 p.mu.Unlock() 453 454 err = p.pause(ctx) 455 require.NoError(t, err) 456 457 p.mu.Lock() 458 require.False(t, p.mu.started) 459 require.False(t, p.mu.inPreRecv) 460 require.False(t, p.mu.paused) 461 p.mu.Unlock() 462 } 463 464 func TestPipeMultipleStartAndPause(t *testing.T) { 465 ctx, cancel := context.WithCancel(context.Background()) 466 defer cancel() 467 468 clientProxy, client := net.Pipe() 469 defer clientProxy.Close() 470 defer client.Close() 471 serverProxy, server := net.Pipe() 472 defer serverProxy.Close() 473 defer server.Close() 474 475 rt := runtime.DefaultRuntime() 476 runtime.SetupProcessLevelRuntime(rt) 477 logger := rt.Logger() 478 tun := newTunnel(ctx, logger, newCounterSet()) 479 480 cc := newMySQLConn("client", clientProxy, 0, nil, nil, 0) 481 sc := newMySQLConn("server", serverProxy, 0, nil, nil, 0) 482 p := tun.newPipe(pipeClientToServer, cc, sc) 483 484 const ( 485 queryCount = 100 486 concurrency = 200 487 ) 488 489 buf := new(bytes.Buffer) 490 pack := makeSimplePacket("select 1") 491 492 for i := 0; i < queryCount; i++ { 493 if i%2 == 0 { 494 pack[12] = '1' 495 } else { 496 pack[12] = '2' 497 } 498 _, _ = buf.Write(pack) 499 } 500 go func() { 501 _, _ = io.Copy(client, iotest.OneByteReader(buf)) 502 }() 503 504 packetCh := make(chan []byte, queryCount) 505 go func() { 506 receiver := newMySQLConn("receiver", server, 0, nil, nil, 0) 507 for { 508 ret, err := receiver.receive() 509 if err != nil { 510 return 511 } 512 packetCh <- ret 513 } 514 }() 515 516 errKickoffCh := make(chan error, concurrency) 517 errPauseCh := make(chan error, concurrency) 518 for i := 1; i <= concurrency; i++ { 519 go func(p *pipe, i int) { 520 time.Sleep(jitteredInterval(time.Duration((i*2)+500) * time.Millisecond)) 521 errKickoffCh <- p.kickoff(ctx, nil) 522 }(p, i) 523 go func(p *pipe, i int) { 524 time.Sleep(jitteredInterval(time.Duration((i*2)+500) * time.Millisecond)) 525 errPauseCh <- p.pause(ctx) 526 }(p, i) 527 } 528 529 for i := 0; i < concurrency-1; i++ { 530 err := <-errKickoffCh 531 require.NoError(t, err) 532 } 533 534 var lastErr error 535 require.Eventually(t, func() bool { 536 select { 537 case lastErr = <-errKickoffCh: 538 return true 539 default: 540 _ = p.pause(ctx) 541 return false 542 } 543 }, 10*time.Second, 100*time.Millisecond) 544 if lastErr != nil { 545 require.EqualError(t, lastErr, errPipeClosed.Error()) 546 } 547 548 for i := 0; i < concurrency; i++ { 549 err := <-errPauseCh 550 require.NoError(t, err) 551 } 552 553 go func(p *pipe) { _ = p.kickoff(ctx, nil) }(p) 554 555 err := p.waitReady(ctx) 556 require.NoError(t, err) 557 558 for i := 0; i < queryCount; i++ { 559 p := <-packetCh 560 561 expectedStr := "select 1" 562 if i%2 == 1 { 563 expectedStr = "select 2" 564 } 565 require.Equal(t, expectedStr, string(p[5:])) 566 } 567 568 err = p.pause(ctx) 569 require.NoError(t, err) 570 } 571 572 func jitteredInterval(interval time.Duration) time.Duration { 573 return time.Duration(float64(interval) * (0.5 + 0.5*rand.Float64())) 574 } 575 576 func TestCanStartTransfer(t *testing.T) { 577 rt := runtime.DefaultRuntime() 578 runtime.SetupProcessLevelRuntime(rt) 579 logger := rt.Logger() 580 581 t.Run("not_started", func(t *testing.T) { 582 tu := &tunnel{ 583 logger: logger, 584 } 585 can := tu.canStartTransfer(false) 586 require.False(t, can) 587 }) 588 589 t.Run("inTransfer", func(t *testing.T) { 590 tu := &tunnel{ 591 logger: logger, 592 } 593 tu.mu.inTransfer = true 594 can := tu.canStartTransfer(false) 595 require.False(t, can) 596 }) 597 598 t.Run("lastCmd", func(t *testing.T) { 599 tu := &tunnel{ 600 logger: logger, 601 } 602 tu.mu.csp = &pipe{} 603 tu.mu.scp = &pipe{} 604 tu.mu.started = true 605 csp, scp := tu.getPipes() 606 now := time.Now() 607 csp.mu.lastCmdTime = now.Add(time.Second) 608 scp.mu.lastCmdTime = now 609 can := tu.canStartTransfer(false) 610 require.False(t, can) 611 }) 612 613 t.Run("inTxn", func(t *testing.T) { 614 tu := &tunnel{ 615 logger: logger, 616 } 617 tu.mu.scp = &pipe{} 618 tu.mu.scp.src = newMySQLConn("", nil, 0, nil, nil, 0) 619 tu.mu.scp.mu.inTxn = true 620 can := tu.canStartTransfer(false) 621 require.False(t, can) 622 }) 623 624 t.Run("ok", func(t *testing.T) { 625 tu := &tunnel{ 626 logger: logger, 627 } 628 tu.mu.csp = &pipe{} 629 tu.mu.scp = &pipe{} 630 tu.mu.scp.src = newMySQLConn("", nil, 0, nil, nil, 0) 631 tu.mu.started = true 632 csp, scp := tu.getPipes() 633 now := time.Now() 634 csp.mu.lastCmdTime = now 635 scp.mu.lastCmdTime = now.Add(time.Second) 636 can := tu.canStartTransfer(false) 637 require.True(t, can) 638 }) 639 } 640 641 func TestReplaceServerConn(t *testing.T) { 642 defer leaktest.AfterTest(t)() 643 644 ctx := context.TODO() 645 clientProxy, client := net.Pipe() 646 serverProxy, _ := net.Pipe() 647 648 rt := runtime.DefaultRuntime() 649 tu := newTunnel(ctx, rt.Logger(), nil) 650 defer func() { 651 require.NoError(t, tu.Close()) 652 }() 653 654 cc := newMockClientConn(clientProxy, "t1", clientInfo{}, nil, tu) 655 require.NotNil(t, cc) 656 657 sc := newMockServerConn(serverProxy) 658 require.NotNil(t, sc) 659 660 err := tu.run(cc, sc) 661 require.NoError(t, err) 662 663 mysqlCC, mysqlSC := tu.getConns() 664 require.Equal(t, clientProxy, mysqlCC.src) 665 require.Equal(t, serverProxy, mysqlSC.src) 666 667 csp, scp := tu.getPipes() 668 require.NoError(t, csp.pause(ctx)) 669 require.NoError(t, scp.pause(ctx)) 670 671 newServerProxy, newServer := net.Pipe() 672 newSC := newMockServerConn(newServerProxy) 673 require.NotNil(t, sc) 674 newServerC := newMySQLConn("new-server", newSC.RawConn(), 0, nil, nil, 0) 675 tu.replaceServerConn(newServerC, false) 676 _, newMysqlSC := tu.getConns() 677 require.Equal(t, newServerC, newMysqlSC) 678 require.NoError(t, tu.kickoff()) 679 680 go func() { 681 _, err := client.Write(makeSimplePacket("select 1")) 682 require.NoError(t, err) 683 }() 684 685 buf := make([]byte, 30) 686 n, err := newServer.Read(buf) 687 require.NoError(t, err) 688 require.Equal(t, "select 1", string(buf[5:n])) 689 }