github.com/nitinawathare/ethereumassignment3@v0.0.0-20211021213010-f07344c2b868/go-ethereum/swarm/network/protocol_test.go (about)

     1  // Copyright 2016 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package network
    18  
    19  import (
    20  	"crypto/ecdsa"
    21  	"flag"
    22  	"fmt"
    23  	"os"
    24  	"sync"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/ethereum/go-ethereum/crypto"
    29  	"github.com/ethereum/go-ethereum/log"
    30  	"github.com/ethereum/go-ethereum/p2p"
    31  	"github.com/ethereum/go-ethereum/p2p/enode"
    32  	"github.com/ethereum/go-ethereum/p2p/enr"
    33  	"github.com/ethereum/go-ethereum/p2p/protocols"
    34  	p2ptest "github.com/ethereum/go-ethereum/p2p/testing"
    35  	"github.com/ethereum/go-ethereum/swarm/pot"
    36  )
    37  
    38  const (
    39  	TestProtocolVersion = 8
    40  )
    41  
    42  var TestProtocolNetworkID = DefaultTestNetworkID
    43  
    44  var (
    45  	loglevel = flag.Int("loglevel", 2, "verbosity of logs")
    46  )
    47  
    48  func init() {
    49  	flag.Parse()
    50  	log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*loglevel), log.StreamHandler(os.Stderr, log.TerminalFormat(true))))
    51  }
    52  
    53  func HandshakeMsgExchange(lhs, rhs *HandshakeMsg, id enode.ID) []p2ptest.Exchange {
    54  	return []p2ptest.Exchange{
    55  		{
    56  			Expects: []p2ptest.Expect{
    57  				{
    58  					Code: 0,
    59  					Msg:  lhs,
    60  					Peer: id,
    61  				},
    62  			},
    63  		},
    64  		{
    65  			Triggers: []p2ptest.Trigger{
    66  				{
    67  					Code: 0,
    68  					Msg:  rhs,
    69  					Peer: id,
    70  				},
    71  			},
    72  		},
    73  	}
    74  }
    75  
    76  func newBzzBaseTester(n int, prvkey *ecdsa.PrivateKey, spec *protocols.Spec, run func(*BzzPeer) error) (*bzzTester, error) {
    77  	var addrs [][]byte
    78  	for i := 0; i < n; i++ {
    79  		addr := pot.RandomAddress()
    80  		addrs = append(addrs, addr[:])
    81  	}
    82  	pt, _, err := newBzzBaseTesterWithAddrs(prvkey, addrs, spec, run)
    83  	return pt, err
    84  }
    85  
    86  func newBzzBaseTesterWithAddrs(prvkey *ecdsa.PrivateKey, addrs [][]byte, spec *protocols.Spec, run func(*BzzPeer) error) (*bzzTester, [][]byte, error) {
    87  	n := len(addrs)
    88  	cs := make(map[enode.ID]chan bool)
    89  
    90  	srv := func(p *BzzPeer) error {
    91  		defer func() {
    92  			if cs[p.ID()] != nil {
    93  				close(cs[p.ID()])
    94  			}
    95  		}()
    96  		return run(p)
    97  	}
    98  	mu := &sync.Mutex{}
    99  	nodeToAddr := make(map[enode.ID][]byte)
   100  	protocol := func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
   101  		mu.Lock()
   102  		defer mu.Unlock()
   103  		nodeToAddr[p.ID()] = addrs[0]
   104  		bzzAddr := &BzzAddr{addrs[0], []byte(p.Node().String())}
   105  		addrs = addrs[1:]
   106  		return srv(&BzzPeer{Peer: protocols.NewPeer(p, rw, spec), BzzAddr: bzzAddr})
   107  	}
   108  
   109  	s := p2ptest.NewProtocolTester(prvkey, n, protocol)
   110  	var record enr.Record
   111  	bzzKey := PrivateKeyToBzzKey(prvkey)
   112  	record.Set(NewENRAddrEntry(bzzKey))
   113  	err := enode.SignV4(&record, prvkey)
   114  	if err != nil {
   115  		return nil, nil, fmt.Errorf("unable to generate ENR: %v", err)
   116  	}
   117  	nod, err := enode.New(enode.V4ID{}, &record)
   118  	if err != nil {
   119  		return nil, nil, fmt.Errorf("unable to create enode: %v", err)
   120  	}
   121  	addr := getENRBzzAddr(nod)
   122  
   123  	for _, node := range s.Nodes {
   124  		log.Warn("node", "node", node)
   125  		cs[node.ID()] = make(chan bool)
   126  	}
   127  
   128  	var nodeAddrs [][]byte
   129  	pt := &bzzTester{
   130  		addr:           addr,
   131  		ProtocolTester: s,
   132  		cs:             cs,
   133  	}
   134  	for _, n := range pt.Nodes {
   135  		nodeAddrs = append(nodeAddrs, nodeToAddr[n.ID()])
   136  	}
   137  
   138  	return pt, nodeAddrs, nil
   139  }
   140  
