github.com/badrootd/nibiru-cometbft@v0.37.5-0.20240307173500-2a75559eee9b/p2p/switch_test.go (about)

     1  package p2p
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"regexp"
    12  	"strconv"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/cosmos/gogoproto/proto"
    18  	"github.com/prometheus/client_golang/prometheus/promhttp"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/badrootd/nibiru-cometbft/config"
    23  	"github.com/badrootd/nibiru-cometbft/crypto/ed25519"
    24  	"github.com/badrootd/nibiru-cometbft/libs/log"
    25  	cmtsync "github.com/badrootd/nibiru-cometbft/libs/sync"
    26  	"github.com/badrootd/nibiru-cometbft/p2p/conn"
    27  	p2pproto "github.com/badrootd/nibiru-cometbft/proto/tendermint/p2p"
    28  )
    29  
    30  var (
    31  	cfg *config.P2PConfig
    32  )
    33  
    34  func init() {
    35  	cfg = config.DefaultP2PConfig()
    36  	cfg.PexReactor = true
    37  	cfg.AllowDuplicateIP = true
    38  }
    39  
// PeerMessage is a single message recorded by TestReactor.ReceiveEnvelope.
type PeerMessage struct {
	// Contents is the message carried by the received envelope.
	Contents proto.Message
	// Counter is the reactor-wide sequence number assigned when the message was
	// recorded (the first recorded message gets 0).
	Counter  int
}
    44  
