github.com/badrootd/celestia-core@v0.0.0-20240305091328-aa4207a4b25d/p2p/conn/connection_test.go (about) 1 package conn 2 3 import ( 4 "bytes" 5 "encoding/hex" 6 "fmt" 7 "math" 8 "net" 9 "testing" 10 "time" 11 12 "github.com/fortytw2/leaktest" 13 "github.com/gogo/protobuf/proto" 14 "github.com/stretchr/testify/assert" 15 "github.com/stretchr/testify/require" 16 17 "github.com/badrootd/celestia-core/libs/log" 18 "github.com/badrootd/celestia-core/libs/protoio" 19 tmp2p "github.com/badrootd/celestia-core/proto/tendermint/p2p" 20 "github.com/badrootd/celestia-core/proto/tendermint/types" 21 ) 22 23 const maxPingPongPacketSize = 1024 // bytes 24 25 func createTestMConnection(conn net.Conn) *MConnection { 26 onReceive := func(chID byte, msgBytes []byte) { 27 } 28 onError := func(r interface{}) { 29 } 30 c := createMConnectionWithCallbacks(conn, onReceive, onError) 31 c.SetLogger(log.TestingLogger()) 32 return c 33 } 34 35 func createMConnectionWithCallbacks( 36 conn net.Conn, 37 onReceive func(chID byte, msgBytes []byte), 38 onError func(r interface{}), 39 ) *MConnection { 40 cfg := DefaultMConnConfig() 41 cfg.PingInterval = 90 * time.Millisecond 42 cfg.PongTimeout = 45 * time.Millisecond 43 chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}} 44 c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, cfg) 45 c.SetLogger(log.TestingLogger()) 46 return c 47 } 48 49 func createMConnectionWithCallbacksConfigs( 50 conn net.Conn, 51 onReceive func(chID byte, msgBytes []byte), 52 onError func(r interface{}), 53 cfg MConnConfig, 54 ) *MConnection { 55 chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}} 56 c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, cfg) 57 c.SetLogger(log.TestingLogger()) 58 return c 59 } 60 61 func TestMConnectionSendFlushStop(t *testing.T) { 62 server, client := NetPipe() 63 defer server.Close() 64 defer client.Close() 65 66 clientConn := createTestMConnection(client) 67 err := clientConn.Start() 68 require.Nil(t, err) 69 defer clientConn.Stop() //nolint:errcheck // ignore for tests 70 71 msg := []byte("abc") 72 assert.True(t, clientConn.Send(0x01, msg)) 73 74 msgLength := 14 75 76 // start the reader in a new routine, so we can flush 77 errCh := make(chan error) 78 go func() { 79 msgB := make([]byte, msgLength) 80 _, err := server.Read(msgB) 81 if err != nil { 82 t.Error(err) 83 return 84 } 85 errCh <- err 86 }() 87 88 // stop the conn - it should flush all conns 89 clientConn.FlushStop() 90 91 timer := time.NewTimer(3 * time.Second) 92 select { 93 case <-errCh: 94 case <-timer.C: 95 t.Error("timed out waiting for msgs to be read") 96 } 97 } 98 99 func TestMConnectionSend(t *testing.T) { 100 server, client := NetPipe() 101 defer server.Close() 102 defer client.Close() 103 104 mconn := createTestMConnection(client) 105 err := mconn.Start() 106 require.Nil(t, err) 107 defer mconn.Stop() //nolint:errcheck // ignore for tests 108 109 msg := []byte("Ant-Man") 110 assert.True(t, mconn.Send(0x01, msg)) 111 // Note: subsequent Send/TrySend calls could pass because we are reading from 112 // the send queue in a separate goroutine. 113 _, err = server.Read(make([]byte, len(msg))) 114 if err != nil { 115 t.Error(err) 116 } 117 assert.True(t, mconn.CanSend(0x01)) 118 119 msg = []byte("Spider-Man") 120 assert.True(t, mconn.TrySend(0x01, msg)) 121 _, err = server.Read(make([]byte, len(msg))) 122 if err != nil { 123 t.Error(err) 124 } 125 126 assert.False(t, mconn.CanSend(0x05), "CanSend should return false because channel is unknown") 127 assert.False(t, mconn.Send(0x05, []byte("Absorbing Man")), "Send should return false because channel is unknown") 128 } 129 130 func TestMConnectionSendRate(t *testing.T) { 131 server, client := NetPipe() 132 defer server.Close() 133 defer client.Close() 134 135 clientConn := createTestMConnection(client) 136 err := clientConn.Start() 137 require.Nil(t, err) 138 defer clientConn.Stop() //nolint:errcheck // ignore for tests 139 140 // prepare a message to send from client to the server 141 msg := bytes.Repeat([]byte{1}, 1000*1024) 142 143 // send the message and check if it was sent successfully 144 done := clientConn.Send(0x01, msg) 145 assert.True(t, done) 146 147 // read the message from the server 148 _, err = server.Read(make([]byte, len(msg))) 149 if err != nil { 150 t.Error(err) 151 } 152 153 // check if the peak send rate is within the expected range 154 peakSendRate := clientConn.Status().SendMonitor.PeakRate 155 // the peak send rate should be less than or equal to the max send rate 156 // the max send rate is calculated based on the configured SendRate and other configs 157 maxSendRate := clientConn.maxSendRate() 158 assert.True(t, peakSendRate <= clientConn.maxSendRate(), fmt.Sprintf("peakSendRate %d > maxSendRate %d", peakSendRate, maxSendRate)) 159 } 160 161 // maxSendRate returns the maximum send rate in bytes per second based on the MConnection's SendRate and other configs. It is used to calculate the highest expected value for the peak send rate. 162 // The returned value is slightly higher than the configured SendRate. 163 func (c *MConnection) maxSendRate() int64 { 164 sampleRate := c.sendMonitor.GetSampleRate().Seconds() 165 numberOfSamplePerSecond := 1 / sampleRate 166 sendRate := float64(round(float64(c.config.SendRate) * sampleRate)) 167 batchSizeBytes := float64(numBatchPacketMsgs * c._maxPacketMsgSize) 168 effectiveRatePerSample := math.Ceil(sendRate/batchSizeBytes) * batchSizeBytes 169 effectiveSendRate := round(numberOfSamplePerSecond * effectiveRatePerSample) 170 171 return effectiveSendRate 172 } 173 174 // round returns x rounded to the nearest int64 (non-negative values only). 175 func round(x float64) int64 { 176 if _, frac := math.Modf(x); frac >= 0.5 { 177 return int64(math.Ceil(x)) 178 } 179 return int64(math.Floor(x)) 180 } 181 182 func TestMConnectionReceiveRate(t *testing.T) { 183 server, client := NetPipe() 184 defer server.Close() 185 defer client.Close() 186 187 // prepare a client connection with callbacks to receive messages 188 receivedCh := make(chan []byte) 189 errorsCh := make(chan interface{}) 190 onReceive := func(chID byte, msgBytes []byte) { 191 receivedCh <- msgBytes 192 } 193 onError := func(r interface{}) { 194 errorsCh <- r 195 } 196 197 cnfg := DefaultMConnConfig() 198 cnfg.SendRate = 500_000 // 500 KB/s 199 cnfg.RecvRate = 500_000 // 500 KB/s 200 201 clientConn := createMConnectionWithCallbacksConfigs(client, onReceive, onError, cnfg) 202 err := clientConn.Start() 203 require.Nil(t, err) 204 defer clientConn.Stop() //nolint:errcheck // ignore for tests 205 206 serverConn := createMConnectionWithCallbacksConfigs(server, func(chID byte, msgBytes []byte) {}, func(r interface{}) {}, cnfg) 207 err = serverConn.Start() 208 require.Nil(t, err) 209 defer serverConn.Stop() //nolint:errcheck // ignore for tests 210 211 msgSize := int(2 * cnfg.RecvRate) 212 msg := bytes.Repeat([]byte{1}, msgSize) 213 assert.True(t, serverConn.Send(0x01, msg)) 214 215 // approximate the time it takes to receive the message given the configured RecvRate 216 approxDelay := time.Duration(int64(math.Ceil(float64(msgSize)/float64(cnfg.RecvRate))) * int64(time.Second) * 2) 217 218 select { 219 case receivedBytes := <-receivedCh: 220 assert.Equal(t, msg, receivedBytes) 221 case err := <-errorsCh: 222 t.Fatalf("Expected %s, got %+v", msg, err) 223 case <-time.After(approxDelay): 224 t.Fatalf("Did not receive the message in %fs", approxDelay.Seconds()) 225 } 226 227 peakRecvRate := clientConn.recvMonitor.Status().PeakRate 228 maxRecvRate := clientConn.maxRecvRate() 229 230 assert.True(t, peakRecvRate <= maxRecvRate, fmt.Sprintf("peakRecvRate %d > maxRecvRate %d", peakRecvRate, maxRecvRate)) 231 232 peakSendRate := clientConn.sendMonitor.Status().PeakRate 233 maxSendRate := clientConn.maxSendRate() 234 235 assert.True(t, peakSendRate <= maxSendRate, fmt.Sprintf("peakSendRate %d > maxSendRate %d", peakSendRate, maxSendRate)) 236 } 237 238 // maxRecvRate returns the maximum receive rate in bytes per second based on 239 // the MConnection's RecvRate and other configs. 240 // It is used to calculate the highest expected value for the peak receive rate. 241 // Note that the returned value is slightly higher than the configured RecvRate. 242 func (c *MConnection) maxRecvRate() int64 { 243 sampleRate := c.recvMonitor.GetSampleRate().Seconds() 244 numberOfSamplePerSeccond := 1 / sampleRate 245 recvRate := float64(round(float64(c.config.RecvRate) * sampleRate)) 246 batchSizeBytes := float64(c._maxPacketMsgSize) 247 effectiveRecvRatePerSample := math.Ceil(recvRate/batchSizeBytes) * batchSizeBytes 248 effectiveRecvRate := round(numberOfSamplePerSeccond * effectiveRecvRatePerSample) 249 250 return effectiveRecvRate 251 } 252 253 func TestMConnectionReceive(t *testing.T) { 254 server, client := NetPipe() 255 defer server.Close() 256 defer client.Close() 257 258 receivedCh := make(chan []byte) 259 errorsCh := make(chan interface{}) 260 onReceive := func(chID byte, msgBytes []byte) { 261 receivedCh <- msgBytes 262 } 263 onError := func(r interface{}) { 264 errorsCh <- r 265 } 266 mconn1 := createMConnectionWithCallbacks(client, onReceive, onError) 267 err := mconn1.Start() 268 require.Nil(t, err) 269 defer mconn1.Stop() //nolint:errcheck // ignore for tests 270 271 mconn2 := createTestMConnection(server) 272 err = mconn2.Start() 273 require.Nil(t, err) 274 defer mconn2.Stop() //nolint:errcheck // ignore for tests 275 276 msg := []byte("Cyclops") 277 assert.True(t, mconn2.Send(0x01, msg)) 278 279 select { 280 case receivedBytes := <-receivedCh: 281 assert.Equal(t, msg, receivedBytes) 282 case err := <-errorsCh: 283 t.Fatalf("Expected %s, got %+v", msg, err) 284 case <-time.After(500 * time.Millisecond): 285 t.Fatalf("Did not receive %s message in 500ms", msg) 286 } 287 } 288 289 func TestMConnectionStatus(t *testing.T) { 290 server, client := NetPipe() 291 defer server.Close() 292 defer client.Close() 293 294 mconn := createTestMConnection(client) 295 err := mconn.Start() 296 require.Nil(t, err) 297 defer mconn.Stop() //nolint:errcheck // ignore for tests 298 299 status := mconn.Status() 300 assert.NotNil(t, status) 301 assert.Zero(t, status.Channels[0].SendQueueSize) 302 } 303 304 func TestMConnectionPongTimeoutResultsInError(t *testing.T) { 305 server, client := net.Pipe() 306 defer server.Close() 307 defer client.Close() 308 309 receivedCh := make(chan []byte) 310 errorsCh := make(chan interface{}) 311 onReceive := func(chID byte, msgBytes []byte) { 312 receivedCh <- msgBytes 313 } 314 onError := func(r interface{}) { 315 errorsCh <- r 316 } 317 mconn := createMConnectionWithCallbacks(client, onReceive, onError) 318 err := mconn.Start() 319 require.Nil(t, err) 320 defer mconn.Stop() //nolint:errcheck // ignore for tests 321 322 serverGotPing := make(chan struct{}) 323 go func() { 324 // read ping 325 var pkt tmp2p.Packet 326 _, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&pkt) 327 require.NoError(t, err) 328 serverGotPing <- struct{}{} 329 }() 330 <-serverGotPing 331 332 pongTimerExpired := mconn.config.PongTimeout + 200*time.Millisecond 333 select { 334 case msgBytes := <-receivedCh: 335 t.Fatalf("Expected error, but got %v", msgBytes) 336 case err := <-errorsCh: 337 assert.NotNil(t, err) 338 case <-time.After(pongTimerExpired): 339 t.Fatalf("Expected to receive error after %v", pongTimerExpired) 340 } 341 } 342 343 func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { 344 server, client := net.Pipe() 345 defer server.Close() 346 defer client.Close() 347 348 receivedCh := make(chan []byte) 349 errorsCh := make(chan interface{}) 350 onReceive := func(chID byte, msgBytes []byte) { 351 receivedCh <- msgBytes 352 } 353 onError := func(r interface{}) { 354 errorsCh <- r 355 } 356 mconn := createMConnectionWithCallbacks(client, onReceive, onError) 357 err := mconn.Start() 358 require.Nil(t, err) 359 defer mconn.Stop() //nolint:errcheck // ignore for tests 360 361 // sending 3 pongs in a row (abuse) 362 protoWriter := protoio.NewDelimitedWriter(server) 363 364 _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) 365 require.NoError(t, err) 366 367 _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) 368 require.NoError(t, err) 369 370 _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) 371 require.NoError(t, err) 372 373 serverGotPing := make(chan struct{}) 374 go func() { 375 // read ping (one byte) 376 var packet tmp2p.Packet 377 _, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet) 378 require.NoError(t, err) 379 serverGotPing <- struct{}{} 380 381 // respond with pong 382 _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) 383 require.NoError(t, err) 384 }() 385 <-serverGotPing 386 387 pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond 388 select { 389 case msgBytes := <-receivedCh: 390 t.Fatalf("Expected no data, but got %v", msgBytes) 391 case err := <-errorsCh: 392 t.Fatalf("Expected no error, but got %v", err) 393 case <-time.After(pongTimerExpired): 394 assert.True(t, mconn.IsRunning()) 395 } 396 } 397 398 func TestMConnectionMultiplePings(t *testing.T) { 399 server, client := net.Pipe() 400 defer server.Close() 401 defer client.Close() 402 403 receivedCh := make(chan []byte) 404 errorsCh := make(chan interface{}) 405 onReceive := func(chID byte, msgBytes []byte) { 406 receivedCh <- msgBytes 407 } 408 onError := func(r interface{}) { 409 errorsCh <- r 410 } 411 mconn := createMConnectionWithCallbacks(client, onReceive, onError) 412 err := mconn.Start() 413 require.Nil(t, err) 414 defer mconn.Stop() //nolint:errcheck // ignore for tests 415 416 // sending 3 pings in a row (abuse) 417 // see https://github.com/cometbft/cometbft/issues/1190 418 protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize) 419 protoWriter := protoio.NewDelimitedWriter(server) 420 var pkt tmp2p.Packet 421 422 _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{})) 423 require.NoError(t, err) 424 425 _, err = protoReader.ReadMsg(&pkt) 426 require.NoError(t, err) 427 428 _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{})) 429 require.NoError(t, err) 430 431 _, err = protoReader.ReadMsg(&pkt) 432 require.NoError(t, err) 433 434 _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{})) 435 require.NoError(t, err) 436 437 _, err = protoReader.ReadMsg(&pkt) 438 require.NoError(t, err) 439 440 assert.True(t, mconn.IsRunning()) 441 } 442 443 func TestMConnectionPingPongs(t *testing.T) { 444 // check that we are not leaking any go-routines 445 defer leaktest.CheckTimeout(t, 10*time.Second)() 446 447 server, client := net.Pipe() 448 449 defer server.Close() 450 defer client.Close() 451 452 receivedCh := make(chan []byte) 453 errorsCh := make(chan interface{}) 454 onReceive := func(chID byte, msgBytes []byte) { 455 receivedCh <- msgBytes 456 } 457 onError := func(r interface{}) { 458 errorsCh <- r 459 } 460 mconn := createMConnectionWithCallbacks(client, onReceive, onError) 461 err := mconn.Start() 462 require.Nil(t, err) 463 defer mconn.Stop() //nolint:errcheck // ignore for tests 464 465 serverGotPing := make(chan struct{}) 466 go func() { 467 protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize) 468 protoWriter := protoio.NewDelimitedWriter(server) 469 var pkt tmp2p.PacketPing 470 471 // read ping 472 _, err = protoReader.ReadMsg(&pkt) 473 require.NoError(t, err) 474 serverGotPing <- struct{}{} 475 476 // respond with pong 477 _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) 478 require.NoError(t, err) 479 480 time.Sleep(mconn.config.PingInterval) 481 482 // read ping 483 _, err = protoReader.ReadMsg(&pkt) 484 require.NoError(t, err) 485 serverGotPing <- struct{}{} 486 487 // respond with pong 488 _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) 489 require.NoError(t, err) 490 }() 491 <-serverGotPing 492 <-serverGotPing 493 494 pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2 495 select { 496 case msgBytes := <-receivedCh: 497 t.Fatalf("Expected no data, but got %v", msgBytes) 498 case err := <-errorsCh: 499 t.Fatalf("Expected no error, but got %v", err) 500 case <-time.After(2 * pongTimerExpired): 501 assert.True(t, mconn.IsRunning()) 502 } 503 } 504 505 func TestMConnectionStopsAndReturnsError(t *testing.T) { 506 server, client := NetPipe() 507 defer server.Close() 508 defer client.Close() 509 510 receivedCh := make(chan []byte) 511 errorsCh := make(chan interface{}) 512 onReceive := func(chID byte, msgBytes []byte) { 513 receivedCh <- msgBytes 514 } 515 onError := func(r interface{}) { 516 errorsCh <- r 517 } 518 mconn := createMConnectionWithCallbacks(client, onReceive, onError) 519 err := mconn.Start() 520 require.Nil(t, err) 521 defer mconn.Stop() //nolint:errcheck // ignore for tests 522 523 if err := client.Close(); err != nil { 524 t.Error(err) 525 } 526 527 select { 528 case receivedBytes := <-receivedCh: 529 t.Fatalf("Expected error, got %v", receivedBytes) 530 case err := <-errorsCh: 531 assert.NotNil(t, err) 532 assert.False(t, mconn.IsRunning()) 533 case <-time.After(500 * time.Millisecond): 534 t.Fatal("Did not receive error in 500ms") 535 } 536 } 537 538 func newClientAndServerConnsForReadErrors(t *testing.T, chOnErr chan struct{}) (*MConnection, *MConnection) { 539 server, client := NetPipe() 540 541 onReceive := func(chID byte, msgBytes []byte) {} 542 onError := func(r interface{}) {} 543 544 // create client conn with two channels 545 chDescs := []*ChannelDescriptor{ 546 {ID: 0x01, Priority: 1, SendQueueCapacity: 1}, 547 {ID: 0x02, Priority: 1, SendQueueCapacity: 1}, 548 } 549 mconnClient := NewMConnection(client, chDescs, onReceive, onError) 550 mconnClient.SetLogger(log.TestingLogger().With("module", "client")) 551 err := mconnClient.Start() 552 require.Nil(t, err) 553 554 // create server conn with 1 channel 555 // it fires on chOnErr when there's an error 556 serverLogger := log.TestingLogger().With("module", "server") 557 onError = func(r interface{}) { 558 chOnErr <- struct{}{} 559 } 560 mconnServer := createMConnectionWithCallbacks(server, onReceive, onError) 561 mconnServer.SetLogger(serverLogger) 562 err = mconnServer.Start() 563 require.Nil(t, err) 564 return mconnClient, mconnServer 565 } 566 567 func expectSend(ch chan struct{}) bool { 568 after := time.After(time.Second * 5) 569 select { 570 case <-ch: 571 return true 572 case <-after: 573 return false 574 } 575 } 576 577 func TestMConnectionReadErrorBadEncoding(t *testing.T) { 578 chOnErr := make(chan struct{}) 579 mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr) 580 581 client := mconnClient.conn 582 583 // Write it. 584 _, err := client.Write([]byte{1, 2, 3, 4, 5}) 585 require.NoError(t, err) 586 assert.True(t, expectSend(chOnErr), "badly encoded msgPacket") 587 588 t.Cleanup(func() { 589 if err := mconnClient.Stop(); err != nil { 590 t.Log(err) 591 } 592 }) 593 594 t.Cleanup(func() { 595 if err := mconnServer.Stop(); err != nil { 596 t.Log(err) 597 } 598 }) 599 } 600 601 func TestMConnectionReadErrorUnknownChannel(t *testing.T) { 602 chOnErr := make(chan struct{}) 603 mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr) 604 605 msg := []byte("Ant-Man") 606 607 // fail to send msg on channel unknown by client 608 assert.False(t, mconnClient.Send(0x03, msg)) 609 610 // send msg on channel unknown by the server. 611 // should cause an error 612 assert.True(t, mconnClient.Send(0x02, msg)) 613 assert.True(t, expectSend(chOnErr), "unknown channel") 614 615 t.Cleanup(func() { 616 if err := mconnClient.Stop(); err != nil { 617 t.Log(err) 618 } 619 }) 620 621 t.Cleanup(func() { 622 if err := mconnServer.Stop(); err != nil { 623 t.Log(err) 624 } 625 }) 626 } 627 628 func TestMConnectionReadErrorLongMessage(t *testing.T) { 629 chOnErr := make(chan struct{}) 630 chOnRcv := make(chan struct{}) 631 632 mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr) 633 defer mconnClient.Stop() //nolint:errcheck // ignore for tests 634 defer mconnServer.Stop() //nolint:errcheck // ignore for tests 635 636 mconnServer.onReceive = func(chID byte, msgBytes []byte) { 637 chOnRcv <- struct{}{} 638 } 639 640 client := mconnClient.conn 641 protoWriter := protoio.NewDelimitedWriter(client) 642 643 // send msg that's just right 644 var packet = tmp2p.PacketMsg{ 645 ChannelID: 0x01, 646 EOF: true, 647 Data: make([]byte, mconnClient.config.MaxPacketMsgPayloadSize), 648 } 649 650 _, err := protoWriter.WriteMsg(mustWrapPacket(&packet)) 651 require.NoError(t, err) 652 assert.True(t, expectSend(chOnRcv), "msg just right") 653 654 // send msg that's too long 655 packet = tmp2p.PacketMsg{ 656 ChannelID: 0x01, 657 EOF: true, 658 Data: make([]byte, mconnClient.config.MaxPacketMsgPayloadSize+100), 659 } 660 661 _, err = protoWriter.WriteMsg(mustWrapPacket(&packet)) 662 require.Error(t, err) 663 assert.True(t, expectSend(chOnErr), "msg too long") 664 } 665 666 func TestMConnectionReadErrorUnknownMsgType(t *testing.T) { 667 chOnErr := make(chan struct{}) 668 mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr) 669 defer mconnClient.Stop() //nolint:errcheck // ignore for tests 670 defer mconnServer.Stop() //nolint:errcheck // ignore for tests 671 672 // send msg with unknown msg type 673 _, err := protoio.NewDelimitedWriter(mconnClient.conn).WriteMsg(&types.Header{ChainID: "x"}) 674 require.NoError(t, err) 675 assert.True(t, expectSend(chOnErr), "unknown msg type") 676 } 677 678 func TestMConnectionTrySend(t *testing.T) { 679 server, client := NetPipe() 680 defer server.Close() 681 defer client.Close() 682 683 mconn := createTestMConnection(client) 684 err := mconn.Start() 685 require.Nil(t, err) 686 defer mconn.Stop() //nolint:errcheck // ignore for tests 687 688 msg := []byte("Semicolon-Woman") 689 resultCh := make(chan string, 2) 690 assert.True(t, mconn.TrySend(0x01, msg)) 691 _, err = server.Read(make([]byte, len(msg))) 692 require.NoError(t, err) 693 assert.True(t, mconn.CanSend(0x01)) 694 assert.True(t, mconn.TrySend(0x01, msg)) 695 assert.False(t, mconn.CanSend(0x01)) 696 go func() { 697 mconn.TrySend(0x01, msg) 698 resultCh <- "TrySend" 699 }() 700 assert.False(t, mconn.CanSend(0x01)) 701 assert.False(t, mconn.TrySend(0x01, msg)) 702 assert.Equal(t, "TrySend", <-resultCh) 703 } 704 705 //nolint:lll //ignore line length for tests 706 func TestConnVectors(t *testing.T) { 707 708 testCases := []struct { 709 testName string 710 msg proto.Message 711 expBytes string 712 }{ 713 {"PacketPing", &tmp2p.PacketPing{}, "0a00"}, 714 {"PacketPong", &tmp2p.PacketPong{}, "1200"}, 715 {"PacketMsg", &tmp2p.PacketMsg{ChannelID: 1, EOF: false, Data: []byte("data transmitted over the wire")}, "1a2208011a1e64617461207472616e736d6974746564206f766572207468652077697265"}, 716 } 717 718 for _, tc := range testCases { 719 tc := tc 720 721 pm := mustWrapPacket(tc.msg) 722 bz, err := pm.Marshal() 723 require.NoError(t, err, tc.testName) 724 725 require.Equal(t, tc.expBytes, hex.EncodeToString(bz), tc.testName) 726 } 727 } 728 729 func TestMConnectionChannelOverflow(t *testing.T) { 730 chOnErr := make(chan struct{}) 731 chOnRcv := make(chan struct{}) 732 733 mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr) 734 t.Cleanup(stopAll(t, mconnClient, mconnServer)) 735 736 mconnServer.onReceive = func(chID byte, msgBytes []byte) { 737 chOnRcv <- struct{}{} 738 } 739 740 client := mconnClient.conn 741 protoWriter := protoio.NewDelimitedWriter(client) 742 743 var packet = tmp2p.PacketMsg{ 744 ChannelID: 0x01, 745 EOF: true, 746 Data: []byte(`42`), 747 } 748 _, err := protoWriter.WriteMsg(mustWrapPacket(&packet)) 749 require.NoError(t, err) 750 assert.True(t, expectSend(chOnRcv)) 751 752 packet.ChannelID = int32(1025) 753 _, err = protoWriter.WriteMsg(mustWrapPacket(&packet)) 754 require.NoError(t, err) 755 assert.False(t, expectSend(chOnRcv)) 756 757 } 758 759 type stopper interface { 760 Stop() error 761 } 762 763 func stopAll(t *testing.T, stoppers ...stopper) func() { 764 return func() { 765 for _, s := range stoppers { 766 if err := s.Stop(); err != nil { 767 t.Log(err) 768 } 769 } 770 } 771 }