github.com/Finschia/ostracon@v1.1.5/p2p/switch_test.go (about)

     1  package p2p
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http/httptest"
    10  	"regexp"
    11  	"strconv"
    12  	"sync/atomic"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/gogo/protobuf/proto"
    17  	"github.com/prometheus/client_golang/prometheus/promhttp"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  
    21  	p2pproto "github.com/tendermint/tendermint/proto/tendermint/p2p"
    22  
    23  	"github.com/Finschia/ostracon/config"
    24  	"github.com/Finschia/ostracon/crypto/ed25519"
    25  	"github.com/Finschia/ostracon/libs/log"
    26  	tmnet "github.com/Finschia/ostracon/libs/net"
    27  	tmsync "github.com/Finschia/ostracon/libs/sync"
    28  	"github.com/Finschia/ostracon/p2p/conn"
    29  )
    30  
    31  var cfg *config.P2PConfig
    32  
    33  func init() {
    34  	cfg = config.DefaultP2PConfig()
    35  	cfg.PexReactor = true
    36  	cfg.AllowDuplicateIP = true
    37  }
    38  
// PeerMessage is one message recorded by TestReactor: the decoded protobuf
// contents plus the order (Counter) in which it was received.
type PeerMessage struct {
	Contents proto.Message
	Counter  int
}
    43  
// TestReactor is a reactor stub that optionally records every received
// message, keyed by channel ID, for later inspection by tests.
type TestReactor struct {
	BaseReactor

	mtx          tmsync.Mutex // guards the fields below
	channels     []*conn.ChannelDescriptor
	logMessages  bool // when true, ReceiveEnvelope records each message
	msgsCounter  int  // total messages recorded so far (across all channels)
	msgsReceived map[byte][]PeerMessage // recorded messages per channel ID
}
    53  
    54  func NewTestReactor(channels []*conn.ChannelDescriptor, async bool, recvBufSize int, logMessages bool) *TestReactor {
    55  	tr := &TestReactor{
    56  		channels:     channels,
    57  		logMessages:  logMessages,
    58  		msgsReceived: make(map[byte][]PeerMessage),
    59  	}
    60  	tr.BaseReactor = *NewBaseReactor("TestReactor", tr, async, recvBufSize)
    61  	tr.SetLogger(log.TestingLogger())
    62  	return tr
    63  }
    64  
// GetChannels returns the channel descriptors this reactor serves.
func (tr *TestReactor) GetChannels() []*conn.ChannelDescriptor {
	return tr.channels
}
    68  
// AddPeer is a no-op; TestReactor does not track peers.
func (tr *TestReactor) AddPeer(peer Peer) {}
    70  
// RemovePeer is a no-op; TestReactor does not track peers.
func (tr *TestReactor) RemovePeer(peer Peer, reason interface{}) {}
    72  
    73  func (tr *TestReactor) ReceiveEnvelope(e Envelope) {
    74  	if tr.logMessages {
    75  		tr.mtx.Lock()
    76  		defer tr.mtx.Unlock()
    77  		// fmt.Printf("Received: %X, %X\n", e.ChannelID, e.Message)
    78  		tr.msgsReceived[e.ChannelID] = append(tr.msgsReceived[e.ChannelID], PeerMessage{Contents: e.Message, Counter: tr.msgsCounter})
    79  		tr.msgsCounter++
    80  	}
    81  }
    82  
    83  func (tr *TestReactor) Receive(chID byte, peer Peer, msgBytes []byte) {
    84  	msg := &p2pproto.Message{}
    85  	err := proto.Unmarshal(msgBytes, msg)
    86  	if err != nil {
    87  		panic(err)
    88  	}
    89  	um, err := msg.Unwrap()
    90  	if err != nil {
    91  		panic(err)
    92  	}
    93  
    94  	tr.ReceiveEnvelope(Envelope{
    95  		ChannelID: chID,
    96  		Src:       peer,
    97  		Message:   um,
    98  	})
    99  }
   100  