// TestReactor is a reactor used by the switch tests. It serves a fixed set of
// channels and, when logMessages is true, records every envelope it receives so
// tests can inspect them per channel via getMsgs.
type TestReactor struct {
	BaseReactor

	// mtx guards msgsCounter and msgsReceived.
	mtx          cmtsync.Mutex
	// channels is returned as-is by GetChannels.
	channels     []*conn.ChannelDescriptor
	// logMessages enables recording in ReceiveEnvelope.
	logMessages  bool
	// msgsCounter counts recorded messages across all channels.
	msgsCounter  int
	// msgsReceived holds recorded messages keyed by channel ID.
	msgsReceived map[byte][]PeerMessage
}
    54  
    55  func NewTestReactor(channels []*conn.ChannelDescriptor, logMessages bool) *TestReactor {
    56  	tr := &TestReactor{
    57  		channels:     channels,
    58  		logMessages:  logMessages,
    59  		msgsReceived: make(map[byte][]PeerMessage),
    60  	}
    61  	tr.BaseReactor = *NewBaseReactor("TestReactor", tr)
    62  	tr.SetLogger(log.TestingLogger())
    63  	return tr
    64  }
    65  
    66  func (tr *TestReactor) GetChannels() []*conn.ChannelDescriptor {
    67  	return tr.channels
    68  }
    69  
    70  func (tr *TestReactor) AddPeer(peer Peer) {}
    71  
    72  func (tr *TestReactor) RemovePeer(peer Peer, reason interface{}) {}
    73  
    74  func (tr *TestReactor) ReceiveEnvelope(e Envelope) {
    75  	if tr.logMessages {
    76  		tr.mtx.Lock()
    77  		defer tr.mtx.Unlock()
    78  		// fmt.Printf("Received: %X, %X\n", e.ChannelID, e.Message)
    79  		tr.msgsReceived[e.ChannelID] = append(tr.msgsReceived[e.ChannelID], PeerMessage{Contents: e.Message, Counter: tr.msgsCounter})
    80  		tr.msgsCounter++
    81  	}
    82  }
    83  
    84  func (tr *TestReactor) getMsgs(chID byte) []PeerMessage {
    85  	tr.mtx.Lock()
    86  	defer tr.mtx.Unlock()
    87  	return tr.msgsReceived[chID]
    88  }
    89  
    90  //-----------------------------------------------------------------------------
    91  
    92  // convenience method for creating two switches connected to each other.
    93  // XXX: note this uses net.Pipe and not a proper TCP conn
    94  func MakeSwitchPair(t testing.TB, initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) {
    95  	// Create two switches that will be interconnected.
    96  	switches := MakeConnectedSwitches(cfg, 2, initSwitch, Connect2Switches)
    97  	return switches[0], switches[1]
    98  }
    99  
   100  func initSwitchFunc(i int, sw *Switch) *Switch {
   101  	sw.SetAddrBook(&AddrBookMock{
   102  		Addrs:    make(map[string]struct{}),
   103  		OurAddrs: make(map[string]struct{})})
   104  
   105  	// Make two reactors of two channels each
   106  	sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
   107  		{ID: byte(0x00), Priority: 10, MessageType: &p2pproto.Message{}},
   108  		{ID: byte(0x01), Priority: 10, MessageType: &p2pproto.Message{}},
   109  	}, true))
   110  	sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
   111  		{ID: byte(0x02), Priority: 10, MessageType: &p2pproto.Message{}},
   112  		{ID: byte(0x03), Priority: 10, MessageType: &p2pproto.Message{}},
   113  	}, true))
   114  
   115  	return sw
   116  }
   117  
   118  func TestSwitches(t *testing.T) {
   119  	s1, s2 := MakeSwitchPair(t, initSwitchFunc)
   120  	t.Cleanup(func() {
   121  		if err := s1.Stop(); err != nil {
   122  			t.Error(err)
   123  		}
   124  	})
   125  	t.Cleanup(func() {
   126  		if err := s2.Stop(); err != nil {
   127  			t.Error(err)
   128  		}
   129  	})
   130  
   131  	if s1.Peers().Size() != 1 {
   132  		t.Errorf("expected exactly 1 peer in s1, got %v", s1.Peers().Size())
   133  	}
   134  	if s2.Peers().Size() != 1 {
   135  		t.Errorf("expected exactly 1 peer in s2, got %v", s2.Peers().Size())
   136  	}
   137  
   138  	// Lets send some messages
   139  	ch0Msg := &p2pproto.PexAddrs{
   140  		Addrs: []p2pproto.NetAddress{
   141  			{
   142  				ID: "1",
   143  			},
   144  		},
   145  	}
   146  	ch1Msg := &p2pproto.PexAddrs{
   147  		Addrs: []p2pproto.NetAddress{
   148  			{
   149  				ID: "1",
   150  			},
   151  		},
   152  	}
   153  	ch2Msg := &p2pproto.PexAddrs{
   154  		Addrs: []p2pproto.NetAddress{
   155  			{
   156  				ID: "2",
   157  			},
   158  		},
   159  	}
   160  	s1.BroadcastEnvelope(Envelope{ChannelID: byte(0x00), Message: ch0Msg})
   161  	s1.BroadcastEnvelope(Envelope{ChannelID: byte(0x01), Message: ch1Msg})
   162  	s1.BroadcastEnvelope(Envelope{ChannelID: byte(0x02), Message: ch2Msg})
   163  	assertMsgReceivedWithTimeout(t,
   164  		ch0Msg,
   165  		byte(0x00),
   166  		s2.Reactor("foo").(*TestReactor), 200*time.Millisecond, 5*time.Second)
   167  	assertMsgReceivedWithTimeout(t,
   168  		ch1Msg,
   169  		byte(0x01),
   170  		s2.Reactor("foo").(*TestReactor), 200*time.Millisecond, 5*time.Second)
   171  	assertMsgReceivedWithTimeout(t,
   172  		ch2Msg,
   173  		byte(0x02),
   174  		s2.Reactor("bar").(*TestReactor), 200*time.Millisecond, 5*time.Second)
   175  }
   176  
   177  func assertMsgReceivedWithTimeout(
   178  	t *testing.T,
   179  	msg proto.Message,
   180  	channel byte,
   181  	reactor *TestReactor,
   182  	checkPeriod,
   183  	timeout time.Duration,
   184  ) {
   185  	ticker := time.NewTicker(checkPeriod)
   186  	for {
   187  		select {
   188  		case <-ticker.C:
   189  			msgs := reactor.getMsgs(channel)
   190  			expectedBytes, err := proto.Marshal(msgs[0].Contents)
   191  			require.NoError(t, err)
   192  			gotBytes, err := proto.Marshal(msg)
   193  			require.NoError(t, err)
   194  			if len(msgs) > 0 {
   195  				if !bytes.Equal(expectedBytes, gotBytes) {
   196  					t.Fatalf("Unexpected message bytes. Wanted: %X, Got: %X", msg, msgs[0].Counter)
   197  				}
   198  				return
   199  			}
   200  
   201  		case <-time.After(timeout):
   202  			t.Fatalf("Expected to have received 1 message in channel #%v, got zero", channel)
   203  		}
   204  	}
   205  }
   206  
   207  func TestSwitchFiltersOutItself(t *testing.T) {
   208  	s1 := MakeSwitch(cfg, 1, "127.0.0.1", "123.123.123", initSwitchFunc)
   209  
   210  	// simulate s1 having a public IP by creating a remote peer with the same ID
   211  	rp := &remotePeer{PrivKey: s1.nodeKey.PrivKey, Config: cfg}
   212  	rp.Start()
   213  
   214  	// addr should be rejected in addPeer based on the same ID
   215  	err := s1.DialPeerWithAddress(rp.Addr())
   216  	if assert.Error(t, err) {
   217  		if err, ok := err.(ErrRejected); ok {
   218  			if !err.IsSelf() {
   219  				t.Errorf("expected self to be rejected")
   220  			}
   221  		} else {
   222  			t.Errorf("expected ErrRejected")
   223  		}
   224  	}
   225  
   226  	assert.True(t, s1.addrBook.OurAddress(rp.Addr()))
   227  	assert.False(t, s1.addrBook.HasAddress(rp.Addr()))
   228  
   229  	rp.Stop()
   230  
   231  	assertNoPeersAfterTimeout(t, s1, 100*time.Millisecond)
   232  }
   233  
   234  func TestSwitchPeerFilter(t *testing.T) {
   235  	var (
   236  		filters = []PeerFilterFunc{
   237  			func(_ IPeerSet, _ Peer) error { return nil },
   238  			func(_ IPeerSet, _ Peer) error { return fmt.Errorf("denied") },
   239  			func(_ IPeerSet, _ Peer) error { return nil },
   240  		}
   241  		sw = MakeSwitch(
   242  			cfg,
   243  			1,
   244  			"testing",
   245  			"123.123.123",
   246  			initSwitchFunc,
   247  			SwitchPeerFilters(filters...),
   248  		)
   249  	)
   250  	err := sw.Start()
   251  	require.NoError(t, err)
   252  	t.Cleanup(func() {
   253  		if err := sw.Stop(); err != nil {
   254  			t.Error(err)
   255  		}
   256  	})
   257  
   258  	// simulate remote peer
   259  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   260  	rp.Start()
   261  	t.Cleanup(rp.Stop)
   262  
   263  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   264  		chDescs:      sw.chDescs,
   265  		onPeerError:  sw.StopPeerForError,
   266  		isPersistent: sw.IsPeerPersistent,
   267  		reactorsByCh: sw.reactorsByCh,
   268  	})
   269  	if err != nil {
   270  		t.Fatal(err)
   271  	}
   272  
   273  	err = sw.addPeer(p)
   274  	if err, ok := err.(ErrRejected); ok {
   275  		if !err.IsFiltered() {
   276  			t.Errorf("expected peer to be filtered")
   277  		}
   278  	} else {
   279  		t.Errorf("expected ErrRejected")
   280  	}
   281  }
   282  
   283  func TestSwitchPeerFilterTimeout(t *testing.T) {
   284  	var (
   285  		filters = []PeerFilterFunc{
   286  			func(_ IPeerSet, _ Peer) error {
   287  				time.Sleep(10 * time.Millisecond)
   288  				return nil
   289  			},
   290  		}
   291  		sw = MakeSwitch(
   292  			cfg,
   293  			1,
   294  			"testing",
   295  			"123.123.123",
   296  			initSwitchFunc,
   297  			SwitchFilterTimeout(5*time.Millisecond),
   298  			SwitchPeerFilters(filters...),
   299  		)
   300  	)
   301  	err := sw.Start()
   302  	require.NoError(t, err)
   303  	t.Cleanup(func() {
   304  		if err := sw.Stop(); err != nil {
   305  			t.Log(err)
   306  		}
   307  	})
   308  
   309  	// simulate remote peer
   310  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   311  	rp.Start()
   312  	defer rp.Stop()
   313  
   314  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   315  		chDescs:      sw.chDescs,
   316  		onPeerError:  sw.StopPeerForError,
   317  		isPersistent: sw.IsPeerPersistent,
   318  		reactorsByCh: sw.reactorsByCh,
   319  	})
   320  	if err != nil {
   321  		t.Fatal(err)
   322  	}
   323  
   324  	err = sw.addPeer(p)
   325  	if _, ok := err.(ErrFilterTimeout); !ok {
   326  		t.Errorf("expected ErrFilterTimeout")
   327  	}
   328  }
   329  
   330  func TestSwitchPeerFilterDuplicate(t *testing.T) {
   331  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   332  	err := sw.Start()
   333  	require.NoError(t, err)
   334  	t.Cleanup(func() {
   335  		if err := sw.Stop(); err != nil {
   336  			t.Error(err)
   337  		}
   338  	})
   339  
   340  	// simulate remote peer
   341  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   342  	rp.Start()
   343  	defer rp.Stop()
   344  
   345  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   346  		chDescs:      sw.chDescs,
   347  		onPeerError:  sw.StopPeerForError,
   348  		isPersistent: sw.IsPeerPersistent,
   349  		reactorsByCh: sw.reactorsByCh,
   350  	})
   351  	if err != nil {
   352  		t.Fatal(err)
   353  	}
   354  
   355  	if err := sw.addPeer(p); err != nil {
   356  		t.Fatal(err)
   357  	}
   358  
   359  	err = sw.addPeer(p)
   360  	if errRej, ok := err.(ErrRejected); ok {
   361  		if !errRej.IsDuplicate() {
   362  			t.Errorf("expected peer to be duplicate. got %v", errRej)
   363  		}
   364  	} else {
   365  		t.Errorf("expected ErrRejected, got %v", err)
   366  	}
   367  }
   368  
   369  func assertNoPeersAfterTimeout(t *testing.T, sw *Switch, timeout time.Duration) {
   370  	time.Sleep(timeout)
   371  	if sw.Peers().Size() != 0 {
   372  		t.Fatalf("Expected %v to not connect to some peers, got %d", sw, sw.Peers().Size())
   373  	}
   374  }
   375  
   376  func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) {
   377  	assert, require := assert.New(t), require.New(t)
   378  
   379  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   380  	err := sw.Start()
   381  	if err != nil {
   382  		t.Error(err)
   383  	}
   384  	t.Cleanup(func() {
   385  		if err := sw.Stop(); err != nil {
   386  			t.Error(err)
   387  		}
   388  	})
   389  
   390  	// simulate remote peer
   391  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   392  	rp.Start()
   393  	defer rp.Stop()
   394  
   395  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   396  		chDescs:      sw.chDescs,
   397  		onPeerError:  sw.StopPeerForError,
   398  		isPersistent: sw.IsPeerPersistent,
   399  		reactorsByCh: sw.reactorsByCh,
   400  	})
   401  	require.Nil(err)
   402  
   403  	err = sw.addPeer(p)
   404  	require.Nil(err)
   405  
   406  	require.NotNil(sw.Peers().Get(rp.ID()))
   407  
   408  	// simulate failure by closing connection
   409  	err = p.(*peer).CloseConn()
   410  	require.NoError(err)
   411  
   412  	assertNoPeersAfterTimeout(t, sw, 100*time.Millisecond)
   413  	assert.False(p.IsRunning())
   414  }
   415  
   416  func TestSwitchStopPeerForError(t *testing.T) {
   417  	s := httptest.NewServer(promhttp.Handler())
   418  	defer s.Close()
   419  
   420  	scrapeMetrics := func() string {
   421  		resp, err := http.Get(s.URL)
   422  		require.NoError(t, err)
   423  		defer resp.Body.Close()
   424  		buf, _ := io.ReadAll(resp.Body)
   425  		return string(buf)
   426  	}
   427  
   428  	namespace, subsystem, name := config.TestInstrumentationConfig().Namespace, MetricsSubsystem, "peers"
   429  	re := regexp.MustCompile(namespace + `_` + subsystem + `_` + name + ` ([0-9\.]+)`)
   430  	peersMetricValue := func() float64 {
   431  		matches := re.FindStringSubmatch(scrapeMetrics())
   432  		f, _ := strconv.ParseFloat(matches[1], 64)
   433  		return f
   434  	}
   435  
   436  	p2pMetrics := PrometheusMetrics(namespace)
   437  
   438  	// make two connected switches
   439  	sw1, sw2 := MakeSwitchPair(t, func(i int, sw *Switch) *Switch {
   440  		// set metrics on sw1
   441  		if i == 0 {
   442  			opt := WithMetrics(p2pMetrics)
   443  			opt(sw)
   444  		}
   445  		return initSwitchFunc(i, sw)
   446  	})
   447  
   448  	assert.Equal(t, len(sw1.Peers().List()), 1)
   449  	assert.EqualValues(t, 1, peersMetricValue())
   450  
   451  	// send messages to the peer from sw1
   452  	p := sw1.Peers().List()[0]
   453  	p.SendEnvelope(Envelope{
   454  		ChannelID: 0x1,
   455  		Message:   &p2pproto.Message{},
   456  	})
   457  
   458  	// stop sw2. this should cause the p to fail,
   459  	// which results in calling StopPeerForError internally
   460  	t.Cleanup(func() {
   461  		if err := sw2.Stop(); err != nil {
   462  			t.Error(err)
   463  		}
   464  	})
   465  
   466  	// now call StopPeerForError explicitly, eg. from a reactor
   467  	sw1.StopPeerForError(p, fmt.Errorf("some err"))
   468  
   469  	assert.Equal(t, len(sw1.Peers().List()), 0)
   470  	assert.EqualValues(t, 0, peersMetricValue())
   471  }
   472  
