github.com/KYVENetwork/cometbft/v38@v38.0.3/p2p/switch_test.go (about) 1 package p2p 2 3 import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "io" 8 "net" 9 "net/http" 10 "net/http/httptest" 11 "regexp" 12 "strconv" 13 "sync/atomic" 14 "testing" 15 "time" 16 17 "github.com/cosmos/gogoproto/proto" 18 "github.com/prometheus/client_golang/prometheus/promhttp" 19 "github.com/stretchr/testify/assert" 20 "github.com/stretchr/testify/require" 21 22 "github.com/KYVENetwork/cometbft/v38/config" 23 "github.com/KYVENetwork/cometbft/v38/crypto/ed25519" 24 "github.com/KYVENetwork/cometbft/v38/libs/log" 25 cmtsync "github.com/KYVENetwork/cometbft/v38/libs/sync" 26 "github.com/KYVENetwork/cometbft/v38/p2p/conn" 27 p2pproto "github.com/KYVENetwork/cometbft/v38/proto/cometbft/v38/p2p" 28 ) 29 30 var cfg *config.P2PConfig 31 32 func init() { 33 cfg = config.DefaultP2PConfig() 34 cfg.PexReactor = true 35 cfg.AllowDuplicateIP = true 36 } 37 38 type PeerMessage struct { 39 Contents proto.Message 40 Counter int 41 } 42 43 type TestReactor struct { 44 BaseReactor 45 46 mtx cmtsync.Mutex 47 channels []*conn.ChannelDescriptor 48 logMessages bool 49 msgsCounter int 50 msgsReceived map[byte][]PeerMessage 51 } 52 53 func NewTestReactor(channels []*conn.ChannelDescriptor, logMessages bool) *TestReactor { 54 tr := &TestReactor{ 55 channels: channels, 56 logMessages: logMessages, 57 msgsReceived: make(map[byte][]PeerMessage), 58 } 59 tr.BaseReactor = *NewBaseReactor("TestReactor", tr) 60 tr.SetLogger(log.TestingLogger()) 61 return tr 62 } 63 64 func (tr *TestReactor) GetChannels() []*conn.ChannelDescriptor { 65 return tr.channels 66 } 67 68 func (tr *TestReactor) AddPeer(Peer) {} 69 70 func (tr *TestReactor) RemovePeer(Peer, interface{}) {} 71 72 func (tr *TestReactor) Receive(e Envelope) { 73 if tr.logMessages { 74 tr.mtx.Lock() 75 defer tr.mtx.Unlock() 76 fmt.Printf("Received: %X, %X\n", e.ChannelID, e.Message) 77 tr.msgsReceived[e.ChannelID] = append(tr.msgsReceived[e.ChannelID], PeerMessage{Contents: e.Message, Counter: tr.msgsCounter}) 78 tr.msgsCounter++ 79 } 80 } 81 82 func (tr *TestReactor) getMsgs(chID byte) []PeerMessage { 83 tr.mtx.Lock() 84 defer tr.mtx.Unlock() 85 return tr.msgsReceived[chID] 86 } 87 88 //----------------------------------------------------------------------------- 89 90 // convenience method for creating two switches connected to each other. 91 // XXX: note this uses net.Pipe and not a proper TCP conn 92 func MakeSwitchPair(initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) { 93 // Create two switches that will be interconnected. 94 switches := MakeConnectedSwitches(cfg, 2, initSwitch, Connect2Switches) 95 return switches[0], switches[1] 96 } 97 98 func initSwitchFunc(_ int, sw *Switch) *Switch { 99 sw.SetAddrBook(&AddrBookMock{ 100 Addrs: make(map[string]struct{}), 101 OurAddrs: make(map[string]struct{}), 102 }) 103 104 // Make two reactors of two channels each 105 sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{ 106 {ID: byte(0x00), Priority: 10, MessageType: &p2pproto.Message{}}, 107 {ID: byte(0x01), Priority: 10, MessageType: &p2pproto.Message{}}, 108 }, true)) 109 sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{ 110 {ID: byte(0x02), Priority: 10, MessageType: &p2pproto.Message{}}, 111 {ID: byte(0x03), Priority: 10, MessageType: &p2pproto.Message{}}, 112 }, true)) 113 114 return sw 115 } 116 117 func TestSwitches(t *testing.T) { 118 s1, s2 := MakeSwitchPair(initSwitchFunc) 119 t.Cleanup(func() { 120 if err := s1.Stop(); err != nil { 121 t.Error(err) 122 } 123 }) 124 t.Cleanup(func() { 125 if err := s2.Stop(); err != nil { 126 t.Error(err) 127 } 128 }) 129 130 if s1.Peers().Size() != 1 { 131 t.Errorf("expected exactly 1 peer in s1, got %v", s1.Peers().Size()) 132 } 133 if s2.Peers().Size() != 1 { 134 t.Errorf("expected exactly 1 peer in s2, got %v", s2.Peers().Size()) 135 } 136 137 // Lets send some messages 138 ch0Msg := &p2pproto.PexAddrs{ 139 Addrs: []p2pproto.NetAddress{ 140 { 141 ID: "1", 142 }, 143 }, 144 } 145 ch1Msg := &p2pproto.PexAddrs{ 146 Addrs: []p2pproto.NetAddress{ 147 { 148 ID: "1", 149 }, 150 }, 151 } 152 ch2Msg := &p2pproto.PexAddrs{ 153 Addrs: []p2pproto.NetAddress{ 154 { 155 ID: "2", 156 }, 157 }, 158 } 159 s1.Broadcast(Envelope{ChannelID: byte(0x00), Message: ch0Msg}) 160 s1.Broadcast(Envelope{ChannelID: byte(0x01), Message: ch1Msg}) 161 s1.Broadcast(Envelope{ChannelID: byte(0x02), Message: ch2Msg}) 162 assertMsgReceivedWithTimeout(t, 163 ch0Msg, 164 byte(0x00), 165 s2.Reactor("foo").(*TestReactor), 200*time.Millisecond, 5*time.Second) 166 assertMsgReceivedWithTimeout(t, 167 ch1Msg, 168 byte(0x01), 169 s2.Reactor("foo").(*TestReactor), 200*time.Millisecond, 5*time.Second) 170 assertMsgReceivedWithTimeout(t, 171 ch2Msg, 172 byte(0x02), 173 s2.Reactor("bar").(*TestReactor), 200*time.Millisecond, 5*time.Second) 174 } 175 176 func assertMsgReceivedWithTimeout( 177 t *testing.T, 178 msg proto.Message, 179 channel byte, 180 reactor *TestReactor, 181 checkPeriod, 182 timeout time.Duration, 183 ) { 184 ticker := time.NewTicker(checkPeriod) 185 for { 186 select { 187 case <-ticker.C: 188 msgs := reactor.getMsgs(channel) 189 expectedBytes, err := proto.Marshal(msgs[0].Contents) 190 require.NoError(t, err) 191 gotBytes, err := proto.Marshal(msg) 192 require.NoError(t, err) 193 if len(msgs) > 0 { 194 if !bytes.Equal(expectedBytes, gotBytes) { 195 t.Fatalf("Unexpected message bytes. Wanted: %X, Got: %X", msg, msgs[0].Counter) 196 } 197 return 198 } 199 200 case <-time.After(timeout): 201 t.Fatalf("Expected to have received 1 message in channel #%v, got zero", channel) 202 } 203 } 204 } 205 206 func TestSwitchFiltersOutItself(t *testing.T) { 207 s1 := MakeSwitch(cfg, 1, initSwitchFunc) 208 209 // simulate s1 having a public IP by creating a remote peer with the same ID 210 rp := &remotePeer{PrivKey: s1.nodeKey.PrivKey, Config: cfg} 211 rp.Start() 212 213 // addr should be rejected in addPeer based on the same ID 214 err := s1.DialPeerWithAddress(rp.Addr()) 215 if assert.Error(t, err) { 216 if err, ok := err.(ErrRejected); ok { 217 if !err.IsSelf() { 218 t.Errorf("expected self to be rejected") 219 } 220 } else { 221 t.Errorf("expected ErrRejected") 222 } 223 } 224 225 assert.True(t, s1.addrBook.OurAddress(rp.Addr())) 226 assert.False(t, s1.addrBook.HasAddress(rp.Addr())) 227 228 rp.Stop() 229 230 assertNoPeersAfterTimeout(t, s1, 100*time.Millisecond) 231 } 232 233 func TestSwitchPeerFilter(t *testing.T) { 234 var ( 235 filters = []PeerFilterFunc{ 236 func(_ IPeerSet, _ Peer) error { return nil }, 237 func(_ IPeerSet, _ Peer) error { return fmt.Errorf("denied") }, 238 func(_ IPeerSet, _ Peer) error { return nil }, 239 } 240 sw = MakeSwitch( 241 cfg, 242 1, 243 initSwitchFunc, 244 SwitchPeerFilters(filters...), 245 ) 246 ) 247 err := sw.Start() 248 require.NoError(t, err) 249 t.Cleanup(func() { 250 if err := sw.Stop(); err != nil { 251 t.Error(err) 252 } 253 }) 254 255 // simulate remote peer 256 rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} 257 rp.Start() 258 t.Cleanup(rp.Stop) 259 260 p, err := sw.transport.Dial(*rp.Addr(), peerConfig{ 261 chDescs: sw.chDescs, 262 onPeerError: sw.StopPeerForError, 263 isPersistent: sw.IsPeerPersistent, 264 reactorsByCh: sw.reactorsByCh, 265 }) 266 if err != nil { 267 t.Fatal(err) 268 } 269 270 err = sw.addPeer(p) 271 if err, ok := err.(ErrRejected); ok { 272 if !err.IsFiltered() { 273 t.Errorf("expected peer to be filtered") 274 } 275 } else { 276 t.Errorf("expected ErrRejected") 277 } 278 } 279 280 func TestSwitchPeerFilterTimeout(t *testing.T) { 281 var ( 282 filters = []PeerFilterFunc{ 283 func(_ IPeerSet, _ Peer) error { 284 time.Sleep(10 * time.Millisecond) 285 return nil 286 }, 287 } 288 sw = MakeSwitch( 289 cfg, 290 1, 291 initSwitchFunc, 292 SwitchFilterTimeout(5*time.Millisecond), 293 SwitchPeerFilters(filters...), 294 ) 295 ) 296 err := sw.Start() 297 require.NoError(t, err) 298 t.Cleanup(func() { 299 if err := sw.Stop(); err != nil { 300 t.Log(err) 301 } 302 }) 303 304 // simulate remote peer 305 rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} 306 rp.Start() 307 defer rp.Stop() 308 309 p, err := sw.transport.Dial(*rp.Addr(), peerConfig{ 310 chDescs: sw.chDescs, 311 onPeerError: sw.StopPeerForError, 312 isPersistent: sw.IsPeerPersistent, 313 reactorsByCh: sw.reactorsByCh, 314 }) 315 if err != nil { 316 t.Fatal(err) 317 } 318 319 err = sw.addPeer(p) 320 if _, ok := err.(ErrFilterTimeout); !ok { 321 t.Errorf("expected ErrFilterTimeout") 322 } 323 } 324 325 func TestSwitchPeerFilterDuplicate(t *testing.T) { 326 sw := MakeSwitch(cfg, 1, initSwitchFunc) 327 err := sw.Start() 328 require.NoError(t, err) 329 t.Cleanup(func() { 330 if err := sw.Stop(); err != nil { 331 t.Error(err) 332 } 333 }) 334 335 // simulate remote peer 336 rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} 337 rp.Start() 338 defer rp.Stop() 339 340 p, err := sw.transport.Dial(*rp.Addr(), peerConfig{ 341 chDescs: sw.chDescs, 342 onPeerError: sw.StopPeerForError, 343 isPersistent: sw.IsPeerPersistent, 344 reactorsByCh: sw.reactorsByCh, 345 }) 346 if err != nil { 347 t.Fatal(err) 348 } 349 350 if err := sw.addPeer(p); err != nil { 351 t.Fatal(err) 352 } 353 354 err = sw.addPeer(p) 355 if errRej, ok := err.(ErrRejected); ok { 356 if !errRej.IsDuplicate() { 357 t.Errorf("expected peer to be duplicate. got %v", errRej) 358 } 359 } else { 360 t.Errorf("expected ErrRejected, got %v", err) 361 } 362 } 363 364 func assertNoPeersAfterTimeout(t *testing.T, sw *Switch, timeout time.Duration) { 365 time.Sleep(timeout) 366 if sw.Peers().Size() != 0 { 367 t.Fatalf("Expected %v to not connect to some peers, got %d", sw, sw.Peers().Size()) 368 } 369 } 370 371 func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) { 372 assert, require := assert.New(t), require.New(t) 373 374 sw := MakeSwitch(cfg, 1, initSwitchFunc) 375 err := sw.Start() 376 if err != nil { 377 t.Error(err) 378 } 379 t.Cleanup(func() { 380 if err := sw.Stop(); err != nil { 381 t.Error(err) 382 } 383 }) 384 385 // simulate remote peer 386 rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} 387 rp.Start() 388 defer rp.Stop() 389 390 p, err := sw.transport.Dial(*rp.Addr(), peerConfig{ 391 chDescs: sw.chDescs, 392 onPeerError: sw.StopPeerForError, 393 isPersistent: sw.IsPeerPersistent, 394 reactorsByCh: sw.reactorsByCh, 395 }) 396 require.Nil(err) 397 398 err = sw.addPeer(p) 399 require.Nil(err) 400 401 require.NotNil(sw.Peers().Get(rp.ID())) 402 403 // simulate failure by closing connection 404 err = p.(*peer).CloseConn() 405 require.NoError(err) 406 407 assertNoPeersAfterTimeout(t, sw, 100*time.Millisecond) 408 assert.False(p.IsRunning()) 409 } 410 411 func TestSwitchStopPeerForError(t *testing.T) { 412 s := httptest.NewServer(promhttp.Handler()) 413 defer s.Close() 414 415 scrapeMetrics := func() string { 416 resp, err := http.Get(s.URL) 417 require.NoError(t, err) 418 defer resp.Body.Close() 419 buf, _ := io.ReadAll(resp.Body) 420 return string(buf) 421 } 422 423 namespace, subsystem, name := config.TestInstrumentationConfig().Namespace, MetricsSubsystem, "peers" 424 re := regexp.MustCompile(namespace + `_` + subsystem + `_` + name + ` ([0-9\.]+)`) 425 peersMetricValue := func() float64 { 426 matches := re.FindStringSubmatch(scrapeMetrics()) 427 f, _ := strconv.ParseFloat(matches[1], 64) 428 return f 429 } 430 431 p2pMetrics := PrometheusMetrics(namespace) 432 433 // make two connected switches 434 sw1, sw2 := MakeSwitchPair(func(i int, sw *Switch) *Switch { 435 // set metrics on sw1 436 if i == 0 { 437 opt := WithMetrics(p2pMetrics) 438 opt(sw) 439 } 440 return initSwitchFunc(i, sw) 441 }) 442 443 assert.Equal(t, len(sw1.Peers().List()), 1) 444 assert.EqualValues(t, 1, peersMetricValue()) 445 446 // send messages to the peer from sw1 447 p := sw1.Peers().List()[0] 448 p.Send(Envelope{ 449 ChannelID: 0x1, 450 Message: &p2pproto.Message{}, 451 }) 452 453 // stop sw2. this should cause the p to fail, 454 // which results in calling StopPeerForError internally 455 t.Cleanup(func() { 456 if err := sw2.Stop(); err != nil { 457 t.Error(err) 458 } 459 }) 460 461 // now call StopPeerForError explicitly, eg. from a reactor 462 sw1.StopPeerForError(p, fmt.Errorf("some err")) 463 464 assert.Equal(t, len(sw1.Peers().List()), 0) 465 assert.EqualValues(t, 0, peersMetricValue()) 466 } 467 468 func TestSwitchReconnectsToOutboundPersistentPeer(t *testing.T) { 469 sw := MakeSwitch(cfg, 1, initSwitchFunc) 470 err := sw.Start() 471 require.NoError(t, err) 472 t.Cleanup(func() { 473 if err := sw.Stop(); err != nil { 474 t.Error(err) 475 } 476 }) 477 478 // 1. simulate failure by closing connection 479 rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} 480 rp.Start() 481 defer rp.Stop() 482 483 err = sw.AddPersistentPeers([]string{rp.Addr().String()}) 484 require.NoError(t, err) 485 486 err = sw.DialPeerWithAddress(rp.Addr()) 487 require.Nil(t, err) 488 require.NotNil(t, sw.Peers().Get(rp.ID())) 489 490 p := sw.Peers().List()[0] 491 err = p.(*peer).CloseConn() 492 require.NoError(t, err) 493 494 waitUntilSwitchHasAtLeastNPeers(sw, 1) 495 assert.False(t, p.IsRunning()) // old peer instance 496 assert.Equal(t, 1, sw.Peers().Size()) // new peer instance 497 498 // 2. simulate first time dial failure 499 rp = &remotePeer{ 500 PrivKey: ed25519.GenPrivKey(), 501 Config: cfg, 502 // Use different interface to prevent duplicate IP filter, this will break 503 // beyond two peers. 504 listenAddr: "127.0.0.1:0", 505 } 506 rp.Start() 507 defer rp.Stop() 508 509 conf := config.DefaultP2PConfig() 510 conf.TestDialFail = true // will trigger a reconnect 511 err = sw.addOutboundPeerWithConfig(rp.Addr(), conf) 512 require.NotNil(t, err) 513 // DialPeerWithAddres - sw.peerConfig resets the dialer 514 waitUntilSwitchHasAtLeastNPeers(sw, 2) 515 assert.Equal(t, 2, sw.Peers().Size()) 516 } 517 518 func TestSwitchReconnectsToInboundPersistentPeer(t *testing.T) { 519 sw := MakeSwitch(cfg, 1, initSwitchFunc) 520 err := sw.Start() 521 require.NoError(t, err) 522 t.Cleanup(func() { 523 if err := sw.Stop(); err != nil { 524 t.Error(err) 525 } 526 }) 527 528 // 1. simulate failure by closing the connection 529 rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} 530 rp.Start() 531 defer rp.Stop() 532 533 err = sw.AddPersistentPeers([]string{rp.Addr().String()}) 534 require.NoError(t, err) 535 536 conn, err := rp.Dial(sw.NetAddress()) 537 require.NoError(t, err) 538 time.Sleep(50 * time.Millisecond) 539 require.NotNil(t, sw.Peers().Get(rp.ID())) 540 541 conn.Close() 542 543 waitUntilSwitchHasAtLeastNPeers(sw, 1) 544 assert.Equal(t, 1, sw.Peers().Size()) 545 } 546 547 func TestSwitchDialPeersAsync(t *testing.T) { 548 if testing.Short() { 549 return 550 } 551 552 sw := MakeSwitch(cfg, 1, initSwitchFunc) 553 err := sw.Start() 554 require.NoError(t, err) 555 t.Cleanup(func() { 556 if err := sw.Stop(); err != nil { 557 t.Error(err) 558 } 559 }) 560 561 rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} 562 rp.Start() 563 defer rp.Stop() 564 565 err = sw.DialPeersAsync([]string{rp.Addr().String()}) 566 require.NoError(t, err) 567 time.Sleep(dialRandomizerIntervalMilliseconds * time.Millisecond) 568 require.NotNil(t, sw.Peers().Get(rp.ID())) 569 } 570 571 func waitUntilSwitchHasAtLeastNPeers(sw *Switch, n int) { 572 for i := 0; i < 20; i++ { 573 time.Sleep(250 * time.Millisecond) 574 has := sw.Peers().Size() 575 if has >= n { 576 break 577 } 578 } 579 } 580 581 func TestSwitchFullConnectivity(t *testing.T) { 582 switches := MakeConnectedSwitches(cfg, 3, initSwitchFunc, Connect2Switches) 583 defer func() { 584 for _, sw := range switches { 585 sw := sw 586 t.Cleanup(func() { 587 if err := sw.Stop(); err != nil { 588 t.Error(err) 589 } 590 }) 591 } 592 }() 593 594 for i, sw := range switches { 595 if sw.Peers().Size() != 2 { 596 t.Fatalf("Expected each switch to be connected to 2 other, but %d switch only connected to %d", sw.Peers().Size(), i) 597 } 598 } 599 } 600 601 func TestSwitchAcceptRoutine(t *testing.T) { 602 cfg.MaxNumInboundPeers = 5 603 604 // Create some unconditional peers. 605 const unconditionalPeersNum = 2 606 var ( 607 unconditionalPeers = make([]*remotePeer, unconditionalPeersNum) 608 unconditionalPeerIDs = make([]string, unconditionalPeersNum) 609 ) 610 for i := 0; i < unconditionalPeersNum; i++ { 611 peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} 612 peer.Start() 613 unconditionalPeers[i] = peer 614 unconditionalPeerIDs[i] = string(peer.ID()) 615 } 616 617 // make switch 618 sw := MakeSwitch(cfg, 1, initSwitchFunc) 619 err := sw.AddUnconditionalPeerIDs(unconditionalPeerIDs) 620 require.NoError(t, err) 621 err = sw.Start() 622 require.NoError(t, err) 623 t.Cleanup(func() { 624 err := sw.Stop() 625 require.NoError(t, err) 626 }) 627 628 // 0. check there are no peers 629 assert.Equal(t, 0, sw.Peers().Size()) 630 631 // 1. check we connect up to MaxNumInboundPeers 632 peers := make([]*remotePeer, 0) 633 for i := 0; i < cfg.MaxNumInboundPeers; i++ { 634 peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} 635 peers = append(peers, peer) 636 peer.Start() 637 c, err := peer.Dial(sw.NetAddress()) 638 require.NoError(t, err) 639 // spawn a reading routine to prevent connection from closing 640 go func(c net.Conn) { 641 for { 642 one := make([]byte, 1) 643 _, err := c.Read(one) 644 if err != nil { 645 return 646 } 647 } 648 }(c) 649 } 650 time.Sleep(100 * time.Millisecond) 651 assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size()) 652 653 // 2. check we close new connections if we already have MaxNumInboundPeers peers 654 peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} 655 peer.Start() 656 conn, err := peer.Dial(sw.NetAddress()) 657 require.NoError(t, err) 658 // check conn is closed 659 one := make([]byte, 1) 660 _ = conn.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) 661 _, err = conn.Read(one) 662 assert.Error(t, err) 663 assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size()) 664 peer.Stop() 665 666 // 3. check we connect to unconditional peers despite the limit. 667 for _, peer := range unconditionalPeers { 668 c, err := peer.Dial(sw.NetAddress()) 669 require.NoError(t, err) 670 // spawn a reading routine to prevent connection from closing 671 go func(c net.Conn) { 672 for { 673 one := make([]byte, 1) 674 _, err := c.Read(one) 675 if err != nil { 676 return 677 } 678 } 679 }(c) 680 } 681 time.Sleep(10 * time.Millisecond) 682 assert.Equal(t, cfg.MaxNumInboundPeers+unconditionalPeersNum, sw.Peers().Size()) 683 684 for _, peer := range peers { 685 peer.Stop() 686 } 687 for _, peer := range unconditionalPeers { 688 peer.Stop() 689 } 690 } 691 692 type errorTransport struct { 693 acceptErr error 694 } 695 696 func (et errorTransport) NetAddress() NetAddress { 697 panic("not implemented") 698 } 699 700 func (et errorTransport) Accept(peerConfig) (Peer, error) { 701 return nil, et.acceptErr 702 } 703 704 func (errorTransport) Dial(NetAddress, peerConfig) (Peer, error) { 705 panic("not implemented") 706 } 707 708 func (errorTransport) Cleanup(Peer) { 709 panic("not implemented") 710 } 711 712 func TestSwitchAcceptRoutineErrorCases(t *testing.T) { 713 sw := NewSwitch(cfg, errorTransport{ErrFilterTimeout{}}) 714 assert.NotPanics(t, func() { 715 err := sw.Start() 716 require.NoError(t, err) 717 err = sw.Stop() 718 require.NoError(t, err) 719 }) 720 721 sw = NewSwitch(cfg, errorTransport{ErrRejected{conn: nil, err: errors.New("filtered"), isFiltered: true}}) 722 assert.NotPanics(t, func() { 723 err := sw.Start() 724 require.NoError(t, err) 725 err = sw.Stop() 726 require.NoError(t, err) 727 }) 728 // TODO(melekes) check we remove our address from addrBook 729 730 sw = NewSwitch(cfg, errorTransport{ErrTransportClosed{}}) 731 assert.NotPanics(t, func() { 732 err := sw.Start() 733 require.NoError(t, err) 734 err = sw.Stop() 735 require.NoError(t, err) 736 }) 737 } 738 739 // mockReactor checks that InitPeer never called before RemovePeer. If that's 740 // not true, InitCalledBeforeRemoveFinished will return true. 741 type mockReactor struct { 742 *BaseReactor 743 744 // atomic 745 removePeerInProgress uint32 746 initCalledBeforeRemoveFinished uint32 747 } 748 749 func (r *mockReactor) RemovePeer(Peer, interface{}) { 750 atomic.StoreUint32(&r.removePeerInProgress, 1) 751 defer atomic.StoreUint32(&r.removePeerInProgress, 0) 752 time.Sleep(100 * time.Millisecond) 753 } 754 755 func (r *mockReactor) InitPeer(peer Peer) Peer { 756 if atomic.LoadUint32(&r.removePeerInProgress) == 1 { 757 atomic.StoreUint32(&r.initCalledBeforeRemoveFinished, 1) 758 } 759 760 return peer 761 } 762 763 func (r *mockReactor) InitCalledBeforeRemoveFinished() bool { 764 return atomic.LoadUint32(&r.initCalledBeforeRemoveFinished) == 1 765 } 766 767 // see stopAndRemovePeer 768 func TestSwitchInitPeerIsNotCalledBeforeRemovePeer(t *testing.T) { 769 // make reactor 770 reactor := &mockReactor{} 771 reactor.BaseReactor = NewBaseReactor("mockReactor", reactor) 772 773 // make switch 774 sw := MakeSwitch(cfg, 1, func(i int, sw *Switch) *Switch { 775 sw.AddReactor("mock", reactor) 776 return sw 777 }) 778 err := sw.Start() 779 require.NoError(t, err) 780 t.Cleanup(func() { 781 if err := sw.Stop(); err != nil { 782 t.Error(err) 783 } 784 }) 785 786 // add peer 787 rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} 788 rp.Start() 789 defer rp.Stop() 790 _, err = rp.Dial(sw.NetAddress()) 791 require.NoError(t, err) 792 793 // wait till the switch adds rp to the peer set, then stop the peer asynchronously 794 for { 795 time.Sleep(20 * time.Millisecond) 796 if peer := sw.Peers().Get(rp.ID()); peer != nil { 797 go sw.StopPeerForError(peer, "test") 798 break 799 } 800 } 801 802 // simulate peer reconnecting to us 803 _, err = rp.Dial(sw.NetAddress()) 804 require.NoError(t, err) 805 // wait till the switch adds rp to the peer set 806 time.Sleep(50 * time.Millisecond) 807 808 // make sure reactor.RemovePeer is finished before InitPeer is called 809 assert.False(t, reactor.InitCalledBeforeRemoveFinished()) 810 } 811 812 func BenchmarkSwitchBroadcast(b *testing.B) { 813 s1, s2 := MakeSwitchPair(func(i int, sw *Switch) *Switch { 814 // Make bar reactors of bar channels each 815 sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{ 816 {ID: byte(0x00), Priority: 10}, 817 {ID: byte(0x01), Priority: 10}, 818 }, false)) 819 sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{ 820 {ID: byte(0x02), Priority: 10}, 821 {ID: byte(0x03), Priority: 10}, 822 }, false)) 823 return sw 824 }) 825 826 b.Cleanup(func() { 827 if err := s1.Stop(); err != nil { 828 b.Error(err) 829 } 830 }) 831 832 b.Cleanup(func() { 833 if err := s2.Stop(); err != nil { 834 b.Error(err) 835 } 836 }) 837 838 // Allow time for goroutines to boot up 839 time.Sleep(1 * time.Second) 840 841 b.ResetTimer() 842 843 numSuccess, numFailure := 0, 0 844 845 // Send random message from foo channel to another 846 for i := 0; i < b.N; i++ { 847 chID := byte(i % 4) 848 successChan := s1.Broadcast(Envelope{ChannelID: chID}) 849 for s := range successChan { 850 if s { 851 numSuccess++ 852 } else { 853 numFailure++ 854 } 855 } 856 } 857 858 b.Logf("success: %v, failure: %v", numSuccess, numFailure) 859 } 860 861 func TestSwitchRemovalErr(t *testing.T) { 862 sw1, sw2 := MakeSwitchPair(func(i int, sw *Switch) *Switch { 863 return initSwitchFunc(i, sw) 864 }) 865 assert.Equal(t, len(sw1.Peers().List()), 1) 866 p := sw1.Peers().List()[0] 867 868 sw2.StopPeerForError(p, fmt.Errorf("peer should error")) 869 870 assert.Equal(t, sw2.peers.Add(p).Error(), ErrPeerRemoval{}.Error()) 871 }