github.com/oskarth/go-ethereum@v1.6.8-0.20191013093314-dac24a9d3494/swarm/network/protocol_test.go (about)

     1  // Copyright 2016 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package network
    18  
    19  import (
    20  	"flag"
    21  	"fmt"
    22  	"os"
    23  	"sync"
    24  	"testing"
    25  
    26  	"github.com/ethereum/go-ethereum/log"
    27  	"github.com/ethereum/go-ethereum/p2p"
    28  	"github.com/ethereum/go-ethereum/p2p/enode"
    29  	"github.com/ethereum/go-ethereum/p2p/protocols"
    30  	p2ptest "github.com/ethereum/go-ethereum/p2p/testing"
    31  )
    32  
    33  const (
    34  	TestProtocolVersion   = 7
    35  	TestProtocolNetworkID = 3
    36  )
    37  
    38  var (
    39  	loglevel = flag.Int("loglevel", 2, "verbosity of logs")
    40  )
    41  
    42  func init() {
    43  	flag.Parse()
    44  	log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*loglevel), log.StreamHandler(os.Stderr, log.TerminalFormat(true))))
    45  }
    46  
    47  type testStore struct {
    48  	sync.Mutex
    49  
    50  	values map[string][]byte
    51  }
    52  
    53  func (t *testStore) Load(key string) ([]byte, error) {
    54  	t.Lock()
    55  	defer t.Unlock()
    56  	v, ok := t.values[key]
    57  	if !ok {
    58  		return nil, fmt.Errorf("key not found: %s", key)
    59  	}
    60  	return v, nil
    61  }
    62  
    63  func (t *testStore) Save(key string, v []byte) error {
    64  	t.Lock()
    65  	defer t.Unlock()
    66  	t.values[key] = v
    67  	return nil
    68  }
    69  
    70  func HandshakeMsgExchange(lhs, rhs *HandshakeMsg, id enode.ID) []p2ptest.Exchange {
    71  
    72  	return []p2ptest.Exchange{
    73  		{
    74  			Expects: []p2ptest.Expect{
    75  				{
    76  					Code: 0,
    77  					Msg:  lhs,
    78  					Peer: id,
    79  				},
    80  			},
    81  		},
    82  		{
    83  			Triggers: []p2ptest.Trigger{
    84  				{
    85  					Code: 0,
    86  					Msg:  rhs,
    87  					Peer: id,
    88  				},
    89  			},
    90  		},
    91  	}
    92  }
    93  
    94  func newBzzBaseTester(t *testing.T, n int, addr *BzzAddr, spec *protocols.Spec, run func(*BzzPeer) error) *bzzTester {
    95  	cs := make(map[string]chan bool)
    96  
    97  	srv := func(p *BzzPeer) error {
    98  		defer func() {
    99  			if cs[p.ID().String()] != nil {
   100  				close(cs[p.ID().String()])
   101  			}
   102  		}()
   103  		return run(p)
   104  	}
   105  
   106  	protocol := func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
   107  		return srv(&BzzPeer{Peer: protocols.NewPeer(p, rw, spec), BzzAddr: NewAddr(p.Node())})
   108  	}
   109  
   110  	s := p2ptest.NewProtocolTester(t, addr.ID(), n, protocol)
   111  
   112  	for _, node := range s.Nodes {
   113  		cs[node.ID().String()] = make(chan bool)
   114  	}
   115  
   116  	return &bzzTester{
   117  		addr:           addr,
   118  		ProtocolTester: s,
   119  		cs:             cs,
   120  	}
   121  }
   122  
   123  type bzzTester struct {
   124  	*p2ptest.ProtocolTester
   125  	addr *BzzAddr
   126  	cs   map[string]chan bool
   127  	bzz  *Bzz
   128  }
   129  
   130  func newBzz(addr *BzzAddr, lightNode bool) *Bzz {
   131  	config := &BzzConfig{
   132  		OverlayAddr:  addr.Over(),
   133  		UnderlayAddr: addr.Under(),
   134  		HiveParams:   NewHiveParams(),
   135  		NetworkID:    DefaultNetworkID,
   136  		LightNode:    lightNode,
   137  	}
   138  	kad := NewKademlia(addr.OAddr, NewKadParams())
   139  	bzz := NewBzz(config, kad, nil, nil, nil)
   140  	return bzz
   141  }
   142  
   143  func newBzzHandshakeTester(t *testing.T, n int, addr *BzzAddr, lightNode bool) *bzzTester {
   144  	bzz := newBzz(addr, lightNode)
   145  	pt := p2ptest.NewProtocolTester(t, addr.ID(), n, bzz.runBzz)
   146  
   147  	return &bzzTester{
   148  		addr:           addr,
   149  		ProtocolTester: pt,
   150  		bzz:            bzz,
   151  	}
   152  }
   153  
   154  // should test handshakes in one exchange? parallelisation
   155  func (s *bzzTester) testHandshake(lhs, rhs *HandshakeMsg, disconnects ...*p2ptest.Disconnect) error {
   156  	var peers []enode.ID
   157  	id := rhs.Addr.ID()
   158  	if len(disconnects) > 0 {
   159  		for _, d := range disconnects {
   160  			peers = append(peers, d.Peer)
   161  		}
   162  	} else {
   163  		peers = []enode.ID{id}
   164  	}
   165  
   166  	if err := s.TestExchanges(HandshakeMsgExchange(lhs, rhs, id)...); err != nil {
   167  		return err
   168  	}
   169  
   170  	if len(disconnects) > 0 {
   171  		return s.TestDisconnected(disconnects...)
   172  	}
   173  
   174  	// If we don't expect disconnect, ensure peers remain connected
   175  	err := s.TestDisconnected(&p2ptest.Disconnect{
   176  		Peer:  s.Nodes[0].ID(),
   177  		Error: nil,
   178  	})
   179  
   180  	if err == nil {
   181  		return fmt.Errorf("Unexpected peer disconnect")
   182  	}
   183  
   184  	if err.Error() != "timed out waiting for peers to disconnect" {
   185  		return err
   186  	}
   187  
   188  	return nil
   189  }
   190  
   191  func correctBzzHandshake(addr *BzzAddr, lightNode bool) *HandshakeMsg {
   192  	return &HandshakeMsg{
   193  		Version:   TestProtocolVersion,
   194  		NetworkID: TestProtocolNetworkID,
   195  		Addr:      addr,
   196  		LightNode: lightNode,
   197  	}
   198  }
   199  
   200  func TestBzzHandshakeNetworkIDMismatch(t *testing.T) {
   201  	lightNode := false
   202  	addr := RandomAddr()
   203  	s := newBzzHandshakeTester(t, 1, addr, lightNode)
   204  	node := s.Nodes[0]
   205  
   206  	err := s.testHandshake(
   207  		correctBzzHandshake(addr, lightNode),
   208  		&HandshakeMsg{Version: TestProtocolVersion, NetworkID: 321, Addr: NewAddr(node)},
   209  		&p2ptest.Disconnect{Peer: node.ID(), Error: fmt.Errorf("Handshake error: Message handler error: (msg code 0): network id mismatch 321 (!= 3)")},
   210  	)
   211  
   212  	if err != nil {
   213  		t.Fatal(err)
   214  	}
   215  }
   216  
   217  func TestBzzHandshakeVersionMismatch(t *testing.T) {
   218  	lightNode := false
   219  	addr := RandomAddr()
   220  	s := newBzzHandshakeTester(t, 1, addr, lightNode)
   221  	node := s.Nodes[0]
   222  
   223  	err := s.testHandshake(
   224  		correctBzzHandshake(addr, lightNode),
   225  		&HandshakeMsg{Version: 0, NetworkID: TestProtocolNetworkID, Addr: NewAddr(node)},
   226  		&p2ptest.Disconnect{Peer: node.ID(), Error: fmt.Errorf("Handshake error: Message handler error: (msg code 0): version mismatch 0 (!= %d)", TestProtocolVersion)},
   227  	)
   228  
   229  	if err != nil {
   230  		t.Fatal(err)
   231  	}
   232  }
   233  
   234  func TestBzzHandshakeSuccess(t *testing.T) {
   235  	lightNode := false
   236  	addr := RandomAddr()
   237  	s := newBzzHandshakeTester(t, 1, addr, lightNode)
   238  	node := s.Nodes[0]
   239  
   240  	err := s.testHandshake(
   241  		correctBzzHandshake(addr, lightNode),
   242  		&HandshakeMsg{Version: TestProtocolVersion, NetworkID: TestProtocolNetworkID, Addr: NewAddr(node)},
   243  	)
   244  
   245  	if err != nil {
   246  		t.Fatal(err)
   247  	}
   248  }
   249  
   250  func TestBzzHandshakeLightNode(t *testing.T) {
   251  	var lightNodeTests = []struct {
   252  		name      string
   253  		lightNode bool
   254  	}{
   255  		{"on", true},
   256  		{"off", false},
   257  	}
   258  
   259  	for _, test := range lightNodeTests {
   260  		t.Run(test.name, func(t *testing.T) {
   261  			randomAddr := RandomAddr()
   262  			pt := newBzzHandshakeTester(t, 1, randomAddr, false)
   263  			node := pt.Nodes[0]
   264  			addr := NewAddr(node)
   265  
   266  			err := pt.testHandshake(
   267  				correctBzzHandshake(randomAddr, false),
   268  				&HandshakeMsg{Version: TestProtocolVersion, NetworkID: TestProtocolNetworkID, Addr: addr, LightNode: test.lightNode},
   269  			)
   270  
   271  			if err != nil {
   272  				t.Fatal(err)
   273  			}
   274  
   275  			if pt.bzz.handshakes[node.ID()].LightNode != test.lightNode {
   276  				t.Fatalf("peer LightNode flag is %v, should be %v", pt.bzz.handshakes[node.ID()].LightNode, test.lightNode)
   277  			}
   278  		})
   279  	}
   280  }