// TestSwitchReconnectsToOutboundPersistentPeer checks that the switch ends up
// connected to an outbound persistent peer again (1) after an established
// connection is closed and (2) after the first dial attempt to it fails.
func TestSwitchReconnectsToOutboundPersistentPeer(t *testing.T) {
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// 1. simulate failure by closing connection
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	// Mark rp as persistent before dialing it.
	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
	require.NoError(t, err)

	err = sw.DialPeerWithAddress(rp.Addr())
	require.Nil(t, err)
	require.NotNil(t, sw.Peers().Get(rp.ID()))

	// Close the underlying connection of the only peer to force a failure.
	p := sw.Peers().List()[0]
	err = p.(*peer).CloseConn()
	require.NoError(t, err)

	// Polls up to 20 times at 250ms intervals; it never fails the test itself,
	// so the assertions below do the checking.
	waitUntilSwitchHasAtLeastNPeers(sw, 1)
	assert.False(t, p.IsRunning())        // old peer instance
	assert.Equal(t, 1, sw.Peers().Size()) // new peer instance

	// 2. simulate first time dial failure
	rp = &remotePeer{
		PrivKey: ed25519.GenPrivKey(),
		Config:  cfg,
		// Use different interface to prevent duplicate IP filter, this will break
		// beyond two peers.
		listenAddr: "127.0.0.1:0",
	}
	rp.Start()
	defer rp.Stop()

	conf := config.DefaultP2PConfig()
	conf.TestDialFail = true // will trigger a reconnect
	err = sw.addOutboundPeerWithConfig(rp.Addr(), conf)
	require.NotNil(t, err)
	// DialPeerWithAddress - sw.peerConfig resets the dialer
	waitUntilSwitchHasAtLeastNPeers(sw, 2)
	assert.Equal(t, 2, sw.Peers().Size())
}
   522  
