github.com/arieschain/arieschain@v0.0.0-20191023063405-37c074544356/p2p/testing/protocolsession.go (about)

     1  package testing
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/quickchainproject/quickchain/log"
    10  	"github.com/quickchainproject/quickchain/p2p"
    11  	"github.com/quickchainproject/quickchain/p2p/discover"
    12  	"github.com/quickchainproject/quickchain/p2p/simulations/adapters"
    13  )
    14  
// errTimedOut is the error returned by expect and testExchange when their
// timers fire before the awaited result arrives.
var errTimedOut = errors.New("timed out")
    16  
// ProtocolSession is a quasi simulation of a pivot node running
// a service and a number of dummy peers that can send (trigger) or
// receive (expect) messages.
type ProtocolSession struct {
	// NOTE(review): presumably the p2p server of the pivot node; it is not
	// referenced by the methods in this file - confirm against the constructor.
	Server *p2p.Server
	// IDs lists node IDs for the session; in this file only its length is
	// read, for error messages.
	IDs []discover.NodeID
	// adapter is used to look up the simulated node for a given peer ID.
	adapter *adapters.SimAdapter
	// events delivers peer events; TestDisconnected reads drop events from it.
	events chan *p2p.PeerEvent
}
    26  
// Exchange is the basic unit of protocol tests.
// The triggers and expects in the arrays are run immediately and asynchronously,
// thus one cannot have multiple expects for the SAME peer with DIFFERENT message types
// because it's unpredictable which expect will receive which message
// (with expect #1 and #2, messages might be sent #2 and #1, and both expects will
// complain about a wrong message code).
// An exchange is defined on a session.
type Exchange struct {
	// Label identifies the exchange in error and log messages.
	Label string
	// Triggers are the messages to send to the pivot node.
	Triggers []Trigger
	// Expects are the messages the peers should receive from the pivot node.
	Expects []Expect
	// Timeout bounds the whole exchange; zero means 2 seconds.
	Timeout time.Duration
}
    39  
// Trigger is part of the exchange: an incoming message for the pivot node,
// sent by a peer.
type Trigger struct {
	Msg     interface{}     // message to be sent
	Code    uint64          // message code
	Peer    discover.NodeID // the peer whose mock node performs the send
	Timeout time.Duration   // timeout for the sending; zero means 1 second
}
    48  
// Expect is part of an exchange: an outgoing message from the pivot node,
// received by a peer.
type Expect struct {
	Msg     interface{}     // message to expect; must not be nil
	Code    uint64          // message code
	Peer    discover.NodeID // the peer that expects the message
	Timeout time.Duration   // timeout for receiving; zero means 2 seconds
}
    57  