// getMsgs returns the messages recorded so far for the given channel.
// Note: it returns the internal slice under lock, not a copy.
func (tr *TestReactor) getMsgs(chID byte) []PeerMessage {
	tr.mtx.Lock()
	defer tr.mtx.Unlock()
	return tr.msgsReceived[chID]
}
   106  
   107  //-----------------------------------------------------------------------------
   108  
   109  // convenience method for creating two switches connected to each other.
   110  // XXX: note this uses net.Pipe and not a proper TCP conn
   111  func MakeSwitchPair(t testing.TB, initSwitch func(int, *Switch, *config.P2PConfig) *Switch) (*Switch, *Switch) {
   112  	// Create two switches that will be interconnected.
   113  	switches := MakeConnectedSwitches(cfg, 2, initSwitch, Connect2Switches)
   114  	return switches[0], switches[1]
   115  }
   116  
   117  func initSwitchFunc(i int, sw *Switch, config *config.P2PConfig) *Switch {
   118  	sw.SetAddrBook(&AddrBookMock{
   119  		Addrs:    make(map[string]struct{}),
   120  		OurAddrs: make(map[string]struct{}),
   121  	})
   122  
   123  	// Make two reactors of two channels each
   124  	sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
   125  		{ID: byte(0x00), Priority: 10, MessageType: &p2pproto.Message{}},
   126  		{ID: byte(0x01), Priority: 10, MessageType: &p2pproto.Message{}},
   127  	}, config.RecvAsync, 1000, true))
   128  	sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
   129  		{ID: byte(0x02), Priority: 10, MessageType: &p2pproto.Message{}},
   130  		{ID: byte(0x03), Priority: 10, MessageType: &p2pproto.Message{}},
   131  	}, config.RecvAsync, 1000, true))
   132  
   133  	return sw
   134  }
   135  
   136  func TestSwitches(t *testing.T) {
   137  	s1, s2 := MakeSwitchPair(t, initSwitchFunc)
   138  	t.Cleanup(func() {
   139  		if err := s1.Stop(); err != nil {
   140  			t.Error(err)
   141  		}
   142  	})
   143  	t.Cleanup(func() {
   144  		if err := s2.Stop(); err != nil {
   145  			t.Error(err)
   146  		}
   147  	})
   148  
   149  	if s1.Peers().Size() != 1 {
   150  		t.Errorf("expected exactly 1 peer in s1, got %v", s1.Peers().Size())
   151  	}
   152  	if s2.Peers().Size() != 1 {
   153  		t.Errorf("expected exactly 1 peer in s2, got %v", s2.Peers().Size())
   154  	}
   155  
   156  	// Lets send some messages
   157  	ch0Msg := &p2pproto.PexAddrs{
   158  		Addrs: []p2pproto.NetAddress{
   159  			{
   160  				ID: "1",
   161  			},
   162  		},
   163  	}
   164  	ch1Msg := &p2pproto.PexAddrs{
   165  		Addrs: []p2pproto.NetAddress{
   166  			{
   167  				ID: "1",
   168  			},
   169  		},
   170  	}
   171  	ch2Msg := &p2pproto.PexAddrs{
   172  		Addrs: []p2pproto.NetAddress{
   173  			{
   174  				ID: "2",
   175  			},
   176  		},
   177  	}
   178  	s1.BroadcastEnvelope(Envelope{ChannelID: byte(0x00), Message: ch0Msg})
   179  	s1.BroadcastEnvelope(Envelope{ChannelID: byte(0x01), Message: ch1Msg})
   180  	s1.BroadcastEnvelope(Envelope{ChannelID: byte(0x02), Message: ch2Msg})
   181  	assertMsgReceivedWithTimeout(t,
   182  		ch0Msg,
   183  		byte(0x00),
   184  		s2.Reactor("foo").(*TestReactor), 200*time.Millisecond, 5*time.Second)
   185  	assertMsgReceivedWithTimeout(t,
   186  		ch1Msg,
   187  		byte(0x01),
   188  		s2.Reactor("foo").(*TestReactor), 200*time.Millisecond, 5*time.Second)
   189  	assertMsgReceivedWithTimeout(t,
   190  		ch2Msg,
   191  		byte(0x02),
   192  		s2.Reactor("bar").(*TestReactor), 200*time.Millisecond, 5*time.Second)
   193  }
   194  
   195  func assertMsgReceivedWithTimeout(
   196  	t *testing.T,
   197  	msg proto.Message,
   198  	channel byte,
   199  	reactor *TestReactor,
   200  	checkPeriod,
   201  	timeout time.Duration,
   202  ) {
   203  	ticker := time.NewTicker(checkPeriod)
   204  	for {
   205  		select {
   206  		case <-ticker.C:
   207  			msgs := reactor.getMsgs(channel)
   208  			expectedBytes, err := proto.Marshal(msgs[0].Contents)
   209  			require.NoError(t, err)
   210  			gotBytes, err := proto.Marshal(msg)
   211  			require.NoError(t, err)
   212  			if len(msgs) > 0 {
   213  				if !bytes.Equal(expectedBytes, gotBytes) {
   214  					t.Fatalf("Unexpected message bytes. Wanted: %X, Got: %X", msg, msgs[0].Counter)
   215  				}
   216  				return
   217  			}
   218  
   219  		case <-time.After(timeout):
   220  			t.Fatalf("Expected to have received 1 message in channel #%v, got zero", channel)
   221  		}
   222  	}
   223  }
   224  
   225  func TestSwitchFiltersOutItself(t *testing.T) {
   226  	s1 := MakeSwitch(cfg, 1, "127.0.0.1", "123.123.123", initSwitchFunc)
   227  
   228  	// simulate s1 having a public IP by creating a remote peer with the same ID
   229  	rp := &remotePeer{PrivKey: s1.nodeKey.PrivKey, Config: cfg}
   230  	rp.Start()
   231  
   232  	// addr should be rejected in addPeer based on the same ID
   233  	err := s1.DialPeerWithAddress(rp.Addr())
   234  	if assert.Error(t, err) {
   235  		if err, ok := err.(ErrRejected); ok {
   236  			if !err.IsSelf() {
   237  				t.Errorf("expected self to be rejected")
   238  			}
   239  		} else {
   240  			t.Errorf("expected ErrRejected")
   241  		}
   242  	}
   243  
   244  	assert.True(t, s1.addrBook.OurAddress(rp.Addr()))
   245  	assert.False(t, s1.addrBook.HasAddress(rp.Addr()))
   246  
   247  	rp.Stop()
   248  
   249  	assertNoPeersAfterTimeout(t, s1, 100*time.Millisecond)
   250  }
   251  
   252  func TestSwitchPeerFilter(t *testing.T) {
   253  	var (
   254  		filters = []PeerFilterFunc{
   255  			func(_ IPeerSet, _ Peer) error { return nil },
   256  			func(_ IPeerSet, _ Peer) error { return fmt.Errorf("denied") },
   257  			func(_ IPeerSet, _ Peer) error { return nil },
   258  		}
   259  		sw = MakeSwitch(
   260  			cfg,
   261  			1,
   262  			"testing",
   263  			"123.123.123",
   264  			initSwitchFunc,
   265  			SwitchPeerFilters(filters...),
   266  		)
   267  	)
   268  	err := sw.Start()
   269  	require.NoError(t, err)
   270  	t.Cleanup(func() {
   271  		if err := sw.Stop(); err != nil {
   272  			t.Error(err)
   273  		}
   274  	})
   275  
   276  	// simulate remote peer
   277  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   278  	rp.Start()
   279  	t.Cleanup(rp.Stop)
   280  
   281  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   282  		chDescs:      sw.chDescs,
   283  		onPeerError:  sw.StopPeerForError,
   284  		isPersistent: sw.IsPeerPersistent,
   285  		reactorsByCh: sw.reactorsByCh,
   286  	})
   287  	if err != nil {
   288  		t.Fatal(err)
   289  	}
   290  
   291  	err = sw.addPeer(p)
   292  	if err, ok := err.(ErrRejected); ok {
   293  		if !err.IsFiltered() {
   294  			t.Errorf("expected peer to be filtered")
   295  		}
   296  	} else {
   297  		t.Errorf("expected ErrRejected")
   298  	}
   299  }
   300  
   301  func TestSwitchPeerFilterTimeout(t *testing.T) {
   302  	var (
   303  		filters = []PeerFilterFunc{
   304  			func(_ IPeerSet, _ Peer) error {
   305  				time.Sleep(10 * time.Millisecond)
   306  				return nil
   307  			},
   308  		}
   309  		sw = MakeSwitch(
   310  			cfg,
   311  			1,
   312  			"testing",
   313  			"123.123.123",
   314  			initSwitchFunc,
   315  			SwitchFilterTimeout(5*time.Millisecond),
   316  			SwitchPeerFilters(filters...),
   317  		)
   318  	)
   319  	err := sw.Start()
   320  	require.NoError(t, err)
   321  	t.Cleanup(func() {
   322  		if err := sw.Stop(); err != nil {
   323  			t.Log(err)
   324  		}
   325  	})
   326  
   327  	// simulate remote peer
   328  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   329  	rp.Start()
   330  	defer rp.Stop()
   331  
   332  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   333  		chDescs:      sw.chDescs,
   334  		onPeerError:  sw.StopPeerForError,
   335  		isPersistent: sw.IsPeerPersistent,
   336  		reactorsByCh: sw.reactorsByCh,
   337  	})
   338  	if err != nil {
   339  		t.Fatal(err)
   340  	}
   341  
   342  	err = sw.addPeer(p)
   343  	if _, ok := err.(ErrFilterTimeout); !ok {
   344  		t.Errorf("expected ErrFilterTimeout")
   345  	}
   346  }
   347  
   348  func TestSwitchPeerFilterDuplicate(t *testing.T) {
   349  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   350  	err := sw.Start()
   351  	require.NoError(t, err)
   352  	t.Cleanup(func() {
   353  		if err := sw.Stop(); err != nil {
   354  			t.Error(err)
   355  		}
   356  	})
   357  
   358  	// simulate remote peer
   359  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   360  	rp.Start()
   361  	defer rp.Stop()
   362  
   363  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   364  		chDescs:      sw.chDescs,
   365  		onPeerError:  sw.StopPeerForError,
   366  		isPersistent: sw.IsPeerPersistent,
   367  		reactorsByCh: sw.reactorsByCh,
   368  	})
   369  	if err != nil {
   370  		t.Fatal(err)
   371  	}
   372  
   373  	if err := sw.addPeer(p); err != nil {
   374  		t.Fatal(err)
   375  	}
   376  
   377  	err = sw.addPeer(p)
   378  	if errRej, ok := err.(ErrRejected); ok {
   379  		if !errRej.IsDuplicate() {
   380  			t.Errorf("expected peer to be duplicate. got %v", errRej)
   381  		}
   382  	} else {
   383  		t.Errorf("expected ErrRejected, got %v", err)
   384  	}
   385  }
   386  
   387  func assertNoPeersAfterTimeout(t *testing.T, sw *Switch, timeout time.Duration) {
   388  	time.Sleep(timeout)
   389  	if sw.Peers().Size() != 0 {
   390  		t.Fatalf("Expected %v to not connect to some peers, got %d", sw, sw.Peers().Size())
   391  	}
   392  }
   393  
   394  func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) {
   395  	assert, require := assert.New(t), require.New(t)
   396  
   397  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   398  	err := sw.Start()
   399  	if err != nil {
   400  		t.Error(err)
   401  	}
   402  	t.Cleanup(func() {
   403  		if err := sw.Stop(); err != nil {
   404  			t.Error(err)
   405  		}
   406  	})
   407  
   408  	// simulate remote peer
   409  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   410  	rp.Start()
   411  	defer rp.Stop()
   412  
   413  	p, err := sw.transport.Dial(*rp.Addr(), peerConfig{
   414  		chDescs:      sw.chDescs,
   415  		onPeerError:  sw.StopPeerForError,
   416  		isPersistent: sw.IsPeerPersistent,
   417  		reactorsByCh: sw.reactorsByCh,
   418  	})
   419  	require.Nil(err)
   420  
   421  	err = sw.addPeer(p)
   422  	require.Nil(err)
   423  
   424  	require.NotNil(sw.Peers().Get(rp.ID()))
   425  
   426  	// simulate failure by closing connection
   427  	err = p.(*peer).CloseConn()
   428  	require.NoError(err)
   429  
   430  	assertNoPeersAfterTimeout(t, sw, 100*time.Millisecond)
   431  	assert.False(p.IsRunning())
   432  }
   433  
