github.com/consideritdone/landslidecore@v0.0.0-20230718131026-a8b21c5cf8a7/p2p/switch_test.go (about)

     1  package p2p
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"regexp"
    12  	"strconv"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/prometheus/client_golang/prometheus/promhttp"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  
    21  	"github.com/consideritdone/landslidecore/config"
    22  	"github.com/consideritdone/landslidecore/crypto/ed25519"
    23  	"github.com/consideritdone/landslidecore/libs/log"
    24  	tmsync "github.com/consideritdone/landslidecore/libs/sync"
    25  	"github.com/consideritdone/landslidecore/p2p/conn"
    26  )
    27  
    28  var (
    29  	cfg *config.P2PConfig
    30  )
    31  
    32  func init() {
    33  	cfg = config.DefaultP2PConfig()
    34  	cfg.PexReactor = true
    35  	cfg.AllowDuplicateIP = true
    36  }
    37  
// PeerMessage is one message recorded by TestReactor.Receive when message
// logging is enabled. Field order matters: Receive builds it positionally.
type PeerMessage struct {
	PeerID  ID     // ID of the peer the message was received from
	Bytes   []byte // message payload exactly as passed to Receive
	Counter int    // 0-based receive sequence number, shared across all channels of the reactor
}
    43  
    44  type TestReactor struct {
    45  	BaseReactor
    46  
    47  	mtx          tmsync.Mutex
    48  	channels     []*conn.ChannelDescriptor
    49  	logMessages  bool
    50  	msgsCounter  int
    51  	msgsReceived map[byte][]PeerMessage
    52  }
    53  
    54  func NewTestReactor(channels []*conn.ChannelDescriptor, logMessages bool) *TestReactor {
    55  	tr := &TestReactor{
    56  		channels:     channels,
    57  		logMessages:  logMessages,
    58  		msgsReceived: make(map[byte][]PeerMessage),
    59  	}
    60  	tr.BaseReactor = *NewBaseReactor("TestReactor", tr)
    61  	tr.SetLogger(log.TestingLogger())
    62  	return tr
    63  }
    64  
    65  func (tr *TestReactor) GetChannels() []*conn.ChannelDescriptor {
    66  	return tr.channels
    67  }
    68  
    69  func (tr *TestReactor) AddPeer(peer Peer) {}
    70  
    71  func (tr *TestReactor) RemovePeer(peer Peer, reason interface{}) {}
    72  
    73  func (tr *TestReactor) Receive(chID byte, peer Peer, msgBytes []byte) {
    74  	if tr.logMessages {
    75  		tr.mtx.Lock()
    76  		defer tr.mtx.Unlock()
    77  		// fmt.Printf("Received: %X, %X\n", chID, msgBytes)
    78  		tr.msgsReceived[chID] = append(tr.msgsReceived[chID], PeerMessage{peer.ID(), msgBytes, tr.msgsCounter})
    79  		tr.msgsCounter++
    80  	}
    81  }
    82  
    83  func (tr *TestReactor) getMsgs(chID byte) []PeerMessage {
    84  	tr.mtx.Lock()
    85  	defer tr.mtx.Unlock()
    86  	return tr.msgsReceived[chID]
    87  }
    88  
    89  //-----------------------------------------------------------------------------
    90  
    91  // convenience method for creating two switches connected to each other.
    92  // XXX: note this uses net.Pipe and not a proper TCP conn
    93  func MakeSwitchPair(t testing.TB, initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) {
    94  	// Create two switches that will be interconnected.
    95  	switches := MakeConnectedSwitches(cfg, 2, initSwitch, Connect2Switches)
    96  	return switches[0], switches[1]
    97  }
    98  
    99  func initSwitchFunc(i int, sw *Switch) *Switch {
   100  	sw.SetAddrBook(&AddrBookMock{
   101  		Addrs:    make(map[string]struct{}),
   102  		OurAddrs: make(map[string]struct{})})
   103  
   104  	// Make two reactors of two channels each
   105  	sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
   106  		{ID: byte(0x00), Priority: 10},
   107  		{ID: byte(0x01), Priority: 10},
   108  	}, true))
   109  	sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
   110  		{ID: byte(0x02), Priority: 10},
   111  		{ID: byte(0x03), Priority: 10},
   112  	}, true))
   113  
   114  	return sw
   115  }
   116  
   117  func TestSwitches(t *testing.T) {
   118  	s1, s2 := MakeSwitchPair(t, initSwitchFunc)
   119  	t.Cleanup(func() {
   120  		if err := s1.Stop(); err != nil {
   121  			t.Error(err)
   122  		}
   123  	})
   124  	t.Cleanup(func() {
   125  		if err := s2.Stop(); err != nil {
   126  			t.Error(err)
   127  		}
   128  	})
   129  
   130  	if s1.Peers().Size() != 1 {
   131  		t.Errorf("expected exactly 1 peer in s1, got %v", s1.Peers().Size())
   132  	}
   133  	if s2.Peers().Size() != 1 {
   134  		t.Errorf("expected exactly 1 peer in s2, got %v", s2.Peers().Size())
   135  	}
   136  
   137  	// Lets send some messages
   138  	ch0Msg := []byte("channel zero")
   139  	ch1Msg := []byte("channel foo")
   140  	ch2Msg := []byte("channel bar")
   141  
   142  	s1.Broadcast(byte(0x00), ch0Msg)
   143  	s1.Broadcast(byte(0x01), ch1Msg)
   144  	s1.Broadcast(byte(0x02), ch2Msg)
   145  
   146  	assertMsgReceivedWithTimeout(t,
   147  		ch0Msg,
   148  		byte(0x00),
   149  		s2.Reactor("foo").(*TestReactor), 10*time.Millisecond, 5*time.Second)
   150  	assertMsgReceivedWithTimeout(t,
   151  		ch1Msg,
   152  		byte(0x01),
   153  		s2.Reactor("foo").(*TestReactor), 10*time.Millisecond, 5*time.Second)
   154  	assertMsgReceivedWithTimeout(t,
   155  		ch2Msg,
   156  		byte(0x02),
   157  		s2.Reactor("bar").(*TestReactor), 10*time.Millisecond, 5*time.Second)
   158  }
   159  
   160  func assertMsgReceivedWithTimeout(
   161  	t *testing.T,
   162  	msgBytes []byte,
   163  	channel byte,
   164  	reactor *TestReactor,
   165  	checkPeriod,
   166  	timeout time.Duration,
   167  ) {
   168  	ticker := time.NewTicker(checkPeriod)
   169  	for {
   170  		select {
   171  		case <-ticker.C:
   172  			msgs := reactor.getMsgs(channel)
   173  			if len(msgs) > 0 {
   174  				if !bytes.Equal(msgs[0].Bytes, msgBytes) {
   175  					t.Fatalf("Unexpected message bytes. Wanted: %X, Got: %X", msgBytes, msgs[0].Bytes)
   176  				}
   177  				return
   178  			}
   179  
   180  		case <-time.After(timeout):
   181  			t.Fatalf("Expected to have received 1 message in channel #%v, got zero", channel)
   182  		}
   183  	}
   184  }
   185  
   186  func TestSwitchFiltersOutItself(t *testing.T) {
   187  	s1 := MakeSwitch(cfg, 1, "127.0.0.1", "123.123.123", initSwitchFunc)
   188  
   189  	// simulate s1 having a public IP by creating a remote peer with the same ID
   190  	rp := &remotePeer{PrivKey: s1.nodeKey.PrivKey, Config: cfg}
   191  	rp.Start()
   192  
   193  	// addr should be rejected in addPeer based on the same ID
   194  	err := s1.DialPeerWithAddress(rp.Addr())
   195  	if assert.Error(t, err) {
   196  		if err, ok := err.(ErrRejected); ok {
   197  			if !err.IsSelf() {
   198  				t.Errorf("expected self to be rejected")
   199  			}
   200  		} else {
   201  			t.Errorf("expected ErrRejected")
   202  		}
   203  	}
   204  
   205  	assert.True(t, s1.addrBook.OurAddress(rp.Addr()))
   206  	assert.False(t, s1.addrBook.HasAddress(rp.Addr()))
   207  
   208  	rp.Stop()
   209  
   210  	assertNoPeersAfterTimeout(t, s1, 100*time.Millisecond)
   211  }
   212  
   213  func TestSwitchPeerFilter(t *testing.T) {
   214  	var (
   215  		filters = []PeerFilterFunc{
   216  			func(_ IPeerSet, _ Peer) error { return nil },
   217  			func(_ IPeerSet, _ Peer) error { return fmt.Errorf("denied") },
   218  			func(_ IPeerSet, _ Peer) error { return nil },
   219  		}
   220  		sw = MakeSwitch(
   221  			cfg,
   222  			1,
   223  			"testing",
   224  			"123.123.123",
   225  			initSwitchFunc,
   226  			SwitchPeerFilters(filters...),
   227  		)
   228  	)
   229  	err := sw.Start()
   230  	require.NoError(t, err)
   231  	t.Cleanup(func() {
   232  		if err := sw.Stop(); err != nil {
   233  			t.Error(err)
   234  		}
   235  	})
   236  
   237  	// simulate remote peer
   238  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   239  	rp.Start()
   240  	t.Cleanup(rp.Stop)
   241  
   242  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   243  		chDescs:      sw.chDescs,
   244  		onPeerError:  sw.StopPeerForError,
   245  		isPersistent: sw.IsPeerPersistent,
   246  		reactorsByCh: sw.reactorsByCh,
   247  	})
   248  	if err != nil {
   249  		t.Fatal(err)
   250  	}
   251  
   252  	err = sw.addPeer(p)
   253  	if err, ok := err.(ErrRejected); ok {
   254  		if !err.IsFiltered() {
   255  			t.Errorf("expected peer to be filtered")
   256  		}
   257  	} else {
   258  		t.Errorf("expected ErrRejected")
   259  	}
   260  }
   261  
   262  func TestSwitchPeerFilterTimeout(t *testing.T) {
   263  	var (
   264  		filters = []PeerFilterFunc{
   265  			func(_ IPeerSet, _ Peer) error {
   266  				time.Sleep(10 * time.Millisecond)
   267  				return nil
   268  			},
   269  		}
   270  		sw = MakeSwitch(
   271  			cfg,
   272  			1,
   273  			"testing",
   274  			"123.123.123",
   275  			initSwitchFunc,
   276  			SwitchFilterTimeout(5*time.Millisecond),
   277  			SwitchPeerFilters(filters...),
   278  		)
   279  	)
   280  	err := sw.Start()
   281  	require.NoError(t, err)
   282  	t.Cleanup(func() {
   283  		if err := sw.Stop(); err != nil {
   284  			t.Log(err)
   285  		}
   286  	})
   287  
   288  	// simulate remote peer
   289  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   290  	rp.Start()
   291  	defer rp.Stop()
   292  
   293  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   294  		chDescs:      sw.chDescs,
   295  		onPeerError:  sw.StopPeerForError,
   296  		isPersistent: sw.IsPeerPersistent,
   297  		reactorsByCh: sw.reactorsByCh,
   298  	})
   299  	if err != nil {
   300  		t.Fatal(err)
   301  	}
   302  
   303  	err = sw.addPeer(p)
   304  	if _, ok := err.(ErrFilterTimeout); !ok {
   305  		t.Errorf("expected ErrFilterTimeout")
   306  	}
   307  }
   308  
   309  func TestSwitchPeerFilterDuplicate(t *testing.T) {
   310  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   311  	err := sw.Start()
   312  	require.NoError(t, err)
   313  	t.Cleanup(func() {
   314  		if err := sw.Stop(); err != nil {
   315  			t.Error(err)
   316  		}
   317  	})
   318  
   319  	// simulate remote peer
   320  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   321  	rp.Start()
   322  	defer rp.Stop()
   323  
   324  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   325  		chDescs:      sw.chDescs,
   326  		onPeerError:  sw.StopPeerForError,
   327  		isPersistent: sw.IsPeerPersistent,
   328  		reactorsByCh: sw.reactorsByCh,
   329  	})
   330  	if err != nil {
   331  		t.Fatal(err)
   332  	}
   333  
   334  	if err := sw.addPeer(p); err != nil {
   335  		t.Fatal(err)
   336  	}
   337  
   338  	err = sw.addPeer(p)
   339  	if errRej, ok := err.(ErrRejected); ok {
   340  		if !errRej.IsDuplicate() {
   341  			t.Errorf("expected peer to be duplicate. got %v", errRej)
   342  		}
   343  	} else {
   344  		t.Errorf("expected ErrRejected, got %v", err)
   345  	}
   346  }
   347  
   348  func assertNoPeersAfterTimeout(t *testing.T, sw *Switch, timeout time.Duration) {
   349  	time.Sleep(timeout)
   350  	if sw.Peers().Size() != 0 {
   351  		t.Fatalf("Expected %v to not connect to some peers, got %d", sw, sw.Peers().Size())
   352  	}
   353  }
   354  
   355  func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) {
   356  	assert, require := assert.New(t), require.New(t)
   357  
   358  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   359  	err := sw.Start()
   360  	if err != nil {
   361  		t.Error(err)
   362  	}
   363  	t.Cleanup(func() {
   364  		if err := sw.Stop(); err != nil {
   365  			t.Error(err)
   366  		}
   367  	})
   368  
   369  	// simulate remote peer
   370  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   371  	rp.Start()
   372  	defer rp.Stop()
   373  
   374  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   375  		chDescs:      sw.chDescs,
   376  		onPeerError:  sw.StopPeerForError,
   377  		isPersistent: sw.IsPeerPersistent,
   378  		reactorsByCh: sw.reactorsByCh,
   379  	})
   380  	require.Nil(err)
   381  
   382  	err = sw.addPeer(p)
   383  	require.Nil(err)
   384  
   385  	require.NotNil(sw.Peers().Get(rp.ID()))
   386  
   387  	// simulate failure by closing connection
   388  	err = p.(*peer).CloseConn()
   389  	require.NoError(err)
   390  
   391  	assertNoPeersAfterTimeout(t, sw, 100*time.Millisecond)
   392  	assert.False(p.IsRunning())
   393  }
   394  
   395  func TestSwitchStopPeerForError(t *testing.T) {
   396  	s := httptest.NewServer(promhttp.Handler())
   397  	defer s.Close()
   398  
   399  	scrapeMetrics := func() string {
   400  		resp, err := http.Get(s.URL)
   401  		require.NoError(t, err)
   402  		defer resp.Body.Close()
   403  		buf, _ := ioutil.ReadAll(resp.Body)
   404  		return string(buf)
   405  	}
   406  
   407  	namespace, subsystem, name := config.TestInstrumentationConfig().Namespace, MetricsSubsystem, "peers"
   408  	re := regexp.MustCompile(namespace + `_` + subsystem + `_` + name + ` ([0-9\.]+)`)
   409  	peersMetricValue := func() float64 {
   410  		matches := re.FindStringSubmatch(scrapeMetrics())
   411  		f, _ := strconv.ParseFloat(matches[1], 64)
   412  		return f
   413  	}
   414  
   415  	p2pMetrics := PrometheusMetrics(namespace)
   416  
   417  	// make two connected switches
   418  	sw1, sw2 := MakeSwitchPair(t, func(i int, sw *Switch) *Switch {
   419  		// set metrics on sw1
   420  		if i == 0 {
   421  			opt := WithMetrics(p2pMetrics)
   422  			opt(sw)
   423  		}
   424  		return initSwitchFunc(i, sw)
   425  	})
   426  
   427  	assert.Equal(t, len(sw1.Peers().List()), 1)
   428  	assert.EqualValues(t, 1, peersMetricValue())
   429  
   430  	// send messages to the peer from sw1
   431  	p := sw1.Peers().List()[0]
   432  	p.Send(0x1, []byte("here's a message to send"))
   433  
   434  	// stop sw2. this should cause the p to fail,
   435  	// which results in calling StopPeerForError internally
   436  	t.Cleanup(func() {
   437  		if err := sw2.Stop(); err != nil {
   438  			t.Error(err)
   439  		}
   440  	})
   441  
   442  	// now call StopPeerForError explicitly, eg. from a reactor
   443  	sw1.StopPeerForError(p, fmt.Errorf("some err"))
   444  
   445  	assert.Equal(t, len(sw1.Peers().List()), 0)
   446  	assert.EqualValues(t, 0, peersMetricValue())
   447  }
   448  
   449  func TestSwitchReconnectsToOutboundPersistentPeer(t *testing.T) {
   450  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   451  	err := sw.Start()
   452  	require.NoError(t, err)
   453  	t.Cleanup(func() {
   454  		if err := sw.Stop(); err != nil {
   455  			t.Error(err)
   456  		}
   457  	})
   458  
   459  	// 1. simulate failure by closing connection
   460  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   461  	rp.Start()
   462  	defer rp.Stop()
   463  
   464  	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
   465  	require.NoError(t, err)
   466  
   467  	err = sw.DialPeerWithAddress(rp.Addr())
   468  	require.Nil(t, err)
   469  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   470  
   471  	p := sw.Peers().List()[0]
   472  	err = p.(*peer).CloseConn()
   473  	require.NoError(t, err)
   474  
   475  	waitUntilSwitchHasAtLeastNPeers(sw, 1)
   476  	assert.False(t, p.IsRunning())        // old peer instance
   477  	assert.Equal(t, 1, sw.Peers().Size()) // new peer instance
   478  
   479  	// 2. simulate first time dial failure
   480  	rp = &remotePeer{
   481  		PrivKey: ed25519.GenPrivKey(),
   482  		Config:  cfg,
   483  		// Use different interface to prevent duplicate IP filter, this will break
   484  		// beyond two peers.
   485  		listenAddr: "127.0.0.1:0",
   486  	}
   487  	rp.Start()
   488  	defer rp.Stop()
   489  
   490  	conf := config.DefaultP2PConfig()
   491  	conf.TestDialFail = true // will trigger a reconnect
   492  	err = sw.addOutboundPeerWithConfig(rp.Addr(), conf)
   493  	require.NotNil(t, err)
   494  	// DialPeerWithAddres - sw.peerConfig resets the dialer
   495  	waitUntilSwitchHasAtLeastNPeers(sw, 2)
   496  	assert.Equal(t, 2, sw.Peers().Size())
   497  }
   498  
   499  func TestSwitchReconnectsToInboundPersistentPeer(t *testing.T) {
   500  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   501  	err := sw.Start()
   502  	require.NoError(t, err)
   503  	t.Cleanup(func() {
   504  		if err := sw.Stop(); err != nil {
   505  			t.Error(err)
   506  		}
   507  	})
   508  
   509  	// 1. simulate failure by closing the connection
   510  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   511  	rp.Start()
   512  	defer rp.Stop()
   513  
   514  	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
   515  	require.NoError(t, err)
   516  
   517  	conn, err := rp.Dial(sw.NetAddress())
   518  	require.NoError(t, err)
   519  	time.Sleep(50 * time.Millisecond)
   520  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   521  
   522  	conn.Close()
   523  
   524  	waitUntilSwitchHasAtLeastNPeers(sw, 1)
   525  	assert.Equal(t, 1, sw.Peers().Size())
   526  }
   527  
   528  func TestSwitchDialPeersAsync(t *testing.T) {
   529  	if testing.Short() {
   530  		return
   531  	}
   532  
   533  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   534  	err := sw.Start()
   535  	require.NoError(t, err)
   536  	t.Cleanup(func() {
   537  		if err := sw.Stop(); err != nil {
   538  			t.Error(err)
   539  		}
   540  	})
   541  
   542  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   543  	rp.Start()
   544  	defer rp.Stop()
   545  
   546  	err = sw.DialPeersAsync([]string{rp.Addr().String()})
   547  	require.NoError(t, err)
   548  	time.Sleep(dialRandomizerIntervalMilliseconds * time.Millisecond)
   549  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   550  }
   551  
   552  func waitUntilSwitchHasAtLeastNPeers(sw *Switch, n int) {
   553  	for i := 0; i < 20; i++ {
   554  		time.Sleep(250 * time.Millisecond)
   555  		has := sw.Peers().Size()
   556  		if has >= n {
   557  			break
   558  		}
   559  	}
   560  }
   561  
   562  func TestSwitchFullConnectivity(t *testing.T) {
   563  	switches := MakeConnectedSwitches(cfg, 3, initSwitchFunc, Connect2Switches)
   564  	defer func() {
   565  		for _, sw := range switches {
   566  			sw := sw
   567  			t.Cleanup(func() {
   568  				if err := sw.Stop(); err != nil {
   569  					t.Error(err)
   570  				}
   571  			})
   572  		}
   573  	}()
   574  
   575  	for i, sw := range switches {
   576  		if sw.Peers().Size() != 2 {
   577  			t.Fatalf("Expected each switch to be connected to 2 other, but %d switch only connected to %d", sw.Peers().Size(), i)
   578  		}
   579  	}
   580  }
   581  
   582  func TestSwitchAcceptRoutine(t *testing.T) {
   583  	cfg.MaxNumInboundPeers = 5
   584  
   585  	// Create some unconditional peers.
   586  	const unconditionalPeersNum = 2
   587  	var (
   588  		unconditionalPeers   = make([]*remotePeer, unconditionalPeersNum)
   589  		unconditionalPeerIDs = make([]string, unconditionalPeersNum)
   590  	)
   591  	for i := 0; i < unconditionalPeersNum; i++ {
   592  		peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   593  		peer.Start()
   594  		unconditionalPeers[i] = peer
   595  		unconditionalPeerIDs[i] = string(peer.ID())
   596  	}
   597  
   598  	// make switch
   599  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   600  	err := sw.AddUnconditionalPeerIDs(unconditionalPeerIDs)
   601  	require.NoError(t, err)
   602  	err = sw.Start()
   603  	require.NoError(t, err)
   604  	t.Cleanup(func() {
   605  		err := sw.Stop()
   606  		require.NoError(t, err)
   607  	})
   608  
   609  	// 0. check there are no peers
   610  	assert.Equal(t, 0, sw.Peers().Size())
   611  
   612  	// 1. check we connect up to MaxNumInboundPeers
   613  	peers := make([]*remotePeer, 0)
   614  	for i := 0; i < cfg.MaxNumInboundPeers; i++ {
   615  		peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   616  		peers = append(peers, peer)
   617  		peer.Start()
   618  		c, err := peer.Dial(sw.NetAddress())
   619  		require.NoError(t, err)
   620  		// spawn a reading routine to prevent connection from closing
   621  		go func(c net.Conn) {
   622  			for {
   623  				one := make([]byte, 1)
   624  				_, err := c.Read(one)
   625  				if err != nil {
   626  					return
   627  				}
   628  			}
   629  		}(c)
   630  	}
   631  	time.Sleep(100 * time.Millisecond)
   632  	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())
   633  
   634  	// 2. check we close new connections if we already have MaxNumInboundPeers peers
   635  	peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   636  	peer.Start()
   637  	conn, err := peer.Dial(sw.NetAddress())
   638  	require.NoError(t, err)
   639  	// check conn is closed
   640  	one := make([]byte, 1)
   641  	_ = conn.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
   642  	_, err = conn.Read(one)
   643  	assert.Error(t, err)
   644  	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())
   645  	peer.Stop()
   646  
   647  	// 3. check we connect to unconditional peers despite the limit.
   648  	for _, peer := range unconditionalPeers {
   649  		c, err := peer.Dial(sw.NetAddress())
   650  		require.NoError(t, err)
   651  		// spawn a reading routine to prevent connection from closing
   652  		go func(c net.Conn) {
   653  			for {
   654  				one := make([]byte, 1)
   655  				_, err := c.Read(one)
   656  				if err != nil {
   657  					return
   658  				}
   659  			}
   660  		}(c)
   661  	}
   662  	time.Sleep(10 * time.Millisecond)
   663  	assert.Equal(t, cfg.MaxNumInboundPeers+unconditionalPeersNum, sw.Peers().Size())
   664  
   665  	for _, peer := range peers {
   666  		peer.Stop()
   667  	}
   668  	for _, peer := range unconditionalPeers {
   669  		peer.Stop()
   670  	}
   671  }
   672  
   673  type errorTransport struct {
   674  	acceptErr error
   675  }
   676  
   677  func (et errorTransport) NetAddress() NetAddress {
   678  	panic("not implemented")
   679  }
   680  
   681  func (et errorTransport) Accept(c peerConfig) (Peer, error) {
   682  	return nil, et.acceptErr
   683  }
   684  func (errorTransport) Dial(NetAddress, peerConfig) (Peer, error) {
   685  	panic("not implemented")
   686  }
   687  func (errorTransport) Cleanup(Peer) {
   688  	panic("not implemented")
   689  }
   690  
   691  func TestSwitchAcceptRoutineErrorCases(t *testing.T) {
   692  	sw := NewSwitch(cfg, errorTransport{ErrFilterTimeout{}})
   693  	assert.NotPanics(t, func() {
   694  		err := sw.Start()
   695  		require.NoError(t, err)
   696  		err = sw.Stop()
   697  		require.NoError(t, err)
   698  	})
   699  
   700  	sw = NewSwitch(cfg, errorTransport{ErrRejected{conn: nil, err: errors.New("filtered"), isFiltered: true}})
   701  	assert.NotPanics(t, func() {
   702  		err := sw.Start()
   703  		require.NoError(t, err)
   704  		err = sw.Stop()
   705  		require.NoError(t, err)
   706  	})
   707  	// TODO(melekes) check we remove our address from addrBook
   708  
   709  	sw = NewSwitch(cfg, errorTransport{ErrTransportClosed{}})
   710  	assert.NotPanics(t, func() {
   711  		err := sw.Start()
   712  		require.NoError(t, err)
   713  		err = sw.Stop()
   714  		require.NoError(t, err)
   715  	})
   716  }
   717  
   718  // mockReactor checks that InitPeer never called before RemovePeer. If that's
   719  // not true, InitCalledBeforeRemoveFinished will return true.
   720  type mockReactor struct {
   721  	*BaseReactor
   722  
   723  	// atomic
   724  	removePeerInProgress           uint32
   725  	initCalledBeforeRemoveFinished uint32
   726  }
   727  
   728  func (r *mockReactor) RemovePeer(peer Peer, reason interface{}) {
   729  	atomic.StoreUint32(&r.removePeerInProgress, 1)
   730  	defer atomic.StoreUint32(&r.removePeerInProgress, 0)
   731  	time.Sleep(100 * time.Millisecond)
   732  }
   733  
   734  func (r *mockReactor) InitPeer(peer Peer) Peer {
   735  	if atomic.LoadUint32(&r.removePeerInProgress) == 1 {
   736  		atomic.StoreUint32(&r.initCalledBeforeRemoveFinished, 1)
   737  	}
   738  
   739  	return peer
   740  }
   741  
   742  func (r *mockReactor) InitCalledBeforeRemoveFinished() bool {
   743  	return atomic.LoadUint32(&r.initCalledBeforeRemoveFinished) == 1
   744  }
   745  
   746  // see stopAndRemovePeer
   747  func TestSwitchInitPeerIsNotCalledBeforeRemovePeer(t *testing.T) {
   748  	// make reactor
   749  	reactor := &mockReactor{}
   750  	reactor.BaseReactor = NewBaseReactor("mockReactor", reactor)
   751  
   752  	// make switch
   753  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", func(i int, sw *Switch) *Switch {
   754  		sw.AddReactor("mock", reactor)
   755  		return sw
   756  	})
   757  	err := sw.Start()
   758  	require.NoError(t, err)
   759  	t.Cleanup(func() {
   760  		if err := sw.Stop(); err != nil {
   761  			t.Error(err)
   762  		}
   763  	})
   764  
   765  	// add peer
   766  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   767  	rp.Start()
   768  	defer rp.Stop()
   769  	_, err = rp.Dial(sw.NetAddress())
   770  	require.NoError(t, err)
   771  
   772  	// wait till the switch adds rp to the peer set, then stop the peer asynchronously
   773  	for {
   774  		time.Sleep(20 * time.Millisecond)
   775  		if peer := sw.Peers().Get(rp.ID()); peer != nil {
   776  			go sw.StopPeerForError(peer, "test")
   777  			break
   778  		}
   779  	}
   780  
   781  	// simulate peer reconnecting to us
   782  	_, err = rp.Dial(sw.NetAddress())
   783  	require.NoError(t, err)
   784  	// wait till the switch adds rp to the peer set
   785  	time.Sleep(50 * time.Millisecond)
   786  
   787  	// make sure reactor.RemovePeer is finished before InitPeer is called
   788  	assert.False(t, reactor.InitCalledBeforeRemoveFinished())
   789  }
   790  
   791  func BenchmarkSwitchBroadcast(b *testing.B) {
   792  	s1, s2 := MakeSwitchPair(b, func(i int, sw *Switch) *Switch {
   793  		// Make bar reactors of bar channels each
   794  		sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
   795  			{ID: byte(0x00), Priority: 10},
   796  			{ID: byte(0x01), Priority: 10},
   797  		}, false))
   798  		sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
   799  			{ID: byte(0x02), Priority: 10},
   800  			{ID: byte(0x03), Priority: 10},
   801  		}, false))
   802  		return sw
   803  	})
   804  
   805  	b.Cleanup(func() {
   806  		if err := s1.Stop(); err != nil {
   807  			b.Error(err)
   808  		}
   809  	})
   810  
   811  	b.Cleanup(func() {
   812  		if err := s2.Stop(); err != nil {
   813  			b.Error(err)
   814  		}
   815  	})
   816  
   817  	// Allow time for goroutines to boot up
   818  	time.Sleep(1 * time.Second)
   819  
   820  	b.ResetTimer()
   821  
   822  	numSuccess, numFailure := 0, 0
   823  
   824  	// Send random message from foo channel to another
   825  	for i := 0; i < b.N; i++ {
   826  		chID := byte(i % 4)
   827  		successChan := s1.Broadcast(chID, []byte("test data"))
   828  		for s := range successChan {
   829  			if s {
   830  				numSuccess++
   831  			} else {
   832  				numFailure++
   833  			}
   834  		}
   835  	}
   836  
   837  	b.Logf("success: %v, failure: %v", numSuccess, numFailure)
   838  }