// Disconnect represents a disconnect event, used and checked by TestDisconnected.
type Disconnect struct {
	Peer  discover.NodeID // disconnected peer
	Error error           // expected disconnect reason; nil matches an empty event error
}
    63  
    64  // trigger sends messages from peers
    65  func (self *ProtocolSession) trigger(trig Trigger) error {
    66  	simNode, ok := self.adapter.GetNode(trig.Peer)
    67  	if !ok {
    68  		return fmt.Errorf("trigger: peer %v does not exist (1- %v)", trig.Peer, len(self.IDs))
    69  	}
    70  	mockNode, ok := simNode.Services()[0].(*mockNode)
    71  	if !ok {
    72  		return fmt.Errorf("trigger: peer %v is not a mock", trig.Peer)
    73  	}
    74  
    75  	errc := make(chan error)
    76  
    77  	go func() {
    78  		errc <- mockNode.Trigger(&trig)
    79  	}()
    80  
    81  	t := trig.Timeout
    82  	if t == time.Duration(0) {
    83  		t = 1000 * time.Millisecond
    84  	}
    85  	select {
    86  	case err := <-errc:
    87  		return err
    88  	case <-time.After(t):
    89  		return fmt.Errorf("timout expecting %v to send to peer %v", trig.Msg, trig.Peer)
    90  	}
    91  }
    92  
    93  // expect checks an expectation of a message sent out by the pivot node
    94  func (self *ProtocolSession) expect(exps []Expect) error {
    95  	// construct a map of expectations for each node
    96  	peerExpects := make(map[discover.NodeID][]Expect)
    97  	for _, exp := range exps {
    98  		if exp.Msg == nil {
    99  			return errors.New("no message to expect")
   100  		}
   101  		peerExpects[exp.Peer] = append(peerExpects[exp.Peer], exp)
   102  	}
   103  
   104  	// construct a map of mockNodes for each node
   105  	mockNodes := make(map[discover.NodeID]*mockNode)
   106  	for nodeID := range peerExpects {
   107  		simNode, ok := self.adapter.GetNode(nodeID)
   108  		if !ok {
   109  			return fmt.Errorf("trigger: peer %v does not exist (1- %v)", nodeID, len(self.IDs))
   110  		}
   111  		mockNode, ok := simNode.Services()[0].(*mockNode)
   112  		if !ok {
   113  			return fmt.Errorf("trigger: peer %v is not a mock", nodeID)
   114  		}
   115  		mockNodes[nodeID] = mockNode
   116  	}
   117  
   118  	// done chanell cancels all created goroutines when function returns
   119  	done := make(chan struct{})
   120  	defer close(done)
   121  	// errc catches the first error from
   122  	errc := make(chan error)
   123  
   124  	wg := &sync.WaitGroup{}
   125  	wg.Add(len(mockNodes))
   126  	for nodeID, mockNode := range mockNodes {
   127  		nodeID := nodeID
   128  		mockNode := mockNode
   129  		go func() {
   130  			defer wg.Done()
   131  
   132  			// Sum all Expect timeouts to give the maximum
   133  			// time for all expectations to finish.
   134  			// mockNode.Expect checks all received messages against
   135  			// a list of expected messages and timeout for each
   136  			// of them can not be checked separately.
   137  			var t time.Duration
   138  			for _, exp := range peerExpects[nodeID] {
   139  				if exp.Timeout == time.Duration(0) {
   140  					t += 2000 * time.Millisecond
   141  				} else {
   142  					t += exp.Timeout
   143  				}
   144  			}
   145  			alarm := time.NewTimer(t)
   146  			defer alarm.Stop()
   147  
   148  			// expectErrc is used to check if error returned
   149  			// from mockNode.Expect is not nil and to send it to
   150  			// errc only in that case.
   151  			// done channel will be closed when function
   152  			expectErrc := make(chan error)
   153  			go func() {
   154  				select {
   155  				case expectErrc <- mockNode.Expect(peerExpects[nodeID]...):
   156  				case <-done:
   157  				case <-alarm.C:
   158  				}
   159  			}()
   160  
   161  			select {
   162  			case err := <-expectErrc:
   163  				if err != nil {
   164  					select {
   165  					case errc <- err:
   166  					case <-done:
   167  					case <-alarm.C:
   168  						errc <- errTimedOut
   169  					}
   170  				}
   171  			case <-done:
   172  			case <-alarm.C:
   173  				errc <- errTimedOut
   174  			}
   175  
   176  		}()
   177  	}
   178  
   179  	go func() {
   180  		wg.Wait()
   181  		// close errc when all goroutines finish to return nill err from errc
   182  		close(errc)
   183  	}()
   184  
   185  	return <-errc
   186  }
   187  
   188  // TestExchanges tests a series of exchanges against the session
   189  func (self *ProtocolSession) TestExchanges(exchanges ...Exchange) error {
   190  	for i, e := range exchanges {
   191  		if err := self.testExchange(e); err != nil {
   192  			return fmt.Errorf("exchange #%d %q: %v", i, e.Label, err)
   193  		}
   194  		log.Trace(fmt.Sprintf("exchange #%d %q: run successfully", i, e.Label))
   195  	}
   196  	return nil
   197  }
   198  
   199  // testExchange tests a single Exchange.
   200  // Default timeout value is 2 seconds.
   201  func (self *ProtocolSession) testExchange(e Exchange) error {
   202  	errc := make(chan error)
   203  	done := make(chan struct{})
   204  	defer close(done)
   205  
   206  	go func() {
   207  		for _, trig := range e.Triggers {
   208  			err := self.trigger(trig)
   209  			if err != nil {
   210  				errc <- err
   211  				return
   212  			}
   213  		}
   214  
   215  		select {
   216  		case errc <- self.expect(e.Expects):
   217  		case <-done:
   218  		}
   219  	}()
   220  
   221  	// time out globally or finish when all expectations satisfied
   222  	t := e.Timeout
   223  	if t == 0 {
   224  		t = 2000 * time.Millisecond
   225  	}
   226  	alarm := time.NewTimer(t)
   227  	select {
   228  	case err := <-errc:
   229  		return err
   230  	case <-alarm.C:
   231  		return errTimedOut
   232  	}
   233  }
   234  
   235  // TestDisconnected tests the disconnections given as arguments
   236  // the disconnect structs describe what disconnect error is expected on which peer
   237  func (self *ProtocolSession) TestDisconnected(disconnects ...*Disconnect) error {
   238  	expects := make(map[discover.NodeID]error)
   239  	for _, disconnect := range disconnects {
   240  		expects[disconnect.Peer] = disconnect.Error
   241  	}
   242  
   243  	timeout := time.After(time.Second)
   244  	for len(expects) > 0 {
   245  		select {
   246  		case event := <-self.events:
   247  			if event.Type != p2p.PeerEventTypeDrop {
   248  				continue
   249  			}
   250  			expectErr, ok := expects[event.Peer]
   251  			if !ok {
   252  				continue
   253  			}
   254  
   255  			if !(expectErr == nil && event.Error == "" || expectErr != nil && expectErr.Error() == event.Error) {
   256  				return fmt.Errorf("unexpected error on peer %v. expected '%v', got '%v'", event.Peer, expectErr, event.Error)
   257  			}
   258  			delete(expects, event.Peer)
   259  		case <-timeout:
   260  			return fmt.Errorf("timed out waiting for peers to disconnect")
   261  		}
   262  	}
   263  	return nil
   264  }