// TestSwitchStopPeerForError checks that the Prometheus "peers" gauge tracks
// peer addition and that StopPeerForError removes the peer and decrements it.
func TestSwitchStopPeerForError(t *testing.T) {
	s := httptest.NewServer(promhttp.Handler())
	defer s.Close()

	// scrapeMetrics fetches the raw Prometheus metrics page as a string.
	scrapeMetrics := func() string {
		resp, err := tmnet.HttpGet(s.URL, 60*time.Second)
		require.NoError(t, err)
		defer resp.Body.Close()
		buf, _ := io.ReadAll(resp.Body) // best effort: an empty body simply fails the regex match below
		return string(buf)
	}

	namespace, subsystem, name := config.TestInstrumentationConfig().Namespace, MetricsSubsystem, "peers"
	re := regexp.MustCompile(namespace + `_` + subsystem + `_` + name + ` ([0-9\.]+)`)
	// peersMetricValue extracts the current value of the "peers" gauge from a scrape.
	peersMetricValue := func() float64 {
		matches := re.FindStringSubmatch(scrapeMetrics())
		f, _ := strconv.ParseFloat(matches[1], 64)
		return f
	}

	p2pMetrics := PrometheusMetrics(namespace)

	// make two connected switches
	sw1, sw2 := MakeSwitchPair(t, func(i int, sw *Switch, config *config.P2PConfig) *Switch {
		// set metrics on sw1
		if i == 0 {
			opt := WithMetrics(p2pMetrics)
			opt(sw)
		}
		return initSwitchFunc(i, sw, config)
	})

	assert.Equal(t, len(sw1.Peers().List()), 1)
	assert.EqualValues(t, 1, peersMetricValue())

	// send messages to the peer from sw1
	p := sw1.Peers().List()[0]
	SendEnvelopeShim(p, Envelope{
		ChannelID: 0x1,
		Message:   &p2pproto.Message{},
	}, sw1.Logger)

	// stop sw2. this should cause the p to fail,
	// which results in calling StopPeerForError internally
	t.Cleanup(func() {
		if err := sw2.Stop(); err != nil {
			t.Error(err)
		}
	})

	// now call StopPeerForError explicitly, eg. from a reactor
	sw1.StopPeerForError(p, fmt.Errorf("some err"))

	assert.Equal(t, len(sw1.Peers().List()), 0)
	assert.EqualValues(t, 0, peersMetricValue())
}
   490  
