github.com/evdatsion/aphelion-dpos-bft@v0.32.1/p2p/switch_test.go (about)

     1  package p2p
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"regexp"
    13  	"strconv"
    14  	"sync"
    15  	"sync/atomic"
    16  	"testing"
    17  	"time"
    18  
    19  	stdprometheus "github.com/prometheus/client_golang/prometheus"
    20  	"github.com/stretchr/testify/assert"
    21  	"github.com/stretchr/testify/require"
    22  
    23  	"github.com/evdatsion/aphelion-dpos-bft/config"
    24  	"github.com/evdatsion/aphelion-dpos-bft/crypto/ed25519"
    25  	"github.com/evdatsion/aphelion-dpos-bft/libs/log"
    26  	"github.com/evdatsion/aphelion-dpos-bft/p2p/conn"
    27  )
    28  
    29  var (
    30  	cfg *config.P2PConfig
    31  )
    32  
    33  func init() {
    34  	cfg = config.DefaultP2PConfig()
    35  	cfg.PexReactor = true
    36  	cfg.AllowDuplicateIP = true
    37  }
    38  
    39  type PeerMessage struct {
    40  	PeerID  ID
    41  	Bytes   []byte
    42  	Counter int
    43  }
    44  
    45  type TestReactor struct {
    46  	BaseReactor
    47  
    48  	mtx          sync.Mutex
    49  	channels     []*conn.ChannelDescriptor
    50  	logMessages  bool
    51  	msgsCounter  int
    52  	msgsReceived map[byte][]PeerMessage
    53  }
    54  
    55  func NewTestReactor(channels []*conn.ChannelDescriptor, logMessages bool) *TestReactor {
    56  	tr := &TestReactor{
    57  		channels:     channels,
    58  		logMessages:  logMessages,
    59  		msgsReceived: make(map[byte][]PeerMessage),
    60  	}
    61  	tr.BaseReactor = *NewBaseReactor("TestReactor", tr)
    62  	tr.SetLogger(log.TestingLogger())
    63  	return tr
    64  }
    65  
    66  func (tr *TestReactor) GetChannels() []*conn.ChannelDescriptor {
    67  	return tr.channels
    68  }
    69  
    70  func (tr *TestReactor) AddPeer(peer Peer) {}
    71  
    72  func (tr *TestReactor) RemovePeer(peer Peer, reason interface{}) {}
    73  
    74  func (tr *TestReactor) Receive(chID byte, peer Peer, msgBytes []byte) {
    75  	if tr.logMessages {
    76  		tr.mtx.Lock()
    77  		defer tr.mtx.Unlock()
    78  		//fmt.Printf("Received: %X, %X\n", chID, msgBytes)
    79  		tr.msgsReceived[chID] = append(tr.msgsReceived[chID], PeerMessage{peer.ID(), msgBytes, tr.msgsCounter})
    80  		tr.msgsCounter++
    81  	}
    82  }
    83  
    84  func (tr *TestReactor) getMsgs(chID byte) []PeerMessage {
    85  	tr.mtx.Lock()
    86  	defer tr.mtx.Unlock()
    87  	return tr.msgsReceived[chID]
    88  }
    89  
    90  //-----------------------------------------------------------------------------
    91  
    92  // convenience method for creating two switches connected to each other.
    93  // XXX: note this uses net.Pipe and not a proper TCP conn
    94  func MakeSwitchPair(t testing.TB, initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) {
    95  	// Create two switches that will be interconnected.
    96  	switches := MakeConnectedSwitches(cfg, 2, initSwitch, Connect2Switches)
    97  	return switches[0], switches[1]
    98  }
    99  
   100  func initSwitchFunc(i int, sw *Switch) *Switch {
   101  	sw.SetAddrBook(&addrBookMock{
   102  		addrs:    make(map[string]struct{}),
   103  		ourAddrs: make(map[string]struct{})})
   104  
   105  	// Make two reactors of two channels each
   106  	sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
   107  		{ID: byte(0x00), Priority: 10},
   108  		{ID: byte(0x01), Priority: 10},
   109  	}, true))
   110  	sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
   111  		{ID: byte(0x02), Priority: 10},
   112  		{ID: byte(0x03), Priority: 10},
   113  	}, true))
   114  
   115  	return sw
   116  }
   117  
   118  func TestSwitches(t *testing.T) {
   119  	s1, s2 := MakeSwitchPair(t, initSwitchFunc)
   120  	defer s1.Stop()
   121  	defer s2.Stop()
   122  
   123  	if s1.Peers().Size() != 1 {
   124  		t.Errorf("Expected exactly 1 peer in s1, got %v", s1.Peers().Size())
   125  	}
   126  	if s2.Peers().Size() != 1 {
   127  		t.Errorf("Expected exactly 1 peer in s2, got %v", s2.Peers().Size())
   128  	}
   129  
   130  	// Lets send some messages
   131  	ch0Msg := []byte("channel zero")
   132  	ch1Msg := []byte("channel foo")
   133  	ch2Msg := []byte("channel bar")
   134  
   135  	s1.Broadcast(byte(0x00), ch0Msg)
   136  	s1.Broadcast(byte(0x01), ch1Msg)
   137  	s1.Broadcast(byte(0x02), ch2Msg)
   138  
   139  	assertMsgReceivedWithTimeout(t, ch0Msg, byte(0x00), s2.Reactor("foo").(*TestReactor), 10*time.Millisecond, 5*time.Second)
   140  	assertMsgReceivedWithTimeout(t, ch1Msg, byte(0x01), s2.Reactor("foo").(*TestReactor), 10*time.Millisecond, 5*time.Second)
   141  	assertMsgReceivedWithTimeout(t, ch2Msg, byte(0x02), s2.Reactor("bar").(*TestReactor), 10*time.Millisecond, 5*time.Second)
   142  }
   143  
   144  func assertMsgReceivedWithTimeout(t *testing.T, msgBytes []byte, channel byte, reactor *TestReactor, checkPeriod, timeout time.Duration) {
   145  	ticker := time.NewTicker(checkPeriod)
   146  	for {
   147  		select {
   148  		case <-ticker.C:
   149  			msgs := reactor.getMsgs(channel)
   150  			if len(msgs) > 0 {
   151  				if !bytes.Equal(msgs[0].Bytes, msgBytes) {
   152  					t.Fatalf("Unexpected message bytes. Wanted: %X, Got: %X", msgBytes, msgs[0].Bytes)
   153  				}
   154  				return
   155  			}
   156  
   157  		case <-time.After(timeout):
   158  			t.Fatalf("Expected to have received 1 message in channel #%v, got zero", channel)
   159  		}
   160  	}
   161  }
   162  
   163  func TestSwitchFiltersOutItself(t *testing.T) {
   164  	s1 := MakeSwitch(cfg, 1, "127.0.0.1", "123.123.123", initSwitchFunc)
   165  
   166  	// simulate s1 having a public IP by creating a remote peer with the same ID
   167  	rp := &remotePeer{PrivKey: s1.nodeKey.PrivKey, Config: cfg}
   168  	rp.Start()
   169  
   170  	// addr should be rejected in addPeer based on the same ID
   171  	err := s1.DialPeerWithAddress(rp.Addr())
   172  	if assert.Error(t, err) {
   173  		if err, ok := err.(ErrRejected); ok {
   174  			if !err.IsSelf() {
   175  				t.Errorf("expected self to be rejected")
   176  			}
   177  		} else {
   178  			t.Errorf("expected ErrRejected")
   179  		}
   180  	}
   181  
   182  	assert.True(t, s1.addrBook.OurAddress(rp.Addr()))
   183  	assert.False(t, s1.addrBook.HasAddress(rp.Addr()))
   184  
   185  	rp.Stop()
   186  
   187  	assertNoPeersAfterTimeout(t, s1, 100*time.Millisecond)
   188  }
   189  
   190  func TestSwitchPeerFilter(t *testing.T) {
   191  	var (
   192  		filters = []PeerFilterFunc{
   193  			func(_ IPeerSet, _ Peer) error { return nil },
   194  			func(_ IPeerSet, _ Peer) error { return fmt.Errorf("denied!") },
   195  			func(_ IPeerSet, _ Peer) error { return nil },
   196  		}
   197  		sw = MakeSwitch(
   198  			cfg,
   199  			1,
   200  			"testing",
   201  			"123.123.123",
   202  			initSwitchFunc,
   203  			SwitchPeerFilters(filters...),
   204  		)
   205  	)
   206  	defer sw.Stop()
   207  
   208  	// simulate remote peer
   209  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   210  	rp.Start()
   211  	defer rp.Stop()
   212  
   213  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   214  		chDescs:      sw.chDescs,
   215  		onPeerError:  sw.StopPeerForError,
   216  		isPersistent: sw.isPeerPersistentFn(),
   217  		reactorsByCh: sw.reactorsByCh,
   218  	})
   219  	if err != nil {
   220  		t.Fatal(err)
   221  	}
   222  
   223  	err = sw.addPeer(p)
   224  	if err, ok := err.(ErrRejected); ok {
   225  		if !err.IsFiltered() {
   226  			t.Errorf("expected peer to be filtered")
   227  		}
   228  	} else {
   229  		t.Errorf("expected ErrRejected")
   230  	}
   231  }
   232  
   233  func TestSwitchPeerFilterTimeout(t *testing.T) {
   234  	var (
   235  		filters = []PeerFilterFunc{
   236  			func(_ IPeerSet, _ Peer) error {
   237  				time.Sleep(10 * time.Millisecond)
   238  				return nil
   239  			},
   240  		}
   241  		sw = MakeSwitch(
   242  			cfg,
   243  			1,
   244  			"testing",
   245  			"123.123.123",
   246  			initSwitchFunc,
   247  			SwitchFilterTimeout(5*time.Millisecond),
   248  			SwitchPeerFilters(filters...),
   249  		)
   250  	)
   251  	defer sw.Stop()
   252  
   253  	// simulate remote peer
   254  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   255  	rp.Start()
   256  	defer rp.Stop()
   257  
   258  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   259  		chDescs:      sw.chDescs,
   260  		onPeerError:  sw.StopPeerForError,
   261  		isPersistent: sw.isPeerPersistentFn(),
   262  		reactorsByCh: sw.reactorsByCh,
   263  	})
   264  	if err != nil {
   265  		t.Fatal(err)
   266  	}
   267  
   268  	err = sw.addPeer(p)
   269  	if _, ok := err.(ErrFilterTimeout); !ok {
   270  		t.Errorf("expected ErrFilterTimeout")
   271  	}
   272  }
   273  
   274  func TestSwitchPeerFilterDuplicate(t *testing.T) {
   275  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   276  	sw.Start()
   277  	defer sw.Stop()
   278  
   279  	// simulate remote peer
   280  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   281  	rp.Start()
   282  	defer rp.Stop()
   283  
   284  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   285  		chDescs:      sw.chDescs,
   286  		onPeerError:  sw.StopPeerForError,
   287  		isPersistent: sw.isPeerPersistentFn(),
   288  		reactorsByCh: sw.reactorsByCh,
   289  	})
   290  	if err != nil {
   291  		t.Fatal(err)
   292  	}
   293  
   294  	if err := sw.addPeer(p); err != nil {
   295  		t.Fatal(err)
   296  	}
   297  
   298  	err = sw.addPeer(p)
   299  	if errRej, ok := err.(ErrRejected); ok {
   300  		if !errRej.IsDuplicate() {
   301  			t.Errorf("expected peer to be duplicate. got %v", errRej)
   302  		}
   303  	} else {
   304  		t.Errorf("expected ErrRejected, got %v", err)
   305  	}
   306  }
   307  
   308  func assertNoPeersAfterTimeout(t *testing.T, sw *Switch, timeout time.Duration) {
   309  	time.Sleep(timeout)
   310  	if sw.Peers().Size() != 0 {
   311  		t.Fatalf("Expected %v to not connect to some peers, got %d", sw, sw.Peers().Size())
   312  	}
   313  }
   314  
   315  func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) {
   316  	assert, require := assert.New(t), require.New(t)
   317  
   318  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   319  	err := sw.Start()
   320  	if err != nil {
   321  		t.Error(err)
   322  	}
   323  	defer sw.Stop()
   324  
   325  	// simulate remote peer
   326  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   327  	rp.Start()
   328  	defer rp.Stop()
   329  
   330  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   331  		chDescs:      sw.chDescs,
   332  		onPeerError:  sw.StopPeerForError,
   333  		isPersistent: sw.isPeerPersistentFn(),
   334  		reactorsByCh: sw.reactorsByCh,
   335  	})
   336  	require.Nil(err)
   337  
   338  	err = sw.addPeer(p)
   339  	require.Nil(err)
   340  
   341  	require.NotNil(sw.Peers().Get(rp.ID()))
   342  
   343  	// simulate failure by closing connection
   344  	p.(*peer).CloseConn()
   345  
   346  	assertNoPeersAfterTimeout(t, sw, 100*time.Millisecond)
   347  	assert.False(p.IsRunning())
   348  }
   349  
   350  func TestSwitchStopPeerForError(t *testing.T) {
   351  	s := httptest.NewServer(stdprometheus.UninstrumentedHandler())
   352  	defer s.Close()
   353  
   354  	scrapeMetrics := func() string {
   355  		resp, _ := http.Get(s.URL)
   356  		buf, _ := ioutil.ReadAll(resp.Body)
   357  		return string(buf)
   358  	}
   359  
   360  	namespace, subsystem, name := config.TestInstrumentationConfig().Namespace, MetricsSubsystem, "peers"
   361  	re := regexp.MustCompile(namespace + `_` + subsystem + `_` + name + ` ([0-9\.]+)`)
   362  	peersMetricValue := func() float64 {
   363  		matches := re.FindStringSubmatch(scrapeMetrics())
   364  		f, _ := strconv.ParseFloat(matches[1], 64)
   365  		return f
   366  	}
   367  
   368  	p2pMetrics := PrometheusMetrics(namespace)
   369  
   370  	// make two connected switches
   371  	sw1, sw2 := MakeSwitchPair(t, func(i int, sw *Switch) *Switch {
   372  		// set metrics on sw1
   373  		if i == 0 {
   374  			opt := WithMetrics(p2pMetrics)
   375  			opt(sw)
   376  		}
   377  		return initSwitchFunc(i, sw)
   378  	})
   379  
   380  	assert.Equal(t, len(sw1.Peers().List()), 1)
   381  	assert.EqualValues(t, 1, peersMetricValue())
   382  
   383  	// send messages to the peer from sw1
   384  	p := sw1.Peers().List()[0]
   385  	p.Send(0x1, []byte("here's a message to send"))
   386  
   387  	// stop sw2. this should cause the p to fail,
   388  	// which results in calling StopPeerForError internally
   389  	sw2.Stop()
   390  
   391  	// now call StopPeerForError explicitly, eg. from a reactor
   392  	sw1.StopPeerForError(p, fmt.Errorf("some err"))
   393  
   394  	assert.Equal(t, len(sw1.Peers().List()), 0)
   395  	assert.EqualValues(t, 0, peersMetricValue())
   396  }
   397  
   398  func TestSwitchReconnectsToOutboundPersistentPeer(t *testing.T) {
   399  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   400  	err := sw.Start()
   401  	require.NoError(t, err)
   402  	defer sw.Stop()
   403  
   404  	// 1. simulate failure by closing connection
   405  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   406  	rp.Start()
   407  	defer rp.Stop()
   408  
   409  	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
   410  	require.NoError(t, err)
   411  
   412  	err = sw.DialPeerWithAddress(rp.Addr())
   413  	require.Nil(t, err)
   414  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   415  
   416  	p := sw.Peers().List()[0]
   417  	p.(*peer).CloseConn()
   418  
   419  	waitUntilSwitchHasAtLeastNPeers(sw, 1)
   420  	assert.False(t, p.IsRunning())        // old peer instance
   421  	assert.Equal(t, 1, sw.Peers().Size()) // new peer instance
   422  
   423  	// 2. simulate first time dial failure
   424  	rp = &remotePeer{
   425  		PrivKey: ed25519.GenPrivKey(),
   426  		Config:  cfg,
   427  		// Use different interface to prevent duplicate IP filter, this will break
   428  		// beyond two peers.
   429  		listenAddr: "127.0.0.1:0",
   430  	}
   431  	rp.Start()
   432  	defer rp.Stop()
   433  
   434  	conf := config.DefaultP2PConfig()
   435  	conf.TestDialFail = true // will trigger a reconnect
   436  	err = sw.addOutboundPeerWithConfig(rp.Addr(), conf)
   437  	require.NotNil(t, err)
   438  	// DialPeerWithAddres - sw.peerConfig resets the dialer
   439  	waitUntilSwitchHasAtLeastNPeers(sw, 2)
   440  	assert.Equal(t, 2, sw.Peers().Size())
   441  }
   442  
   443  func TestSwitchReconnectsToInboundPersistentPeer(t *testing.T) {
   444  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   445  	err := sw.Start()
   446  	require.NoError(t, err)
   447  	defer sw.Stop()
   448  
   449  	// 1. simulate failure by closing the connection
   450  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   451  	rp.Start()
   452  	defer rp.Stop()
   453  
   454  	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
   455  	require.NoError(t, err)
   456  
   457  	conn, err := rp.Dial(sw.NetAddress())
   458  	require.NoError(t, err)
   459  	time.Sleep(50 * time.Millisecond)
   460  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   461  
   462  	conn.Close()
   463  
   464  	waitUntilSwitchHasAtLeastNPeers(sw, 1)
   465  	assert.Equal(t, 1, sw.Peers().Size())
   466  }
   467  
   468  func TestSwitchDialPeersAsync(t *testing.T) {
   469  	if testing.Short() {
   470  		return
   471  	}
   472  
   473  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   474  	err := sw.Start()
   475  	require.NoError(t, err)
   476  	defer sw.Stop()
   477  
   478  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   479  	rp.Start()
   480  	defer rp.Stop()
   481  
   482  	err = sw.DialPeersAsync([]string{rp.Addr().String()})
   483  	require.NoError(t, err)
   484  	time.Sleep(dialRandomizerIntervalMilliseconds * time.Millisecond)
   485  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   486  }
   487  
   488  func waitUntilSwitchHasAtLeastNPeers(sw *Switch, n int) {
   489  	for i := 0; i < 20; i++ {
   490  		time.Sleep(250 * time.Millisecond)
   491  		has := sw.Peers().Size()
   492  		if has >= n {
   493  			break
   494  		}
   495  	}
   496  }
   497  
   498  func TestSwitchFullConnectivity(t *testing.T) {
   499  	switches := MakeConnectedSwitches(cfg, 3, initSwitchFunc, Connect2Switches)
   500  	defer func() {
   501  		for _, sw := range switches {
   502  			sw.Stop()
   503  		}
   504  	}()
   505  
   506  	for i, sw := range switches {
   507  		if sw.Peers().Size() != 2 {
   508  			t.Fatalf("Expected each switch to be connected to 2 other, but %d switch only connected to %d", sw.Peers().Size(), i)
   509  		}
   510  	}
   511  }
   512  
   513  func TestSwitchAcceptRoutine(t *testing.T) {
   514  	cfg.MaxNumInboundPeers = 5
   515  
   516  	// make switch
   517  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   518  	err := sw.Start()
   519  	require.NoError(t, err)
   520  	defer sw.Stop()
   521  
   522  	remotePeers := make([]*remotePeer, 0)
   523  	assert.Equal(t, 0, sw.Peers().Size())
   524  
   525  	// 1. check we connect up to MaxNumInboundPeers
   526  	for i := 0; i < cfg.MaxNumInboundPeers; i++ {
   527  		rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   528  		remotePeers = append(remotePeers, rp)
   529  		rp.Start()
   530  		c, err := rp.Dial(sw.NetAddress())
   531  		require.NoError(t, err)
   532  		// spawn a reading routine to prevent connection from closing
   533  		go func(c net.Conn) {
   534  			for {
   535  				one := make([]byte, 1)
   536  				_, err := c.Read(one)
   537  				if err != nil {
   538  					return
   539  				}
   540  			}
   541  		}(c)
   542  	}
   543  	time.Sleep(10 * time.Millisecond)
   544  	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())
   545  
   546  	// 2. check we close new connections if we already have MaxNumInboundPeers peers
   547  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   548  	rp.Start()
   549  	conn, err := rp.Dial(sw.NetAddress())
   550  	require.NoError(t, err)
   551  	// check conn is closed
   552  	one := make([]byte, 1)
   553  	conn.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
   554  	_, err = conn.Read(one)
   555  	assert.Equal(t, io.EOF, err)
   556  	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())
   557  	rp.Stop()
   558  
   559  	// stop remote peers
   560  	for _, rp := range remotePeers {
   561  		rp.Stop()
   562  	}
   563  }
   564  
   565  type errorTransport struct {
   566  	acceptErr error
   567  }
   568  
   569  func (et errorTransport) NetAddress() NetAddress {
   570  	panic("not implemented")
   571  }
   572  
   573  func (et errorTransport) Accept(c peerConfig) (Peer, error) {
   574  	return nil, et.acceptErr
   575  }
   576  func (errorTransport) Dial(NetAddress, peerConfig) (Peer, error) {
   577  	panic("not implemented")
   578  }
   579  func (errorTransport) Cleanup(Peer) {
   580  	panic("not implemented")
   581  }
   582  
   583  func TestSwitchAcceptRoutineErrorCases(t *testing.T) {
   584  	sw := NewSwitch(cfg, errorTransport{ErrFilterTimeout{}})
   585  	assert.NotPanics(t, func() {
   586  		err := sw.Start()
   587  		assert.NoError(t, err)
   588  		sw.Stop()
   589  	})
   590  
   591  	sw = NewSwitch(cfg, errorTransport{ErrRejected{conn: nil, err: errors.New("filtered"), isFiltered: true}})
   592  	assert.NotPanics(t, func() {
   593  		err := sw.Start()
   594  		assert.NoError(t, err)
   595  		sw.Stop()
   596  	})
   597  	// TODO(melekes) check we remove our address from addrBook
   598  
   599  	sw = NewSwitch(cfg, errorTransport{ErrTransportClosed{}})
   600  	assert.NotPanics(t, func() {
   601  		err := sw.Start()
   602  		assert.NoError(t, err)
   603  		sw.Stop()
   604  	})
   605  }
   606  
   607  // mockReactor checks that InitPeer never called before RemovePeer. If that's
   608  // not true, InitCalledBeforeRemoveFinished will return true.
   609  type mockReactor struct {
   610  	*BaseReactor
   611  
   612  	// atomic
   613  	removePeerInProgress           uint32
   614  	initCalledBeforeRemoveFinished uint32
   615  }
   616  
   617  func (r *mockReactor) RemovePeer(peer Peer, reason interface{}) {
   618  	atomic.StoreUint32(&r.removePeerInProgress, 1)
   619  	defer atomic.StoreUint32(&r.removePeerInProgress, 0)
   620  	time.Sleep(100 * time.Millisecond)
   621  }
   622  
   623  func (r *mockReactor) InitPeer(peer Peer) Peer {
   624  	if atomic.LoadUint32(&r.removePeerInProgress) == 1 {
   625  		atomic.StoreUint32(&r.initCalledBeforeRemoveFinished, 1)
   626  	}
   627  
   628  	return peer
   629  }
   630  
   631  func (r *mockReactor) InitCalledBeforeRemoveFinished() bool {
   632  	return atomic.LoadUint32(&r.initCalledBeforeRemoveFinished) == 1
   633  }
   634  
   635  // see stopAndRemovePeer
   636  func TestSwitchInitPeerIsNotCalledBeforeRemovePeer(t *testing.T) {
   637  	// make reactor
   638  	reactor := &mockReactor{}
   639  	reactor.BaseReactor = NewBaseReactor("mockReactor", reactor)
   640  
   641  	// make switch
   642  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", func(i int, sw *Switch) *Switch {
   643  		sw.AddReactor("mock", reactor)
   644  		return sw
   645  	})
   646  	err := sw.Start()
   647  	require.NoError(t, err)
   648  	defer sw.Stop()
   649  
   650  	// add peer
   651  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   652  	rp.Start()
   653  	defer rp.Stop()
   654  	_, err = rp.Dial(sw.NetAddress())
   655  	require.NoError(t, err)
   656  	// wait till the switch adds rp to the peer set
   657  	time.Sleep(50 * time.Millisecond)
   658  
   659  	// stop peer asynchronously
   660  	go sw.StopPeerForError(sw.Peers().Get(rp.ID()), "test")
   661  
   662  	// simulate peer reconnecting to us
   663  	_, err = rp.Dial(sw.NetAddress())
   664  	require.NoError(t, err)
   665  	// wait till the switch adds rp to the peer set
   666  	time.Sleep(50 * time.Millisecond)
   667  
   668  	// make sure reactor.RemovePeer is finished before InitPeer is called
   669  	assert.False(t, reactor.InitCalledBeforeRemoveFinished())
   670  }
   671  
   672  func BenchmarkSwitchBroadcast(b *testing.B) {
   673  	s1, s2 := MakeSwitchPair(b, func(i int, sw *Switch) *Switch {
   674  		// Make bar reactors of bar channels each
   675  		sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
   676  			{ID: byte(0x00), Priority: 10},
   677  			{ID: byte(0x01), Priority: 10},
   678  		}, false))
   679  		sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
   680  			{ID: byte(0x02), Priority: 10},
   681  			{ID: byte(0x03), Priority: 10},
   682  		}, false))
   683  		return sw
   684  	})
   685  	defer s1.Stop()
   686  	defer s2.Stop()
   687  
   688  	// Allow time for goroutines to boot up
   689  	time.Sleep(1 * time.Second)
   690  
   691  	b.ResetTimer()
   692  
   693  	numSuccess, numFailure := 0, 0
   694  
   695  	// Send random message from foo channel to another
   696  	for i := 0; i < b.N; i++ {
   697  		chID := byte(i % 4)
   698  		successChan := s1.Broadcast(chID, []byte("test data"))
   699  		for s := range successChan {
   700  			if s {
   701  				numSuccess++
   702  			} else {
   703  				numFailure++
   704  			}
   705  		}
   706  	}
   707  
   708  	b.Logf("success: %v, failure: %v", numSuccess, numFailure)
   709  }
   710  
   711  type addrBookMock struct {
   712  	addrs    map[string]struct{}
   713  	ourAddrs map[string]struct{}
   714  }
   715  
   716  var _ AddrBook = (*addrBookMock)(nil)
   717  
   718  func (book *addrBookMock) AddAddress(addr *NetAddress, src *NetAddress) error {
   719  	book.addrs[addr.String()] = struct{}{}
   720  	return nil
   721  }
   722  func (book *addrBookMock) AddOurAddress(addr *NetAddress) { book.ourAddrs[addr.String()] = struct{}{} }
   723  func (book *addrBookMock) OurAddress(addr *NetAddress) bool {
   724  	_, ok := book.ourAddrs[addr.String()]
   725  	return ok
   726  }
   727  func (book *addrBookMock) MarkGood(ID) {}
   728  func (book *addrBookMock) HasAddress(addr *NetAddress) bool {
   729  	_, ok := book.addrs[addr.String()]
   730  	return ok
   731  }
   732  func (book *addrBookMock) RemoveAddress(addr *NetAddress) {
   733  	delete(book.addrs, addr.String())
   734  }
   735  func (book *addrBookMock) Save() {}