github.com/arieschain/arieschain@v0.0.0-20191023063405-37c074544356/p2p/testing/protocoltester.go (about)

/*
Package testing provides a unit test scheme to check simple protocol
message exchanges with one pivot node and a number of dummy peers.
The pivot test node runs a node.Service; the dummy peers run a mock node
that can be used to send and receive messages.
*/
     7  
     8  package testing
     9  
    10  import (
    11  	"bytes"
    12  	"fmt"
    13  	"io"
    14  	"io/ioutil"
    15  	"strings"
    16  	"sync"
    17  	"testing"
    18  
    19  	"github.com/quickchainproject/quickchain/log"
    20  	"github.com/quickchainproject/quickchain/node"
    21  	"github.com/quickchainproject/quickchain/p2p"
    22  	"github.com/quickchainproject/quickchain/p2p/discover"
    23  	"github.com/quickchainproject/quickchain/p2p/simulations"
    24  	"github.com/quickchainproject/quickchain/p2p/simulations/adapters"
    25  	"github.com/quickchainproject/quickchain/rlp"
    26  	"github.com/quickchainproject/quickchain/rpc"
    27  )
    28  
    29  // ProtocolTester is the tester environment used for unit testing protocol
    30  // message exchanges. It uses p2p/simulations framework
    31  type ProtocolTester struct {
    32  	*ProtocolSession
    33  	network *simulations.Network
    34  }
    35  
    36  // NewProtocolTester constructs a new ProtocolTester
    37  // it takes as argument the pivot node id, the number of dummy peers and the
    38  // protocol run function called on a peer connection by the p2p server
    39  func NewProtocolTester(t *testing.T, id discover.NodeID, n int, run func(*p2p.Peer, p2p.MsgReadWriter) error) *ProtocolTester {
    40  	services := adapters.Services{
    41  		"test": func(ctx *adapters.ServiceContext) (node.Service, error) {
    42  			return &testNode{run}, nil
    43  		},
    44  		"mock": func(ctx *adapters.ServiceContext) (node.Service, error) {
    45  			return newMockNode(), nil
    46  		},
    47  	}
    48  	adapter := adapters.NewSimAdapter(services)
    49  	net := simulations.NewNetwork(adapter, &simulations.NetworkConfig{})
    50  	if _, err := net.NewNodeWithConfig(&adapters.NodeConfig{
    51  		ID:              id,
    52  		EnableMsgEvents: true,
    53  		Services:        []string{"test"},
    54  	}); err != nil {
    55  		panic(err.Error())
    56  	}
    57  	if err := net.Start(id); err != nil {
    58  		panic(err.Error())
    59  	}
    60  
    61  	node := net.GetNode(id).Node.(*adapters.SimNode)
    62  	peers := make([]*adapters.NodeConfig, n)
    63  	peerIDs := make([]discover.NodeID, n)
    64  	for i := 0; i < n; i++ {
    65  		peers[i] = adapters.RandomNodeConfig()
    66  		peers[i].Services = []string{"mock"}
    67  		peerIDs[i] = peers[i].ID
    68  	}
    69  	events := make(chan *p2p.PeerEvent, 1000)
    70  	node.SubscribeEvents(events)
    71  	ps := &ProtocolSession{
    72  		Server:  node.Server(),
    73  		IDs:     peerIDs,
    74  		adapter: adapter,
    75  		events:  events,
    76  	}
    77  	self := &ProtocolTester{
    78  		ProtocolSession: ps,
    79  		network:         net,
    80  	}
    81  
    82  	self.Connect(id, peers...)
    83  
    84  	return self
    85  }
    86  
    87  // Stop stops the p2p server
    88  func (self *ProtocolTester) Stop() error {
    89  	self.Server.Stop()
    90  	return nil
    91  }
    92  
    93  // Connect brings up the remote peer node and connects it using the
    94  // p2p/simulations network connection with the in memory network adapter
    95  func (self *ProtocolTester) Connect(selfID discover.NodeID, peers ...*adapters.NodeConfig) {
    96  	for _, peer := range peers {
    97  		log.Trace(fmt.Sprintf("start node %v", peer.ID))
    98  		if _, err := self.network.NewNodeWithConfig(peer); err != nil {
    99  			panic(fmt.Sprintf("error starting peer %v: %v", peer.ID, err))
   100  		}
   101  		if err := self.network.Start(peer.ID); err != nil {
   102  			panic(fmt.Sprintf("error starting peer %v: %v", peer.ID, err))
   103  		}
   104  		log.Trace(fmt.Sprintf("connect to %v", peer.ID))
   105  		if err := self.network.Connect(selfID, peer.ID); err != nil {
   106  			panic(fmt.Sprintf("error connecting to peer %v: %v", peer.ID, err))
   107  		}
   108  	}
   109  
   110  }
   111  
   112  // testNode wraps a protocol run function and implements the node.Service
   113  // interface
   114  type testNode struct {
   115  	run func(*p2p.Peer, p2p.MsgReadWriter) error
   116  }
   117  
   118  func (t *testNode) Protocols() []p2p.Protocol {
   119  	return []p2p.Protocol{{
   120  		Length: 100,
   121  		Run:    t.run,
   122  	}}
   123  }
   124  
   125  func (t *testNode) APIs() []rpc.API {
   126  	return nil
   127  }
   128  
   129  func (t *testNode) Start(server *p2p.Server) error {
   130  	return nil
   131  }
   132  
   133  func (t *testNode) Stop() error {
   134  	return nil
   135  }
   136  
   137  // mockNode is a testNode which doesn't actually run a protocol, instead
   138  // exposing channels so that tests can manually trigger and expect certain
   139  // messages
   140  type mockNode struct {
   141  	testNode
   142  
   143  	trigger  chan *Trigger
   144  	expect   chan []Expect
   145  	err      chan error
   146  	stop     chan struct{}
   147  	stopOnce sync.Once
   148  }
   149  
   150  func newMockNode() *mockNode {
   151  	mock := &mockNode{
   152  		trigger: make(chan *Trigger),
   153  		expect:  make(chan []Expect),
   154  		err:     make(chan error),
   155  		stop:    make(chan struct{}),
   156  	}
   157  	mock.testNode.run = mock.Run
   158  	return mock
   159  }
   160  
   161  // Run is a protocol run function which just loops waiting for tests to
   162  // instruct it to either trigger or expect a message from the peer
   163  func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
   164  	for {
   165  		select {
   166  		case trig := <-m.trigger:
   167  			m.err <- p2p.Send(rw, trig.Code, trig.Msg)
   168  		case exps := <-m.expect:
   169  			m.err <- expectMsgs(rw, exps)
   170  		case <-m.stop:
   171  			return nil
   172  		}
   173  	}
   174  }
   175  
   176  func (m *mockNode) Trigger(trig *Trigger) error {
   177  	m.trigger <- trig
   178  	return <-m.err
   179  }
   180  
   181  func (m *mockNode) Expect(exp ...Expect) error {
   182  	m.expect <- exp
   183  	return <-m.err
   184  }
   185  
   186  func (m *mockNode) Stop() error {
   187  	m.stopOnce.Do(func() { close(m.stop) })
   188  	return nil
   189  }
   190  
   191  func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error {
   192  	matched := make([]bool, len(exps))
   193  	for {
   194  		msg, err := rw.ReadMsg()
   195  		if err != nil {
   196  			if err == io.EOF {
   197  				break
   198  			}
   199  			return err
   200  		}
   201  		actualContent, err := ioutil.ReadAll(msg.Payload)
   202  		if err != nil {
   203  			return err
   204  		}
   205  		var found bool
   206  		for i, exp := range exps {
   207  			if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) {
   208  				if matched[i] {
   209  					return fmt.Errorf("message #%d received two times", i)
   210  				}
   211  				matched[i] = true
   212  				found = true
   213  				break
   214  			}
   215  		}
   216  		if !found {
   217  			expected := make([]string, 0)
   218  			for i, exp := range exps {
   219  				if matched[i] {
   220  					continue
   221  				}
   222  				expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg)))
   223  			}
   224  			return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or "))
   225  		}
   226  		done := true
   227  		for _, m := range matched {
   228  			if !m {
   229  				done = false
   230  				break
   231  			}
   232  		}
   233  		if done {
   234  			return nil
   235  		}
   236  	}
   237  	for i, m := range matched {
   238  		if !m {
   239  			return fmt.Errorf("expected message #%d not received", i)
   240  		}
   241  	}
   242  	return nil
   243  }
   244  
   245  // mustEncodeMsg uses rlp to encode a message.
   246  // In case of error it panics.
   247  func mustEncodeMsg(msg interface{}) []byte {
   248  	contentEnc, err := rlp.EncodeToBytes(msg)
   249  	if err != nil {
   250  		panic("content encode error: " + err.Error())
   251  	}
   252  	return contentEnc
   253  }