// TestSwitchReconnectsToOutboundPersistentPeer checks that a persistent peer
// is redialed both after its connection breaks and after a failed first dial.
func TestSwitchReconnectsToOutboundPersistentPeer(t *testing.T) {
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// 1. simulate failure by closing connection
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()

	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
	require.NoError(t, err)

	err = sw.DialPeerWithAddress(rp.Addr())
	require.Nil(t, err)
	require.NotNil(t, sw.Peers().Get(rp.ID()))

	p := sw.Peers().List()[0]
	err = p.(*peer).CloseConn()
	require.NoError(t, err)

	// the switch must replace the broken peer with a fresh connection
	waitUntilSwitchHasAtLeastNPeers(sw, 1)
	assert.False(t, p.IsRunning())        // old peer instance
	assert.Equal(t, 1, sw.Peers().Size()) // new peer instance

	// 2. simulate first time dial failure
	rp = &remotePeer{
		PrivKey: ed25519.GenPrivKey(),
		Config:  cfg,
		// Use different interface to prevent duplicate IP filter, this will break
		// beyond two peers.
		listenAddr: "127.0.0.1:0",
	}
	rp.Start()
	defer rp.Stop()

	conf := config.DefaultP2PConfig()
	conf.TestDialFail = true // will trigger a reconnect
	err = sw.addOutboundPeerWithConfig(rp.Addr(), conf)
	require.NotNil(t, err)
	// DialPeerWithAddres - sw.peerConfig resets the dialer
	waitUntilSwitchHasAtLeastNPeers(sw, 2)
	assert.Equal(t, 2, sw.Peers().Size())
}
   540  
   541  func TestSwitchReconnectsToInboundPersistentPeer(t *testing.T) {
   542  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   543  	err := sw.Start()
   544  	require.NoError(t, err)
   545  	t.Cleanup(func() {
   546  		if err := sw.Stop(); err != nil {
   547  			t.Error(err)
   548  		}
   549  	})
   550  
   551  	// 1. simulate failure by closing the connection
   552  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   553  	rp.Start()
   554  	defer rp.Stop()
   555  
   556  	err = sw.AddPersistentPeers([]string{rp.Addr().String()})
   557  	require.NoError(t, err)
   558  
   559  	conn, err := rp.Dial(sw.NetAddress())
   560  	require.NoError(t, err)
   561  	time.Sleep(50 * time.Millisecond)
   562  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   563  
   564  	conn.Close()
   565  
   566  	waitUntilSwitchHasAtLeastNPeers(sw, 1)
   567  	assert.Equal(t, 1, sw.Peers().Size())
   568  }
   569  
   570  func TestSwitchDialPeersAsync(t *testing.T) {
   571  	if testing.Short() {
   572  		return
   573  	}
   574  
   575  	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
   576  	err := sw.Start()
   577  	require.NoError(t, err)
   578  	t.Cleanup(func() {
   579  		if err := sw.Stop(); err != nil {
   580  			t.Error(err)
   581  		}
   582  	})
   583  
   584  	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
   585  	rp.Start()
   586  	defer rp.Stop()
   587  
   588  	err = sw.DialPeersAsync([]string{rp.Addr().String()})
   589  	require.NoError(t, err)
   590  	time.Sleep(dialRandomizerIntervalMilliseconds * time.Millisecond)
   591  	require.NotNil(t, sw.Peers().Get(rp.ID()))
   592  }
   593  
   594  func waitUntilSwitchHasAtLeastNPeers(sw *Switch, n int) {
   595  	for i := 0; i < 20; i++ {
   596  		time.Sleep(250 * time.Millisecond)
   597  		has := sw.Peers().Size()
   598  		if has >= n {
   599  			break
   600  		}
   601  	}
   602  }
   603  
   604  func TestSwitchFullConnectivity(t *testing.T) {
   605  	switches := MakeConnectedSwitches(cfg, 3, initSwitchFunc, Connect2Switches)
   606  	defer func() {
   607  		for _, sw := range switches {
   608  			sw := sw
   609  			t.Cleanup(func() {
   610  				if err := sw.Stop(); err != nil {
   611  					t.Error(err)
   612  				}
   613  			})
   614  		}
   615  	}()
   616  
   617  	for i, sw := range switches {
   618  		if sw.Peers().Size() != 2 {
   619  			t.Fatalf("Expected each switch to be connected to 2 other, but %d switch only connected to %d", sw.Peers().Size(), i)
   620  		}
   621  	}
   622  }
   623  
