github.com/klaytn/klaytn@v1.10.2/networks/p2p/peer_test.go (about)

     1  // Modifications Copyright 2018 The klaytn Authors
     2  // Copyright 2014 The go-ethereum Authors
     3  // This file is part of the go-ethereum library.
     4  //
     5  // The go-ethereum library is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Lesser General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // The go-ethereum library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    13  // GNU Lesser General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public License
    16  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    17  //
    18  // This file is derived from p2p/peer_test.go (2018/06/04).
    19  // Modified and improved for the klaytn development.
    20  
    21  package p2p
    22  
    23  import (
    24  	"errors"
    25  	"fmt"
    26  	"math"
    27  	"math/rand"
    28  	"net"
    29  	"reflect"
    30  	"testing"
    31  	"time"
    32  
    33  	"github.com/stretchr/testify/assert"
    34  )
    35  
    36  var discard = Protocol{
    37  	Name:   "discard",
    38  	Length: 1,
    39  	Run: func(p *Peer, rw MsgReadWriter) error {
    40  		for {
    41  			msg, err := rw.ReadMsg()
    42  			if err != nil {
    43  				return err
    44  			}
    45  			fmt.Printf("discarding %d\n", msg.Code)
    46  			if err = msg.Discard(); err != nil {
    47  				return err
    48  			}
    49  		}
    50  	},
    51  }
    52  
    53  func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) {
    54  	fd1, fd2 := net.Pipe()
    55  	c1 := &conn{fd: fd1, transport: newTestTransport(randomID(), fd1, false)}
    56  	c2 := &conn{fd: fd2, transport: newTestTransport(randomID(), fd2, false)}
    57  	for _, p := range protos {
    58  		c1.caps = append(c1.caps, p.cap())
    59  		c2.caps = append(c2.caps, p.cap())
    60  	}
    61  
    62  	peer, _ := newPeer([]*conn{c1}, protos, defaultRWTimerConfig)
    63  	errc := make(chan error, 1)
    64  	go func() {
    65  		_, err := peer.run()
    66  		errc <- err
    67  	}()
    68  
    69  	closer := func() { c2.close(errors.New("close func called")) }
    70  	return closer, c2, peer, errc
    71  }
    72  
    73  func testPeerWithRWs(protos []Protocol, channelSize int) (func(), []*conn, *Peer, <-chan error) {
    74  	serverSideConn := make([]*conn, 0, channelSize)
    75  	peerSideConn := make([]*conn, 0, channelSize)
    76  
    77  	for i := 0; i < channelSize; i++ {
    78  		fd1, fd2 := net.Pipe()
    79  		c1 := &conn{fd: fd1, transport: newTestTransport(randomID(), fd1, true)}
    80  		c2 := &conn{fd: fd2, transport: newTestTransport(randomID(), fd2, true)}
    81  		for _, p := range protos {
    82  			c1.caps = append(c1.caps, p.cap())
    83  			c2.caps = append(c2.caps, p.cap())
    84  		}
    85  		serverSideConn = append(serverSideConn, c1)
    86  		peerSideConn = append(peerSideConn, c2)
    87  	}
    88  
    89  	peer, _ := newPeer(serverSideConn, protos, defaultRWTimerConfig)
    90  	errc := make(chan error, 1)
    91  	go func() {
    92  		_, err := peer.runWithRWs()
    93  		errc <- err
    94  	}()
    95  
    96  	closer := func() {
    97  		for _, conn := range peerSideConn {
    98  			conn.close(errors.New("close func called"))
    99  		}
   100  	}
   101  	return closer, peerSideConn, peer, errc
   102  }
   103  
   104  func TestPeerProtoReadMsg(t *testing.T) {
   105  	proto := Protocol{
   106  		Name:   "a",
   107  		Length: 5,
   108  		Run: func(peer *Peer, rw MsgReadWriter) error {
   109  			if err := ExpectMsg(rw, 2, []uint{1}); err != nil {
   110  				t.Error(err)
   111  			}
   112  			if err := ExpectMsg(rw, 3, []uint{2}); err != nil {
   113  				t.Error(err)
   114  			}
   115  			if err := ExpectMsg(rw, 4, []uint{3}); err != nil {
   116  				t.Error(err)
   117  			}
   118  			return nil
   119  		},
   120  	}
   121  
   122  	closer, rw, _, errc := testPeer([]Protocol{proto})
   123  	defer closer()
   124  
   125  	Send(rw, baseProtocolLength+2, []uint{1})
   126  	Send(rw, baseProtocolLength+3, []uint{2})
   127  	Send(rw, baseProtocolLength+4, []uint{3})
   128  
   129  	select {
   130  	case err := <-errc:
   131  		if err != errProtocolReturned {
   132  			t.Errorf("peer returned error: %v", err)
   133  		}
   134  	case <-time.After(2 * time.Second):
   135  		t.Errorf("receive timeout")
   136  	}
   137  }
   138  
   139  func TestPeerProtoEncodeMsg(t *testing.T) {
   140  	proto := Protocol{
   141  		Name:   "a",
   142  		Length: 2,
   143  		Run: func(peer *Peer, rw MsgReadWriter) error {
   144  			if err := SendItems(rw, 2); err == nil {
   145  				t.Error("expected error for out-of-range msg code, got nil")
   146  			}
   147  			if err := SendItems(rw, 1, "foo", "bar"); err != nil {
   148  				t.Errorf("write error: %v", err)
   149  			}
   150  			return nil
   151  		},
   152  	}
   153  	closer, rw, _, _ := testPeer([]Protocol{proto})
   154  	defer closer()
   155  
   156  	if err := ExpectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
   157  		t.Error(err)
   158  	}
   159  }
   160  
   161  func TestPeerPing(t *testing.T) {
   162  	closer, rw, _, _ := testPeer(nil)
   163  	defer closer()
   164  	if err := SendItems(rw, pingMsg); err != nil {
   165  		t.Fatal(err)
   166  	}
   167  	if err := ExpectMsg(rw, pongMsg, nil); err != nil {
   168  		t.Error(err)
   169  	}
   170  }
   171  
   172  func TestPeerDisconnect(t *testing.T) {
   173  	testData := []DiscReason{DiscQuitting, math.MaxUint}
   174  	for _, tc := range testData {
   175  		closer, rw, _, disc := testPeer(nil)
   176  		if err := SendItems(rw, discMsg, tc); err != nil {
   177  			t.Fatal(err)
   178  		}
   179  		select {
   180  		case reason := <-disc:
   181  			assert.Equal(t, tc.Error(), reason.Error())
   182  		case <-time.After(500 * time.Millisecond):
   183  			t.Error("peer did not return")
   184  		}
   185  		closer()
   186  	}
   187  }
   188  
   189  // This test is supposed to verify that Peer can reliably handle
   190  // multiple causes of disconnection occurring at the same time.
   191  func TestPeerDisconnectRace(t *testing.T) {
   192  	maybe := func() bool { return rand.Intn(2) == 1 }
   193  
   194  	for i := 0; i < 1000; i++ {
   195  		protoclose := make(chan error)
   196  		protodisc := make(chan DiscReason)
   197  		closer, rw, p, disc := testPeer([]Protocol{
   198  			{
   199  				Name:   "closereq",
   200  				Run:    func(p *Peer, rw MsgReadWriter) error { return <-protoclose },
   201  				Length: 1,
   202  			},
   203  			{
   204  				Name:   "disconnect",
   205  				Run:    func(p *Peer, rw MsgReadWriter) error { p.Disconnect(<-protodisc); return nil },
   206  				Length: 1,
   207  			},
   208  		})
   209  
   210  		// Simulate incoming messages.
   211  		go SendItems(rw, baseProtocolLength+1)
   212  		go SendItems(rw, baseProtocolLength+2)
   213  		// Close the network connection.
   214  		go closer()
   215  		// Make protocol "closereq" return.
   216  		protoclose <- errors.New("protocol closed")
   217  		// Make protocol "disconnect" call peer.Disconnect
   218  		protodisc <- DiscAlreadyConnected
   219  		// In some cases, simulate something else calling peer.Disconnect.
   220  		if maybe() {
   221  			go p.Disconnect(DiscInvalidIdentity)
   222  		}
   223  		// In some cases, simulate remote requesting a disconnect.
   224  		if maybe() {
   225  			go SendItems(rw, discMsg, DiscQuitting)
   226  		}
   227  
   228  		select {
   229  		case <-disc:
   230  		case <-time.After(2 * time.Second):
   231  			// Peer.run should return quickly. If it doesn't the Peer
   232  			// goroutines are probably deadlocked. Call panic in order to
   233  			// show the stacks.
   234  			panic("Peer.run took to long to return.")
   235  		}
   236  	}
   237  }
   238  
   239  func TestMultiChannelPeerProtoReadMsg(t *testing.T) {
   240  	proto := Protocol{
   241  		Name:   "a",
   242  		Length: 5,
   243  		RunWithRWs: func(peer *Peer, rws []MsgReadWriter) error {
   244  			for _, rw := range rws {
   245  				if err := ExpectMsg(rw, 2, []uint{1}); err != nil {
   246  					t.Error(err)
   247  				}
   248  				if err := ExpectMsg(rw, 3, []uint{2}); err != nil {
   249  					t.Error(err)
   250  				}
   251  				if err := ExpectMsg(rw, 4, []uint{3}); err != nil {
   252  					t.Error(err)
   253  				}
   254  			}
   255  			return nil
   256  		},
   257  	}
   258  
   259  	closer, rws, _, errc := testPeerWithRWs([]Protocol{proto}, 2)
   260  	defer closer()
   261  
   262  	for _, rw := range rws {
   263  		Send(rw, baseProtocolLength+2, []uint{1})
   264  		Send(rw, baseProtocolLength+3, []uint{2})
   265  		Send(rw, baseProtocolLength+4, []uint{3})
   266  	}
   267  
   268  	select {
   269  	case err := <-errc:
   270  		if err != errProtocolReturned {
   271  			t.Errorf("peer returned error: %v", err)
   272  		}
   273  	case <-time.After(2 * time.Second):
   274  		t.Errorf("receive timeout")
   275  	}
   276  }
   277  
   278  func TestMultiChannelPeerProtoEncodeMsg(t *testing.T) {
   279  	proto := Protocol{
   280  		Name:   "a",
   281  		Length: 2,
   282  		RunWithRWs: func(peer *Peer, rws []MsgReadWriter) error {
   283  			for _, rw := range rws {
   284  				if err := SendItems(rw, 2); err == nil {
   285  					t.Error("expected error for out-of-range msg code, got nil")
   286  				}
   287  				if err := SendItems(rw, 1, "foo", "bar"); err != nil {
   288  					t.Errorf("write error: %v", err)
   289  				}
   290  			}
   291  			return nil
   292  		},
   293  	}
   294  	closer, rws, _, _ := testPeerWithRWs([]Protocol{proto}, 2)
   295  	defer closer()
   296  
   297  	for _, rw := range rws {
   298  		if err := ExpectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
   299  			t.Error(err)
   300  		}
   301  	}
   302  }
   303  
   304  func TestMultiChannelPeerPing(t *testing.T) {
   305  	closer, rws, _, _ := testPeerWithRWs(nil, 2)
   306  	defer closer()
   307  
   308  	for _, rw := range rws {
   309  		if err := SendItems(rw, pingMsg); err != nil {
   310  			t.Fatal(err)
   311  		}
   312  		if err := ExpectMsg(rw, pongMsg, nil); err != nil {
   313  			t.Error(err)
   314  		}
   315  	}
   316  }
   317  
   318  func TestMultiChannelPeerDisconnect(t *testing.T) {
   319  	channelSize := 2
   320  	for i := 0; i < channelSize; i++ {
   321  		closer, rws, _, disc := testPeerWithRWs(nil, channelSize)
   322  		defer closer()
   323  
   324  		if err := SendItems(rws[i], discMsg, DiscQuitting); err != nil {
   325  			t.Fatal(err)
   326  		}
   327  
   328  		select {
   329  		case reason := <-disc:
   330  			if reason != DiscQuitting {
   331  				t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscQuitting)
   332  			}
   333  		case <-time.After(500 * time.Millisecond):
   334  			t.Error("peer did not return")
   335  		}
   336  	}
   337  }
   338  
   339  // This test is supposed to verify that Peer can reliably handle
   340  // multiple causes of disconnection occurring at the same time.
   341  func TestMultiChannelPeerDisconnectRace(t *testing.T) {
   342  	maybe := func() bool { return rand.Intn(2) == 1 }
   343  	channelSize := 2
   344  
   345  	for i := 0; i < 1000; i++ {
   346  		protoclose := make(chan error)
   347  		protodisc := make(chan DiscReason)
   348  		closer, rws, p, disc := testPeerWithRWs([]Protocol{
   349  			{
   350  				Name:       "closereq",
   351  				RunWithRWs: func(p *Peer, rw []MsgReadWriter) error { return <-protoclose },
   352  				Length:     1,
   353  			},
   354  			{
   355  				Name:       "disconnect",
   356  				RunWithRWs: func(p *Peer, rw []MsgReadWriter) error { p.Disconnect(<-protodisc); return nil },
   357  				Length:     1,
   358  			},
   359  		}, channelSize)
   360  
   361  		// Simulate incoming messages.
   362  		for _, rw := range rws {
   363  			go SendItems(rw, baseProtocolLength+1)
   364  			go SendItems(rw, baseProtocolLength+2)
   365  		}
   366  		// Close the network connection.
   367  		go closer()
   368  		// Make protocol "closereq" return.
   369  		protoclose <- errors.New("protocol closed")
   370  		// Make protocol "disconnect" call peer.Disconnect
   371  		protodisc <- DiscAlreadyConnected
   372  		// In some cases, simulate something else calling peer.Disconnect.
   373  		if maybe() {
   374  			go p.Disconnect(DiscInvalidIdentity)
   375  		}
   376  		// In some cases, simulate remote requesting a disconnect.
   377  		if maybe() {
   378  			go SendItems(rws[rand.Intn(channelSize)], discMsg, DiscQuitting)
   379  		}
   380  
   381  		select {
   382  		case <-disc:
   383  		case <-time.After(2 * time.Second):
   384  			// Peer.run should return quickly. If it doesn't the Peer
   385  			// goroutines are probably deadlocked. Call panic in order to
   386  			// show the stacks.
   387  			panic("Peer.run took to long to return.")
   388  		}
   389  	}
   390  }
   391  
   392  func TestNewPeer(t *testing.T) {
   393  	name := "nodename"
   394  	caps := []Cap{{"foo", 2}, {"bar", 3}}
   395  	id := randomID()
   396  	p := NewPeer(id, name, caps)
   397  	if p.ID() != id {
   398  		t.Errorf("ID mismatch: got %v, expected %v", p.ID(), id)
   399  	}
   400  	if p.Name() != name {
   401  		t.Errorf("Name mismatch: got %v, expected %v", p.Name(), name)
   402  	}
   403  	if !reflect.DeepEqual(p.Caps(), caps) {
   404  		t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
   405  	}
   406  
   407  	p.Disconnect(DiscAlreadyConnected) // Should not hang
   408  }
   409  
   410  func TestMatchProtocols(t *testing.T) {
   411  	tests := []struct {
   412  		Remote []Cap
   413  		Local  []Protocol
   414  		Match  map[string]protoRW
   415  	}{
   416  		{
   417  			// No remote capabilities
   418  			Local: []Protocol{{Name: "a"}},
   419  		},
   420  		{
   421  			// No local protocols
   422  			Remote: []Cap{{Name: "a"}},
   423  		},
   424  		{
   425  			// No mutual protocols
   426  			Remote: []Cap{{Name: "a"}},
   427  			Local:  []Protocol{{Name: "b"}},
   428  		},
   429  		{
   430  			// Some matches, some differences
   431  			Remote: []Cap{{Name: "local"}, {Name: "match1"}, {Name: "match2"}},
   432  			Local:  []Protocol{{Name: "match1"}, {Name: "match2"}, {Name: "remote"}},
   433  			Match: map[string]protoRW{
   434  				"match1": {Protocol: Protocol{Name: "match1"}, tc: defaultRWTimerConfig},
   435  				"match2": {Protocol: Protocol{Name: "match2"}, tc: defaultRWTimerConfig},
   436  			},
   437  		},
   438  		{
   439  			// Various alphabetical ordering
   440  			Remote: []Cap{{Name: "aa"}, {Name: "ab"}, {Name: "bb"}, {Name: "ba"}},
   441  			Local:  []Protocol{{Name: "ba"}, {Name: "bb"}, {Name: "ab"}, {Name: "aa"}},
   442  			Match: map[string]protoRW{
   443  				"aa": {Protocol: Protocol{Name: "aa"}, tc: defaultRWTimerConfig},
   444  				"ab": {Protocol: Protocol{Name: "ab"}, tc: defaultRWTimerConfig},
   445  				"ba": {Protocol: Protocol{Name: "ba"}, tc: defaultRWTimerConfig},
   446  				"bb": {Protocol: Protocol{Name: "bb"}, tc: defaultRWTimerConfig},
   447  			},
   448  		},
   449  		{
   450  			// No mutual versions
   451  			Remote: []Cap{{Version: 1}},
   452  			Local:  []Protocol{{Version: 2}},
   453  		},
   454  		{
   455  			// Multiple versions, single common
   456  			Remote: []Cap{{Version: 1}, {Version: 2}},
   457  			Local:  []Protocol{{Version: 2}, {Version: 3}},
   458  			Match:  map[string]protoRW{"": {Protocol: Protocol{Version: 2}, tc: defaultRWTimerConfig}},
   459  		},
   460  		{
   461  			// Multiple versions, multiple common
   462  			Remote: []Cap{{Version: 1}, {Version: 2}, {Version: 3}, {Version: 4}},
   463  			Local:  []Protocol{{Version: 2}, {Version: 3}},
   464  			Match:  map[string]protoRW{"": {Protocol: Protocol{Version: 3}, tc: defaultRWTimerConfig}},
   465  		},
   466  		{
   467  			// Various version orderings
   468  			Remote: []Cap{{Version: 4}, {Version: 1}, {Version: 3}, {Version: 2}},
   469  			Local:  []Protocol{{Version: 2}, {Version: 3}, {Version: 1}},
   470  			Match:  map[string]protoRW{"": {Protocol: Protocol{Version: 3}, tc: defaultRWTimerConfig}},
   471  		},
   472  		{
   473  			// Versions overriding sub-protocol lengths
   474  			Remote: []Cap{{Version: 1}, {Version: 2}, {Version: 3}, {Name: "a"}},
   475  			Local:  []Protocol{{Version: 1, Length: 1}, {Version: 2, Length: 2}, {Version: 3, Length: 3}, {Name: "a"}},
   476  			Match: map[string]protoRW{
   477  				"":  {Protocol: Protocol{Version: 3}, tc: defaultRWTimerConfig},
   478  				"a": {Protocol: Protocol{Name: "a"}, offset: 3, tc: defaultRWTimerConfig},
   479  			},
   480  		},
   481  	}
   482  
   483  	for i, tt := range tests {
   484  		result := matchProtocols(tt.Local, tt.Remote, nil, defaultRWTimerConfig)
   485  		if len(result) != len(tt.Match) {
   486  			t.Errorf("test %d: negotiation mismatch: have %v, want %v", i, len(result), len(tt.Match))
   487  			continue
   488  		}
   489  		// Make sure all negotiated protocols are needed and correct
   490  		for name, proto := range result {
   491  			match, ok := tt.Match[name]
   492  			if !ok {
   493  				t.Errorf("test %d, protobuf '%s': negotiated but shouldn't have", i, name)
   494  				continue
   495  			}
   496  			if proto[ConnDefault].Name != match.Name {
   497  				t.Errorf("test %d, protobuf '%s': name mismatch: have %v, want %v", i, name, proto[ConnDefault].Name, match.Name)
   498  			}
   499  			if proto[ConnDefault].Version != match.Version {
   500  				t.Errorf("test %d, protobuf '%s': version mismatch: have %v, want %v", i, name, proto[ConnDefault].Version, match.Version)
   501  			}
   502  			if proto[ConnDefault].offset-baseProtocolLength != match.offset {
   503  				t.Errorf("test %d, protobuf '%s': offset mismatch: have %v, want %v", i, name, proto[ConnDefault].offset-baseProtocolLength, match.offset)
   504  			}
   505  		}
   506  		// Make sure no protocols missed negotiation
   507  		for name := range tt.Match {
   508  			if _, ok := result[name]; !ok {
   509  				t.Errorf("test %d, protobuf '%s': not negotiated, should have", i, name)
   510  				continue
   511  			}
   512  		}
   513  	}
   514  }