// bzzTester wraps a p2ptest.ProtocolTester with the bzz-specific state the
// tests in this package use.
type bzzTester struct {
	*p2ptest.ProtocolTester
	// addr is the pivot node's bzz address, derived from its signed ENR.
	addr *BzzAddr
	// cs maps each dummy peer's node ID to a channel that is closed when that
	// peer's run function returns. In this file only newBzzBaseTesterWithAddrs
	// sets it.
	cs   map[enode.ID]chan bool
	// bzz is the pivot's Bzz instance. In this file only newBzzHandshakeTester
	// sets it.
	bzz  *Bzz
}
   147  
   148  func newBzz(addr *BzzAddr, lightNode bool) *Bzz {
   149  	config := &BzzConfig{
   150  		OverlayAddr:  addr.Over(),
   151  		UnderlayAddr: addr.Under(),
   152  		HiveParams:   NewHiveParams(),
   153  		NetworkID:    DefaultTestNetworkID,
   154  		LightNode:    lightNode,
   155  	}
   156  	kad := NewKademlia(addr.OAddr, NewKadParams())
   157  	bzz := NewBzz(config, kad, nil, nil, nil)
   158  	return bzz
   159  }
   160  
   161  func newBzzHandshakeTester(n int, prvkey *ecdsa.PrivateKey, lightNode bool) (*bzzTester, error) {
   162  
   163  	var record enr.Record
   164  	bzzkey := PrivateKeyToBzzKey(prvkey)
   165  	record.Set(NewENRAddrEntry(bzzkey))
   166  	record.Set(ENRLightNodeEntry(lightNode))
   167  	err := enode.SignV4(&record, prvkey)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  	nod, err := enode.New(enode.V4ID{}, &record)
   172  	addr := getENRBzzAddr(nod)
   173  
   174  	bzz := newBzz(addr, lightNode)
   175  
   176  	pt := p2ptest.NewProtocolTester(prvkey, n, bzz.runBzz)
   177  
   178  	return &bzzTester{
   179  		addr:           addr,
   180  		ProtocolTester: pt,
   181  		bzz:            bzz,
   182  	}, nil
   183  }
   184  
   185  // should test handshakes in one exchange? parallelisation
   186  func (s *bzzTester) testHandshake(lhs, rhs *HandshakeMsg, disconnects ...*p2ptest.Disconnect) error {
   187  	if err := s.TestExchanges(HandshakeMsgExchange(lhs, rhs, rhs.Addr.ID())...); err != nil {
   188  		return err
   189  	}
   190  
   191  	if len(disconnects) > 0 {
   192  		return s.TestDisconnected(disconnects...)
   193  	}
   194  
   195  	// If we don't expect disconnect, ensure peers remain connected
   196  	err := s.TestDisconnected(&p2ptest.Disconnect{
   197  		Peer:  s.Nodes[0].ID(),
   198  		Error: nil,
   199  	})
   200  
   201  	if err == nil {
   202  		return fmt.Errorf("Unexpected peer disconnect")
   203  	}
   204  
   205  	if err.Error() != "timed out waiting for peers to disconnect" {
   206  		return err
   207  	}
   208  
   209  	return nil
   210  }
   211  
   212  func correctBzzHandshake(addr *BzzAddr, lightNode bool) *HandshakeMsg {
   213  	return &HandshakeMsg{
   214  		Version:   TestProtocolVersion,
   215  		NetworkID: TestProtocolNetworkID,
   216  		Addr:      addr,
   217  		LightNode: lightNode,
   218  	}
   219  }
   220  
   221  func TestBzzHandshakeNetworkIDMismatch(t *testing.T) {
   222  	lightNode := false
   223  	prvkey, err := crypto.GenerateKey()
   224  	if err != nil {
   225  		t.Fatal(err)
   226  	}
   227  	s, err := newBzzHandshakeTester(1, prvkey, lightNode)
   228  	if err != nil {
   229  		t.Fatal(err)
   230  	}
   231  	node := s.Nodes[0]
   232  
   233  	err = s.testHandshake(
   234  		correctBzzHandshake(s.addr, lightNode),
   235  		&HandshakeMsg{Version: TestProtocolVersion, NetworkID: 321, Addr: NewAddr(node)},
   236  		&p2ptest.Disconnect{Peer: node.ID(), Error: fmt.Errorf("Handshake error: Message handler error: (msg code 0): network id mismatch 321 (!= %v)", TestProtocolNetworkID)},
   237  	)
   238  
   239  	if err != nil {
   240  		t.Fatal(err)
   241  	}
   242  }
   243  
   244  func TestBzzHandshakeVersionMismatch(t *testing.T) {
   245  	lightNode := false
   246  	prvkey, err := crypto.GenerateKey()
   247  	if err != nil {
   248  		t.Fatal(err)
   249  	}
   250  	s, err := newBzzHandshakeTester(1, prvkey, lightNode)
   251  	if err != nil {
   252  		t.Fatal(err)
   253  	}
   254  	node := s.Nodes[0]
   255  
   256  	err = s.testHandshake(
   257  		correctBzzHandshake(s.addr, lightNode),
   258  		&HandshakeMsg{Version: 0, NetworkID: TestProtocolNetworkID, Addr: NewAddr(node)},
   259  		&p2ptest.Disconnect{Peer: node.ID(), Error: fmt.Errorf("Handshake error: Message handler error: (msg code 0): version mismatch 0 (!= %d)", TestProtocolVersion)},
   260  	)
   261  
   262  	if err != nil {
   263  		t.Fatal(err)
   264  	}
   265  }
   266  
   267  func TestBzzHandshakeSuccess(t *testing.T) {
   268  	lightNode := false
   269  	prvkey, err := crypto.GenerateKey()
   270  	if err != nil {
   271  		t.Fatal(err)
   272  	}
   273  	s, err := newBzzHandshakeTester(1, prvkey, lightNode)
   274  	if err != nil {
   275  		t.Fatal(err)
   276  	}
   277  	node := s.Nodes[0]
   278  
   279  	err = s.testHandshake(
   280  		correctBzzHandshake(s.addr, lightNode),
   281  		&HandshakeMsg{Version: TestProtocolVersion, NetworkID: TestProtocolNetworkID, Addr: NewAddr(node)},
   282  	)
   283  
   284  	if err != nil {
   285  		t.Fatal(err)
   286  	}
   287  }
   288  
   289  func TestBzzHandshakeLightNode(t *testing.T) {
   290  	var lightNodeTests = []struct {
   291  		name      string
   292  		lightNode bool
   293  	}{
   294  		{"on", true},
   295  		{"off", false},
   296  	}
   297  
   298  	for _, test := range lightNodeTests {
   299  		t.Run(test.name, func(t *testing.T) {
   300  			prvkey, err := crypto.GenerateKey()
   301  			if err != nil {
   302  				t.Fatal(err)
   303  			}
   304  			pt, err := newBzzHandshakeTester(1, prvkey, false)
   305  			if err != nil {
   306  				t.Fatal(err)
   307  			}
   308  
   309  			node := pt.Nodes[0]
   310  			addr := NewAddr(node)
   311  
   312  			err = pt.testHandshake(
   313  				correctBzzHandshake(pt.addr, false),
   314  				&HandshakeMsg{Version: TestProtocolVersion, NetworkID: TestProtocolNetworkID, Addr: addr, LightNode: test.lightNode},
   315  			)
   316  
   317  			if err != nil {
   318  				t.Fatal(err)
   319  			}
   320  
   321  			select {
   322  
   323  			case <-pt.bzz.handshakes[node.ID()].done:
   324  				if pt.bzz.handshakes[node.ID()].LightNode != test.lightNode {
   325  					t.Fatalf("peer LightNode flag is %v, should be %v", pt.bzz.handshakes[node.ID()].LightNode, test.lightNode)
   326  				}
   327  			case <-time.After(10 * time.Second):
   328  				t.Fatal("test timeout")
   329  			}
   330  		})
   331  	}
   332  }