// TestSwitchAcceptRoutine checks the inbound connection limit: the switch
// accepts up to MaxNumInboundPeers regular peers, closes connections beyond
// the limit, but still accepts unconditional peers past it.
//
// NOTE(review): mutates the package-level cfg (MaxNumInboundPeers = 5), which
// leaks into any test that runs afterwards.
func TestSwitchAcceptRoutine(t *testing.T) {
	cfg.MaxNumInboundPeers = 5

	// Create some unconditional peers.
	const unconditionalPeersNum = 2
	var (
		unconditionalPeers   = make([]*remotePeer, unconditionalPeersNum)
		unconditionalPeerIDs = make([]string, unconditionalPeersNum)
	)
	for i := 0; i < unconditionalPeersNum; i++ {
		peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
		peer.Start()
		unconditionalPeers[i] = peer
		unconditionalPeerIDs[i] = string(peer.ID())
	}

	// make switch
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc)
	err := sw.AddUnconditionalPeerIDs(unconditionalPeerIDs)
	require.NoError(t, err)
	err = sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		err := sw.Stop()
		require.NoError(t, err)
	})

	// 0. check there are no peers
	assert.Equal(t, 0, sw.Peers().Size())

	// 1. check we connect up to MaxNumInboundPeers
	peers := make([]*remotePeer, 0)
	for i := 0; i < cfg.MaxNumInboundPeers; i++ {
		peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
		peers = append(peers, peer)
		peer.Start()
		c, err := peer.Dial(sw.NetAddress())
		require.NoError(t, err)
		// spawn a reading routine to prevent connection from closing
		go func(c net.Conn) {
			for {
				one := make([]byte, 1)
				_, err := c.Read(one)
				if err != nil {
					return
				}
			}
		}(c)
	}
	time.Sleep(100 * time.Millisecond)
	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())

	// 2. check we close new connections if we already have MaxNumInboundPeers peers
	peer := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	peer.Start()
	conn, err := peer.Dial(sw.NetAddress())
	require.NoError(t, err)
	// check conn is closed: a read past the deadline must fail
	one := make([]byte, 1)
	_ = conn.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
	_, err = conn.Read(one)
	assert.Error(t, err)
	assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size())
	peer.Stop()

	// 3. check we connect to unconditional peers despite the limit.
	for _, peer := range unconditionalPeers {
		c, err := peer.Dial(sw.NetAddress())
		require.NoError(t, err)
		// spawn a reading routine to prevent connection from closing
		go func(c net.Conn) {
			for {
				one := make([]byte, 1)
				_, err := c.Read(one)
				if err != nil {
					return
				}
			}
		}(c)
	}
	time.Sleep(10 * time.Millisecond)
	assert.Equal(t, cfg.MaxNumInboundPeers+unconditionalPeersNum, sw.Peers().Size())

	for _, peer := range peers {
		peer.Stop()
	}
	for _, peer := range unconditionalPeers {
		peer.Stop()
	}
}
   714  
// errorTransport is a Transport stub whose Accept always fails with the
// configured error; every other method panics if called.
type errorTransport struct {
	acceptErr error
}

func (et errorTransport) NetAddress() NetAddress {
	panic("not implemented")
}

// Accept always returns the configured acceptErr.
func (et errorTransport) Accept(c peerConfig) (Peer, error) {
	return nil, et.acceptErr
}

