github.com/KYVENetwork/cometbft/v38@v38.0.3/p2p/switch_test.go

     1  package p2p
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"regexp"
    12  	"strconv"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/cosmos/gogoproto/proto"
    18  	"github.com/prometheus/client_golang/prometheus/promhttp"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/KYVENetwork/cometbft/v38/config"
    23  	"github.com/KYVENetwork/cometbft/v38/crypto/ed25519"
    24  	"github.com/KYVENetwork/cometbft/v38/libs/log"
    25  	cmtsync "github.com/KYVENetwork/cometbft/v38/libs/sync"
    26  	"github.com/KYVENetwork/cometbft/v38/p2p/conn"
    27  	p2pproto "github.com/KYVENetwork/cometbft/v38/proto/cometbft/v38/p2p"
    28  )
    29  
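        // cfg is the P2P configuration shared by every test in this file; it is populated in init().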
    30  var cfg *config.P2PConfig
    31  
    32  func init() {
    33  	cfg = config.DefaultP2PConfig()
    34  	cfg.PexReactor = true
    35  	cfg.AllowDuplicateIP = true
    36  }
    37  
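        // PeerMessage pairs a message received by the TestReactor with a counter recording its arrival order.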
    38  type PeerMessage struct {
    39  	Contents proto.Message
    40  	Counter  int
    41  }
    42  
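        // TestReactor is a reactor that, when logMessages is set, records every message it receives per channel.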
    43  type TestReactor struct {
    44  	BaseReactor
    45  
    46  	mtx          cmtsync.Mutex
    47  	channels     []*conn.ChannelDescriptor
    48  	logMessages  bool
    49  	msgsCounter  int
    50  	msgsReceived map[byte][]PeerMessage
    51  }
    52  
    53  func NewTestReactor(channels []*conn.ChannelDescriptor, logMessages bool) *TestReactor {
    54  	tr := &TestReactor{
    55  		channels:     channels,
    56  		logMessages:  logMessages,
    57  		msgsReceived: make(map[byte][]PeerMessage),
    58  	}
    59  	tr.BaseReactor = *NewBaseReactor("TestReactor", tr)
    60  	tr.SetLogger(log.TestingLogger())
    61  	return tr
    62  }
    63  
    64  func (tr *TestReactor) GetChannels() []*conn.ChannelDescriptor {
    65  	return tr.channels
    66  }
    67  
    68  func (tr *TestReactor) AddPeer(Peer) {}
    69  
    70  func (tr *TestReactor) RemovePeer(Peer, interface{}) {}
    71  
    72  func (tr *TestReactor) Receive(e Envelope) {
    73  	if tr.logMessages {
    74  		tr.mtx.Lock()
    75  		defer tr.mtx.Unlock()
    76  		fmt.Printf("Received: %X, %X\n", e.ChannelID, e.Message)
    77  		tr.msgsReceived[e.ChannelID] = append(tr.msgsReceived[e.ChannelID], PeerMessage{Contents: e.Message, Counter: tr.msgsCounter})
    78  		tr.msgsCounter++
    79  	}
    80  }
    81  
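        // getMsgs returns the messages recorded so far on the given channel.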
    82  func (tr *TestReactor) getMsgs(chID byte) []PeerMessage {
    83  	tr.mtx.Lock()
    84  	defer tr.mtx.Unlock()
    85  	return tr.msgsReceived[chID]
    86  }
    87  
    88  //-----------------------------------------------------------------------------
    89  
     90  // MakeSwitchPair is a convenience function for creating two switches connected to each other.
     91  // XXX: note this uses net.Pipe and not a proper TCP conn.
    92  func MakeSwitchPair(initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) {
    93  	// Create two switches that will be interconnected.
    94  	switches := MakeConnectedSwitches(cfg, 2, initSwitch, Connect2Switches)
    95  	return switches[0], switches[1]
    96  }
    97  
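        // initSwitchFunc equips a switch with a mock address book and two logging TestReactors
        // ("foo" and "bar"), each serving two channels.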
    98  func initSwitchFunc(_ int, sw *Switch) *Switch {
    99  	sw.SetAddrBook(&AddrBookMock{
   100  		Addrs:    make(map[string]struct{}),
   101  		OurAddrs: make(map[string]struct{}),
   102  	})
   103  
   104  	// Make two reactors of two channels each
   105  	sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
   106  		{ID: byte(0x00), Priority: 10, MessageType: &p2pproto.Message{}},
   107  		{ID: byte(0x01), Priority: 10, MessageType: &p2pproto.Message{}},
   108  	}, true))
   109  	sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
   110  		{ID: byte(0x02), Priority: 10, MessageType: &p2pproto.Message{}},
   111  		{ID: byte(0x03), Priority: 10, MessageType: &p2pproto.Message{}},
   112  	}, true))
   113  
   114  	return sw
   115  }
   116  
   117  func TestSwitches(t *testing.T) {
   118  	s1, s2 := MakeSwitchPair(initSwitchFunc)
   119  	t.Cleanup(func() {
   120  		if err := s1.Stop(); err != nil {
   121  			t.Error(err)
   122  		}
   123  	})
   124  	t.Cleanup(func() {
   125  		if err := s2.Stop(); err != nil {
   126  			t.Error(err)
   127  		}
   128  	})
   129  
   130  	if s1.Peers().Size() != 1 {
   131  		t.Errorf("expected exactly 1 peer in s1, got %v", s1.Peers().Size())
   132  	}
   133  	if s2.Peers().Size() != 1 {
   134  		t.Errorf("expected exactly 1 peer in s2, got %v", s2.Peers().Size())
   135  	}
   136  
    137  	// Let's send some messages
   138  	ch0Msg := &p2pproto.PexAddrs{
   139  		Addrs: []p2pproto.NetAddress{
   140  			{
   141  				ID: "1",
   142  			},
   143  		},
   144  	}
   145  	ch1Msg := &p2pproto.PexAddrs{
   146  		Addrs: []p2pproto.NetAddress{
   147  			{
   148  				ID: "1",
   149  			},
   150  		},
   151  	}
   152  	ch2Msg := &p2pproto.PexAddrs{
   153  		Addrs: []p2pproto.NetAddress{
   154  			{
   155  				ID: "2",
   156  			},
   157  		},
   158  	}
   159  	s1.Broadcast(Envelope{ChannelID: byte(0x00), Message: ch0Msg})
   160  	s1.Broadcast(Envelope{ChannelID: byte(0x01), Message: ch1Msg})
   161  	s1.Broadcast(Envelope{ChannelID: byte(0x02), Message: ch2Msg})
   162  	assertMsgReceivedWithTimeout(t,
   163  		ch0Msg,
   164  		byte(0x00),
   165  		s2.Reactor("foo").(*TestReactor), 200*time.Millisecond, 5*time.Second)
   166  	assertMsgReceivedWithTimeout(t,
   167  		ch1Msg,
   168  		byte(0x01),
   169  		s2.Reactor("foo").(*TestReactor), 200*time.Millisecond, 5*time.Second)
   170  	assertMsgReceivedWithTimeout(t,
   171  		ch2Msg,
   172  		byte(0x02),
   173  		s2.Reactor("bar").(*TestReactor), 200*time.Millisecond, 5*time.Second)
   174  }
   175  
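        // assertMsgReceivedWithTimeout polls the reactor every checkPeriod until msg shows up on the
        // given channel, failing the test if timeout elapses first.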
   176  func assertMsgReceivedWithTimeout(
   177  	t *testing.T,
   178  	msg proto.Message,
   179  	channel byte,
   180  	reactor *TestReactor,
   181  	checkPeriod,
   182  	timeout time.Duration,
   183  ) {
   184  	ticker := time.NewTicker(checkPeriod)
   185  	for {
   186  		select {
   187  		case <-ticker.C:
   188  			msgs := reactor.getMsgs(channel)
    189  			if len(msgs) > 0 {
    190  				expectedBytes, err := proto.Marshal(msg)
    191  				require.NoError(t, err)
    192  				gotBytes, err := proto.Marshal(msgs[0].Contents)
    193  				require.NoError(t, err)
    194  				if !bytes.Equal(expectedBytes, gotBytes) {
    195  					t.Fatalf("Unexpected message bytes. Wanted: %X, Got: %X", expectedBytes, gotBytes)
    196  				}
    197  				return
    198  			}
   199  
   200  		case <-time.After(timeout):
   201  			t.Fatalf("Expected to have received 1 message in channel #%v, got zero", channel)
   202  		}
   203  	}
   204  }
   205  
   206  func TestSwitchFiltersOutItself(t *testing.T) {
   207  	s1 := MakeSwitch(cfg, 1, initSwitchFunc)
   208  
   209  	// simulate s1 having a public IP by creating a remote peer with the same ID
   210  	rp := &remotePeer{PrivKey: s1.nodeKey.PrivKey, Config: cfg}
   211  	rp.Start()
   212  
   213  	// addr should be rejected in addPeer based on the same ID
   214  	err := s1.DialPeerWithAddress(rp.Addr())
   215  	if assert.Error(t, err) {
   216  		if err, ok := err.(ErrRejected); ok {
   217  			if !err.IsSelf() {
   218  				t.Errorf("expected self to be rejected")
   219  			}
   220  		} else {
   221  			t.Errorf("expected ErrRejected")
   222  		}
   223  	}
   224  
   225  	assert.True(t, s1.addrBook.OurAddress(rp.Addr()))
   226  	assert.False(t, s1.addrBook.HasAddress(rp.Addr()))
   227  
   228  	rp.Stop()
   229  
   230  	assertNoPeersAfterTimeout(t, s1, 100*time.Millisecond)
   231  }
   232  
   233  func TestSwitchPeerFilter(t *testing.T) {
   234  	var (
   235  		filters = []PeerFilterFunc{
   236  			func(_ IPeerSet, _ Peer) error { return nil },
   237  			func(_ IPeerSet, _ Peer) error { return fmt.Errorf("denied") },
   238  			func(_ IPeerSet, _ Peer) error { return nil },
   239  		}
   240  		sw = MakeSwitch(
   241  			cfg,
   242  			1,
   243  			initSwitchFunc,
   244  			SwitchPeerFilters(filters...),
   245  		)
   246  	)
   247  	err := sw.Start()
   248  	require.NoError(t, err)
   249  	t.Cleanup(func() {
   250  		if err := sw.Stop(); err != nil {
   251  			t.Error(err)
   252  		}
   253  	})
   254  
   255  	// simulate remote peer
   256  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   257  	rp.Start()
   258  	t.Cleanup(rp.Stop)
   259  
   260  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   261  		chDescs:      sw.chDescs,
   262  		onPeerError:  sw.StopPeerForError,
   263  		isPersistent: sw.IsPeerPersistent,
   264  		reactorsByCh: sw.reactorsByCh,
   265  	})
   266  	if err != nil {
   267  		t.Fatal(err)
   268  	}
   269  
   270  	err = sw.addPeer(p)
   271  	if err, ok := err.(ErrRejected); ok {
   272  		if !err.IsFiltered() {
   273  			t.Errorf("expected peer to be filtered")
   274  		}
   275  	} else {
   276  		t.Errorf("expected ErrRejected")
   277  	}
   278  }
   279  
   280  func TestSwitchPeerFilterTimeout(t *testing.T) {
   281  	var (
   282  		filters = []PeerFilterFunc{
   283  			func(_ IPeerSet, _ Peer) error {
   284  				time.Sleep(10 * time.Millisecond)
   285  				return nil
   286  			},
   287  		}
   288  		sw = MakeSwitch(
   289  			cfg,
   290  			1,
   291  			initSwitchFunc,
   292  			SwitchFilterTimeout(5*time.Millisecond),
   293  			SwitchPeerFilters(filters...),
   294  		)
   295  	)
   296  	err := sw.Start()
   297  	require.NoError(t, err)
   298  	t.Cleanup(func() {
   299  		if err := sw.Stop(); err != nil {
   300  			t.Log(err)
   301  		}
   302  	})
   303  
   304  	// simulate remote peer
   305  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   306  	rp.Start()
   307  	defer rp.Stop()
   308  
   309  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   310  		chDescs:      sw.chDescs,
   311  		onPeerError:  sw.StopPeerForError,
   312  		isPersistent: sw.IsPeerPersistent,
   313  		reactorsByCh: sw.reactorsByCh,
   314  	})
   315  	if err != nil {
   316  		t.Fatal(err)
   317  	}
   318  
   319  	err = sw.addPeer(p)
   320  	if _, ok := err.(ErrFilterTimeout); !ok {
   321  		t.Errorf("expected ErrFilterTimeout")
   322  	}
   323  }
   324  
   325  func TestSwitchPeerFilterDuplicate(t *testing.T) {
   326  	sw := MakeSwitch(cfg, 1, initSwitchFunc)
   327  	err := sw.Start()
   328  	require.NoError(t, err)
   329  	t.Cleanup(func() {
   330  		if err := sw.Stop(); err != nil {
   331  			t.Error(err)
   332  		}
   333  	})
   334  
   335  	// simulate remote peer
   336  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   337  	rp.Start()
   338  	defer rp.Stop()
   339  
   340  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   341  		chDescs:      sw.chDescs,
   342  		onPeerError:  sw.StopPeerForError,
   343  		isPersistent: sw.IsPeerPersistent,
   344  		reactorsByCh: sw.reactorsByCh,
   345  	})
   346  	if err != nil {
   347  		t.Fatal(err)
   348  	}
   349  
   350  	if err := sw.addPeer(p); err != nil {
   351  		t.Fatal(err)
   352  	}
   353  
   354  	err = sw.addPeer(p)
   355  	if errRej, ok := err.(ErrRejected); ok {
   356  		if !errRej.IsDuplicate() {
   357  			t.Errorf("expected peer to be duplicate. got %v", errRej)
   358  		}
   359  	} else {
   360  		t.Errorf("expected ErrRejected, got %v", err)
   361  	}
   362  }
   363  
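        // assertNoPeersAfterTimeout sleeps for the given duration and then fails the test if the switch
        // still has connected peers.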
   364  func assertNoPeersAfterTimeout(t *testing.T, sw *Switch, timeout time.Duration) {
   365  	time.Sleep(timeout)
   366  	if sw.Peers().Size() != 0 {
    367  		t.Fatalf("Expected %v to have no peers, but got %d", sw, sw.Peers().Size())
   368  	}
   369  }
   370  
   371  func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) {
   372  	assert, require := assert.New(t), require.New(t)
   373  
   374  	sw := MakeSwitch(cfg, 1, initSwitchFunc)
   375  	err := sw.Start()
   376  	if err != nil {
   377  		t.Error(err)
   378  	}
   379  	t.Cleanup(func() {
   380  		if err := sw.Stop(); err != nil {
   381  			t.Error(err)
   382  		}
   383  	})
   384  
   385  	// simulate remote peer
   386  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   387  	rp.Start()
   388  	defer rp.Stop()
   389  
   390  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   391  		chDescs:      sw.chDescs,
   392  		onPeerError:  sw.StopPeerForError,
   393  		isPersistent: sw.IsPeerPersistent,
   394  		reactorsByCh: sw.reactorsByCh,
   395  	})
   396  	require.Nil(err)
   397  
   398  	err = sw.addPeer(p)
   399  	require.Nil(err)
   400  
   401  	require.NotNil(sw.Peers().Get(rp.ID()))
   402  
   403  	// simulate failure by closing connection
   404  	err = p.(*peer).CloseConn()
   405  	require.NoError(err)
   406  
   407  	assertNoPeersAfterTimeout(t, sw, 100*time.Millisecond)
   408  	assert.False(p.IsRunning())
   409  }
   410  
   411  func TestSwitchStopPeerForError(t *testing.T) {
   412  	s := httptest.NewServer(promhttp.Handler())
   413  	defer s.Close()
   414  
   415  	scrapeMetrics := func() string {
   416  		resp, err := http.Get(s.URL)
   417  		require.NoError(t, err)
   418  		defer resp.Body.Close()
   419  		buf, _ := io.ReadAll(resp.Body)
   420  		return string(buf)
   421  	}
   422  
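        	// peersMetricValue scrapes the metrics endpoint and parses out the current value of the peers gauge.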
   423  	namespace, subsystem, name := config.TestInstrumentationConfig().Namespace, MetricsSubsystem, "peers"
   424  	re := regexp.MustCompile(namespace + `_` + subsystem + `_` + name + ` ([0-9\.]+)`)
   425  	peersMetricValue := func() float64 {
   426  		matches := re.FindStringSubmatch(scrapeMetrics())
   427  		f, _ := strconv.ParseFloat(matches[1], 64)
   428  		return f
   429  	}
   430  
   431  	p2pMetrics := PrometheusMetrics(namespace)
   432  
   433  	// make two connected switches
   434  	sw1, sw2 := MakeSwitchPair(func(i int, sw *Switch) *Switch {
   435  		// set metrics on sw1
   436  		if i == 0 {
   437  			opt := WithMetrics(p2pMetrics)
   438  			opt(sw)
   439  		}
   440  		return initSwitchFunc(i, sw)
   441  	})
   442  
   443  	assert.Equal(t, len(sw1.Peers().List()), 1)
   444  	assert.EqualValues(t, 1, peersMetricValue())
   445  
   446  	// send messages to the peer from sw1
   447  	p := sw1.Peers().List()[0]
   448  	p.Send(Envelope{
   449  		ChannelID: 0x1,
   450  		Message:   &p2pproto.Message{},
   451  	})
   452  
    453  	// stop sw2. this should cause the peer to fail,
    454  	// which results in calling StopPeerForError internally
   455  	t.Cleanup(func() {
   456  		if err := sw2.Stop(); err != nil {
   457  			t.Error(err)
   458  		}
   459  	})
   460  
    461  	// now call StopPeerForError explicitly, e.g. from a reactor
   462  	sw1.StopPeerForError(p, fmt.Errorf("some err"))
   463  
   464  	assert.Equal(t, len(sw1.Peers().List()), 0)
   465  	assert.EqualValues(t, 0, peersMetricValue())
   466  }
   467  
   468  func TestSwitchReconnectsToOutboundPersistentPeer(t *testing.T) {
   469  	sw := MakeSwitch(cfg, 1, initSwitchFunc)
   470  	err := sw.Start()
   471  	require.NoError(t, err)
   472  	t.Cleanup(func() {
   473  		if err := sw.Stop(); err != nil {
   474  			t.Error(err)
   475  		}
   476  	})
   477  
   478  	// 1. simulate failure by closing connection
   479  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   480  	rp.Start()
   481  	defer rp.Stop()
   482  
   483  	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
   484  	require.NoError(t, err)
   485  
   486  	err = sw.DialPeerWithAddress(rp.Addr())
   487  	require.Nil(t, err)
   488  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   489  
   490  	p := sw.Peers().List()[0]
   491  	err = p.(*peer).CloseConn()
   492  	require.NoError(t, err)
   493  
   494  	waitUntilSwitchHasAtLeastNPeers(sw, 1)
   495  	assert.False(t, p.IsRunning())        // old peer instance
   496  	assert.Equal(t, 1, sw.Peers().Size()) // new peer instance
   497  
   498  	// 2. simulate first time dial failure
   499  	rp = &remotePeer{
   500  		PrivKey: ed25519.GenPrivKey(),
   501  		Config:  cfg,
    502  		// Use a different interface to avoid the duplicate IP filter; this will break
    503  		// beyond two peers.
   504  		listenAddr: "127.0.0.1:0",
   505  	}
   506  	rp.Start()
   507  	defer rp.Stop()
   508  
   509  	conf := config.DefaultP2PConfig()
   510  	conf.TestDialFail = true // will trigger a reconnect
   511  	err = sw.addOutboundPeerWithConfig(rp.Addr(), conf)
   512  	require.NotNil(t, err)
    513  	// DialPeerWithAddress - sw.peerConfig resets the dialer
   514  	waitUntilSwitchHasAtLeastNPeers(sw, 2)
   515  	assert.Equal(t, 2, sw.Peers().Size())
   516  }
   517  
   518  func TestSwitchReconnectsToInboundPersistentPeer(t *testing.T) {
   519  	sw := MakeSwitch(cfg, 1, initSwitchFunc)
   520  	err := sw.Start()
   521  	require.NoError(t, err)
   522  	t.Cleanup(func() {
   523  		if err := sw.Stop(); err != nil {
   524  			t.Error(err)
   525  		}
   526  	})
   527  
   528  	// 1. simulate failure by closing the connection
   529  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   530  	rp.Start()
   531  	defer rp.Stop()
   532  
   533  	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
   534  	require.NoError(t, err)
   535  
    536  	c, err := rp.Dial(sw.NetAddress())
   537  	require.NoError(t, err)
   538  	time.Sleep(50 * time.Millisecond)
   539  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   540  
    541  	c.Close()
   542  
   543  	waitUntilSwitchHasAtLeastNPeers(sw, 1)
   544  	assert.Equal(t, 1, sw.Peers().Size())
   545  }
   546  
   547  func TestSwitchDialPeersAsync(t *testing.T) {
   548  	if testing.Short() {
    549  		t.Skip("skipping test in short mode")
   550  	}
   551  
   552  	sw := MakeSwitch(cfg, 1, initSwitchFunc)
   553  	err := sw.Start()
   554  	require.NoError(t, err)
   555  	t.Cleanup(func() {
   556  		if err := sw.Stop(); err != nil {
   557  			t.Error(err)
   558  		}
   559  	})
   560  
   561  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   562  	rp.Start()
   563  	defer rp.Stop()
   564  
   565  	err = sw.DialPeersAsync([]string{rp.Addr().String()})
   566  	require.NoError(t, err)
   567  	time.Sleep(dialRandomizerIntervalMilliseconds * time.Millisecond)
   568  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   569  }
   570  
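        // waitUntilSwitchHasAtLeastNPeers polls the switch's peer set every 250ms, for at most 20 attempts
        // (~5s), returning as soon as it has at least n peers.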
   571  func waitUntilSwitchHasAtLeastNPeers(sw *Switch, n int) {
   572  	for i := 0; i < 20; i++ {
   573  		time.Sleep(250 * time.Millisecond)
   574  		has := sw.Peers().Size()
   575  		if has >= n {
   576  			break
   577  		}
   578  	}
   579  }
   580  
   581  func TestSwitchFullConnectivity(t *testing.T) {
   582  	switches := MakeConnectedSwitches(cfg, 3, initSwitchFunc, Connect2Switches)
   583  	defer func() {
   584  		for _, sw := range switches {
   585  			sw := sw
   586  			t.Cleanup(func() {
   587  				if err := sw.Stop(); err != nil {
   588  					t.Error(err)
   589  				}
   590  			})
   591  		}
   592  	}()
   593  
   594  	for i, sw := range switches {
   595  		if sw.Peers().Size() != 2 {
    596  			t.Fatalf("Expected each switch to be connected to 2 others, but switch %d is only connected to %d", i, sw.Peers().Size())
   597  		}
   598  	}
   599  }
   600  
   601  func TestSwitchAcceptRoutine(t *testing.T) {
   602  	cfg.MaxNumInboundPeers = 5
   603  
   604  	// Create some unconditional peers.
   605  	const unconditionalPeersNum = 2
   606  	var (
   607  		unconditionalPeers   = make([]*remotePeer, unconditionalPeersNum)
   608  		unconditionalPeerIDs = make([]string, unconditionalPeersNum)
   609  	)
   610  	for i := 0; i < unconditionalPeersNum; i++ {
   611  		peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   612  		peer.Start()
   613  		unconditionalPeers[i] = peer
   614  		unconditionalPeerIDs[i] = string(peer.ID())
   615  	}
   616  
   617  	// make switch
   618  	sw := MakeSwitch(cfg, 1, initSwitchFunc)
   619  	err := sw.AddUnconditionalPeerIDs(unconditionalPeerIDs)
   620  	require.NoError(t, err)
   621  	err = sw.Start()
   622  	require.NoError(t, err)
   623  	t.Cleanup(func() {
   624  		err := sw.Stop()
   625  		require.NoError(t, err)
   626  	})
   627  
   628  	// 0. check there are no peers
   629  	assert.Equal(t, 0, sw.Peers().Size())
   630  
   631  	// 1. check we connect up to MaxNumInboundPeers
   632  	peers := make([]*remotePeer, 0)
   633  	for i := 0; i < cfg.MaxNumInboundPeers; i++ {
   634  		peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   635  		peers = append(peers, peer)
   636  		peer.Start()
   637  		c, err := peer.Dial(sw.NetAddress())
   638  		require.NoError(t, err)
   639  		// spawn a reading routine to prevent connection from closing
   640  		go func(c net.Conn) {
   641  			for {
   642  				one := make([]byte, 1)
   643  				_, err := c.Read(one)
   644  				if err != nil {
   645  					return
   646  				}
   647  			}
   648  		}(c)
   649  	}
   650  	time.Sleep(100 * time.Millisecond)
   651  	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())
   652  
   653  	// 2. check we close new connections if we already have MaxNumInboundPeers peers
   654  	peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   655  	peer.Start()
    656  	c, err := peer.Dial(sw.NetAddress())
    657  	require.NoError(t, err)
    658  	// check the connection is closed
    659  	one := make([]byte, 1)
    660  	_ = c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
    661  	_, err = c.Read(one)
   662  	assert.Error(t, err)
   663  	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())
   664  	peer.Stop()
   665  
   666  	// 3. check we connect to unconditional peers despite the limit.
   667  	for _, peer := range unconditionalPeers {
   668  		c, err := peer.Dial(sw.NetAddress())
   669  		require.NoError(t, err)
   670  		// spawn a reading routine to prevent connection from closing
   671  		go func(c net.Conn) {
   672  			for {
   673  				one := make([]byte, 1)
   674  				_, err := c.Read(one)
   675  				if err != nil {
   676  					return
   677  				}
   678  			}
   679  		}(c)
   680  	}
   681  	time.Sleep(10 * time.Millisecond)
   682  	assert.Equal(t, cfg.MaxNumInboundPeers+unconditionalPeersNum, sw.Peers().Size())
   683  
   684  	for _, peer := range peers {
   685  		peer.Stop()
   686  	}
   687  	for _, peer := range unconditionalPeers {
   688  		peer.Stop()
   689  	}
   690  }
   691  
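        // errorTransport is a transport stub whose Accept always fails with acceptErr; its remaining
        // methods panic if called.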
   692  type errorTransport struct {
   693  	acceptErr error
   694  }
   695  
   696  func (et errorTransport) NetAddress() NetAddress {
   697  	panic("not implemented")
   698  }
   699  
   700  func (et errorTransport) Accept(peerConfig) (Peer, error) {
   701  	return nil, et.acceptErr
   702  }
   703  
   704  func (errorTransport) Dial(NetAddress, peerConfig) (Peer, error) {
   705  	panic("not implemented")
   706  }
   707  
   708  func (errorTransport) Cleanup(Peer) {
   709  	panic("not implemented")
   710  }
   711  
   712  func TestSwitchAcceptRoutineErrorCases(t *testing.T) {
   713  	sw := NewSwitch(cfg, errorTransport{ErrFilterTimeout{}})
   714  	assert.NotPanics(t, func() {
   715  		err := sw.Start()
   716  		require.NoError(t, err)
   717  		err = sw.Stop()
   718  		require.NoError(t, err)
   719  	})
   720  
   721  	sw = NewSwitch(cfg, errorTransport{ErrRejected{conn: nil, err: errors.New("filtered"), isFiltered: true}})
   722  	assert.NotPanics(t, func() {
   723  		err := sw.Start()
   724  		require.NoError(t, err)
   725  		err = sw.Stop()
   726  		require.NoError(t, err)
   727  	})
   728  	// TODO(melekes) check we remove our address from addrBook
   729  
   730  	sw = NewSwitch(cfg, errorTransport{ErrTransportClosed{}})
   731  	assert.NotPanics(t, func() {
   732  		err := sw.Start()
   733  		require.NoError(t, err)
   734  		err = sw.Stop()
   735  		require.NoError(t, err)
   736  	})
   737  }
   738  
    739  // mockReactor checks that InitPeer is never called before RemovePeer has finished. If it
    740  // is, InitCalledBeforeRemoveFinished will return true.
   741  type mockReactor struct {
   742  	*BaseReactor
   743  
   744  	// atomic
   745  	removePeerInProgress           uint32
   746  	initCalledBeforeRemoveFinished uint32
   747  }
   748  
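        // RemovePeer flags removal as in progress and sleeps briefly so that races with InitPeer can be observed.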
   749  func (r *mockReactor) RemovePeer(Peer, interface{}) {
   750  	atomic.StoreUint32(&r.removePeerInProgress, 1)
   751  	defer atomic.StoreUint32(&r.removePeerInProgress, 0)
   752  	time.Sleep(100 * time.Millisecond)
   753  }
   754  
   755  func (r *mockReactor) InitPeer(peer Peer) Peer {
   756  	if atomic.LoadUint32(&r.removePeerInProgress) == 1 {
   757  		atomic.StoreUint32(&r.initCalledBeforeRemoveFinished, 1)
   758  	}
   759  
   760  	return peer
   761  }
   762  
   763  func (r *mockReactor) InitCalledBeforeRemoveFinished() bool {
   764  	return atomic.LoadUint32(&r.initCalledBeforeRemoveFinished) == 1
   765  }
   766  
   767  // see stopAndRemovePeer
   768  func TestSwitchInitPeerIsNotCalledBeforeRemovePeer(t *testing.T) {
   769  	// make reactor
   770  	reactor := &mockReactor{}
   771  	reactor.BaseReactor = NewBaseReactor("mockReactor", reactor)
   772  
   773  	// make switch
   774  	sw := MakeSwitch(cfg, 1, func(i int, sw *Switch) *Switch {
   775  		sw.AddReactor("mock", reactor)
   776  		return sw
   777  	})
   778  	err := sw.Start()
   779  	require.NoError(t, err)
   780  	t.Cleanup(func() {
   781  		if err := sw.Stop(); err != nil {
   782  			t.Error(err)
   783  		}
   784  	})
   785  
   786  	// add peer
   787  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   788  	rp.Start()
   789  	defer rp.Stop()
   790  	_, err = rp.Dial(sw.NetAddress())
   791  	require.NoError(t, err)
   792  
   793  	// wait till the switch adds rp to the peer set, then stop the peer asynchronously
   794  	for {
   795  		time.Sleep(20 * time.Millisecond)
   796  		if peer := sw.Peers().Get(rp.ID()); peer != nil {
   797  			go sw.StopPeerForError(peer, "test")
   798  			break
   799  		}
   800  	}
   801  
   802  	// simulate peer reconnecting to us
   803  	_, err = rp.Dial(sw.NetAddress())
   804  	require.NoError(t, err)
   805  	// wait till the switch adds rp to the peer set
   806  	time.Sleep(50 * time.Millisecond)
   807  
   808  	// make sure reactor.RemovePeer is finished before InitPeer is called
   809  	assert.False(t, reactor.InitCalledBeforeRemoveFinished())
   810  }
   811  
   812  func BenchmarkSwitchBroadcast(b *testing.B) {
   813  	s1, s2 := MakeSwitchPair(func(i int, sw *Switch) *Switch {
    814  		// Make two reactors of two channels each
   815  		sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
   816  			{ID: byte(0x00), Priority: 10},
   817  			{ID: byte(0x01), Priority: 10},
   818  		}, false))
   819  		sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
   820  			{ID: byte(0x02), Priority: 10},
   821  			{ID: byte(0x03), Priority: 10},
   822  		}, false))
   823  		return sw
   824  	})
   825  
   826  	b.Cleanup(func() {
   827  		if err := s1.Stop(); err != nil {
   828  			b.Error(err)
   829  		}
   830  	})
   831  
   832  	b.Cleanup(func() {
   833  		if err := s2.Stop(); err != nil {
   834  			b.Error(err)
   835  		}
   836  	})
   837  
   838  	// Allow time for goroutines to boot up
   839  	time.Sleep(1 * time.Second)
   840  
   841  	b.ResetTimer()
   842  
   843  	numSuccess, numFailure := 0, 0
   844  
    845  	// Broadcast an empty message on each of the four channels in turn
   846  	for i := 0; i < b.N; i++ {
   847  		chID := byte(i % 4)
   848  		successChan := s1.Broadcast(Envelope{ChannelID: chID})
   849  		for s := range successChan {
   850  			if s {
   851  				numSuccess++
   852  			} else {
   853  				numFailure++
   854  			}
   855  		}
   856  	}
   857  
   858  	b.Logf("success: %v, failure: %v", numSuccess, numFailure)
   859  }
   860  
   861  func TestSwitchRemovalErr(t *testing.T) {
   862  	sw1, sw2 := MakeSwitchPair(func(i int, sw *Switch) *Switch {
   863  		return initSwitchFunc(i, sw)
   864  	})
   865  	assert.Equal(t, len(sw1.Peers().List()), 1)
   866  	p := sw1.Peers().List()[0]
   867  
   868  	sw2.StopPeerForError(p, fmt.Errorf("peer should error"))
   869  
   870  	assert.Equal(t, sw2.peers.Add(p).Error(), ErrPeerRemoval{}.Error())
   871  }