// TestSwitchReconnectsToInboundPersistentPeer checks that after the connection
// from a persistent peer that dialed the switch (an inbound connection) is
// closed, the switch ends up with exactly one peer again.
func TestSwitchReconnectsToInboundPersistentPeer(t *testing.T) {
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// 1. simulate failure by closing the connection
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
	require.NoError(t, err)

	// rp dials the switch, so for the switch this is an inbound connection.
	// Note: this local `conn` shadows the imported conn package in this scope.
	conn, err := rp.Dial(sw.NetAddress())
	require.NoError(t, err)
	// NOTE(review): fixed sleep to let the switch register the peer; may be
	// flaky on slow machines.
	time.Sleep(50 * time.Millisecond)
	require.NotNil(t, sw.Peers().Get(rp.ID()))

	// Close from the remote side; the Close error is deliberately ignored.
	conn.Close()

	// Polls up to 20 times at 250ms intervals; never fails the test by itself.
	waitUntilSwitchHasAtLeastNPeers(sw, 1)
	assert.Equal(t, 1, sw.Peers().Size())
}
   551  
   552  func TestSwitchDialPeersAsync(t *testing.T) {
   553  	if testing.Short() {
   554  		return
   555  	}
   556  
   557  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   558  	err := sw.Start()
   559  	require.NoError(t, err)
   560  	t.Cleanup(func() {
   561  		if err := sw.Stop(); err != nil {
   562  			t.Error(err)
   563  		}
   564  	})
   565  
   566  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   567  	rp.Start()
   568  	defer rp.Stop()
   569  
   570  	err = sw.DialPeersAsync([]string{rp.Addr().String()})
   571  	require.NoError(t, err)
   572  	time.Sleep(dialRandomizerIntervalMilliseconds * time.Millisecond)
   573  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   574  }
   575  
   576  func waitUntilSwitchHasAtLeastNPeers(sw *Switch, n int) {
   577  	for i := 0; i < 20; i++ {
   578  		time.Sleep(250 * time.Millisecond)
   579  		has := sw.Peers().Size()
   580  		if has >= n {
   581  			break
   582  		}
   583  	}
   584  }
   585  
   586  func TestSwitchFullConnectivity(t *testing.T) {
   587  	switches := MakeConnectedSwitches(cfg, 3, initSwitchFunc, Connect2Switches)
   588  	defer func() {
   589  		for _, sw := range switches {
   590  			sw := sw
   591  			t.Cleanup(func() {
   592  				if err := sw.Stop(); err != nil {
   593  					t.Error(err)
   594  				}
   595  			})
   596  		}
   597  	}()
   598  
   599  	for i, sw := range switches {
   600  		if sw.Peers().Size() != 2 {
   601  			t.Fatalf("Expected each switch to be connected to 2 other, but %d switch only connected to %d", sw.Peers().Size(), i)
   602  		}
   603  	}
   604  }
   605  
   606  func TestSwitchAcceptRoutine(t *testing.T) {
   607  	cfg.MaxNumInboundPeers = 5
   608  
   609  	// Create some unconditional peers.
   610  	const unconditionalPeersNum = 2
   611  	var (
   612  		unconditionalPeers   = make([]*remotePeer, unconditionalPeersNum)
   613  		unconditionalPeerIDs = make([]string, unconditionalPeersNum)
   614  	)
   615  	for i := 0; i < unconditionalPeersNum; i++ {
   616  		peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   617  		peer.Start()
   618  		unconditionalPeers[i] = peer
   619  		unconditionalPeerIDs[i] = string(peer.ID())
   620  	}
   621  
   622  	// make switch
   623  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   624  	err := sw.AddUnconditionalPeerIDs(unconditionalPeerIDs)
   625  	require.NoError(t, err)
   626  	err = sw.Start()
   627  	require.NoError(t, err)
   628  	t.Cleanup(func() {
   629  		err := sw.Stop()
   630  		require.NoError(t, err)
   631  	})
   632  
   633  	// 0. check there are no peers
   634  	assert.Equal(t, 0, sw.Peers().Size())
   635  
   636  	// 1. check we connect up to MaxNumInboundPeers
   637  	peers := make([]*remotePeer, 0)
   638  	for i := 0; i < cfg.MaxNumInboundPeers; i++ {
   639  		peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   640  		peers = append(peers, peer)
   641  		peer.Start()
   642  		c, err := peer.Dial(sw.NetAddress())
   643  		require.NoError(t, err)
   644  		// spawn a reading routine to prevent connection from closing
   645  		go func(c net.Conn) {
   646  			for {
   647  				one := make([]byte, 1)
   648  				_, err := c.Read(one)
   649  				if err != nil {
   650  					return
   651  				}
   652  			}
   653  		}(c)
   654  	}
   655  	time.Sleep(100 * time.Millisecond)
   656  	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())
   657  
   658  	// 2. check we close new connections if we already have MaxNumInboundPeers peers
   659  	peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   660  	peer.Start()
   661  	conn, err := peer.Dial(sw.NetAddress())
   662  	require.NoError(t, err)
   663  	// check conn is closed
   664  	one := make([]byte, 1)
   665  	_ = conn.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
   666  	_, err = conn.Read(one)
   667  	assert.Error(t, err)
   668  	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())
   669  	peer.Stop()
   670  
   671  	// 3. check we connect to unconditional peers despite the limit.
   672  	for _, peer := range unconditionalPeers {
   673  		c, err := peer.Dial(sw.NetAddress())
   674  		require.NoError(t, err)
   675  		// spawn a reading routine to prevent connection from closing
   676  		go func(c net.Conn) {
   677  			for {
   678  				one := make([]byte, 1)
   679  				_, err := c.Read(one)
   680  				if err != nil {
   681  					return
   682  				}
   683  			}
   684  		}(c)
   685  	}
   686  	time.Sleep(10 * time.Millisecond)
   687  	assert.Equal(t, cfg.MaxNumInboundPeers+unconditionalPeersNum, sw.Peers().Size())
   688  
   689  	for _, peer := range peers {
   690  		peer.Stop()
   691  	}
   692  	for _, peer := range unconditionalPeers {
   693  		peer.Stop()
   694  	}
   695  }
   696  
