github.com/ari-anchor/sei-tendermint@v0.0.0-20230519144642-dc826b7b56bb/internal/p2p/conn/connection_test.go

package conn

import (
	"context"
	"encoding/hex"
	"io"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/fortytw2/leaktest"
	"github.com/gogo/protobuf/proto"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/ari-anchor/sei-tendermint/internal/libs/protoio"
	"github.com/ari-anchor/sei-tendermint/libs/log"
	"github.com/ari-anchor/sei-tendermint/libs/service"
	tmp2p "github.com/ari-anchor/sei-tendermint/proto/tendermint/p2p"
	"github.com/ari-anchor/sei-tendermint/proto/tendermint/types"
)

const maxPingPongPacketSize = 1024 // bytes

func createTestMConnection(logger log.Logger, conn net.Conn) *MConnection {
	return createMConnectionWithCallbacks(logger, conn,
		// onReceive
		func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		},
		// onError
		func(ctx context.Context, r interface{}) {
		})
}

func createMConnectionWithCallbacks(
	logger log.Logger,
	conn net.Conn,
	onReceive func(ctx context.Context, chID ChannelID, msgBytes []byte),
	onError func(ctx context.Context, r interface{}),
) *MConnection {
	cfg := DefaultMConnConfig()
	cfg.PingInterval = 250 * time.Millisecond
	cfg.PongTimeout = 500 * time.Millisecond
	chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
	c := NewMConnection(logger, conn, chDescs, onReceive, onError, cfg)
	return c
}

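// The tests below talk to the MConnection over a net.Pipe by writing and
// reading length-delimited tmp2p.Packet protos via protoio. mustWrapPacket,
// defined elsewhere in this package, wraps a concrete ping/pong/msg proto into
// the Packet oneof before it hits the wire. A minimal sketch of that wrapping,
// assuming the generated Packet_* oneof wrapper types (illustrative only, not
// the package's actual implementation):
//
//	func wrapPacketSketch(pb proto.Message) *tmp2p.Packet {
//		switch pb := pb.(type) {
//		case *tmp2p.PacketPing:
//			return &tmp2p.Packet{Sum: &tmp2p.Packet_PacketPing{PacketPing: pb}}
//		case *tmp2p.PacketPong:
//			return &tmp2p.Packet{Sum: &tmp2p.Packet_PacketPong{PacketPong: pb}}
//		case *tmp2p.PacketMsg:
//			return &tmp2p.Packet{Sum: &tmp2p.Packet_PacketMsg{PacketMsg: pb}}
//		default:
//			panic("unknown packet type")
//		}
//	}
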
func TestMConnectionSendFlushStop(t *testing.T) {
	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	clientConn := createTestMConnection(log.NewNopLogger(), client)
	err := clientConn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(clientConn))

	msg := []byte("abc")
	assert.True(t, clientConn.Send(0x01, msg))

	msgLength := 14

	// start the reader in a new routine, so we can flush
	errCh := make(chan error)
	go func() {
		msgB := make([]byte, msgLength)
		_, err := server.Read(msgB)
		if err != nil {
			t.Error(err)
			return
		}
		errCh <- err
	}()

	timer := time.NewTimer(3 * time.Second)
	select {
	case <-errCh:
	case <-timer.C:
		t.Error("timed out waiting for msgs to be read")
	}
}

func TestMConnectionSend(t *testing.T) {
	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createTestMConnection(log.NewNopLogger(), client)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	msg := []byte("Ant-Man")
	assert.True(t, mconn.Send(0x01, msg))
	// Note: subsequent Send/TrySend calls could pass because we are reading from
	// the send queue in a separate goroutine.
	_, err = server.Read(make([]byte, len(msg)))
	if err != nil {
		t.Error(err)
	}

	msg = []byte("Spider-Man")
	assert.True(t, mconn.Send(0x01, msg))
	_, err = server.Read(make([]byte, len(msg)))
	if err != nil {
		t.Error(err)
	}

	assert.False(t, mconn.Send(0x05, []byte("Absorbing Man")), "Send should return false because channel is unknown")
}

func TestMConnectionReceive(t *testing.T) {
	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case receivedCh <- msgBytes:
		case <-ctx.Done():
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case errorsCh <- r:
		case <-ctx.Done():
		}
	}
	logger := log.NewNopLogger()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn1 := createMConnectionWithCallbacks(logger, client, onReceive, onError)
	err := mconn1.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn1))

	mconn2 := createTestMConnection(logger, server)
	err = mconn2.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn2))

	msg := []byte("Cyclops")
	assert.True(t, mconn2.Send(0x01, msg))

	select {
	case receivedBytes := <-receivedCh:
		assert.Equal(t, msg, receivedBytes)
	case err := <-errorsCh:
		t.Fatalf("Expected %s, got %+v", msg, err)
	case <-time.After(500 * time.Millisecond):
		t.Fatalf("Did not receive %s message in 500ms", msg)
	}
}

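// TestMConnectionWillEventuallyTimeout starts a connection whose peer never
// answers pings. With the helper config above (PingInterval 250ms, PongTimeout
// 500ms) the send routine is expected to notice the missing pong, record the
// timeout, and shut the connection down on its own.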
func TestMConnectionWillEventuallyTimeout(t *testing.T) {
	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, nil, nil)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))
	require.True(t, mconn.IsRunning())

	go func() {
		// drain the server side of the pipe so that writes from the
		// send routine don't block.
		ticker := time.NewTicker(10 * time.Millisecond)
		defer ticker.Stop()

		for {
			select {
			case <-ticker.C:
				_, _ = io.ReadAll(server)
			case <-ctx.Done():
				return
			}
		}
	}()

	// wait for the send routine to die because it never receives a pong
	select {
	case <-mconn.doneSendRoutine:
		require.True(t, time.Since(mconn.getLastMessageAt()) > mconn.config.PongTimeout,
			"the connection state reflects that we've passed the pong timeout")
		// since we hit the timeout, things should be shutdown
		require.False(t, mconn.IsRunning())
	case <-time.After(2 * mconn.config.PongTimeout):
		t.Fatal("connection did not hit timeout", mconn.config.PongTimeout)
	}
}

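// The ping/pong tests below drive the protocol by hand from the "server" end
// of the pipe: read one length-delimited Packet (expected to be a ping) and
// answer with a pong. A minimal sketch of that exchange as a helper, using the
// same protoio framing the tests use inline (illustrative only; the tests do
// not actually define such a helper):
//
//	func respondToPing(t *testing.T, conn net.Conn) {
//		t.Helper()
//		protoReader := protoio.NewDelimitedReader(conn, maxPingPongPacketSize)
//		protoWriter := protoio.NewDelimitedWriter(conn)
//		var pkt tmp2p.Packet
//		_, err := protoReader.ReadMsg(&pkt)
//		require.NoError(t, err)
//		_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
//		require.NoError(t, err)
//	}
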
func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case receivedCh <- msgBytes:
		case <-ctx.Done():
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case errorsCh <- r:
		case <-ctx.Done():
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	// sending 3 pongs in a row (abuse)
	protoWriter := protoio.NewDelimitedWriter(server)

	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
	require.NoError(t, err)

	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
	require.NoError(t, err)

	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
	require.NoError(t, err)

	// read ping
	var packet tmp2p.Packet
	_, err = protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet)
	require.NoError(t, err)

	// respond with pong
	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
	require.NoError(t, err)

	pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
	select {
	case msgBytes := <-receivedCh:
		t.Fatalf("Expected no data, but got %v", msgBytes)
	case err := <-errorsCh:
		t.Fatalf("Expected no error, but got %v", err)
	case <-time.After(pongTimerExpired):
		assert.True(t, mconn.IsRunning())
	}
}

func TestMConnectionMultiplePings(t *testing.T) {
	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case receivedCh <- msgBytes:
		case <-ctx.Done():
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case errorsCh <- r:
		case <-ctx.Done():
		}
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	// sending 3 pings in a row (abuse)
	// see https://github.com/ari-anchor/sei-tendermint/issues/1190
	protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
	protoWriter := protoio.NewDelimitedWriter(server)
	var pkt tmp2p.Packet

	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
	require.NoError(t, err)

	_, err = protoReader.ReadMsg(&pkt)
	require.NoError(t, err)

	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
	require.NoError(t, err)

	_, err = protoReader.ReadMsg(&pkt)
	require.NoError(t, err)

	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
	require.NoError(t, err)

	_, err = protoReader.ReadMsg(&pkt)
	require.NoError(t, err)

	assert.True(t, mconn.IsRunning())
}

func TestMConnectionPingPongs(t *testing.T) {
	// check that we are not leaking any go-routines
	t.Cleanup(leaktest.CheckTimeout(t, 10*time.Second))

	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case receivedCh <- msgBytes:
		case <-ctx.Done():
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case errorsCh <- r:
		case <-ctx.Done():
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
	protoWriter := protoio.NewDelimitedWriter(server)
	var pkt tmp2p.PacketPing

	// read ping
	_, err = protoReader.ReadMsg(&pkt)
	require.NoError(t, err)

	// respond with pong
	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
	require.NoError(t, err)

	time.Sleep(mconn.config.PingInterval)

	// read ping
	_, err = protoReader.ReadMsg(&pkt)
	require.NoError(t, err)

	// respond with pong
	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
	require.NoError(t, err)

	pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 4
	select {
	case msgBytes := <-receivedCh:
		t.Fatalf("Expected no data, but got %v", msgBytes)
	case err := <-errorsCh:
		t.Fatalf("Expected no error, but got %v", err)
	case <-time.After(2 * pongTimerExpired):
		assert.True(t, mconn.IsRunning())
	}
}

func TestMConnectionStopsAndReturnsError(t *testing.T) {
	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case receivedCh <- msgBytes:
		case <-ctx.Done():
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case errorsCh <- r:
		case <-ctx.Done():
		}
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	if err := client.Close(); err != nil {
		t.Error(err)
	}

	select {
	case receivedBytes := <-receivedCh:
		t.Fatalf("Expected error, got %v", receivedBytes)
	case err := <-errorsCh:
		assert.NotNil(t, err)
		assert.False(t, mconn.IsRunning())
	case <-time.After(500 * time.Millisecond):
		t.Fatal("Did not receive error in 500ms")
	}
}

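// newClientAndServerConnsForReadErrors builds the fixture used by the
// read-error tests: the client side registers channels 0x01 and 0x02, while
// the server side (built with createMConnectionWithCallbacks) only knows
// channel 0x01, and every server-side error is signalled on chOnErr. That
// asymmetry is what lets the tests provoke "unknown channel" and bad-encoding
// errors on the receiving end.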
func newClientAndServerConnsForReadErrors(
	ctx context.Context,
	t *testing.T,
	chOnErr chan struct{},
) (*MConnection, *MConnection) {
	server, client := net.Pipe()

	onReceive := func(context.Context, ChannelID, []byte) {}
	onError := func(context.Context, interface{}) {}

	// create client conn with two channels
	chDescs := []*ChannelDescriptor{
		{ID: 0x01, Priority: 1, SendQueueCapacity: 1},
		{ID: 0x02, Priority: 1, SendQueueCapacity: 1},
	}
	logger := log.NewNopLogger()

	mconnClient := NewMConnection(logger.With("module", "client"), client, chDescs, onReceive, onError, DefaultMConnConfig())
	err := mconnClient.Start(ctx)
	require.NoError(t, err)

	// create server conn with 1 channel
	// it fires on chOnErr when there's an error
	serverLogger := logger.With("module", "server")
	onError = func(ctx context.Context, r interface{}) {
		select {
		case <-ctx.Done():
		case chOnErr <- struct{}{}:
		}
	}

	mconnServer := createMConnectionWithCallbacks(serverLogger, server, onReceive, onError)
	err = mconnServer.Start(ctx)
	require.NoError(t, err)
	return mconnClient, mconnServer
}

func expectSend(ch chan struct{}) bool {
	after := time.After(time.Second * 5)
	select {
	case <-ch:
		return true
	case <-after:
		return false
	}
}

func TestMConnectionReadErrorBadEncoding(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	chOnErr := make(chan struct{})
	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)

	client := mconnClient.conn

	// write raw bytes that are not a valid length-delimited Packet
	_, err := client.Write([]byte{1, 2, 3, 4, 5})
	require.NoError(t, err)
	assert.True(t, expectSend(chOnErr), "badly encoded msgPacket")
	t.Cleanup(waitAll(mconnClient, mconnServer))
}

func TestMConnectionReadErrorUnknownChannel(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	chOnErr := make(chan struct{})
	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)

	msg := []byte("Ant-Man")

	// fail to send msg on channel unknown by client
	assert.False(t, mconnClient.Send(0x03, msg))

	// send msg on channel unknown by the server.
	// should cause an error
	assert.True(t, mconnClient.Send(0x02, msg))
	assert.True(t, expectSend(chOnErr), "unknown channel")
	t.Cleanup(waitAll(mconnClient, mconnServer))
}

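// TestMConnectionReadErrorLongMessage checks the receive-side size limit: a
// PacketMsg whose payload is exactly MaxPacketMsgPayloadSize is delivered,
// while one that exceeds it is rejected by the server and surfaces on chOnErr
// (the test also expects the oversized write itself to fail).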
func TestMConnectionReadErrorLongMessage(t *testing.T) {
	chOnErr := make(chan struct{})
	chOnRcv := make(chan struct{})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
	t.Cleanup(waitAll(mconnClient, mconnServer))

	mconnServer.onReceive = func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case <-ctx.Done():
		case chOnRcv <- struct{}{}:
		}
	}

	client := mconnClient.conn
	protoWriter := protoio.NewDelimitedWriter(client)

	// send msg that's just right
	var packet = tmp2p.PacketMsg{
		ChannelID: 0x01,
		EOF:       true,
		Data:      make([]byte, mconnClient.config.MaxPacketMsgPayloadSize),
	}

	_, err := protoWriter.WriteMsg(mustWrapPacket(&packet))
	require.NoError(t, err)
	assert.True(t, expectSend(chOnRcv), "msg just right")

	// send msg that's too long
	packet = tmp2p.PacketMsg{
		ChannelID: 0x01,
		EOF:       true,
		Data:      make([]byte, mconnClient.config.MaxPacketMsgPayloadSize+100),
	}

	_, err = protoWriter.WriteMsg(mustWrapPacket(&packet))
	require.Error(t, err)
	assert.True(t, expectSend(chOnErr), "msg too long")
}

func TestMConnectionReadErrorUnknownMsgType(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	chOnErr := make(chan struct{})
	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
	t.Cleanup(waitAll(mconnClient, mconnServer))

	// send msg with unknown msg type
	_, err := protoio.NewDelimitedWriter(mconnClient.conn).WriteMsg(&types.Header{ChainID: "x"})
	require.NoError(t, err)
	assert.True(t, expectSend(chOnErr), "unknown msg type")
}

func TestMConnectionTrySend(t *testing.T) {
	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createTestMConnection(log.NewNopLogger(), client)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	msg := []byte("Semicolon-Woman")
	resultCh := make(chan string, 2)
	assert.True(t, mconn.Send(0x01, msg))
	_, err = server.Read(make([]byte, len(msg)))
	require.NoError(t, err)
	assert.True(t, mconn.Send(0x01, msg))
	go func() {
		mconn.Send(0x01, msg)
		resultCh <- "TrySend"
	}()
	assert.False(t, mconn.Send(0x01, msg))
	assert.Equal(t, "TrySend", <-resultCh)
}

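// TestConnVectors pins the wire encoding of wrapped packets. The expected hex
// decodes as ordinary protobuf: "0a00" is Packet field 1 (packet_ping) with an
// empty body, "1200" is field 2 (packet_pong), and the PacketMsg vector is
// field 3 (0x1a) with a 0x22-byte body holding channel_id=1 (0x08 0x01) and a
// 30-byte data field (0x1a 0x1e ...); EOF=false is omitted as a proto3 default.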
func TestConnVectors(t *testing.T) {

	testCases := []struct {
		testName string
		msg      proto.Message
		expBytes string
	}{
		{"PacketPing", &tmp2p.PacketPing{}, "0a00"},
		{"PacketPong", &tmp2p.PacketPong{}, "1200"},
		{"PacketMsg", &tmp2p.PacketMsg{ChannelID: 1, EOF: false, Data: []byte("data transmitted over the wire")}, "1a2208011a1e64617461207472616e736d6974746564206f766572207468652077697265"},
	}

	for _, tc := range testCases {
		tc := tc

		pm := mustWrapPacket(tc.msg)
		bz, err := pm.Marshal()
		require.NoError(t, err, tc.testName)

		require.Equal(t, tc.expBytes, hex.EncodeToString(bz), tc.testName)
	}
}

func TestMConnectionChannelOverflow(t *testing.T) {
	chOnErr := make(chan struct{})
	chOnRcv := make(chan struct{})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
	t.Cleanup(waitAll(mconnClient, mconnServer))

	mconnServer.onReceive = func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case <-ctx.Done():
		case chOnRcv <- struct{}{}:
		}
	}

	client := mconnClient.conn
	protoWriter := protoio.NewDelimitedWriter(client)

	var packet = tmp2p.PacketMsg{
		ChannelID: 0x01,
		EOF:       true,
		Data:      []byte(`42`),
	}
	_, err := protoWriter.WriteMsg(mustWrapPacket(&packet))
	require.NoError(t, err)
	assert.True(t, expectSend(chOnRcv))

	// 1025 overflows the one-byte channel ID space and is unknown to the
	// server, so nothing should be delivered
	packet.ChannelID = int32(1025)
	_, err = protoWriter.WriteMsg(mustWrapPacket(&packet))
	require.NoError(t, err)
	assert.False(t, expectSend(chOnRcv))
}

func waitAll(waiters ...service.Service) func() {
	return func() {
		switch len(waiters) {
		case 0:
			return
		case 1:
			waiters[0].Wait()
			return
		default:
			wg := &sync.WaitGroup{}

			for _, w := range waiters {
				wg.Add(1)
				go func(s service.Service) {
					defer wg.Done()
					s.Wait()
				}(w)
			}

			wg.Wait()
		}
	}
}

type closer interface {
	Close() error
}

func closeAll(t *testing.T, closers ...closer) func() {
	return func() {
		for _, s := range closers {
			if err := s.Close(); err != nil {
				t.Log(err)
			}
		}
	}
}