github.com/vipernet-xyz/tm@v0.34.24/p2p/switch_test.go (about)

     1  package p2p
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"regexp"
    12  	"strconv"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/gogo/protobuf/proto"
    18  	"github.com/prometheus/client_golang/prometheus/promhttp"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/vipernet-xyz/tm/config"
    23  	"github.com/vipernet-xyz/tm/crypto/ed25519"
    24  	"github.com/vipernet-xyz/tm/libs/log"
    25  	tmsync "github.com/vipernet-xyz/tm/libs/sync"
    26  	"github.com/vipernet-xyz/tm/p2p/conn"
    27  	p2pproto "github.com/vipernet-xyz/tm/proto/tendermint/p2p"
    28  )
    29  
    30  var cfg *config.P2PConfig
    31  
    32  func init() {
    33  	cfg = config.DefaultP2PConfig()
    34  	cfg.PexReactor = true
    35  	cfg.AllowDuplicateIP = true
    36  }
    37  
    38  type PeerMessage struct {
    39  	Contents proto.Message
    40  	Counter  int
    41  }
    42  
    43  type TestReactor struct {
    44  	BaseReactor
    45  
    46  	mtx          tmsync.Mutex
    47  	channels     []*conn.ChannelDescriptor
    48  	logMessages  bool
    49  	msgsCounter  int
    50  	msgsReceived map[byte][]PeerMessage
    51  }
    52  
    53  func NewTestReactor(channels []*conn.ChannelDescriptor, logMessages bool) *TestReactor {
    54  	tr := &TestReactor{
    55  		channels:     channels,
    56  		logMessages:  logMessages,
    57  		msgsReceived: make(map[byte][]PeerMessage),
    58  	}
    59  	tr.BaseReactor = *NewBaseReactor("TestReactor", tr)
    60  	tr.SetLogger(log.TestingLogger())
    61  	return tr
    62  }
    63  
    64  func (tr *TestReactor) GetChannels() []*conn.ChannelDescriptor {
    65  	return tr.channels
    66  }
    67  
    68  func (tr *TestReactor) AddPeer(peer Peer) {}
    69  
    70  func (tr *TestReactor) RemovePeer(peer Peer, reason interface{}) {}
    71  
    72  func (tr *TestReactor) ReceiveEnvelope(e Envelope) {
    73  	if tr.logMessages {
    74  		tr.mtx.Lock()
    75  		defer tr.mtx.Unlock()
    76  		// fmt.Printf("Received: %X, %X\n", e.ChannelID, e.Message)
    77  		tr.msgsReceived[e.ChannelID] = append(tr.msgsReceived[e.ChannelID], PeerMessage{Contents: e.Message, Counter: tr.msgsCounter})
    78  		tr.msgsCounter++
    79  	}
    80  }
    81  
    82  func (tr *TestReactor) Receive(chID byte, peer Peer, msgBytes []byte) {
    83  	msg := &p2pproto.Message{}
    84  	err := proto.Unmarshal(msgBytes, msg)
    85  	if err != nil {
    86  		panic(err)
    87  	}
    88  	um, err := msg.Unwrap()
    89  	if err != nil {
    90  		panic(err)
    91  	}
    92  
    93  	tr.ReceiveEnvelope(Envelope{
    94  		ChannelID: chID,
    95  		Src:       peer,
    96  		Message:   um,
    97  	})
    98  }
    99  
   100  func (tr *TestReactor) getMsgs(chID byte) []PeerMessage {
   101  	tr.mtx.Lock()
   102  	defer tr.mtx.Unlock()
   103  	return tr.msgsReceived[chID]
   104  }
   105  
   106  //-----------------------------------------------------------------------------
   107  
   108  // convenience method for creating two switches connected to each other.
   109  // XXX: note this uses net.Pipe and not a proper TCP conn
   110  func MakeSwitchPair(t testing.TB, initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) {
   111  	// Create two switches that will be interconnected.
   112  	switches := MakeConnectedSwitches(cfg, 2, initSwitch, Connect2Switches)
   113  	return switches[0], switches[1]
   114  }
   115  
   116  func initSwitchFunc(i int, sw *Switch) *Switch {
   117  	sw.SetAddrBook(&AddrBookMock{
   118  		Addrs:    make(map[string]struct{}),
   119  		OurAddrs: make(map[string]struct{}),
   120  	})
   121  
   122  	// Make two reactors of two channels each
   123  	sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
   124  		{ID: byte(0x00), Priority: 10, MessageType: &p2pproto.Message{}},
   125  		{ID: byte(0x01), Priority: 10, MessageType: &p2pproto.Message{}},
   126  	}, true))
   127  	sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
   128  		{ID: byte(0x02), Priority: 10, MessageType: &p2pproto.Message{}},
   129  		{ID: byte(0x03), Priority: 10, MessageType: &p2pproto.Message{}},
   130  	}, true))
   131  
   132  	return sw
   133  }
   134  
   135  func TestSwitches(t *testing.T) {
   136  	s1, s2 := MakeSwitchPair(t, initSwitchFunc)
   137  	t.Cleanup(func() {
   138  		if err := s1.Stop(); err != nil {
   139  			t.Error(err)
   140  		}
   141  	})
   142  	t.Cleanup(func() {
   143  		if err := s2.Stop(); err != nil {
   144  			t.Error(err)
   145  		}
   146  	})
   147  
   148  	if s1.Peers().Size() != 1 {
   149  		t.Errorf("expected exactly 1 peer in s1, got %v", s1.Peers().Size())
   150  	}
   151  	if s2.Peers().Size() != 1 {
   152  		t.Errorf("expected exactly 1 peer in s2, got %v", s2.Peers().Size())
   153  	}
   154  
   155  	// Lets send some messages
   156  	ch0Msg := &p2pproto.PexAddrs{
   157  		Addrs: []p2pproto.NetAddress{
   158  			{
   159  				ID: "1",
   160  			},
   161  		},
   162  	}
   163  	ch1Msg := &p2pproto.PexAddrs{
   164  		Addrs: []p2pproto.NetAddress{
   165  			{
   166  				ID: "1",
   167  			},
   168  		},
   169  	}
   170  	ch2Msg := &p2pproto.PexAddrs{
   171  		Addrs: []p2pproto.NetAddress{
   172  			{
   173  				ID: "2",
   174  			},
   175  		},
   176  	}
   177  	s1.BroadcastEnvelope(Envelope{ChannelID: byte(0x00), Message: ch0Msg})
   178  	s1.BroadcastEnvelope(Envelope{ChannelID: byte(0x01), Message: ch1Msg})
   179  	s1.BroadcastEnvelope(Envelope{ChannelID: byte(0x02), Message: ch2Msg})
   180  	assertMsgReceivedWithTimeout(t,
   181  		ch0Msg,
   182  		byte(0x00),
   183  		s2.Reactor("foo").(*TestReactor), 200*time.Millisecond, 5*time.Second)
   184  	assertMsgReceivedWithTimeout(t,
   185  		ch1Msg,
   186  		byte(0x01),
   187  		s2.Reactor("foo").(*TestReactor), 200*time.Millisecond, 5*time.Second)
   188  	assertMsgReceivedWithTimeout(t,
   189  		ch2Msg,
   190  		byte(0x02),
   191  		s2.Reactor("bar").(*TestReactor), 200*time.Millisecond, 5*time.Second)
   192  }
   193  
   194  func assertMsgReceivedWithTimeout(
   195  	t *testing.T,
   196  	msg proto.Message,
   197  	channel byte,
   198  	reactor *TestReactor,
   199  	checkPeriod,
   200  	timeout time.Duration,
   201  ) {
   202  	ticker := time.NewTicker(checkPeriod)
   203  	for {
   204  		select {
   205  		case <-ticker.C:
   206  			msgs := reactor.getMsgs(channel)
   207  			expectedBytes, err := proto.Marshal(msgs[0].Contents)
   208  			require.NoError(t, err)
   209  			gotBytes, err := proto.Marshal(msg)
   210  			require.NoError(t, err)
   211  			if len(msgs) > 0 {
   212  				if !bytes.Equal(expectedBytes, gotBytes) {
   213  					t.Fatalf("Unexpected message bytes. Wanted: %X, Got: %X", msg, msgs[0].Counter)
   214  				}
   215  				return
   216  			}
   217  
   218  		case <-time.After(timeout):
   219  			t.Fatalf("Expected to have received 1 message in channel #%v, got zero", channel)
   220  		}
   221  	}
   222  }
   223  
   224  func TestSwitchFiltersOutItself(t *testing.T) {
   225  	s1 := MakeSwitch(cfg, 1, "127.0.0.1", "123.123.123", initSwitchFunc)
   226  
   227  	// simulate s1 having a public IP by creating a remote peer with the same ID
   228  	rp := &remotePeer{PrivKey: s1.nodeKey.PrivKey, Config: cfg}
   229  	rp.Start()
   230  
   231  	// addr should be rejected in addPeer based on the same ID
   232  	err := s1.DialPeerWithAddress(rp.Addr())
   233  	if assert.Error(t, err) {
   234  		if err, ok := err.(ErrRejected); ok {
   235  			if !err.IsSelf() {
   236  				t.Errorf("expected self to be rejected")
   237  			}
   238  		} else {
   239  			t.Errorf("expected ErrRejected")
   240  		}
   241  	}
   242  
   243  	assert.True(t, s1.addrBook.OurAddress(rp.Addr()))
   244  	assert.False(t, s1.addrBook.HasAddress(rp.Addr()))
   245  
   246  	rp.Stop()
   247  
   248  	assertNoPeersAfterTimeout(t, s1, 100*time.Millisecond)
   249  }
   250  
   251  func TestSwitchPeerFilter(t *testing.T) {
   252  	var (
   253  		filters = []PeerFilterFunc{
   254  			func(_ IPeerSet, _ Peer) error { return nil },
   255  			func(_ IPeerSet, _ Peer) error { return fmt.Errorf("denied") },
   256  			func(_ IPeerSet, _ Peer) error { return nil },
   257  		}
   258  		sw = MakeSwitch(
   259  			cfg,
   260  			1,
   261  			"testing",
   262  			"123.123.123",
   263  			initSwitchFunc,
   264  			SwitchPeerFilters(filters...),
   265  		)
   266  	)
   267  	err := sw.Start()
   268  	require.NoError(t, err)
   269  	t.Cleanup(func() {
   270  		if err := sw.Stop(); err != nil {
   271  			t.Error(err)
   272  		}
   273  	})
   274  
   275  	// simulate remote peer
   276  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   277  	rp.Start()
   278  	t.Cleanup(rp.Stop)
   279  
   280  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   281  		chDescs:      sw.chDescs,
   282  		onPeerError:  sw.StopPeerForError,
   283  		isPersistent: sw.IsPeerPersistent,
   284  		reactorsByCh: sw.reactorsByCh,
   285  	})
   286  	if err != nil {
   287  		t.Fatal(err)
   288  	}
   289  
   290  	err = sw.addPeer(p)
   291  	if err, ok := err.(ErrRejected); ok {
   292  		if !err.IsFiltered() {
   293  			t.Errorf("expected peer to be filtered")
   294  		}
   295  	} else {
   296  		t.Errorf("expected ErrRejected")
   297  	}
   298  }
   299  
   300  func TestSwitchPeerFilterTimeout(t *testing.T) {
   301  	var (
   302  		filters = []PeerFilterFunc{
   303  			func(_ IPeerSet, _ Peer) error {
   304  				time.Sleep(10 * time.Millisecond)
   305  				return nil
   306  			},
   307  		}
   308  		sw = MakeSwitch(
   309  			cfg,
   310  			1,
   311  			"testing",
   312  			"123.123.123",
   313  			initSwitchFunc,
   314  			SwitchFilterTimeout(5*time.Millisecond),
   315  			SwitchPeerFilters(filters...),
   316  		)
   317  	)
   318  	err := sw.Start()
   319  	require.NoError(t, err)
   320  	t.Cleanup(func() {
   321  		if err := sw.Stop(); err != nil {
   322  			t.Log(err)
   323  		}
   324  	})
   325  
   326  	// simulate remote peer
   327  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   328  	rp.Start()
   329  	defer rp.Stop()
   330  
   331  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   332  		chDescs:      sw.chDescs,
   333  		onPeerError:  sw.StopPeerForError,
   334  		isPersistent: sw.IsPeerPersistent,
   335  		reactorsByCh: sw.reactorsByCh,
   336  	})
   337  	if err != nil {
   338  		t.Fatal(err)
   339  	}
   340  
   341  	err = sw.addPeer(p)
   342  	if _, ok := err.(ErrFilterTimeout); !ok {
   343  		t.Errorf("expected ErrFilterTimeout")
   344  	}
   345  }
   346  
   347  func TestSwitchPeerFilterDuplicate(t *testing.T) {
   348  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   349  	err := sw.Start()
   350  	require.NoError(t, err)
   351  	t.Cleanup(func() {
   352  		if err := sw.Stop(); err != nil {
   353  			t.Error(err)
   354  		}
   355  	})
   356  
   357  	// simulate remote peer
   358  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   359  	rp.Start()
   360  	defer rp.Stop()
   361  
   362  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   363  		chDescs:      sw.chDescs,
   364  		onPeerError:  sw.StopPeerForError,
   365  		isPersistent: sw.IsPeerPersistent,
   366  		reactorsByCh: sw.reactorsByCh,
   367  	})
   368  	if err != nil {
   369  		t.Fatal(err)
   370  	}
   371  
   372  	if err := sw.addPeer(p); err != nil {
   373  		t.Fatal(err)
   374  	}
   375  
   376  	err = sw.addPeer(p)
   377  	if errRej, ok := err.(ErrRejected); ok {
   378  		if !errRej.IsDuplicate() {
   379  			t.Errorf("expected peer to be duplicate. got %v", errRej)
   380  		}
   381  	} else {
   382  		t.Errorf("expected ErrRejected, got %v", err)
   383  	}
   384  }
   385  
   386  func assertNoPeersAfterTimeout(t *testing.T, sw *Switch, timeout time.Duration) {
   387  	time.Sleep(timeout)
   388  	if sw.Peers().Size() != 0 {
   389  		t.Fatalf("Expected %v to not connect to some peers, got %d", sw, sw.Peers().Size())
   390  	}
   391  }
   392  
   393  func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) {
   394  	assert, require := assert.New(t), require.New(t)
   395  
   396  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   397  	err := sw.Start()
   398  	if err != nil {
   399  		t.Error(err)
   400  	}
   401  	t.Cleanup(func() {
   402  		if err := sw.Stop(); err != nil {
   403  			t.Error(err)
   404  		}
   405  	})
   406  
   407  	// simulate remote peer
   408  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   409  	rp.Start()
   410  	defer rp.Stop()
   411  
   412  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   413  		chDescs:      sw.chDescs,
   414  		onPeerError:  sw.StopPeerForError,
   415  		isPersistent: sw.IsPeerPersistent,
   416  		reactorsByCh: sw.reactorsByCh,
   417  	})
   418  	require.Nil(err)
   419  
   420  	err = sw.addPeer(p)
   421  	require.Nil(err)
   422  
   423  	require.NotNil(sw.Peers().Get(rp.ID()))
   424  
   425  	// simulate failure by closing connection
   426  	err = p.(*peer).CloseConn()
   427  	require.NoError(err)
   428  
   429  	assertNoPeersAfterTimeout(t, sw, 100*time.Millisecond)
   430  	assert.False(p.IsRunning())
   431  }
   432  
   433  func TestSwitchStopPeerForError(t *testing.T) {
   434  	s := httptest.NewServer(promhttp.Handler())
   435  	defer s.Close()
   436  
   437  	scrapeMetrics := func() string {
   438  		resp, err := http.Get(s.URL)
   439  		require.NoError(t, err)
   440  		defer resp.Body.Close()
   441  		buf, _ := io.ReadAll(resp.Body)
   442  		return string(buf)
   443  	}
   444  
   445  	namespace, subsystem, name := config.TestInstrumentationConfig().Namespace, MetricsSubsystem, "peers"
   446  	re := regexp.MustCompile(namespace + `_` + subsystem + `_` + name + ` ([0-9\.]+)`)
   447  	peersMetricValue := func() float64 {
   448  		matches := re.FindStringSubmatch(scrapeMetrics())
   449  		f, _ := strconv.ParseFloat(matches[1], 64)
   450  		return f
   451  	}
   452  
   453  	p2pMetrics := PrometheusMetrics(namespace)
   454  
   455  	// make two connected switches
   456  	sw1, sw2 := MakeSwitchPair(t, func(i int, sw *Switch) *Switch {
   457  		// set metrics on sw1
   458  		if i == 0 {
   459  			opt := WithMetrics(p2pMetrics)
   460  			opt(sw)
   461  		}
   462  		return initSwitchFunc(i, sw)
   463  	})
   464  
   465  	assert.Equal(t, len(sw1.Peers().List()), 1)
   466  	assert.EqualValues(t, 1, peersMetricValue())
   467  
   468  	// send messages to the peer from sw1
   469  	p := sw1.Peers().List()[0]
   470  	SendEnvelopeShim(p, Envelope{
   471  		ChannelID: 0x1,
   472  		Message:   &p2pproto.Message{},
   473  	}, sw1.Logger)
   474  
   475  	// stop sw2. this should cause the p to fail,
   476  	// which results in calling StopPeerForError internally
   477  	t.Cleanup(func() {
   478  		if err := sw2.Stop(); err != nil {
   479  			t.Error(err)
   480  		}
   481  	})
   482  
   483  	// now call StopPeerForError explicitly, eg. from a reactor
   484  	sw1.StopPeerForError(p, fmt.Errorf("some err"))
   485  
   486  	assert.Equal(t, len(sw1.Peers().List()), 0)
   487  	assert.EqualValues(t, 0, peersMetricValue())
   488  }
   489  
   490  func TestSwitchReconnectsToOutboundPersistentPeer(t *testing.T) {
   491  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   492  	err := sw.Start()
   493  	require.NoError(t, err)
   494  	t.Cleanup(func() {
   495  		if err := sw.Stop(); err != nil {
   496  			t.Error(err)
   497  		}
   498  	})
   499  
   500  	// 1. simulate failure by closing connection
   501  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   502  	rp.Start()
   503  	defer rp.Stop()
   504  
   505  	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
   506  	require.NoError(t, err)
   507  
   508  	err = sw.DialPeerWithAddress(rp.Addr())
   509  	require.Nil(t, err)
   510  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   511  
   512  	p := sw.Peers().List()[0]
   513  	err = p.(*peer).CloseConn()
   514  	require.NoError(t, err)
   515  
   516  	waitUntilSwitchHasAtLeastNPeers(sw, 1)
   517  	assert.False(t, p.IsRunning())        // old peer instance
   518  	assert.Equal(t, 1, sw.Peers().Size()) // new peer instance
   519  
   520  	// 2. simulate first time dial failure
   521  	rp = &remotePeer{
   522  		PrivKey: ed25519.GenPrivKey(),
   523  		Config:  cfg,
   524  		// Use different interface to prevent duplicate IP filter, this will break
   525  		// beyond two peers.
   526  		listenAddr: "127.0.0.1:0",
   527  	}
   528  	rp.Start()
   529  	defer rp.Stop()
   530  
   531  	conf := config.DefaultP2PConfig()
   532  	conf.TestDialFail = true // will trigger a reconnect
   533  	err = sw.addOutboundPeerWithConfig(rp.Addr(), conf)
   534  	require.NotNil(t, err)
   535  	// DialPeerWithAddres - sw.peerConfig resets the dialer
   536  	waitUntilSwitchHasAtLeastNPeers(sw, 2)
   537  	assert.Equal(t, 2, sw.Peers().Size())
   538  }
   539  
   540  func TestSwitchReconnectsToInboundPersistentPeer(t *testing.T) {
   541  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   542  	err := sw.Start()
   543  	require.NoError(t, err)
   544  	t.Cleanup(func() {
   545  		if err := sw.Stop(); err != nil {
   546  			t.Error(err)
   547  		}
   548  	})
   549  
   550  	// 1. simulate failure by closing the connection
   551  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   552  	rp.Start()
   553  	defer rp.Stop()
   554  
   555  	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
   556  	require.NoError(t, err)
   557  
   558  	conn, err := rp.Dial(sw.NetAddress())
   559  	require.NoError(t, err)
   560  	time.Sleep(50 * time.Millisecond)
   561  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   562  
   563  	conn.Close()
   564  
   565  	waitUntilSwitchHasAtLeastNPeers(sw, 1)
   566  	assert.Equal(t, 1, sw.Peers().Size())
   567  }
   568  
   569  func TestSwitchDialPeersAsync(t *testing.T) {
   570  	if testing.Short() {
   571  		return
   572  	}
   573  
   574  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   575  	err := sw.Start()
   576  	require.NoError(t, err)
   577  	t.Cleanup(func() {
   578  		if err := sw.Stop(); err != nil {
   579  			t.Error(err)
   580  		}
   581  	})
   582  
   583  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   584  	rp.Start()
   585  	defer rp.Stop()
   586  
   587  	err = sw.DialPeersAsync([]string{rp.Addr().String()})
   588  	require.NoError(t, err)
   589  	time.Sleep(dialRandomizerIntervalMilliseconds * time.Millisecond)
   590  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   591  }
   592  
   593  func waitUntilSwitchHasAtLeastNPeers(sw *Switch, n int) {
   594  	for i := 0; i < 20; i++ {
   595  		time.Sleep(250 * time.Millisecond)
   596  		has := sw.Peers().Size()
   597  		if has >= n {
   598  			break
   599  		}
   600  	}
   601  }
   602  
   603  func TestSwitchFullConnectivity(t *testing.T) {
   604  	switches := MakeConnectedSwitches(cfg, 3, initSwitchFunc, Connect2Switches)
   605  	defer func() {
   606  		for _, sw := range switches {
   607  			sw := sw
   608  			t.Cleanup(func() {
   609  				if err := sw.Stop(); err != nil {
   610  					t.Error(err)
   611  				}
   612  			})
   613  		}
   614  	}()
   615  
   616  	for i, sw := range switches {
   617  		if sw.Peers().Size() != 2 {
   618  			t.Fatalf("Expected each switch to be connected to 2 other, but %d switch only connected to %d", sw.Peers().Size(), i)
   619  		}
   620  	}
   621  }
   622  
   623  func TestSwitchAcceptRoutine(t *testing.T) {
   624  	cfg.MaxNumInboundPeers = 5
   625  
   626  	// Create some unconditional peers.
   627  	const unconditionalPeersNum = 2
   628  	var (
   629  		unconditionalPeers   = make([]*remotePeer, unconditionalPeersNum)
   630  		unconditionalPeerIDs = make([]string, unconditionalPeersNum)
   631  	)
   632  	for i := 0; i < unconditionalPeersNum; i++ {
   633  		peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   634  		peer.Start()
   635  		unconditionalPeers[i] = peer
   636  		unconditionalPeerIDs[i] = string(peer.ID())
   637  	}
   638  
   639  	// make switch
   640  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   641  	err := sw.AddUnconditionalPeerIDs(unconditionalPeerIDs)
   642  	require.NoError(t, err)
   643  	err = sw.Start()
   644  	require.NoError(t, err)
   645  	t.Cleanup(func() {
   646  		err := sw.Stop()
   647  		require.NoError(t, err)
   648  	})
   649  
   650  	// 0. check there are no peers
   651  	assert.Equal(t, 0, sw.Peers().Size())
   652  
   653  	// 1. check we connect up to MaxNumInboundPeers
   654  	peers := make([]*remotePeer, 0)
   655  	for i := 0; i < cfg.MaxNumInboundPeers; i++ {
   656  		peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   657  		peers = append(peers, peer)
   658  		peer.Start()
   659  		c, err := peer.Dial(sw.NetAddress())
   660  		require.NoError(t, err)
   661  		// spawn a reading routine to prevent connection from closing
   662  		go func(c net.Conn) {
   663  			for {
   664  				one := make([]byte, 1)
   665  				_, err := c.Read(one)
   666  				if err != nil {
   667  					return
   668  				}
   669  			}
   670  		}(c)
   671  	}
   672  	time.Sleep(100 * time.Millisecond)
   673  	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())
   674  
   675  	// 2. check we close new connections if we already have MaxNumInboundPeers peers
   676  	peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   677  	peer.Start()
   678  	conn, err := peer.Dial(sw.NetAddress())
   679  	require.NoError(t, err)
   680  	// check conn is closed
   681  	one := make([]byte, 1)
   682  	_ = conn.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
   683  	_, err = conn.Read(one)
   684  	assert.Error(t, err)
   685  	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())
   686  	peer.Stop()
   687  
   688  	// 3. check we connect to unconditional peers despite the limit.
   689  	for _, peer := range unconditionalPeers {
   690  		c, err := peer.Dial(sw.NetAddress())
   691  		require.NoError(t, err)
   692  		// spawn a reading routine to prevent connection from closing
   693  		go func(c net.Conn) {
   694  			for {
   695  				one := make([]byte, 1)
   696  				_, err := c.Read(one)
   697  				if err != nil {
   698  					return
   699  				}
   700  			}
   701  		}(c)
   702  	}
   703  	time.Sleep(10 * time.Millisecond)
   704  	assert.Equal(t, cfg.MaxNumInboundPeers+unconditionalPeersNum, sw.Peers().Size())
   705  
   706  	for _, peer := range peers {
   707  		peer.Stop()
   708  	}
   709  	for _, peer := range unconditionalPeers {
   710  		peer.Stop()
   711  	}
   712  }
   713  
   714  type errorTransport struct {
   715  	acceptErr error
   716  }
   717  
   718  func (et errorTransport) NetAddress() NetAddress {
   719  	panic("not implemented")
   720  }
   721  
   722  func (et errorTransport) Accept(c peerConfig) (Peer, error) {
   723  	return nil, et.acceptErr
   724  }
   725  
   726  func (errorTransport) Dial(NetAddress, peerConfig) (Peer, error) {
   727  	panic("not implemented")
   728  }
   729  
   730  func (errorTransport) Cleanup(Peer) {
   731  	panic("not implemented")
   732  }
   733  
   734  func TestSwitchAcceptRoutineErrorCases(t *testing.T) {
   735  	sw := NewSwitch(cfg, errorTransport{ErrFilterTimeout{}})
   736  	assert.NotPanics(t, func() {
   737  		err := sw.Start()
   738  		require.NoError(t, err)
   739  		err = sw.Stop()
   740  		require.NoError(t, err)
   741  	})
   742  
   743  	sw = NewSwitch(cfg, errorTransport{ErrRejected{conn: nil, err: errors.New("filtered"), isFiltered: true}})
   744  	assert.NotPanics(t, func() {
   745  		err := sw.Start()
   746  		require.NoError(t, err)
   747  		err = sw.Stop()
   748  		require.NoError(t, err)
   749  	})
   750  	// TODO(melekes) check we remove our address from addrBook
   751  
   752  	sw = NewSwitch(cfg, errorTransport{ErrTransportClosed{}})
   753  	assert.NotPanics(t, func() {
   754  		err := sw.Start()
   755  		require.NoError(t, err)
   756  		err = sw.Stop()
   757  		require.NoError(t, err)
   758  	})
   759  }
   760  
   761  // mockReactor checks that InitPeer never called before RemovePeer. If that's
   762  // not true, InitCalledBeforeRemoveFinished will return true.
   763  type mockReactor struct {
   764  	*BaseReactor
   765  
   766  	// atomic
   767  	removePeerInProgress           uint32
   768  	initCalledBeforeRemoveFinished uint32
   769  }
   770  
   771  func (r *mockReactor) RemovePeer(peer Peer, reason interface{}) {
   772  	atomic.StoreUint32(&r.removePeerInProgress, 1)
   773  	defer atomic.StoreUint32(&r.removePeerInProgress, 0)
   774  	time.Sleep(100 * time.Millisecond)
   775  }
   776  
   777  func (r *mockReactor) InitPeer(peer Peer) Peer {
   778  	if atomic.LoadUint32(&r.removePeerInProgress) == 1 {
   779  		atomic.StoreUint32(&r.initCalledBeforeRemoveFinished, 1)
   780  	}
   781  
   782  	return peer
   783  }
   784  
   785  func (r *mockReactor) InitCalledBeforeRemoveFinished() bool {
   786  	return atomic.LoadUint32(&r.initCalledBeforeRemoveFinished) == 1
   787  }
   788  
   789  // see stopAndRemovePeer
   790  func TestSwitchInitPeerIsNotCalledBeforeRemovePeer(t *testing.T) {
   791  	// make reactor
   792  	reactor := &mockReactor{}
   793  	reactor.BaseReactor = NewBaseReactor("mockReactor", reactor)
   794  
   795  	// make switch
   796  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", func(i int, sw *Switch) *Switch {
   797  		sw.AddReactor("mock", reactor)
   798  		return sw
   799  	})
   800  	err := sw.Start()
   801  	require.NoError(t, err)
   802  	t.Cleanup(func() {
   803  		if err := sw.Stop(); err != nil {
   804  			t.Error(err)
   805  		}
   806  	})
   807  
   808  	// add peer
   809  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   810  	rp.Start()
   811  	defer rp.Stop()
   812  	_, err = rp.Dial(sw.NetAddress())
   813  	require.NoError(t, err)
   814  
   815  	// wait till the switch adds rp to the peer set, then stop the peer asynchronously
   816  	for {
   817  		time.Sleep(20 * time.Millisecond)
   818  		if peer := sw.Peers().Get(rp.ID()); peer != nil {
   819  			go sw.StopPeerForError(peer, "test")
   820  			break
   821  		}
   822  	}
   823  
   824  	// simulate peer reconnecting to us
   825  	_, err = rp.Dial(sw.NetAddress())
   826  	require.NoError(t, err)
   827  	// wait till the switch adds rp to the peer set
   828  	time.Sleep(50 * time.Millisecond)
   829  
   830  	// make sure reactor.RemovePeer is finished before InitPeer is called
   831  	assert.False(t, reactor.InitCalledBeforeRemoveFinished())
   832  }
   833  
   834  func BenchmarkSwitchBroadcast(b *testing.B) {
   835  	s1, s2 := MakeSwitchPair(b, func(i int, sw *Switch) *Switch {
   836  		// Make bar reactors of bar channels each
   837  		sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
   838  			{ID: byte(0x00), Priority: 10},
   839  			{ID: byte(0x01), Priority: 10},
   840  		}, false))
   841  		sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
   842  			{ID: byte(0x02), Priority: 10},
   843  			{ID: byte(0x03), Priority: 10},
   844  		}, false))
   845  		return sw
   846  	})
   847  
   848  	b.Cleanup(func() {
   849  		if err := s1.Stop(); err != nil {
   850  			b.Error(err)
   851  		}
   852  	})
   853  
   854  	b.Cleanup(func() {
   855  		if err := s2.Stop(); err != nil {
   856  			b.Error(err)
   857  		}
   858  	})
   859  
   860  	// Allow time for goroutines to boot up
   861  	time.Sleep(1 * time.Second)
   862  
   863  	b.ResetTimer()
   864  
   865  	numSuccess, numFailure := 0, 0
   866  
   867  	// Send random message from foo channel to another
   868  	for i := 0; i < b.N; i++ {
   869  		chID := byte(i % 4)
   870  		successChan := s1.BroadcastEnvelope(Envelope{ChannelID: chID})
   871  		for s := range successChan {
   872  			if s {
   873  				numSuccess++
   874  			} else {
   875  				numFailure++
   876  			}
   877  		}
   878  	}
   879  
   880  	b.Logf("success: %v, failure: %v", numSuccess, numFailure)
   881  }
   882  
   883  func TestSwitchRemovalErr(t *testing.T) {
   884  	sw1, sw2 := MakeSwitchPair(t, func(i int, sw *Switch) *Switch {
   885  		return initSwitchFunc(i, sw)
   886  	})
   887  	assert.Equal(t, len(sw1.Peers().List()), 1)
   888  	p := sw1.Peers().List()[0]
   889  
   890  	sw2.StopPeerForError(p, fmt.Errorf("peer should error"))
   891  
   892  	assert.Equal(t, sw2.peers.Add(p).Error(), ErrPeerRemoval{}.Error())
   893  }