func (errorTransport) Dial(NetAddress, peerConfig) (Peer, error) {
	panic("not implemented")
}

func (errorTransport) Cleanup(Peer) {
	panic("not implemented")
}
   734  
   735  func TestSwitchAcceptRoutineErrorCases(t *testing.T) {
   736  	sw := NewSwitch(cfg, errorTransport{ErrFilterTimeout{}})
   737  	assert.NotPanics(t, func() {
   738  		err := sw.Start()
   739  		require.NoError(t, err)
   740  		err = sw.Stop()
   741  		require.NoError(t, err)
   742  	})
   743  
   744  	sw = NewSwitch(cfg, errorTransport{ErrRejected{conn: nil, err: errors.New("filtered"), isFiltered: true}})
   745  	assert.NotPanics(t, func() {
   746  		err := sw.Start()
   747  		require.NoError(t, err)
   748  		err = sw.Stop()
   749  		require.NoError(t, err)
   750  	})
   751  	// TODO(melekes) check we remove our address from addrBook
   752  
   753  	sw = NewSwitch(cfg, errorTransport{ErrTransportClosed{}})
   754  	assert.NotPanics(t, func() {
   755  		err := sw.Start()
   756  		require.NoError(t, err)
   757  		err = sw.Stop()
   758  		require.NoError(t, err)
   759  	})
   760  }
   761  
// mockReactor checks that InitPeer never called before RemovePeer. If that's
// not true, InitCalledBeforeRemoveFinished will return true.
type mockReactor struct {
	*BaseReactor

	// atomic
	removePeerInProgress           uint32 // 1 while RemovePeer is executing
	initCalledBeforeRemoveFinished uint32 // set to 1 if the ordering was violated
}

// RemovePeer flags that removal is in progress for 100ms so a concurrent
// InitPeer call can detect the overlap.
func (r *mockReactor) RemovePeer(peer Peer, reason interface{}) {
	atomic.StoreUint32(&r.removePeerInProgress, 1)
	defer atomic.StoreUint32(&r.removePeerInProgress, 0)
	time.Sleep(100 * time.Millisecond)
}

// InitPeer records whether it ran while RemovePeer was still in progress.
func (r *mockReactor) InitPeer(peer Peer) Peer {
	if atomic.LoadUint32(&r.removePeerInProgress) == 1 {
		atomic.StoreUint32(&r.initCalledBeforeRemoveFinished, 1)
	}

	return peer
}

// InitCalledBeforeRemoveFinished reports whether InitPeer overlapped RemovePeer.
func (r *mockReactor) InitCalledBeforeRemoveFinished() bool {
	return atomic.LoadUint32(&r.initCalledBeforeRemoveFinished) == 1
}
   789  
// see stopAndRemovePeer
// TestSwitchInitPeerIsNotCalledBeforeRemovePeer verifies that when a peer is
// being removed and the same peer immediately reconnects, the reactor's
// InitPeer for the new connection is not called until RemovePeer for the old
// one has finished.
func TestSwitchInitPeerIsNotCalledBeforeRemovePeer(t *testing.T) {
	// make reactor
	reactor := &mockReactor{}
	reactor.BaseReactor = NewBaseReactor("mockReactor", reactor, true, 1000)

	// make switch
	sw := MakeSwitch(cfg, 1, "testing", "123.123.123", func(i int, sw *Switch, config *config.P2PConfig) *Switch {
		sw.AddReactor("mock", reactor)
		return sw
	}) // mock reactor uses 1000 chan buffer
	err := sw.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := sw.Stop(); err != nil {
			t.Error(err)
		}
	})

	// add peer
	rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg}
	rp.Start()
	defer rp.Stop()
	_, err = rp.Dial(sw.NetAddress())
	require.NoError(t, err)

	// wait till the switch adds rp to the peer set, then stop the peer asynchronously
	for {
		time.Sleep(20 * time.Millisecond)
		if peer := sw.Peers().Get(rp.ID()); peer != nil {
			go sw.StopPeerForError(peer, "test")
			break
		}
	}

	// simulate peer reconnecting to us
	_, err = rp.Dial(sw.NetAddress())
	require.NoError(t, err)
	// wait till the switch adds rp to the peer set
	time.Sleep(50 * time.Millisecond)

	// make sure reactor.RemovePeer is finished before InitPeer is called
	assert.False(t, reactor.InitCalledBeforeRemoveFinished())
}
   834  
