github.com/badrootd/nibiru-cometbft@v0.37.5-0.20240307173500-2a75559eee9b/p2p/switch_test.go

package p2p

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"regexp"
	"strconv"
	"sync/atomic"
	"testing"
	"time"

	"github.com/cosmos/gogoproto/proto"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/badrootd/nibiru-cometbft/config"
	"github.com/badrootd/nibiru-cometbft/crypto/ed25519"
	"github.com/badrootd/nibiru-cometbft/libs/log"
	cmtsync "github.com/badrootd/nibiru-cometbft/libs/sync"
	"github.com/badrootd/nibiru-cometbft/p2p/conn"
	p2pproto "github.com/badrootd/nibiru-cometbft/proto/tendermint/p2p"
)

var cfg *config.P2PConfig

func init() {
	cfg = config.DefaultP2PConfig()
	cfg.PexReactor = true
	cfg.AllowDuplicateIP = true
}

type PeerMessage struct {
	Contents proto.Message
	Counter  int
}

type TestReactor struct {
	BaseReactor

	mtx          cmtsync.Mutex
	channels     []*conn.ChannelDescriptor
	logMessages  bool
	msgsCounter  int
	msgsReceived map[byte][]PeerMessage
}

func NewTestReactor(channels []*conn.ChannelDescriptor, logMessages bool) *TestReactor {
	tr := &TestReactor{
		channels:     channels,
		logMessages:  logMessages,
		msgsReceived: make(map[byte][]PeerMessage),
	}
	tr.BaseReactor = *NewBaseReactor("TestReactor", tr)
	tr.SetLogger(log.TestingLogger())
	return tr
}

func (tr *TestReactor) GetChannels() []*conn.ChannelDescriptor {
	return tr.channels
}

func (tr *TestReactor) AddPeer(peer Peer) {}

func (tr *TestReactor) RemovePeer(peer Peer, reason interface{}) {}

func (tr *TestReactor) ReceiveEnvelope(e Envelope) {
	if tr.logMessages {
		tr.mtx.Lock()
		defer tr.mtx.Unlock()
		// fmt.Printf("Received: %X, %X\n", e.ChannelID, e.Message)
		tr.msgsReceived[e.ChannelID] = append(tr.msgsReceived[e.ChannelID], PeerMessage{Contents: e.Message, Counter: tr.msgsCounter})
		tr.msgsCounter++
	}
}

func (tr *TestReactor) getMsgs(chID byte) []PeerMessage {
	tr.mtx.Lock()
	defer tr.mtx.Unlock()
	return tr.msgsReceived[chID]
}

//-----------------------------------------------------------------------------
// MakeSwitchPair is a convenience function for creating two switches connected
// to each other.
// XXX: note this uses net.Pipe and not a proper TCP conn.
func MakeSwitchPair(t testing.TB, initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) {
	// Create two switches that will be interconnected.
	switches := MakeConnectedSwitches(cfg, 2, initSwitch, Connect2Switches)
	return switches[0], switches[1]
}

func initSwitchFunc(i int, sw *Switch) *Switch {
	sw.SetAddrBook(&AddrBookMock{
		Addrs:    make(map[string]struct{}),
		OurAddrs: make(map[string]struct{})})

	// Make two reactors of two channels each
	sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
		{ID: byte(0x00), Priority: 10, MessageType: &p2pproto.Message{}},
		{ID: byte(0x01), Priority: 10, MessageType: &p2pproto.Message{}},
	}, true))
	sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
		{ID: byte(0x02), Priority: 10, MessageType: &p2pproto.Message{}},
		{ID: byte(0x03), Priority: 10, MessageType: &p2pproto.Message{}},
	}, true))

	return sw
}

func TestSwitches(t *testing.T) {
	s1, s2 := MakeSwitchPair(t, initSwitchFunc)
	t.Cleanup(func() {
		if err := s1.Stop(); err != nil {
			t.Error(err)
		}
	})
	t.Cleanup(func() {
		if err := s2.Stop(); err != nil {
			t.Error(err)
		}
	})

	if s1.Peers().Size() != 1 {
		t.Errorf("expected exactly 1 peer in s1, got %v", s1.Peers().Size())
	}
	if s2.Peers().Size() != 1 {
		t.Errorf("expected exactly 1 peer in s2, got %v", s2.Peers().Size())
	}

	// Let's send some messages
	ch0Msg := &p2pproto.PexAddrs{
		Addrs: []p2pproto.NetAddress{
			{
				ID: "1",
			},
		},
	}
	ch1Msg := &p2pproto.PexAddrs{
		Addrs: []p2pproto.NetAddress{
			{
				ID: "1",
			},
		},
	}
	ch2Msg := &p2pproto.PexAddrs{
		Addrs: []p2pproto.NetAddress{
			{
				ID: "2",
			},
		},
	}
	s1.BroadcastEnvelope(Envelope{ChannelID: byte(0x00), Message: ch0Msg})
	s1.BroadcastEnvelope(Envelope{ChannelID: byte(0x01), Message: ch1Msg})
	s1.BroadcastEnvelope(Envelope{ChannelID: byte(0x02), Message: ch2Msg})
	assertMsgReceivedWithTimeout(t,
		ch0Msg,
		byte(0x00),
		s2.Reactor("foo").(*TestReactor), 200*time.Millisecond, 5*time.Second)
	assertMsgReceivedWithTimeout(t,
		ch1Msg,
		byte(0x01),
		s2.Reactor("foo").(*TestReactor), 200*time.Millisecond, 5*time.Second)
	assertMsgReceivedWithTimeout(t,
		ch2Msg,
		byte(0x02),
		s2.Reactor("bar").(*TestReactor), 200*time.Millisecond, 5*time.Second)
}

// assertMsgReceivedWithTimeout polls the reactor every checkPeriod until the
// expected message shows up on the given channel, failing after timeout.
func assertMsgReceivedWithTimeout(
	t *testing.T,
	msg proto.Message,
	channel byte,
	reactor *TestReactor,
	checkPeriod,
	timeout time.Duration,
) {
	ticker := time.NewTicker(checkPeriod)
	defer ticker.Stop()
	expectedBytes, err := proto.Marshal(msg)
	require.NoError(t, err)
	for {
		select {
		case <-ticker.C:
			msgs := reactor.getMsgs(channel)
			// Check the length before indexing; the original indexed msgs[0]
			// first, which panics while no message has arrived yet.
			if len(msgs) > 0 {
				gotBytes, err := proto.Marshal(msgs[0].Contents)
				require.NoError(t, err)
				if !bytes.Equal(gotBytes, expectedBytes) {
					t.Fatalf("Unexpected message bytes. Wanted: %X, Got: %X", expectedBytes, gotBytes)
				}
				return
			}

		case <-time.After(timeout):
			t.Fatalf("Expected to have received 1 message in channel #%v, got zero", channel)
		}
	}
}
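// assertMsgReceivedEventually is an illustrative alternative to the polling
// loop in assertMsgReceivedWithTimeout, sketched on top of testify's
// require.Eventually. It is an assumption for illustration only and is not
// used by the tests in this file.
func assertMsgReceivedEventually(
	t *testing.T,
	msg proto.Message,
	channel byte,
	reactor *TestReactor,
	checkPeriod,
	timeout time.Duration,
) {
	t.Helper()
	want, err := proto.Marshal(msg)
	require.NoError(t, err)
	require.Eventually(t, func() bool {
		msgs := reactor.getMsgs(channel)
		if len(msgs) == 0 {
			return false
		}
		got, err := proto.Marshal(msgs[0].Contents)
		// Keep polling until the first received message matches the expected bytes.
		return err == nil && bytes.Equal(got, want)
	}, timeout, checkPeriod, "expected a message on channel %#x", channel)
}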
func TestSwitchFiltersOutItself(t *testing.T) {
	s1 := MakeSwitch(cfg, 1, "127.0.0.1", "123.123.123", initSwitchFunc)

	// simulate s1 having a public IP by creating a remote peer with the same ID
	rp := &remotePeer{PrivKey: s1.nodeKey.PrivKey, Config: cfg}
	rp.Start()

	// addr should be rejected in addPeer based on the same ID
	err := s1.DialPeerWithAddress(rp.Addr())
	if assert.Error(t, err) {
		if err, ok := err.(ErrRejected); ok {
			if !err.IsSelf() {
				t.Errorf("expected self to be rejected")
			}
		} else {
			t.Errorf("expected ErrRejected")
		}
	}

	assert.True(t, s1.addrBook.OurAddress(rp.Addr()))
	assert.False(t, s1.addrBook.HasAddress(rp.Addr()))

	rp.Stop()

	assertNoPeersAfterTimeout(t, s1, 100*time.Millisecond)
}

func TestSwitchPeerFilter(t *testing.T) {
	var (
		filters = []PeerFilterFunc{
			func(_ IPeerSet, _ Peer) error { return nil },
			func(_ IPeerSet, _ Peer) error { return fmt.Errorf("denied") },
			func(_ IPeerSet, _ Peer) error { return nil },
		}
		sw = MakeSwitch(
			cfg,
			1,
			"testing",
			"123.123.123",
			initSwitchFunc,
			SwitchPeerFilters(filters...),
		)
	)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// simulate remote peer
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	t.Cleanup(rp.Stop)

	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
		chDescs:      sw.chDescs,
		onPeerError:  sw.StopPeerForError,
		isPersistent: sw.IsPeerPersistent,
		reactorsByCh: sw.reactorsByCh,
	})
	if err != nil {
		t.Fatal(err)
	}

	err = sw.addPeer(p)
	if err, ok := err.(ErrRejected); ok {
		if !err.IsFiltered() {
			t.Errorf("expected peer to be filtered")
		}
	} else {
		t.Errorf("expected ErrRejected")
	}
}

func TestSwitchPeerFilterTimeout(t *testing.T) {
	var (
		filters = []PeerFilterFunc{
			func(_ IPeerSet, _ Peer) error {
				time.Sleep(10 * time.Millisecond)
				return nil
			},
		}
		sw = MakeSwitch(
			cfg,
			1,
			"testing",
			"123.123.123",
			initSwitchFunc,
			SwitchFilterTimeout(5*time.Millisecond),
			SwitchPeerFilters(filters...),
		)
	)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Log(err)
		}
	})

	// simulate remote peer
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
		chDescs:      sw.chDescs,
		onPeerError:  sw.StopPeerForError,
		isPersistent: sw.IsPeerPersistent,
		reactorsByCh: sw.reactorsByCh,
	})
	if err != nil {
		t.Fatal(err)
	}

	err = sw.addPeer(p)
	if _, ok := err.(ErrFilterTimeout); !ok {
		t.Errorf("expected ErrFilterTimeout")
	}
}
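// dialRemotePeer is a hypothetical helper (not used by the tests in this
// file) that names the transport.Dial boilerplate the filter tests above
// repeat; it assumes the same peerConfig fields those tests pass.
func dialRemotePeer(t *testing.T, sw *Switch, rp *remotePeer) Peer {
	t.Helper()
	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
		chDescs:      sw.chDescs,
		onPeerError:  sw.StopPeerForError,
		isPersistent: sw.IsPeerPersistent,
		reactorsByCh: sw.reactorsByCh,
	})
	require.NoError(t, err)
	return p
}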
func TestSwitchPeerFilterDuplicate(t *testing.T) {
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// simulate remote peer
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
		chDescs:      sw.chDescs,
		onPeerError:  sw.StopPeerForError,
		isPersistent: sw.IsPeerPersistent,
		reactorsByCh: sw.reactorsByCh,
	})
	if err != nil {
		t.Fatal(err)
	}

	if err := sw.addPeer(p); err != nil {
		t.Fatal(err)
	}

	err = sw.addPeer(p)
	if errRej, ok := err.(ErrRejected); ok {
		if !errRej.IsDuplicate() {
			t.Errorf("expected peer to be duplicate. got %v", errRej)
		}
	} else {
		t.Errorf("expected ErrRejected, got %v", err)
	}
}

// assertNoPeersAfterTimeout waits for the given timeout and then checks that
// the switch has no peers left.
func assertNoPeersAfterTimeout(t *testing.T, sw *Switch, timeout time.Duration) {
	time.Sleep(timeout)
	if sw.Peers().Size() != 0 {
		t.Fatalf("Expected %v to have no peers, got %d", sw, sw.Peers().Size())
	}
}

func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) {
	assert, require := assert.New(t), require.New(t)

	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.Start()
	if err != nil {
		t.Error(err)
	}
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// simulate remote peer
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
		chDescs:      sw.chDescs,
		onPeerError:  sw.StopPeerForError,
		isPersistent: sw.IsPeerPersistent,
		reactorsByCh: sw.reactorsByCh,
	})
	require.Nil(err)

	err = sw.addPeer(p)
	require.Nil(err)

	require.NotNil(sw.Peers().Get(rp.ID()))

	// simulate failure by closing connection
	err = p.(*peer).CloseConn()
	require.NoError(err)

	assertNoPeersAfterTimeout(t, sw, 100*time.Millisecond)
	assert.False(p.IsRunning())
}

func TestSwitchStopPeerForError(t *testing.T) {
	s := httptest.NewServer(promhttp.Handler())
	defer s.Close()

	scrapeMetrics := func() string {
		resp, err := http.Get(s.URL)
		require.NoError(t, err)
		defer resp.Body.Close()
		buf, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		return string(buf)
	}

	namespace, subsystem, name := config.TestInstrumentationConfig().Namespace, MetricsSubsystem, "peers"
	re := regexp.MustCompile(namespace + `_` + subsystem + `_` + name + ` ([0-9\.]+)`)
	peersMetricValue := func() float64 {
		matches := re.FindStringSubmatch(scrapeMetrics())
		require.Len(t, matches, 2)
		f, err := strconv.ParseFloat(matches[1], 64)
		require.NoError(t, err)
		return f
	}

	p2pMetrics := PrometheusMetrics(namespace)

	// make two connected switches
	sw1, sw2 := MakeSwitchPair(t, func(i int, sw *Switch) *Switch {
		// set metrics on sw1
		if i == 0 {
			opt := WithMetrics(p2pMetrics)
			opt(sw)
		}
		return initSwitchFunc(i, sw)
	})

	assert.Equal(t, 1, len(sw1.Peers().List()))
	assert.EqualValues(t, 1, peersMetricValue())

	// send messages to the peer from sw1
	p := sw1.Peers().List()[0]
	p.SendEnvelope(Envelope{
		ChannelID: 0x1,
		Message:   &p2pproto.Message{},
	})

	// stop sw2 during cleanup. this should cause p to fail, which results in
	// the switch calling StopPeerForError internally
	t.Cleanup(func() {
		if err := sw2.Stop(); err != nil {
			t.Error(err)
		}
	})

	// now call StopPeerForError explicitly, e.g. as a reactor would
	sw1.StopPeerForError(p, fmt.Errorf("some err"))

	assert.Equal(t, 0, len(sw1.Peers().List()))
	assert.EqualValues(t, 0, peersMetricValue())
}
func TestSwitchReconnectsToOutboundPersistentPeer(t *testing.T) {
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// 1. simulate failure by closing connection
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
	require.NoError(t, err)

	err = sw.DialPeerWithAddress(rp.Addr())
	require.Nil(t, err)
	require.NotNil(t, sw.Peers().Get(rp.ID()))

	p := sw.Peers().List()[0]
	err = p.(*peer).CloseConn()
	require.NoError(t, err)

	waitUntilSwitchHasAtLeastNPeers(sw, 1)
	assert.False(t, p.IsRunning())        // old peer instance
	assert.Equal(t, 1, sw.Peers().Size()) // new peer instance

	// 2. simulate first-time dial failure
	rp = &remotePeer{
		PrivKey: ed25519.GenPrivKey(),
		Config:  cfg,
		// Use a different interface to prevent the duplicate-IP filter from
		// rejecting the peer; this will break beyond two peers.
		listenAddr: "127.0.0.1:0",
	}
	rp.Start()
	defer rp.Stop()

	conf := config.DefaultP2PConfig()
	conf.TestDialFail = true // will trigger a reconnect
	err = sw.addOutboundPeerWithConfig(rp.Addr(), conf)
	require.NotNil(t, err)
	// DialPeerWithAddress - sw.peerConfig resets the dialer
	waitUntilSwitchHasAtLeastNPeers(sw, 2)
	assert.Equal(t, 2, sw.Peers().Size())
}

func TestSwitchReconnectsToInboundPersistentPeer(t *testing.T) {
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// simulate failure by closing the connection
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
	require.NoError(t, err)

	conn, err := rp.Dial(sw.NetAddress())
	require.NoError(t, err)
	time.Sleep(50 * time.Millisecond)
	require.NotNil(t, sw.Peers().Get(rp.ID()))

	conn.Close()

	waitUntilSwitchHasAtLeastNPeers(sw, 1)
	assert.Equal(t, 1, sw.Peers().Size())
}
func TestSwitchDialPeersAsync(t *testing.T) {
	if testing.Short() {
		return
	}

	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	err = sw.DialPeersAsync([]string{rp.Addr().String()})
	require.NoError(t, err)
	time.Sleep(dialRandomizerIntervalMilliseconds * time.Millisecond)
	require.NotNil(t, sw.Peers().Get(rp.ID()))
}

// waitUntilSwitchHasAtLeastNPeers polls the switch for up to ~5s (20 x 250ms)
// and returns once it has at least n peers; it gives up silently on timeout.
func waitUntilSwitchHasAtLeastNPeers(sw *Switch, n int) {
	for i := 0; i < 20; i++ {
		time.Sleep(250 * time.Millisecond)
		has := sw.Peers().Size()
		if has >= n {
			break
		}
	}
}

func TestSwitchFullConnectivity(t *testing.T) {
	switches := MakeConnectedSwitches(cfg, 3, initSwitchFunc, Connect2Switches)
	defer func() {
		for _, sw := range switches {
			sw := sw
			t.Cleanup(func() {
				if err := sw.Stop(); err != nil {
					t.Error(err)
				}
			})
		}
	}()

	for i, sw := range switches {
		if sw.Peers().Size() != 2 {
			t.Fatalf("Expected each switch to be connected to 2 others, but switch %d is only connected to %d", i, sw.Peers().Size())
		}
	}
}
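// waitUntilSwitchHasAtLeastNPeersOrFail is an illustrative, stricter variant
// of waitUntilSwitchHasAtLeastNPeers above (an assumption, not used by these
// tests): it fails the test on timeout instead of giving up silently.
func waitUntilSwitchHasAtLeastNPeersOrFail(t *testing.T, sw *Switch, n int, timeout time.Duration) {
	t.Helper()
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if sw.Peers().Size() >= n {
			return
		}
		time.Sleep(50 * time.Millisecond)
	}
	t.Fatalf("switch has %d peers, want at least %d after %v", sw.Peers().Size(), n, timeout)
}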
func TestSwitchAcceptRoutine(t *testing.T) {
	cfg.MaxNumInboundPeers = 5

	// Create some unconditional peers.
	const unconditionalPeersNum = 2
	var (
		unconditionalPeers   = make([]*remotePeer, unconditionalPeersNum)
		unconditionalPeerIDs = make([]string, unconditionalPeersNum)
	)
	for i := 0; i < unconditionalPeersNum; i++ {
		peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
		peer.Start()
		unconditionalPeers[i] = peer
		unconditionalPeerIDs[i] = string(peer.ID())
	}

	// make switch
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.AddUnconditionalPeerIDs(unconditionalPeerIDs)
	require.NoError(t, err)
	err = sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		err := sw.Stop()
		require.NoError(t, err)
	})

	// 0. check there are no peers
	assert.Equal(t, 0, sw.Peers().Size())

	// 1. check we connect up to MaxNumInboundPeers
	peers := make([]*remotePeer, 0)
	for i := 0; i < cfg.MaxNumInboundPeers; i++ {
		peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
		peers = append(peers, peer)
		peer.Start()
		c, err := peer.Dial(sw.NetAddress())
		require.NoError(t, err)
		// spawn a reading routine to prevent the connection from closing
		go func(c net.Conn) {
			for {
				one := make([]byte, 1)
				_, err := c.Read(one)
				if err != nil {
					return
				}
			}
		}(c)
	}
	time.Sleep(100 * time.Millisecond)
	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())

	// 2. check we close new connections if we already have MaxNumInboundPeers peers
	peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	peer.Start()
	conn, err := peer.Dial(sw.NetAddress())
	require.NoError(t, err)
	// check conn is closed
	one := make([]byte, 1)
	_ = conn.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
	_, err = conn.Read(one)
	assert.Error(t, err)
	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())
	peer.Stop()

	// 3. check we connect to unconditional peers despite the limit.
	for _, peer := range unconditionalPeers {
		c, err := peer.Dial(sw.NetAddress())
		require.NoError(t, err)
		// spawn a reading routine to prevent the connection from closing
		go func(c net.Conn) {
			for {
				one := make([]byte, 1)
				_, err := c.Read(one)
				if err != nil {
					return
				}
			}
		}(c)
	}
	time.Sleep(10 * time.Millisecond)
	assert.Equal(t, cfg.MaxNumInboundPeers+unconditionalPeersNum, sw.Peers().Size())

	for _, peer := range peers {
		peer.Stop()
	}
	for _, peer := range unconditionalPeers {
		peer.Stop()
	}
}
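// keepConnOpen is a hypothetical helper (not used above) that names the
// read-loop pattern TestSwitchAcceptRoutine spawns to prevent its dialed
// connections from closing.
func keepConnOpen(c net.Conn) {
	go func() {
		buf := make([]byte, 1)
		for {
			// Block on reads until the connection is closed, then exit.
			if _, err := c.Read(buf); err != nil {
				return
			}
		}
	}()
}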
type errorTransport struct {
	acceptErr error
}

func (et errorTransport) NetAddress() NetAddress {
	panic("not implemented")
}

func (et errorTransport) Accept(c peerConfig) (Peer, error) {
	return nil, et.acceptErr
}

func (errorTransport) Dial(NetAddress, peerConfig) (Peer, error) {
	panic("not implemented")
}

func (errorTransport) Cleanup(Peer) {
	panic("not implemented")
}

func TestSwitchAcceptRoutineErrorCases(t *testing.T) {
	sw := NewSwitch(cfg, errorTransport{ErrFilterTimeout{}})
	assert.NotPanics(t, func() {
		err := sw.Start()
		require.NoError(t, err)
		err = sw.Stop()
		require.NoError(t, err)
	})

	sw = NewSwitch(cfg, errorTransport{ErrRejected{conn: nil, err: errors.New("filtered"), isFiltered: true}})
	assert.NotPanics(t, func() {
		err := sw.Start()
		require.NoError(t, err)
		err = sw.Stop()
		require.NoError(t, err)
	})
	// TODO(melekes) check we remove our address from addrBook

	sw = NewSwitch(cfg, errorTransport{ErrTransportClosed{}})
	assert.NotPanics(t, func() {
		err := sw.Start()
		require.NoError(t, err)
		err = sw.Stop()
		require.NoError(t, err)
	})
}

// mockReactor checks that InitPeer is never called before RemovePeer has
// finished. If it is, InitCalledBeforeRemoveFinished will return true.
type mockReactor struct {
	*BaseReactor

	// atomic
	removePeerInProgress           uint32
	initCalledBeforeRemoveFinished uint32
}

func (r *mockReactor) RemovePeer(peer Peer, reason interface{}) {
	atomic.StoreUint32(&r.removePeerInProgress, 1)
	defer atomic.StoreUint32(&r.removePeerInProgress, 0)
	time.Sleep(100 * time.Millisecond)
}

func (r *mockReactor) InitPeer(peer Peer) Peer {
	if atomic.LoadUint32(&r.removePeerInProgress) == 1 {
		atomic.StoreUint32(&r.initCalledBeforeRemoveFinished, 1)
	}

	return peer
}

func (r *mockReactor) InitCalledBeforeRemoveFinished() bool {
	return atomic.LoadUint32(&r.initCalledBeforeRemoveFinished) == 1
}

// see stopAndRemovePeer
func TestSwitchInitPeerIsNotCalledBeforeRemovePeer(t *testing.T) {
	// make reactor
	reactor := &mockReactor{}
	reactor.BaseReactor = NewBaseReactor("mockReactor", reactor)

	// make switch
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", func(i int, sw *Switch) *Switch {
		sw.AddReactor("mock", reactor)
		return sw
	})
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// add peer
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()
	_, err = rp.Dial(sw.NetAddress())
	require.NoError(t, err)

	// wait till the switch adds rp to the peer set, then stop the peer asynchronously
	for {
		time.Sleep(20 * time.Millisecond)
		if peer := sw.Peers().Get(rp.ID()); peer != nil {
			go sw.StopPeerForError(peer, "test")
			break
		}
	}

	// simulate the peer reconnecting to us
	_, err = rp.Dial(sw.NetAddress())
	require.NoError(t, err)
	// wait till the switch adds rp to the peer set
	time.Sleep(50 * time.Millisecond)

	// make sure reactor.RemovePeer is finished before InitPeer is called
	assert.False(t, reactor.InitCalledBeforeRemoveFinished())
}

func BenchmarkSwitchBroadcast(b *testing.B) {
	s1, s2 := MakeSwitchPair(b, func(i int, sw *Switch) *Switch {
		// Make two reactors of two channels each
		sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
			{ID: byte(0x00), Priority: 10},
			{ID: byte(0x01), Priority: 10},
		}, false))
		sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
			{ID: byte(0x02), Priority: 10},
			{ID: byte(0x03), Priority: 10},
		}, false))
		return sw
	})

	b.Cleanup(func() {
		if err := s1.Stop(); err != nil {
			b.Error(err)
		}
	})

	b.Cleanup(func() {
		if err := s2.Stop(); err != nil {
			b.Error(err)
		}
	})

	// Allow time for goroutines to boot up
	time.Sleep(1 * time.Second)

	b.ResetTimer()

	numSuccess, numFailure := 0, 0

	// Broadcast an empty envelope on one of the four channels, round-robin
	for i := 0; i < b.N; i++ {
		chID := byte(i % 4)
		successChan := s1.BroadcastEnvelope(Envelope{ChannelID: chID})
		for s := range successChan {
			if s {
				numSuccess++
			} else {
				numFailure++
			}
		}
	}

	b.Logf("success: %v, failure: %v", numSuccess, numFailure)
}
func TestSwitchRemovalErr(t *testing.T) {
	sw1, sw2 := MakeSwitchPair(t, initSwitchFunc)
	assert.Equal(t, 1, len(sw1.Peers().List()))
	p := sw1.Peers().List()[0]

	sw2.StopPeerForError(p, fmt.Errorf("peer should error"))

	assert.Equal(t, ErrPeerRemoval{}.Error(), sw2.peers.Add(p).Error())
}