// errorTransport is a transport stub. Accept always returns acceptErr, and all
// of its other methods panic.
type errorTransport struct {
	// acceptErr is returned by every call to Accept.
	acceptErr error
}
   700  
   701  func (et errorTransport) NetAddress() NetAddress {
   702  	panic("not implemented")
   703  }
   704  
   705  func (et errorTransport) Accept(c peerConfig) (Peer, error) {
   706  	return nil, et.acceptErr
   707  }
   708  func (errorTransport) Dial(NetAddress, peerConfig) (Peer, error) {
   709  	panic("not implemented")
   710  }
   711  func (errorTransport) Cleanup(Peer) {
   712  	panic("not implemented")
   713  }
   714  
   715  func TestSwitchAcceptRoutineErrorCases(t *testing.T) {
   716  	sw := NewSwitch(cfg, errorTransport{ErrFilterTimeout{}})
   717  	assert.NotPanics(t, func() {
   718  		err := sw.Start()
   719  		require.NoError(t, err)
   720  		err = sw.Stop()
   721  		require.NoError(t, err)
   722  	})
   723  
   724  	sw = NewSwitch(cfg, errorTransport{ErrRejected{conn: nil, err: errors.New("filtered"), isFiltered: true}})
   725  	assert.NotPanics(t, func() {
   726  		err := sw.Start()
   727  		require.NoError(t, err)
   728  		err = sw.Stop()
   729  		require.NoError(t, err)
   730  	})
   731  	// TODO(melekes) check we remove our address from addrBook
   732  
   733  	sw = NewSwitch(cfg, errorTransport{ErrTransportClosed{}})
   734  	assert.NotPanics(t, func() {
   735  		err := sw.Start()
   736  		require.NoError(t, err)
   737  		err = sw.Stop()
   738  		require.NoError(t, err)
   739  	})
   740  }
   741  