// BenchmarkSwitchBroadcast measures BroadcastEnvelope throughput between two
// connected switches, cycling the message over the four reactor channels.
func BenchmarkSwitchBroadcast(b *testing.B) {
	s1, s2 := MakeSwitchPair(b, func(i int, sw *Switch, config *config.P2PConfig) *Switch {
		// Install two reactors with two channels each (message logging disabled).
		sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{
			{ID: byte(0x00), Priority: 10},
			{ID: byte(0x01), Priority: 10},
		}, config.RecvAsync, 1000, false))
		sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{
			{ID: byte(0x02), Priority: 10},
			{ID: byte(0x03), Priority: 10},
		}, config.RecvAsync, 1000, false))
		return sw
	})

	b.Cleanup(func() {
		if err := s1.Stop(); err != nil {
			b.Error(err)
		}
	})

	b.Cleanup(func() {
		if err := s2.Stop(); err != nil {
			b.Error(err)
		}
	})

	// Allow time for goroutines to boot up
	time.Sleep(1 * time.Second)

	b.ResetTimer()

	numSuccess, numFailure := 0, 0

	// Send random message from foo channel to another
	for i := 0; i < b.N; i++ {
		chID := byte(i % 4)
		successChan := s1.BroadcastEnvelope(Envelope{ChannelID: chID})
		// drain the per-peer result channel, tallying successes and failures
		for s := range successChan {
			if s {
				numSuccess++
			} else {
				numFailure++
			}
		}
	}

	b.Logf("success: %v, failure: %v", numSuccess, numFailure)
}
   883  
   884  func TestSwitchRemovalErr(t *testing.T) {
   885  	sw1, sw2 := MakeSwitchPair(t, func(i int, sw *Switch, config *config.P2PConfig) *Switch {
   886  		return initSwitchFunc(i, sw, config)
   887  	})
   888  	assert.Equal(t, len(sw1.Peers().List()), 1)
   889  	p := sw1.Peers().List()[0]
   890  
   891  	sw2.StopPeerForError(p, fmt.Errorf("peer should error"))
   892  
   893  	assert.Equal(t, sw2.peers.Add(p).Error(), ErrPeerRemoval{}.Error())
   894  }
   895  
// addrBookMock is a minimal in-memory AddrBook implementation for tests.
type addrBookMock struct {
	addrs    map[string]struct{} // known peer addresses, keyed by string form
	ourAddrs map[string]struct{} // our own addresses, keyed by string form
}

var _ AddrBook = (*addrBookMock)(nil)

// AddAddress records addr; the src peer is ignored.
func (book *addrBookMock) AddAddress(addr *NetAddress, src *NetAddress) error {
	book.addrs[addr.String()] = struct{}{}
	return nil
}

// AddPrivateIDs is not needed by these tests.
func (book *addrBookMock) AddPrivateIDs(strings []string) {
	panic("implement me")
}

// AddOurAddress records addr as one of our own addresses.
func (book *addrBookMock) AddOurAddress(addr *NetAddress) { book.ourAddrs[addr.String()] = struct{}{} }

// OurAddress reports whether addr was recorded as one of our own.
func (book *addrBookMock) OurAddress(addr *NetAddress) bool {
	_, ok := book.ourAddrs[addr.String()]
	return ok
}

// MarkGood is a no-op.
func (book *addrBookMock) MarkGood(ID) {}

// RemoveAddress forgets addr.
func (book *addrBookMock) RemoveAddress(addr *NetAddress) {
	delete(book.addrs, addr.String())
}

// HasAddress reports whether addr has been recorded.
func (book *addrBookMock) HasAddress(addr *NetAddress) bool {
	_, ok := book.addrs[addr.String()]
	return ok
}

// Save is a no-op.
func (book *addrBookMock) Save() {}
   924  
// NormalReactor forwards every received message to msgChan so a test can
// observe delivery synchronously.
type NormalReactor struct {
	BaseReactor
	channels []*conn.ChannelDescriptor
	msgChan  chan []byte // receives raw bytes from Receive, or a {0} marker from ReceiveEnvelope
}

// NewNormalReactor builds a NormalReactor serving the given channels.
func NewNormalReactor(channels []*conn.ChannelDescriptor, async bool, recvBufSize int) *NormalReactor {
	nr := &NormalReactor{
		channels: channels,
	}
	nr.BaseReactor = *NewBaseReactor("NormalReactor", nr, async, recvBufSize)
	nr.msgChan = make(chan []byte)
	nr.SetLogger(log.TestingLogger())
	return nr
}

// GetChannels returns the channel descriptors this reactor serves.
func (nr *NormalReactor) GetChannels() []*conn.ChannelDescriptor {
	return nr.channels
}

// AddPeer is a no-op.
func (nr *NormalReactor) AddPeer(peer Peer) {}

// RemovePeer is a no-op.
func (nr *NormalReactor) RemovePeer(peer Peer, reason interface{}) {}

// ReceiveEnvelope signals receipt with a one-byte marker (blocks until read,
// since msgChan is unbuffered).
func (nr *NormalReactor) ReceiveEnvelope(e Envelope) {
	nr.msgChan <- []byte{0}
}

// Receive forwards the raw message bytes to msgChan (blocks until read).
func (nr *NormalReactor) Receive(chID byte, peer Peer, msgBytes []byte) {
	nr.msgChan <- msgBytes
}
   956  
// BlockedReactor is a test reactor whose receive methods park until the test
// sends a value on waitChan, simulating a reactor that is slow to process.
type BlockedReactor struct {
	BaseReactor
	channels []*conn.ChannelDescriptor // channel descriptors this reactor serves
	waitChan chan int                  // receive methods block reading from this until released
}
   962  
   963  func NewBlockedReactor(channels []*conn.ChannelDescriptor, async bool, recvBufSize int) *BlockedReactor {
   964  	br := &BlockedReactor{
   965  		channels: channels,
   966  	}
   967  	br.BaseReactor = *NewBaseReactor("BlockedReactor", br, async, recvBufSize)
   968  	br.waitChan = make(chan int, 1)
   969  	br.SetLogger(log.TestingLogger())
   970  	return br
   971  }
   972  
// GetChannels returns the channel descriptors this reactor serves.
func (br *BlockedReactor) GetChannels() []*conn.ChannelDescriptor {
	return br.channels
}
   976  
