github.com/daeglee/go-ethereum@v0.0.0-20190504220456-cad3e8d18e9b/swarm/network/protocol_test.go (about)

     1  // Copyright 2016 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package network
    18  
import (
	"crypto/ecdsa"
	"flag"
	"fmt"
	"os"
	"sync"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/log"
	"github.com/ethereum/go-ethereum/p2p"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
	"github.com/ethereum/go-ethereum/p2p/protocols"
	p2ptest "github.com/ethereum/go-ethereum/p2p/testing"
)
    35  
    36  const (
    37  	TestProtocolVersion   = 8
    38  	TestProtocolNetworkID = 3
    39  )
    40  
    41  var (
    42  	loglevel = flag.Int("loglevel", 2, "verbosity of logs")
    43  )
    44  
    45  func init() {
    46  	flag.Parse()
    47  	log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*loglevel), log.StreamHandler(os.Stderr, log.TerminalFormat(true))))
    48  }
    49  
    50  func HandshakeMsgExchange(lhs, rhs *HandshakeMsg, id enode.ID) []p2ptest.Exchange {
    51  	return []p2ptest.Exchange{
    52  		{
    53  			Expects: []p2ptest.Expect{
    54  				{
    55  					Code: 0,
    56  					Msg:  lhs,
    57  					Peer: id,
    58  				},
    59  			},
    60  		},
    61  		{
    62  			Triggers: []p2ptest.Trigger{
    63  				{
    64  					Code: 0,
    65  					Msg:  rhs,
    66  					Peer: id,
    67  				},
    68  			},
    69  		},
    70  	}
    71  }
    72  
    73  func newBzzBaseTester(t *testing.T, n int, prvkey *ecdsa.PrivateKey, spec *protocols.Spec, run func(*BzzPeer) error) (*bzzTester, error) {
    74  	cs := make(map[string]chan bool)
    75  
    76  	srv := func(p *BzzPeer) error {
    77  		defer func() {
    78  			if cs[p.ID().String()] != nil {
    79  				close(cs[p.ID().String()])
    80  			}
    81  		}()
    82  		return run(p)
    83  	}
    84  
    85  	protocol := func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
    86  		return srv(&BzzPeer{Peer: protocols.NewPeer(p, rw, spec), BzzAddr: NewAddr(p.Node())})
    87  	}
    88  
    89  	s := p2ptest.NewProtocolTester(prvkey, n, protocol)
    90  	var record enr.Record
    91  	bzzKey := PrivateKeyToBzzKey(prvkey)
    92  	record.Set(NewENRAddrEntry(bzzKey))
    93  	err := enode.SignV4(&record, prvkey)
    94  	if err != nil {
    95  		return nil, fmt.Errorf("unable to generate ENR: %v", err)
    96  	}
    97  	nod, err := enode.New(enode.V4ID{}, &record)
    98  	if err != nil {
    99  		return nil, fmt.Errorf("unable to create enode: %v", err)
   100  	}
   101  	addr := getENRBzzAddr(nod)
   102  
   103  	for _, node := range s.Nodes {
   104  		log.Warn("node", "node", node)
   105  		cs[node.ID().String()] = make(chan bool)
   106  	}
   107  
   108  	return &bzzTester{
   109  		addr:           addr,
   110  		ProtocolTester: s,
   111  		cs:             cs,
   112  	}, nil
   113  }
   114  
   115  type bzzTester struct {
   116  	*p2ptest.ProtocolTester
   117  	addr *BzzAddr
   118  	cs   map[string]chan bool
   119  	bzz  *Bzz
   120  }
   121  
   122  func newBzz(addr *BzzAddr, lightNode bool) *Bzz {
   123  	config := &BzzConfig{
   124  		OverlayAddr:  addr.Over(),
   125  		UnderlayAddr: addr.Under(),
   126  		HiveParams:   NewHiveParams(),
   127  		NetworkID:    DefaultNetworkID,
   128  		LightNode:    lightNode,
   129  	}
   130  	kad := NewKademlia(addr.OAddr, NewKadParams())
   131  	bzz := NewBzz(config, kad, nil, nil, nil)
   132  	return bzz
   133  }
   134  
   135  func newBzzHandshakeTester(n int, prvkey *ecdsa.PrivateKey, lightNode bool) (*bzzTester, error) {
   136  
   137  	var record enr.Record
   138  	bzzkey := PrivateKeyToBzzKey(prvkey)
   139  	record.Set(NewENRAddrEntry(bzzkey))
   140  	record.Set(ENRLightNodeEntry(lightNode))
   141  	err := enode.SignV4(&record, prvkey)
   142  	if err != nil {
   143  		return nil, err
   144  	}
   145  	nod, err := enode.New(enode.V4ID{}, &record)
   146  	addr := getENRBzzAddr(nod)
   147  
   148  	bzz := newBzz(addr, lightNode)
   149  
   150  	pt := p2ptest.NewProtocolTester(prvkey, n, bzz.runBzz)
   151  
   152  	return &bzzTester{
   153  		addr:           addr,
   154  		ProtocolTester: pt,
   155  		bzz:            bzz,
   156  	}, nil
   157  }
   158  
   159  // should test handshakes in one exchange? parallelisation
   160  func (s *bzzTester) testHandshake(lhs, rhs *HandshakeMsg, disconnects ...*p2ptest.Disconnect) error {
   161  	if err := s.TestExchanges(HandshakeMsgExchange(lhs, rhs, rhs.Addr.ID())...); err != nil {
   162  		return err
   163  	}
   164  
   165  	if len(disconnects) > 0 {
   166  		return s.TestDisconnected(disconnects...)
   167  	}
   168  
   169  	// If we don't expect disconnect, ensure peers remain connected
   170  	err := s.TestDisconnected(&p2ptest.Disconnect{
   171  		Peer:  s.Nodes[0].ID(),
   172  		Error: nil,
   173  	})
   174  
   175  	if err == nil {
   176  		return fmt.Errorf("Unexpected peer disconnect")
   177  	}
   178  
   179  	if err.Error() != "timed out waiting for peers to disconnect" {
   180  		return err
   181  	}
   182  
   183  	return nil
   184  }
   185  
   186  func correctBzzHandshake(addr *BzzAddr, lightNode bool) *HandshakeMsg {
   187  	return &HandshakeMsg{
   188  		Version:   TestProtocolVersion,
   189  		NetworkID: TestProtocolNetworkID,
   190  		Addr:      addr,
   191  		LightNode: lightNode,
   192  	}
   193  }
   194  
   195  func TestBzzHandshakeNetworkIDMismatch(t *testing.T) {
   196  	lightNode := false
   197  	prvkey, err := crypto.GenerateKey()
   198  	if err != nil {
   199  		t.Fatal(err)
   200  	}
   201  	s, err := newBzzHandshakeTester(1, prvkey, lightNode)
   202  	if err != nil {
   203  		t.Fatal(err)
   204  	}
   205  	node := s.Nodes[0]
   206  
   207  	err = s.testHandshake(
   208  		correctBzzHandshake(s.addr, lightNode),
   209  		&HandshakeMsg{Version: TestProtocolVersion, NetworkID: 321, Addr: NewAddr(node)},
   210  		&p2ptest.Disconnect{Peer: node.ID(), Error: fmt.Errorf("Handshake error: Message handler error: (msg code 0): network id mismatch 321 (!= 3)")},
   211  	)
   212  
   213  	if err != nil {
   214  		t.Fatal(err)
   215  	}
   216  }
   217  
   218  func TestBzzHandshakeVersionMismatch(t *testing.T) {
   219  	lightNode := false
   220  	prvkey, err := crypto.GenerateKey()
   221  	if err != nil {
   222  		t.Fatal(err)
   223  	}
   224  	s, err := newBzzHandshakeTester(1, prvkey, lightNode)
   225  	if err != nil {
   226  		t.Fatal(err)
   227  	}
   228  	node := s.Nodes[0]
   229  
   230  	err = s.testHandshake(
   231  		correctBzzHandshake(s.addr, lightNode),
   232  		&HandshakeMsg{Version: 0, NetworkID: TestProtocolNetworkID, Addr: NewAddr(node)},
   233  		&p2ptest.Disconnect{Peer: node.ID(), Error: fmt.Errorf("Handshake error: Message handler error: (msg code 0): version mismatch 0 (!= %d)", TestProtocolVersion)},
   234  	)
   235  
   236  	if err != nil {
   237  		t.Fatal(err)
   238  	}
   239  }
   240  
   241  func TestBzzHandshakeSuccess(t *testing.T) {
   242  	lightNode := false
   243  	prvkey, err := crypto.GenerateKey()
   244  	if err != nil {
   245  		t.Fatal(err)
   246  	}
   247  	s, err := newBzzHandshakeTester(1, prvkey, lightNode)
   248  	if err != nil {
   249  		t.Fatal(err)
   250  	}
   251  	node := s.Nodes[0]
   252  
   253  	err = s.testHandshake(
   254  		correctBzzHandshake(s.addr, lightNode),
   255  		&HandshakeMsg{Version: TestProtocolVersion, NetworkID: TestProtocolNetworkID, Addr: NewAddr(node)},
   256  	)
   257  
   258  	if err != nil {
   259  		t.Fatal(err)
   260  	}
   261  }
   262  
   263  func TestBzzHandshakeLightNode(t *testing.T) {
   264  	var lightNodeTests = []struct {
   265  		name      string
   266  		lightNode bool
   267  	}{
   268  		{"on", true},
   269  		{"off", false},
   270  	}
   271  
   272  	for _, test := range lightNodeTests {
   273  		t.Run(test.name, func(t *testing.T) {
   274  			prvkey, err := crypto.GenerateKey()
   275  			if err != nil {
   276  				t.Fatal(err)
   277  			}
   278  			pt, err := newBzzHandshakeTester(1, prvkey, false)
   279  			if err != nil {
   280  				t.Fatal(err)
   281  			}
   282  
   283  			node := pt.Nodes[0]
   284  			addr := NewAddr(node)
   285  
   286  			err = pt.testHandshake(
   287  				correctBzzHandshake(pt.addr, false),
   288  				&HandshakeMsg{Version: TestProtocolVersion, NetworkID: TestProtocolNetworkID, Addr: addr, LightNode: test.lightNode},
   289  			)
   290  
   291  			if err != nil {
   292  				t.Fatal(err)
   293  			}
   294  
   295  			select {
   296  
   297  			case <-pt.bzz.handshakes[node.ID()].done:
   298  				if pt.bzz.handshakes[node.ID()].LightNode != test.lightNode {
   299  					t.Fatalf("peer LightNode flag is %v, should be %v", pt.bzz.handshakes[node.ID()].LightNode, test.lightNode)
   300  				}
   301  			case <-time.After(10 * time.Second):
   302  				t.Fatal("test timeout")
   303  			}
   304  		})
   305  	}
   306  }