// mockReactor detects whether InitPeer is ever called while RemovePeer is still
// running. If that happens, InitCalledBeforeRemoveFinished returns true.
type mockReactor struct {
	*BaseReactor

	// atomic: both fields below are accessed only through sync/atomic.
	// removePeerInProgress is 1 while RemovePeer is executing, 0 otherwise.
	removePeerInProgress           uint32
	// initCalledBeforeRemoveFinished is set to 1 when InitPeer observes an
	// in-progress RemovePeer; nothing in this file resets it.
	initCalledBeforeRemoveFinished uint32
}
   751  
   752  func (r *mockReactor) RemovePeer(peer Peer, reason interface{}) {
   753  	atomic.StoreUint32(&r.removePeerInProgress, 1)
   754  	defer atomic.StoreUint32(&r.removePeerInProgress, 0)
   755  	time.Sleep(100 * time.Millisecond)
   756  }
   757  
   758  func (r *mockReactor) InitPeer(peer Peer) Peer {
   759  	if atomic.LoadUint32(&r.removePeerInProgress) == 1 {
   760  		atomic.StoreUint32(&r.initCalledBeforeRemoveFinished, 1)
   761  	}
   762  
   763  	return peer
   764  }
   765  
   766  func (r *mockReactor) InitCalledBeforeRemoveFinished() bool {
   767  	return atomic.LoadUint32(&r.initCalledBeforeRemoveFinished) == 1
   768  }
   769  
   770  // see stopAndRemovePeer
   771  func TestSwitchInitPeerIsNotCalledBeforeRemovePeer(t *testing.T) {
   772  	// make reactor
   773  	reactor := &mockReactor{}
   774  	reactor.BaseReactor = NewBaseReactor("mockReactor", reactor)
   775  
   776  	// make switch
   777  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", func(i int, sw *Switch) *Switch {
   778  		sw.AddReactor("mock", reactor)
   779  		return sw
   780  	})
   781  	err := sw.Start()
   782  	require.NoError(t, err)
   783  	t.Cleanup(func() {
   784  		if err := sw.Stop(); err != nil {
   785  			t.Error(err)
   786  		}
   787  	})
   788  
   789  	// add peer
   790  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   791  	rp.Start()
   792  	defer rp.Stop()
   793  	_, err = rp.Dial(sw.NetAddress())
   794  	require.NoError(t, err)
   795  
   796  	// wait till the switch adds rp to the peer set, then stop the peer asynchronously
   797  	for {
   798  		time.Sleep(20 * time.Millisecond)
   799  		if peer := sw.Peers().Get(rp.ID()); peer != nil {
   800  			go sw.StopPeerForError(peer, "test")
   801  			break
   802  		}
   803  	}
   804  
   805  	// simulate peer reconnecting to us
   806  	_, err = rp.Dial(sw.NetAddress())
   807  	require.NoError(t, err)
   808  	// wait till the switch adds rp to the peer set
   809  	time.Sleep(50 * time.Millisecond)
   810  
   811  	// make sure reactor.RemovePeer is finished before InitPeer is called
   812  	assert.False(t, reactor.InitCalledBeforeRemoveFinished())
   813  }
   814  
   815  func BenchmarkSwitchBroadcast(b *testing.B) {
   816  	s1, s2 := MakeSwitchPair(b, func(i int, sw *Switch) *Switch {
   817  		// Make bar reactors of bar channels each
   818  		sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
   819  			{ID: byte(0x00), Priority: 10},
   820  			{ID: byte(0x01), Priority: 10},
   821  		}, false))
   822  		sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
   823  			{ID: byte(0x02), Priority: 10},
   824  			{ID: byte(0x03), Priority: 10},
   825  		}, false))
   826  		return sw
   827  	})
   828  
   829  	b.Cleanup(func() {
   830  		if err := s1.Stop(); err != nil {
   831  			b.Error(err)
   832  		}
   833  	})
   834  
   835  	b.Cleanup(func() {
   836  		if err := s2.Stop(); err != nil {
   837  			b.Error(err)
   838  		}
   839  	})
   840  
   841  	// Allow time for goroutines to boot up
   842  	time.Sleep(1 * time.Second)
   843  
   844  	b.ResetTimer()
   845  
   846  	numSuccess, numFailure := 0, 0
   847  
   848  	// Send random message from foo channel to another
   849  	for i := 0; i < b.N; i++ {
   850  		chID := byte(i % 4)
   851  		successChan := s1.BroadcastEnvelope(Envelope{ChannelID: chID})
   852  		for s := range successChan {
   853  			if s {
   854  				numSuccess++
   855  			} else {
   856  				numFailure++
   857  			}
   858  		}
   859  	}
   860  
   861  	b.Logf("success: %v, failure: %v", numSuccess, numFailure)
   862  }
   863  
   864  func TestSwitchRemovalErr(t *testing.T) {
   865  
   866  	sw1, sw2 := MakeSwitchPair(t, func(i int, sw *Switch) *Switch {
   867  		return initSwitchFunc(i, sw)
   868  	})
   869  	assert.Equal(t, len(sw1.Peers().List()), 1)
   870  	p := sw1.Peers().List()[0]
   871  
   872  	sw2.StopPeerForError(p, fmt.Errorf("peer should error"))
   873  
   874  	assert.Equal(t, sw2.peers.Add(p).Error(), ErrPeerRemoval{}.Error())
   875  }