github.com/consideritdone/landslidecore@v0.0.0-20230718131026-a8b21c5cf8a7/p2p/switch_test.go

package p2p

import (
	"bytes"
	"errors"
	"fmt"
	"io/ioutil"
	"net"
	"net/http"
	"net/http/httptest"
	"regexp"
	"strconv"
	"sync/atomic"
	"testing"
	"time"

	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/consideritdone/landslidecore/config"
	"github.com/consideritdone/landslidecore/crypto/ed25519"
	"github.com/consideritdone/landslidecore/libs/log"
	tmsync "github.com/consideritdone/landslidecore/libs/sync"
	"github.com/consideritdone/landslidecore/p2p/conn"
)

var (
	cfg *config.P2PConfig
)

func init() {
	cfg = config.DefaultP2PConfig()
	cfg.PexReactor = true
	cfg.AllowDuplicateIP = true
}

type PeerMessage struct {
	PeerID  ID
	Bytes   []byte
	Counter int
}

type TestReactor struct {
	BaseReactor

	mtx          tmsync.Mutex
	channels     []*conn.ChannelDescriptor
	logMessages  bool
	msgsCounter  int
	msgsReceived map[byte][]PeerMessage
}

func NewTestReactor(channels []*conn.ChannelDescriptor, logMessages bool) *TestReactor {
	tr := &TestReactor{
		channels:     channels,
		logMessages:  logMessages,
		msgsReceived: make(map[byte][]PeerMessage),
	}
	tr.BaseReactor = *NewBaseReactor("TestReactor", tr)
	tr.SetLogger(log.TestingLogger())
	return tr
}

func (tr *TestReactor) GetChannels() []*conn.ChannelDescriptor {
	return tr.channels
}

func (tr *TestReactor) AddPeer(peer Peer) {}

func (tr *TestReactor) RemovePeer(peer Peer, reason interface{}) {}

func (tr *TestReactor) Receive(chID byte, peer Peer, msgBytes []byte) {
	if tr.logMessages {
		tr.mtx.Lock()
		defer tr.mtx.Unlock()
		// fmt.Printf("Received: %X, %X\n", chID, msgBytes)
		tr.msgsReceived[chID] = append(tr.msgsReceived[chID], PeerMessage{peer.ID(), msgBytes, tr.msgsCounter})
		tr.msgsCounter++
	}
}

func (tr *TestReactor) getMsgs(chID byte) []PeerMessage {
	tr.mtx.Lock()
	defer tr.mtx.Unlock()
	return tr.msgsReceived[chID]
}

//-----------------------------------------------------------------------------

// convenience method for creating two switches connected to each other.
// XXX: note this uses net.Pipe and not a proper TCP conn
func MakeSwitchPair(t testing.TB, initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) {
	// Create two switches that will be interconnected.
	switches := MakeConnectedSwitches(cfg, 2, initSwitch, Connect2Switches)
	return switches[0], switches[1]
}

func initSwitchFunc(i int, sw *Switch) *Switch {
	sw.SetAddrBook(&AddrBookMock{
		Addrs:    make(map[string]struct{}),
		OurAddrs: make(map[string]struct{})})

	// Make two reactors of two channels each
	sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
		{ID: byte(0x00), Priority: 10},
		{ID: byte(0x01), Priority: 10},
	}, true))
	sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
		{ID: byte(0x02), Priority: 10},
		{ID: byte(0x03), Priority: 10},
	}, true))

	return sw
}

func TestSwitches(t *testing.T) {
	s1, s2 := MakeSwitchPair(t, initSwitchFunc)
	t.Cleanup(func() {
		if err := s1.Stop(); err != nil {
			t.Error(err)
		}
	})
	t.Cleanup(func() {
		if err := s2.Stop(); err != nil {
			t.Error(err)
		}
	})

	if s1.Peers().Size() != 1 {
		t.Errorf("expected exactly 1 peer in s1, got %v", s1.Peers().Size())
	}
	if s2.Peers().Size() != 1 {
		t.Errorf("expected exactly 1 peer in s2, got %v", s2.Peers().Size())
	}

	// Let's send some messages
	ch0Msg := []byte("channel zero")
	ch1Msg := []byte("channel foo")
	ch2Msg := []byte("channel bar")

	s1.Broadcast(byte(0x00), ch0Msg)
	s1.Broadcast(byte(0x01), ch1Msg)
	s1.Broadcast(byte(0x02), ch2Msg)

	assertMsgReceivedWithTimeout(t,
		ch0Msg,
		byte(0x00),
		s2.Reactor("foo").(*TestReactor), 10*time.Millisecond, 5*time.Second)
	assertMsgReceivedWithTimeout(t,
		ch1Msg,
		byte(0x01),
		s2.Reactor("foo").(*TestReactor), 10*time.Millisecond, 5*time.Second)
	assertMsgReceivedWithTimeout(t,
		ch2Msg,
		byte(0x02),
		s2.Reactor("bar").(*TestReactor), 10*time.Millisecond, 5*time.Second)
}

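// assertMsgReceivedWithTimeout polls the reactor every checkPeriod and fails
// the test if the first message received on the given channel does not match
// msgBytes, or if nothing arrives before the timeout expires.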
func assertMsgReceivedWithTimeout(
	t *testing.T,
	msgBytes []byte,
	channel byte,
	reactor *TestReactor,
	checkPeriod,
	timeout time.Duration,
) {
	ticker := time.NewTicker(checkPeriod)
	defer ticker.Stop()
	// Create the timeout channel once, outside the loop, so the deadline is
	// not reset on every tick.
	timedOut := time.After(timeout)
	for {
		select {
		case <-ticker.C:
			msgs := reactor.getMsgs(channel)
			if len(msgs) > 0 {
				if !bytes.Equal(msgs[0].Bytes, msgBytes) {
					t.Fatalf("Unexpected message bytes. Wanted: %X, Got: %X", msgBytes, msgs[0].Bytes)
				}
				return
			}

		case <-timedOut:
			t.Fatalf("Expected to have received 1 message in channel #%v, got zero", channel)
		}
	}
}

func TestSwitchFiltersOutItself(t *testing.T) {
	s1 := MakeSwitch(cfg, 1, "127.0.0.1", "123.123.123", initSwitchFunc)

	// simulate s1 having a public IP by creating a remote peer with the same ID
	rp := &remotePeer{PrivKey: s1.nodeKey.PrivKey, Config: cfg}
	rp.Start()

	// addr should be rejected in addPeer based on the same ID
	err := s1.DialPeerWithAddress(rp.Addr())
	if assert.Error(t, err) {
		if err, ok := err.(ErrRejected); ok {
			if !err.IsSelf() {
				t.Errorf("expected self to be rejected")
			}
		} else {
			t.Errorf("expected ErrRejected")
		}
	}

	assert.True(t, s1.addrBook.OurAddress(rp.Addr()))
	assert.False(t, s1.addrBook.HasAddress(rp.Addr()))

	rp.Stop()

	assertNoPeersAfterTimeout(t, s1, 100*time.Millisecond)
}

func TestSwitchPeerFilter(t *testing.T) {
	var (
		filters = []PeerFilterFunc{
			func(_ IPeerSet, _ Peer) error { return nil },
			func(_ IPeerSet, _ Peer) error { return fmt.Errorf("denied") },
			func(_ IPeerSet, _ Peer) error { return nil },
		}
		sw = MakeSwitch(
			cfg,
			1,
			"testing",
			"123.123.123",
			initSwitchFunc,
			SwitchPeerFilters(filters...),
		)
	)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// simulate remote peer
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	t.Cleanup(rp.Stop)

	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
		chDescs:      sw.chDescs,
		onPeerError:  sw.StopPeerForError,
		isPersistent: sw.IsPeerPersistent,
		reactorsByCh: sw.reactorsByCh,
	})
	if err != nil {
		t.Fatal(err)
	}

	err = sw.addPeer(p)
	if err, ok := err.(ErrRejected); ok {
		if !err.IsFiltered() {
			t.Errorf("expected peer to be filtered")
		}
	} else {
		t.Errorf("expected ErrRejected")
	}
}

func TestSwitchPeerFilterTimeout(t *testing.T) {
	var (
		filters = []PeerFilterFunc{
			func(_ IPeerSet, _ Peer) error {
				time.Sleep(10 * time.Millisecond)
				return nil
			},
		}
		sw = MakeSwitch(
			cfg,
			1,
			"testing",
			"123.123.123",
			initSwitchFunc,
			SwitchFilterTimeout(5*time.Millisecond),
			SwitchPeerFilters(filters...),
		)
	)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Log(err)
		}
	})

	// simulate remote peer
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
		chDescs:      sw.chDescs,
		onPeerError:  sw.StopPeerForError,
		isPersistent: sw.IsPeerPersistent,
		reactorsByCh: sw.reactorsByCh,
	})
	if err != nil {
		t.Fatal(err)
	}

	err = sw.addPeer(p)
	if _, ok := err.(ErrFilterTimeout); !ok {
		t.Errorf("expected ErrFilterTimeout")
	}
}

func TestSwitchPeerFilterDuplicate(t *testing.T) {
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// simulate remote peer
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
		chDescs:      sw.chDescs,
		onPeerError:  sw.StopPeerForError,
		isPersistent: sw.IsPeerPersistent,
		reactorsByCh: sw.reactorsByCh,
	})
	if err != nil {
		t.Fatal(err)
	}

	if err := sw.addPeer(p); err != nil {
		t.Fatal(err)
	}

	err = sw.addPeer(p)
	if errRej, ok := err.(ErrRejected); ok {
		if !errRej.IsDuplicate() {
			t.Errorf("expected peer to be duplicate. got %v", errRej)
		}
	} else {
		t.Errorf("expected ErrRejected, got %v", err)
	}
}

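// assertNoPeersAfterTimeout waits for the given duration and then fails the
// test if the switch still has any peers.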
func assertNoPeersAfterTimeout(t *testing.T, sw *Switch, timeout time.Duration) {
	time.Sleep(timeout)
	if sw.Peers().Size() != 0 {
		t.Fatalf("Expected %v to have no peers, got %d", sw, sw.Peers().Size())
	}
}

func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) {
	assert, require := assert.New(t), require.New(t)

	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.Start()
	if err != nil {
		t.Error(err)
	}
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// simulate remote peer
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
		chDescs:      sw.chDescs,
		onPeerError:  sw.StopPeerForError,
		isPersistent: sw.IsPeerPersistent,
		reactorsByCh: sw.reactorsByCh,
	})
	require.Nil(err)

	err = sw.addPeer(p)
	require.Nil(err)

	require.NotNil(sw.Peers().Get(rp.ID()))

	// simulate failure by closing connection
	err = p.(*peer).CloseConn()
	require.NoError(err)

	assertNoPeersAfterTimeout(t, sw, 100*time.Millisecond)
	assert.False(p.IsRunning())
}

func TestSwitchStopPeerForError(t *testing.T) {
	s := httptest.NewServer(promhttp.Handler())
	defer s.Close()

	scrapeMetrics := func() string {
		resp, err := http.Get(s.URL)
		require.NoError(t, err)
		defer resp.Body.Close()
		buf, _ := ioutil.ReadAll(resp.Body)
		return string(buf)
	}

	namespace, subsystem, name := config.TestInstrumentationConfig().Namespace, MetricsSubsystem, "peers"
	re := regexp.MustCompile(namespace + `_` + subsystem + `_` + name + ` ([0-9\.]+)`)
	peersMetricValue := func() float64 {
		matches := re.FindStringSubmatch(scrapeMetrics())
		f, _ := strconv.ParseFloat(matches[1], 64)
		return f
	}

	p2pMetrics := PrometheusMetrics(namespace)

	// make two connected switches
	sw1, sw2 := MakeSwitchPair(t, func(i int, sw *Switch) *Switch {
		// set metrics on sw1
		if i == 0 {
			opt := WithMetrics(p2pMetrics)
			opt(sw)
		}
		return initSwitchFunc(i, sw)
	})

	assert.Equal(t, len(sw1.Peers().List()), 1)
	assert.EqualValues(t, 1, peersMetricValue())

	// send messages to the peer from sw1
	p := sw1.Peers().List()[0]
	p.Send(0x1, []byte("here's a message to send"))

	// stop sw2. this should cause the p to fail,
	// which results in calling StopPeerForError internally
	t.Cleanup(func() {
		if err := sw2.Stop(); err != nil {
			t.Error(err)
		}
	})

	// now call StopPeerForError explicitly, e.g. from a reactor
	sw1.StopPeerForError(p, fmt.Errorf("some err"))

	assert.Equal(t, len(sw1.Peers().List()), 0)
	assert.EqualValues(t, 0, peersMetricValue())
}

func TestSwitchReconnectsToOutboundPersistentPeer(t *testing.T) {
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// 1. simulate failure by closing connection
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
	require.NoError(t, err)

	err = sw.DialPeerWithAddress(rp.Addr())
	require.Nil(t, err)
	require.NotNil(t, sw.Peers().Get(rp.ID()))

	p := sw.Peers().List()[0]
	err = p.(*peer).CloseConn()
	require.NoError(t, err)

	waitUntilSwitchHasAtLeastNPeers(sw, 1)
	assert.False(t, p.IsRunning())        // old peer instance
	assert.Equal(t, 1, sw.Peers().Size()) // new peer instance

	// 2. simulate first time dial failure
	rp = &remotePeer{
		PrivKey: ed25519.GenPrivKey(),
		Config:  cfg,
		// Use a different interface to prevent the duplicate IP filter; this
		// will break beyond two peers.
		listenAddr: "127.0.0.1:0",
	}
	rp.Start()
	defer rp.Stop()

	conf := config.DefaultP2PConfig()
	conf.TestDialFail = true // will trigger a reconnect
	err = sw.addOutboundPeerWithConfig(rp.Addr(), conf)
	require.NotNil(t, err)
	// DialPeerWithAddress - sw.peerConfig resets the dialer
	waitUntilSwitchHasAtLeastNPeers(sw, 2)
	assert.Equal(t, 2, sw.Peers().Size())
}

func TestSwitchReconnectsToInboundPersistentPeer(t *testing.T) {
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// 1. simulate failure by closing the connection
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
	require.NoError(t, err)

	conn, err := rp.Dial(sw.NetAddress())
	require.NoError(t, err)
	time.Sleep(50 * time.Millisecond)
	require.NotNil(t, sw.Peers().Get(rp.ID()))

	conn.Close()

	waitUntilSwitchHasAtLeastNPeers(sw, 1)
	assert.Equal(t, 1, sw.Peers().Size())
}

func TestSwitchDialPeersAsync(t *testing.T) {
	if testing.Short() {
		return
	}

	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	err = sw.DialPeersAsync([]string{rp.Addr().String()})
	require.NoError(t, err)
	time.Sleep(dialRandomizerIntervalMilliseconds * time.Millisecond)
	require.NotNil(t, sw.Peers().Get(rp.ID()))
}

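// waitUntilSwitchHasAtLeastNPeers polls the switch every 250ms and returns as
// soon as it reports at least n peers, giving up after 20 attempts (~5s).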
func waitUntilSwitchHasAtLeastNPeers(sw *Switch, n int) {
	for i := 0; i < 20; i++ {
		time.Sleep(250 * time.Millisecond)
		has := sw.Peers().Size()
		if has >= n {
			break
		}
	}
}

func TestSwitchFullConnectivity(t *testing.T) {
	switches := MakeConnectedSwitches(cfg, 3, initSwitchFunc, Connect2Switches)
	defer func() {
		for _, sw := range switches {
			sw := sw
			t.Cleanup(func() {
				if err := sw.Stop(); err != nil {
					t.Error(err)
				}
			})
		}
	}()

	for i, sw := range switches {
		if sw.Peers().Size() != 2 {
			t.Fatalf("Expected each switch to be connected to 2 others, but switch %d is only connected to %d", i, sw.Peers().Size())
		}
	}
}

func TestSwitchAcceptRoutine(t *testing.T) {
	cfg.MaxNumInboundPeers = 5

	// Create some unconditional peers.
	const unconditionalPeersNum = 2
	var (
		unconditionalPeers   = make([]*remotePeer, unconditionalPeersNum)
		unconditionalPeerIDs = make([]string, unconditionalPeersNum)
	)
	for i := 0; i < unconditionalPeersNum; i++ {
		peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
		peer.Start()
		unconditionalPeers[i] = peer
		unconditionalPeerIDs[i] = string(peer.ID())
	}

	// make switch
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.AddUnconditionalPeerIDs(unconditionalPeerIDs)
	require.NoError(t, err)
	err = sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		err := sw.Stop()
		require.NoError(t, err)
	})

	// 0. check there are no peers
	assert.Equal(t, 0, sw.Peers().Size())

	// 1. check we connect up to MaxNumInboundPeers
	peers := make([]*remotePeer, 0)
	for i := 0; i < cfg.MaxNumInboundPeers; i++ {
		peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
		peers = append(peers, peer)
		peer.Start()
		c, err := peer.Dial(sw.NetAddress())
		require.NoError(t, err)
		// spawn a reading routine to prevent connection from closing
		go func(c net.Conn) {
			for {
				one := make([]byte, 1)
				_, err := c.Read(one)
				if err != nil {
					return
				}
			}
		}(c)
	}
	time.Sleep(100 * time.Millisecond)
	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())

	// 2. check we close new connections if we already have MaxNumInboundPeers peers
	peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	peer.Start()
	conn, err := peer.Dial(sw.NetAddress())
	require.NoError(t, err)
	// check conn is closed
	one := make([]byte, 1)
	_ = conn.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
	_, err = conn.Read(one)
	assert.Error(t, err)
	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())
	peer.Stop()

	// 3. check we connect to unconditional peers despite the limit.
	for _, peer := range unconditionalPeers {
		c, err := peer.Dial(sw.NetAddress())
		require.NoError(t, err)
		// spawn a reading routine to prevent connection from closing
		go func(c net.Conn) {
			for {
				one := make([]byte, 1)
				_, err := c.Read(one)
				if err != nil {
					return
				}
			}
		}(c)
	}
	time.Sleep(10 * time.Millisecond)
	assert.Equal(t, cfg.MaxNumInboundPeers+unconditionalPeersNum, sw.Peers().Size())

	for _, peer := range peers {
		peer.Stop()
	}
	for _, peer := range unconditionalPeers {
		peer.Stop()
	}
}

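// errorTransport is a stub transport whose Accept always returns the
// configured error; NetAddress, Dial, and Cleanup panic. It is used to
// exercise the switch's accept-routine error handling.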
type errorTransport struct {
	acceptErr error
}

func (et errorTransport) NetAddress() NetAddress {
	panic("not implemented")
}

func (et errorTransport) Accept(c peerConfig) (Peer, error) {
	return nil, et.acceptErr
}
func (errorTransport) Dial(NetAddress, peerConfig) (Peer, error) {
	panic("not implemented")
}
func (errorTransport) Cleanup(Peer) {
	panic("not implemented")
}

func TestSwitchAcceptRoutineErrorCases(t *testing.T) {
	sw := NewSwitch(cfg, errorTransport{ErrFilterTimeout{}})
	assert.NotPanics(t, func() {
		err := sw.Start()
		require.NoError(t, err)
		err = sw.Stop()
		require.NoError(t, err)
	})

	sw = NewSwitch(cfg, errorTransport{ErrRejected{conn: nil, err: errors.New("filtered"), isFiltered: true}})
	assert.NotPanics(t, func() {
		err := sw.Start()
		require.NoError(t, err)
		err = sw.Stop()
		require.NoError(t, err)
	})
	// TODO(melekes) check we remove our address from addrBook

	sw = NewSwitch(cfg, errorTransport{ErrTransportClosed{}})
	assert.NotPanics(t, func() {
		err := sw.Start()
		require.NoError(t, err)
		err = sw.Stop()
		require.NoError(t, err)
	})
}

// mockReactor checks that InitPeer is never called before RemovePeer has
// finished. If it is, InitCalledBeforeRemoveFinished will return true.
type mockReactor struct {
	*BaseReactor

	// atomic
	removePeerInProgress           uint32
	initCalledBeforeRemoveFinished uint32
}

func (r *mockReactor) RemovePeer(peer Peer, reason interface{}) {
	atomic.StoreUint32(&r.removePeerInProgress, 1)
	defer atomic.StoreUint32(&r.removePeerInProgress, 0)
	time.Sleep(100 * time.Millisecond)
}

func (r *mockReactor) InitPeer(peer Peer) Peer {
	if atomic.LoadUint32(&r.removePeerInProgress) == 1 {
		atomic.StoreUint32(&r.initCalledBeforeRemoveFinished, 1)
	}

	return peer
}

func (r *mockReactor) InitCalledBeforeRemoveFinished() bool {
	return atomic.LoadUint32(&r.initCalledBeforeRemoveFinished) == 1
}

// see stopAndRemovePeer
func TestSwitchInitPeerIsNotCalledBeforeRemovePeer(t *testing.T) {
	// make reactor
	reactor := &mockReactor{}
	reactor.BaseReactor = NewBaseReactor("mockReactor", reactor)

	// make switch
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", func(i int, sw *Switch) *Switch {
		sw.AddReactor("mock", reactor)
		return sw
	})
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// add peer
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()
	_, err = rp.Dial(sw.NetAddress())
	require.NoError(t, err)

	// wait till the switch adds rp to the peer set, then stop the peer asynchronously
	for {
		time.Sleep(20 * time.Millisecond)
		if peer := sw.Peers().Get(rp.ID()); peer != nil {
			go sw.StopPeerForError(peer, "test")
			break
		}
	}

	// simulate peer reconnecting to us
	_, err = rp.Dial(sw.NetAddress())
	require.NoError(t, err)
	// wait till the switch adds rp to the peer set
	time.Sleep(50 * time.Millisecond)

	// make sure reactor.RemovePeer is finished before InitPeer is called
	assert.False(t, reactor.InitCalledBeforeRemoveFinished())
}

func BenchmarkSwitchBroadcast(b *testing.B) {
	s1, s2 := MakeSwitchPair(b, func(i int, sw *Switch) *Switch {
		// Make two reactors of two channels each
		sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
			{ID: byte(0x00), Priority: 10},
			{ID: byte(0x01), Priority: 10},
		}, false))
		sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
			{ID: byte(0x02), Priority: 10},
			{ID: byte(0x03), Priority: 10},
		}, false))
		return sw
	})

	b.Cleanup(func() {
		if err := s1.Stop(); err != nil {
			b.Error(err)
		}
	})

	b.Cleanup(func() {
		if err := s2.Stop(); err != nil {
			b.Error(err)
		}
	})

	// Allow time for goroutines to boot up
	time.Sleep(1 * time.Second)

	b.ResetTimer()

	numSuccess, numFailure := 0, 0

	// Send a random message from one channel to another
	for i := 0; i < b.N; i++ {
		chID := byte(i % 4)
		successChan := s1.Broadcast(chID, []byte("test data"))
		for s := range successChan {
			if s {
				numSuccess++
			} else {
				numFailure++
			}
		}
	}

	b.Logf("success: %v, failure: %v", numSuccess, numFailure)
}