github.com/daeglee/go-ethereum@v0.0.0-20190504220456-cad3e8d18e9b/swarm/pss/prox_test.go (about)

     1  package pss
     2  
     3  import (
     4  	"context"
     5  	"crypto/ecdsa"
     6  	"encoding/binary"
     7  	"errors"
     8  	"fmt"
     9  	"strconv"
    10  	"strings"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/ethereum/go-ethereum/common"
    16  	"github.com/ethereum/go-ethereum/common/hexutil"
    17  	"github.com/ethereum/go-ethereum/log"
    18  	"github.com/ethereum/go-ethereum/node"
    19  	"github.com/ethereum/go-ethereum/p2p"
    20  	"github.com/ethereum/go-ethereum/p2p/enode"
    21  	"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
    22  	"github.com/ethereum/go-ethereum/rpc"
    23  	"github.com/ethereum/go-ethereum/swarm/network"
    24  	"github.com/ethereum/go-ethereum/swarm/network/simulation"
    25  	"github.com/ethereum/go-ethereum/swarm/pot"
    26  	"github.com/ethereum/go-ethereum/swarm/state"
    27  )
    28  
    29  // needed to make the enode id of the receiving node available to the handler for triggers
    30  type handlerContextFunc func(*testData, *adapters.NodeConfig) *handler
    31  
    32  // struct to notify reception of messages to simulation driver
    33  // TODO To make code cleaner:
    34  // - consider a separate pss unwrap to message event in sim framework (this will make eventual message propagation analysis with pss easier/possible in the future)
    35  // - consider also test api calls to inspect handling results of messages
    36  type handlerNotification struct {
    37  	id     enode.ID
    38  	serial uint64
    39  }
    40  
    41  type testData struct {
    42  	mu               sync.Mutex
    43  	sim              *simulation.Simulation
    44  	handlerDone      bool // set to true on termination of the simulation run
    45  	requiredMessages int
    46  	allowedMessages  int
    47  	messageCount     int
    48  	kademlias        map[enode.ID]*network.Kademlia
    49  	nodeAddrs        map[enode.ID][]byte      // make predictable overlay addresses from the generated random enode ids
    50  	recipients       map[int][]enode.ID       // for logging output only
    51  	allowed          map[int][]enode.ID       // allowed recipients
    52  	expectedMsgs     map[enode.ID][]uint64    // message serials we expect respective nodes to receive
    53  	allowedMsgs      map[enode.ID][]uint64    // message serials we expect respective nodes to receive
    54  	senders          map[int]enode.ID         // originating nodes of the messages (intention is to choose as far as possible from the receiving neighborhood)
    55  	handlerC         chan handlerNotification // passes message from pss message handler to simulation driver
    56  	doneC            chan struct{}            // terminates the handler channel listener
    57  	errC             chan error               // error to pass to main sim thread
    58  	msgC             chan handlerNotification // message receipt notification to main sim thread
    59  	msgs             [][]byte                 // recipient addresses of messages
    60  }
    61  
    62  var (
    63  	pof   = pot.DefaultPof(256) // generate messages and index them
    64  	topic = BytesToTopic([]byte{0xf3, 0x9e, 0x06, 0x82})
    65  )
    66  
    67  func (d *testData) getMsgCount() int {
    68  	d.mu.Lock()
    69  	defer d.mu.Unlock()
    70  	return d.messageCount
    71  }
    72  
    73  func (d *testData) incrementMsgCount() int {
    74  	d.mu.Lock()
    75  	defer d.mu.Unlock()
    76  	d.messageCount++
    77  	return d.messageCount
    78  }
    79  
    80  func (d *testData) isDone() bool {
    81  	d.mu.Lock()
    82  	defer d.mu.Unlock()
    83  	return d.handlerDone
    84  }
    85  
    86  func (d *testData) setDone() {
    87  	d.mu.Lock()
    88  	defer d.mu.Unlock()
    89  	d.handlerDone = true
    90  }
    91  
    92  func getCmdParams(t *testing.T) (int, int, time.Duration) {
    93  	args := strings.Split(t.Name(), "/")
    94  	msgCount, err := strconv.ParseInt(args[2], 10, 16)
    95  	if err != nil {
    96  		t.Fatal(err)
    97  	}
    98  	nodeCount, err := strconv.ParseInt(args[1], 10, 16)
    99  	if err != nil {
   100  		t.Fatal(err)
   101  	}
   102  	timeoutStr := fmt.Sprintf("%ss", args[3])
   103  	timeoutDur, err := time.ParseDuration(timeoutStr)
   104  	if err != nil {
   105  		t.Fatal(err)
   106  	}
   107  	return int(msgCount), int(nodeCount), timeoutDur
   108  }
   109  
   110  func newTestData() *testData {
   111  	return &testData{
   112  		kademlias:    make(map[enode.ID]*network.Kademlia),
   113  		nodeAddrs:    make(map[enode.ID][]byte),
   114  		recipients:   make(map[int][]enode.ID),
   115  		allowed:      make(map[int][]enode.ID),
   116  		expectedMsgs: make(map[enode.ID][]uint64),
   117  		allowedMsgs:  make(map[enode.ID][]uint64),
   118  		senders:      make(map[int]enode.ID),
   119  		handlerC:     make(chan handlerNotification),
   120  		doneC:        make(chan struct{}),
   121  		errC:         make(chan error),
   122  		msgC:         make(chan handlerNotification),
   123  	}
   124  }
   125  
   126  func (d *testData) getKademlia(nodeId *enode.ID) (*network.Kademlia, error) {
   127  	kadif, ok := d.sim.NodeItem(*nodeId, simulation.BucketKeyKademlia)
   128  	if !ok {
   129  		return nil, fmt.Errorf("no kademlia entry for %v", nodeId)
   130  	}
   131  	kad, ok := kadif.(*network.Kademlia)
   132  	if !ok {
   133  		return nil, fmt.Errorf("invalid kademlia entry for %v", nodeId)
   134  	}
   135  	return kad, nil
   136  }
   137  
   138  func (d *testData) init(msgCount int) error {
   139  	log.Debug("TestProxNetwork start")
   140  
   141  	for _, nodeId := range d.sim.NodeIDs() {
   142  		kad, err := d.getKademlia(&nodeId)
   143  		if err != nil {
   144  			return err
   145  		}
   146  		d.nodeAddrs[nodeId] = kad.BaseAddr()
   147  	}
   148  
   149  	for i := 0; i < int(msgCount); i++ {
   150  		msgAddr := pot.RandomAddress() // we choose message addresses randomly
   151  		d.msgs = append(d.msgs, msgAddr.Bytes())
   152  		smallestPo := 256
   153  		var targets []enode.ID
   154  		var closestPO int
   155  
   156  		// loop through all nodes and find the required and allowed recipients of each message
   157  		// (for more information, please see the comment to the main test function)
   158  		for _, nod := range d.sim.Net.GetNodes() {
   159  			po, _ := pof(d.msgs[i], d.nodeAddrs[nod.ID()], 0)
   160  			depth := d.kademlias[nod.ID()].NeighbourhoodDepth()
   161  
   162  			// only nodes with closest IDs (wrt the msg address) will be required recipients
   163  			if po > closestPO {
   164  				closestPO = po
   165  				targets = nil
   166  				targets = append(targets, nod.ID())
   167  			} else if po == closestPO {
   168  				targets = append(targets, nod.ID())
   169  			}
   170  
   171  			if po >= depth {
   172  				d.allowedMessages++
   173  				d.allowed[i] = append(d.allowed[i], nod.ID())
   174  				d.allowedMsgs[nod.ID()] = append(d.allowedMsgs[nod.ID()], uint64(i))
   175  			}
   176  
   177  			// a node with the smallest PO (wrt msg) will be the sender,
   178  			// in order to increase the distance the msg must travel
   179  			if po < smallestPo {
   180  				smallestPo = po
   181  				d.senders[i] = nod.ID()
   182  			}
   183  		}
   184  
   185  		d.requiredMessages += len(targets)
   186  		for _, id := range targets {
   187  			d.recipients[i] = append(d.recipients[i], id)
   188  			d.expectedMsgs[id] = append(d.expectedMsgs[id], uint64(i))
   189  		}
   190  
   191  		log.Debug("nn for msg", "targets", len(d.recipients[i]), "msgidx", i, "msg", common.Bytes2Hex(msgAddr[:8]), "sender", d.senders[i], "senderpo", smallestPo)
   192  	}
   193  	log.Debug("msgs to receive", "count", d.requiredMessages)
   194  	return nil
   195  }
   196  
   197  // Here we test specific functionality of the pss, setting the prox property of
   198  // the handler. The tests generate a number of messages with random addresses.
   199  // Then, for each message it calculates which nodes have the msg address
   200  // within its nearest neighborhood depth, and stores those nodes as possible
   201  // recipients. Those nodes that are the closest to the message address (nodes
   202  // belonging to the deepest PO wrt the msg address) are stored as required
   203  // recipients. The difference between allowed and required recipients results
   204  // from the fact that the nearest neighbours are not necessarily reciprocal.
   205  // Upon sending the messages, the test verifies that the respective message is
   206  // passed to the message handlers of these required recipients. The test fails
   207  // if a message is handled by recipient which is not listed among the allowed
   208  // recipients of this particular message. It also fails after timeout, if not
   209  // all the required recipients have received their respective messages.
   210  //
   211  // For example, if proximity order of certain msg address is 4, and node X
   212  // has PO=5 wrt the message address, and nodes Y and Z have PO=6, then:
   213  // nodes Y and Z will be considered required recipients of the msg,
   214  // whereas nodes X, Y and Z will be allowed recipients.
   215  func TestProxNetwork(t *testing.T) {
   216  	t.Run("16/16/15", testProxNetwork)
   217  }
   218  
   219  // params in run name: nodes/msgs
   220  func TestProxNetworkLong(t *testing.T) {
   221  	if !*longrunning {
   222  		t.Skip("run with --longrunning flag to run extensive network tests")
   223  	}
   224  	t.Run("8/100/30", testProxNetwork)
   225  	t.Run("16/100/30", testProxNetwork)
   226  	t.Run("32/100/60", testProxNetwork)
   227  	t.Run("64/100/60", testProxNetwork)
   228  	t.Run("128/100/120", testProxNetwork)
   229  }
   230  
   231  func testProxNetwork(t *testing.T) {
   232  	tstdata := newTestData()
   233  	msgCount, nodeCount, timeout := getCmdParams(t)
   234  	handlerContextFuncs := make(map[Topic]handlerContextFunc)
   235  	handlerContextFuncs[topic] = nodeMsgHandler
   236  	services := newProxServices(tstdata, true, handlerContextFuncs, tstdata.kademlias)
   237  	tstdata.sim = simulation.New(services)
   238  	defer tstdata.sim.Close()
   239  	ctx, cancel := context.WithTimeout(context.Background(), timeout)
   240  	defer cancel()
   241  	filename := fmt.Sprintf("testdata/snapshot_%d.json", nodeCount)
   242  	err := tstdata.sim.UploadSnapshot(ctx, filename)
   243  	if err != nil {
   244  		t.Fatal(err)
   245  	}
   246  	err = tstdata.init(msgCount) // initialize the test data
   247  	if err != nil {
   248  		t.Fatal(err)
   249  	}
   250  	wrapper := func(c context.Context, _ *simulation.Simulation) error {
   251  		return testRoutine(tstdata, c)
   252  	}
   253  	result := tstdata.sim.Run(ctx, wrapper) // call the main test function
   254  	if result.Error != nil {
   255  		// context deadline exceeded
   256  		// however, it might just mean that not all possible messages are received
   257  		// now we must check if all required messages are received
   258  		cnt := tstdata.getMsgCount()
   259  		log.Debug("TestProxNetwork finished", "rcv", cnt)
   260  		if cnt < tstdata.requiredMessages {
   261  			t.Fatal(result.Error)
   262  		}
   263  	}
   264  	t.Logf("completed %d", result.Duration)
   265  }
   266  
   267  func (tstdata *testData) sendAllMsgs() {
   268  	for i, msg := range tstdata.msgs {
   269  		log.Debug("sending msg", "idx", i, "from", tstdata.senders[i])
   270  		nodeClient, err := tstdata.sim.Net.GetNode(tstdata.senders[i]).Client()
   271  		if err != nil {
   272  			tstdata.errC <- err
   273  		}
   274  		var uvarByte [8]byte
   275  		binary.PutUvarint(uvarByte[:], uint64(i))
   276  		nodeClient.Call(nil, "pss_sendRaw", hexutil.Encode(msg), hexutil.Encode(topic[:]), hexutil.Encode(uvarByte[:]))
   277  	}
   278  	log.Debug("all messages sent")
   279  }
   280  
   281  // testRoutine is the main test function, called by Simulation.Run()
   282  func testRoutine(tstdata *testData, ctx context.Context) error {
   283  	go handlerChannelListener(tstdata, ctx)
   284  	go tstdata.sendAllMsgs()
   285  	received := 0
   286  
   287  	// collect incoming messages and terminate with corresponding status when message handler listener ends
   288  	for {
   289  		select {
   290  		case err := <-tstdata.errC:
   291  			return err
   292  		case hn := <-tstdata.msgC:
   293  			received++
   294  			log.Debug("msg received", "msgs_received", received, "total_expected", tstdata.requiredMessages, "id", hn.id, "serial", hn.serial)
   295  			if received == tstdata.allowedMessages {
   296  				close(tstdata.doneC)
   297  				return nil
   298  			}
   299  		}
   300  	}
   301  	return nil
   302  }
   303  
   304  func handlerChannelListener(tstdata *testData, ctx context.Context) {
   305  	for {
   306  		select {
   307  		case <-tstdata.doneC: // graceful exit
   308  			tstdata.setDone()
   309  			tstdata.errC <- nil
   310  			return
   311  
   312  		case <-ctx.Done(): // timeout or cancel
   313  			tstdata.setDone()
   314  			tstdata.errC <- ctx.Err()
   315  			return
   316  
   317  		// incoming message from pss message handler
   318  		case handlerNotification := <-tstdata.handlerC:
   319  			// check if recipient has already received all its messages and notify to fail the test if so
   320  			aMsgs := tstdata.allowedMsgs[handlerNotification.id]
   321  			if len(aMsgs) == 0 {
   322  				tstdata.setDone()
   323  				tstdata.errC <- fmt.Errorf("too many messages received by recipient %x", handlerNotification.id)
   324  				return
   325  			}
   326  
   327  			// check if message serial is in expected messages for this recipient and notify to fail the test if not
   328  			idx := -1
   329  			for i, msg := range aMsgs {
   330  				if handlerNotification.serial == msg {
   331  					idx = i
   332  					break
   333  				}
   334  			}
   335  			if idx == -1 {
   336  				tstdata.setDone()
   337  				tstdata.errC <- fmt.Errorf("message %d received by wrong recipient %v", handlerNotification.serial, handlerNotification.id)
   338  				return
   339  			}
   340  
   341  			// message is ok, so remove that message serial from the recipient expectation array and notify the main sim thread
   342  			aMsgs[idx] = aMsgs[len(aMsgs)-1]
   343  			aMsgs = aMsgs[:len(aMsgs)-1]
   344  			tstdata.msgC <- handlerNotification
   345  		}
   346  	}
   347  }
   348  
   349  func nodeMsgHandler(tstdata *testData, config *adapters.NodeConfig) *handler {
   350  	return &handler{
   351  		f: func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error {
   352  			cnt := tstdata.incrementMsgCount()
   353  			log.Debug("nodeMsgHandler rcv", "cnt", cnt)
   354  
   355  			// using simple serial in message body, makes it easy to keep track of who's getting what
   356  			serial, c := binary.Uvarint(msg)
   357  			if c <= 0 {
   358  				log.Crit(fmt.Sprintf("corrupt message received by %x (uvarint parse returned %d)", config.ID, c))
   359  			}
   360  
   361  			if tstdata.isDone() {
   362  				return errors.New("handlers aborted") // terminate if simulation is over
   363  			}
   364  
   365  			// pass message context to the listener in the simulation
   366  			tstdata.handlerC <- handlerNotification{
   367  				id:     config.ID,
   368  				serial: serial,
   369  			}
   370  			return nil
   371  		},
   372  		caps: &handlerCaps{
   373  			raw:  true, // we use raw messages for simplicity
   374  			prox: true,
   375  		},
   376  	}
   377  }
   378  
   379  // an adaptation of the same services setup as in pss_test.go
   380  // replaces pss_test.go when those tests are rewritten to the new swarm/network/simulation package
   381  func newProxServices(tstdata *testData, allowRaw bool, handlerContextFuncs map[Topic]handlerContextFunc, kademlias map[enode.ID]*network.Kademlia) map[string]simulation.ServiceFunc {
   382  	stateStore := state.NewInmemoryStore()
   383  	kademlia := func(id enode.ID, bzzkey []byte) *network.Kademlia {
   384  		if k, ok := kademlias[id]; ok {
   385  			return k
   386  		}
   387  		params := network.NewKadParams()
   388  		params.MaxBinSize = 3
   389  		params.MinBinSize = 1
   390  		params.MaxRetries = 1000
   391  		params.RetryExponent = 2
   392  		params.RetryInterval = 1000000
   393  		kademlias[id] = network.NewKademlia(bzzkey, params)
   394  		return kademlias[id]
   395  	}
   396  	return map[string]simulation.ServiceFunc{
   397  		"bzz": func(ctx *adapters.ServiceContext, b *sync.Map) (node.Service, func(), error) {
   398  			var err error
   399  			var bzzPrivateKey *ecdsa.PrivateKey
   400  			// normally translation of enode id to swarm address is concealed by the network package
   401  			// however, we need to keep track of it in the test driver as well.
   402  			// if the translation in the network package changes, that can cause these tests to unpredictably fail
   403  			// therefore we keep a local copy of the translation here
   404  			addr := network.NewAddr(ctx.Config.Node())
   405  			bzzPrivateKey, err = simulation.BzzPrivateKeyFromConfig(ctx.Config)
   406  			if err != nil {
   407  				return nil, nil, err
   408  			}
   409  			addr.OAddr = network.PrivateKeyToBzzKey(bzzPrivateKey)
   410  			b.Store(simulation.BucketKeyBzzPrivateKey, bzzPrivateKey)
   411  			hp := network.NewHiveParams()
   412  			hp.Discovery = false
   413  			config := &network.BzzConfig{
   414  				OverlayAddr:  addr.Over(),
   415  				UnderlayAddr: addr.Under(),
   416  				HiveParams:   hp,
   417  			}
   418  			return network.NewBzz(config, kademlia(ctx.Config.ID, addr.OAddr), stateStore, nil, nil), nil, nil
   419  		},
   420  		"pss": func(ctx *adapters.ServiceContext, b *sync.Map) (node.Service, func(), error) {
   421  			// execadapter does not exec init()
   422  			initTest()
   423  
   424  			// create keys in whisper and set up the pss object
   425  			ctxlocal, cancel := context.WithTimeout(context.Background(), time.Second*3)
   426  			defer cancel()
   427  			keys, err := wapi.NewKeyPair(ctxlocal)
   428  			privkey, err := w.GetPrivateKey(keys)
   429  			pssp := NewPssParams().WithPrivateKey(privkey)
   430  			pssp.AllowRaw = allowRaw
   431  			bzzPrivateKey, err := simulation.BzzPrivateKeyFromConfig(ctx.Config)
   432  			if err != nil {
   433  				return nil, nil, err
   434  			}
   435  			bzzKey := network.PrivateKeyToBzzKey(bzzPrivateKey)
   436  			pskad := kademlia(ctx.Config.ID, bzzKey)
   437  			ps, err := NewPss(pskad, pssp)
   438  			if err != nil {
   439  				return nil, nil, err
   440  			}
   441  
   442  			// register the handlers we've been passed
   443  			var deregisters []func()
   444  			for tpc, hndlrFunc := range handlerContextFuncs {
   445  				deregisters = append(deregisters, ps.Register(&tpc, hndlrFunc(tstdata, ctx.Config)))
   446  			}
   447  
   448  			// if handshake mode is set, add the controller
   449  			// TODO: This should be hooked to the handshake test file
   450  			if useHandshake {
   451  				SetHandshakeController(ps, NewHandshakeParams())
   452  			}
   453  
   454  			// we expose some api calls for cheating
   455  			ps.addAPI(rpc.API{
   456  				Namespace: "psstest",
   457  				Version:   "0.3",
   458  				Service:   NewAPITest(ps),
   459  				Public:    false,
   460  			})
   461  
   462  			b.Store(simulation.BucketKeyKademlia, pskad)
   463  
   464  			// return Pss and cleanups
   465  			return ps, func() {
   466  				// run the handler deregister functions in reverse order
   467  				for i := len(deregisters); i > 0; i-- {
   468  					deregisters[i-1]()
   469  				}
   470  			}, nil
   471  		},
   472  	}
   473  }