// AddPeer is a no-op; peer lifecycle is irrelevant to these tests.
func (br *BlockedReactor) AddPeer(peer Peer) {}
   978  
// RemovePeer is a no-op; peer lifecycle is irrelevant to these tests.
func (br *BlockedReactor) RemovePeer(peer Peer, reason interface{}) {}
   980  
// ReceiveEnvelope parks until the test releases the reactor via waitChan.
func (br *BlockedReactor) ReceiveEnvelope(e Envelope) {
	<-br.waitChan
}
   984  
// Receive parks until the test releases the reactor via waitChan.
func (br *BlockedReactor) Receive(chID byte, peer Peer, msgBytes []byte) {
	<-br.waitChan
}
   988  
// Names under which the test reactors are registered on the switch.
const (
	reactorNameNormal  = "normal"
	reactorNameBlocked = "blocked"
)
   993  
   994  func TestSyncReactor(t *testing.T) {
   995  	cfg.RecvAsync = false
   996  	s1, s2 := MakeSwitchPair(t, getInitSwitchFunc(0))
   997  	defer func() {
   998  		err := s1.Stop()
   999  		if err != nil {
  1000  			t.Log(err)
  1001  		}
  1002  	}()
  1003  	defer func() {
  1004  		err := s2.Stop()
  1005  		if err != nil {
  1006  			t.Log(err)
  1007  		}
  1008  	}()
  1009  
  1010  	normalReactor := s2.Reactor(reactorNameNormal).(*NormalReactor)
  1011  	blockedReactor := s2.Reactor(reactorNameBlocked).(*BlockedReactor)
  1012  	// the message for blocked reactor is first
  1013  	s1.BroadcastEnvelope(Envelope{
  1014  		ChannelID: blockedReactor.channels[0].ID,
  1015  		Message:   &p2pproto.PexRequest{},
  1016  	})
  1017  	// to make order among messages
  1018  	time.Sleep(time.Millisecond * 200)
  1019  	// and then second message is for normal reactor
  1020  	s1.BroadcastEnvelope(Envelope{
  1021  		ChannelID: normalReactor.channels[0].ID,
  1022  		Message:   &p2pproto.PexRequest{},
  1023  	})
  1024  
  1025  	select {
  1026  	case <-normalReactor.msgChan:
  1027  		assert.Fail(t, "blocked reactor is not blocked")
  1028  	case <-time.After(time.Second * 1):
  1029  		assert.True(t, true, "blocked reactor is blocked: OK")
  1030  	}
  1031  
  1032  	blockedReactor.waitChan <- 1 // release blocked reactor
  1033  	msg := <-normalReactor.msgChan
  1034  	assert.True(t, bytes.Equal(msg, []byte{0}))
  1035  }
  1036  
  1037  func TestAsyncReactor(t *testing.T) {
  1038  	cfg.RecvAsync = true
  1039  	s1, s2 := MakeSwitchPair(t, getInitSwitchFunc(1))
  1040  	defer func() {
  1041  		err := s1.Stop()
  1042  		if err != nil {
  1043  			t.Log(err)
  1044  		}
  1045  	}()
  1046  	defer func() {
  1047  		err := s2.Stop()
  1048  		if err != nil {
  1049  			t.Log(err)
  1050  		}
  1051  	}()
  1052  
  1053  	normalReactor := s2.Reactor(reactorNameNormal).(*NormalReactor)
  1054  	blockedReactor := s2.Reactor(reactorNameBlocked).(*BlockedReactor)
  1055  	// the message for blocked reactor is first
  1056  	s1.BroadcastEnvelope(Envelope{
  1057  		ChannelID: blockedReactor.channels[0].ID,
  1058  		Message:   &p2pproto.PexRequest{},
  1059  	})
  1060  	// to make order among messages
  1061  	time.Sleep(time.Millisecond * 200)
  1062  	// and then second message is for normal reactor
  1063  	s1.BroadcastEnvelope(Envelope{
  1064  		ChannelID: normalReactor.channels[0].ID,
  1065  		Message:   &p2pproto.PexRequest{},
  1066  	})
  1067  
  1068  	select {
  1069  	case msg := <-normalReactor.msgChan:
  1070  		assert.True(t, bytes.Equal(msg, []byte{0}))
  1071  	case <-time.After(time.Second * 1):
  1072  		assert.Fail(t, "blocked reactor is blocked")
  1073  	}
  1074  }
  1075  
  1076  func getInitSwitchFunc(bufSize int) func(int, *Switch, *config.P2PConfig) *Switch {
  1077  	return func(i int, sw *Switch, config *config.P2PConfig) *Switch {
  1078  		sw.SetAddrBook(&addrBookMock{
  1079  			addrs:    make(map[string]struct{}),
  1080  			ourAddrs: make(map[string]struct{})})
  1081  
  1082  		// Make two reactors of two channels each
  1083  		sw.AddReactor(reactorNameNormal, NewNormalReactor([]*conn.ChannelDescriptor{
  1084  			{ID: byte(0x00), Priority: 10, MessageType: &p2pproto.Message{}},
  1085  		}, config.RecvAsync, bufSize))
  1086  		sw.AddReactor(reactorNameBlocked, NewBlockedReactor([]*conn.ChannelDescriptor{
  1087  			{ID: byte(0x01), Priority: 10, MessageType: &p2pproto.Message{}},
  1088  		}, config.RecvAsync, bufSize))
  1089  
  1090  		return sw
  1091  	}
  1092  }