go.dedis.ch/onet/v3@v3.2.11-0.20210930124529-e36530bca7ef/node_test.go (about)

     1  package onet
     2  
     3  import (
     4  	"testing"
     5  	"time"
     6  
     7  	"github.com/google/uuid"
     8  	"github.com/stretchr/testify/require"
     9  	"go.dedis.ch/onet/v3/log"
    10  	"go.dedis.ch/onet/v3/network"
    11  )
    12  
// Names under which the test protocols defined in this file are registered
// with the global protocol registry.
const (
	// ProtocolChannelsName is the name used for NewProtocolChannels.
	ProtocolChannelsName = "ProtocolChannels"
	// ProtocolHandlersName is the name used for NewProtocolHandlers.
	ProtocolHandlersName = "ProtocolHandlers"
	// ProtocolBlockingName is the name used for NewProtocolBlocking.
	ProtocolBlockingName = "ProtocolBlocking"
)
    18  
    19  func init() {
    20  	GlobalProtocolRegister(ProtocolHandlersName, NewProtocolHandlers)
    21  	GlobalProtocolRegister("ProtocolBlocking", NewProtocolBlocking)
    22  	GlobalProtocolRegister(ProtocolChannelsName, NewProtocolChannels)
    23  	GlobalProtocolRegister(testProto, NewProtocolTest)
    24  	Incoming = make(chan struct {
    25  		*TreeNode
    26  		NodeTestMsg
    27  	}, 1)
    28  	network.RegisterMessage(&Closing{})
    29  }
    30  
    31  func TestNodeChannelCreateSlice(t *testing.T) {
    32  	local := NewLocalTest(tSuite)
    33  	_, _, tree := local.GenTree(2, true)
    34  	defer local.CloseAll()
    35  
    36  	p, err := local.CreateProtocol(ProtocolChannelsName, tree)
    37  	if err != nil {
    38  		t.Fatal("Couldn't create new node:", err)
    39  	}
    40  
    41  	var c chan []struct {
    42  		*TreeNode
    43  		NodeTestMsg
    44  	}
    45  	tni := p.(*ProtocolChannels).TreeNodeInstance
    46  	err = tni.RegisterChannel(&c)
    47  	if err != nil {
    48  		t.Fatal("Couldn't register channel:", err)
    49  	}
    50  	tni.Done()
    51  }
    52  
    53  func TestNodeChannelCreate(t *testing.T) {
    54  	local := NewLocalTest(tSuite)
    55  	_, _, tree := local.GenTree(2, true)
    56  	defer local.CloseAll()
    57  
    58  	p, err := local.CreateProtocol(ProtocolChannelsName, tree)
    59  	if err != nil {
    60  		t.Fatal("Couldn't create new node:", err)
    61  	}
    62  	var c chan struct {
    63  		*TreeNode
    64  		NodeTestMsg
    65  	}
    66  	tni := p.(*ProtocolChannels).TreeNodeInstance
    67  	err = tni.RegisterChannel(&c)
    68  	if err != nil {
    69  		t.Fatal("Couldn't register channel:", err)
    70  	}
    71  	err = tni.dispatchChannel([]*ProtocolMsg{{
    72  		Msg:     NodeTestMsg{3},
    73  		MsgType: network.RegisterMessage(NodeTestMsg{}),
    74  		From: &Token{
    75  			TreeID:     tree.ID,
    76  			TreeNodeID: tree.Root.ID,
    77  		}},
    78  	})
    79  	if err != nil {
    80  		t.Fatal("Couldn't dispatch to channel:", err)
    81  	}
    82  	msg := <-c
    83  	if msg.I != 3 {
    84  		t.Fatal("Message should contain '3'")
    85  	}
    86  	tni.Done()
    87  }
    88  
    89  func TestNodeChannel(t *testing.T) {
    90  	local := NewLocalTest(tSuite)
    91  	_, _, tree := local.GenTree(2, true)
    92  	defer local.CloseAll()
    93  
    94  	p, err := local.CreateProtocol(ProtocolChannelsName, tree)
    95  	if err != nil {
    96  		t.Fatal("Couldn't create new node:", err)
    97  	}
    98  	c := make(chan struct {
    99  		*TreeNode
   100  		NodeTestMsg
   101  	}, 1)
   102  	tni := p.(*ProtocolChannels).TreeNodeInstance
   103  	err = tni.RegisterChannel(c)
   104  	if err != nil {
   105  		t.Fatal("Couldn't register channel:", err)
   106  	}
   107  	err = tni.dispatchChannel([]*ProtocolMsg{{
   108  		Msg:     NodeTestMsg{3},
   109  		MsgType: network.RegisterMessage(NodeTestMsg{}),
   110  		From: &Token{
   111  			TreeID:     tree.ID,
   112  			TreeNodeID: tree.Root.ID,
   113  		}},
   114  	})
   115  	if err != nil {
   116  		t.Fatal("Couldn't dispatch to channel:", err)
   117  	}
   118  	msg := <-c
   119  	if msg.I != 3 {
   120  		t.Fatal("Message should contain '3'")
   121  	}
   122  	tni.Done()
   123  }
   124  
   125  // Test instantiation of Node
   126  func TestNodeNew(t *testing.T) {
   127  	local := NewLocalTest(tSuite)
   128  	defer local.CloseAll()
   129  
   130  	hosts, _, tree := local.GenTree(2, true)
   131  	h1 := hosts[0]
   132  	// Try directly StartNewNode
   133  	proto, err := h1.StartProtocol(testProto, tree)
   134  	if err != nil {
   135  		t.Fatal("Could not start new protocol", err)
   136  	}
   137  	p := proto.(*ProtocolTest)
   138  	m := <-p.DispMsg
   139  	if m != "Dispatch" {
   140  		t.Fatal("Dispatch() not called - msg is:", m)
   141  	}
   142  	m = <-p.StartMsg
   143  	if m != "Start" {
   144  		t.Fatal("Start() not called - msg is:", m)
   145  	}
   146  }
   147  
   148  func TestTreeNodeProtocolHandlers(t *testing.T) {
   149  	local := NewLocalTest(tSuite)
   150  	_, _, tree := local.GenTree(3, true)
   151  	defer local.CloseAll()
   152  	log.Lvl2("Sending to children")
   153  	IncomingHandlers = make(chan *TreeNodeInstance, 2)
   154  	p, err := local.CreateProtocol(ProtocolHandlersName, tree)
   155  	if err != nil {
   156  		t.Fatal(err)
   157  	}
   158  	go p.Start()
   159  	log.Lvl2("Waiting for response from child 1/2")
   160  	child1 := <-IncomingHandlers
   161  	defer child1.Done()
   162  	log.Lvl2("Waiting for response from child 2/2")
   163  	child2 := <-IncomingHandlers
   164  	defer child2.Done()
   165  
   166  	if child1.ServerIdentity().ID.Equal(child2.ServerIdentity().ID) {
   167  		t.Fatal("Both entities should be different")
   168  	}
   169  
   170  	log.Lvl2("Sending to parent")
   171  
   172  	tni := p.(*ProtocolHandlers).TreeNodeInstance
   173  	require.Nil(t, child1.SendTo(tni.TreeNode(), &NodeTestAggMsg{}))
   174  	if len(IncomingHandlers) > 0 {
   175  		t.Fatal("This should not trigger yet")
   176  	}
   177  	require.Nil(t, child2.SendTo(tni.TreeNode(), &NodeTestAggMsg{}))
   178  	final := <-IncomingHandlers
   179  	if !final.ServerIdentity().ID.Equal(tni.ServerIdentity().ID) {
   180  		t.Fatal("This should be the same ID")
   181  	}
   182  }
   183  
   184  func TestTreeNodeMsgAggregation(t *testing.T) {
   185  	local := NewLocalTest(tSuite)
   186  	_, _, tree := local.GenTree(3, true)
   187  	defer local.CloseAll()
   188  	root, err := local.StartProtocol(ProtocolChannelsName, tree)
   189  	if err != nil {
   190  		t.Fatal("Couldn't create new node:", err)
   191  	}
   192  	proto := root.(*ProtocolChannels)
   193  	// Wait for both children to be up
   194  	<-Incoming
   195  	<-Incoming
   196  	log.Lvl2("Both children are up")
   197  	child1 := local.GetTreeNodeInstances(tree.Root.Children[0].ServerIdentity.ID)[0]
   198  	child2 := local.GetTreeNodeInstances(tree.Root.Children[1].ServerIdentity.ID)[0]
   199  
   200  	err = local.sendTreeNode(ProtocolChannelsName, child1, proto.TreeNodeInstance, &NodeTestAggMsg{3})
   201  	if err != nil {
   202  		t.Fatal(err)
   203  	}
   204  	if len(proto.IncomingAgg) > 0 {
   205  		t.Fatal("Messages should NOT be there")
   206  	}
   207  	err = local.sendTreeNode(ProtocolChannelsName, child2, proto.TreeNodeInstance, &NodeTestAggMsg{4})
   208  	if err != nil {
   209  		t.Fatal(err)
   210  	}
   211  
   212  	msgs := <-proto.IncomingAgg
   213  	if msgs[0].I != 3 {
   214  		t.Fatal("First message should be 3")
   215  	}
   216  	if msgs[1].I != 4 {
   217  		t.Fatal("Second message should be 4")
   218  	}
   219  
   220  	proto.Closing(ClosingMsg{})
   221  }
   222  
   223  func TestTreeNodeFlags(t *testing.T) {
   224  	testType := network.MessageTypeID(uuid.Nil)
   225  	local := NewLocalTest(tSuite)
   226  	_, _, tree := local.GenTree(3, true)
   227  	defer local.CloseAll()
   228  	p, err := local.CreateProtocol(ProtocolChannelsName, tree)
   229  	if err != nil {
   230  		t.Fatal("Couldn't create node.")
   231  	}
   232  	pc := p.(*ProtocolChannels)
   233  	tni := pc.TreeNodeInstance
   234  	if tni.hasFlag(testType, AggregateMessages) {
   235  		t.Fatal("Should NOT have AggregateMessages-flag")
   236  	}
   237  	tni.setFlag(testType, AggregateMessages)
   238  	if !tni.hasFlag(testType, AggregateMessages) {
   239  		t.Fatal("Should HAVE AggregateMessages-flag cleared")
   240  	}
   241  	tni.clearFlag(testType, AggregateMessages)
   242  	if tni.hasFlag(testType, AggregateMessages) {
   243  		t.Fatal("Should NOT have AggregateMessages-flag")
   244  	}
   245  	pc.Closing(ClosingMsg{})
   246  }
   247  
// Protocol/service Channels test code:

// NodeTestMsg is a test message carrying a single integer. The Start methods
// of the test protocols in this file send it to every child.
type NodeTestMsg struct {
	I int64
}

// Incoming is the package-level channel that every ProtocolChannels instance
// registers for NodeTestMsg messages. It is allocated in init.
var Incoming chan struct {
	*TreeNode
	NodeTestMsg
}

// NodeTestAggMsg is a test message carrying a single integer. In this file it
// is only received as a slice, by ProtocolChannels.IncomingAgg and by
// ProtocolHandlers.HandleMessageAggregate.
type NodeTestAggMsg struct {
	I int64
}
   261  
// ProtocolChannels is a test protocol that receives its messages through
// registered channels.
type ProtocolChannels struct {
	*TreeNodeInstance
	// IncomingAgg receives NodeTestAggMsg messages as slices; it is
	// registered by pointer in NewProtocolChannels.
	IncomingAgg chan []struct {
		*TreeNode
		NodeTestAggMsg
	}
}

// Closing is the empty message that ProtocolChannels.Closing sends to the
// children of a node.
type Closing struct{}

// ClosingMsg is the argument type of the ProtocolChannels.Closing handler: a
// Closing message together with a TreeNode.
type ClosingMsg struct {
	*TreeNode
	Closing
}
   276  
   277  func NewProtocolChannels(n *TreeNodeInstance) (ProtocolInstance, error) {
   278  	p := &ProtocolChannels{
   279  		TreeNodeInstance: n,
   280  	}
   281  	p.RegisterChannel(Incoming)
   282  	p.RegisterChannel(&p.IncomingAgg)
   283  	log.ErrFatal(p.RegisterHandler(p.Closing))
   284  	return p, nil
   285  }
   286  
   287  func (p *ProtocolChannels) Closing(msg ClosingMsg) error {
   288  	log.ErrFatal(p.SendToChildren(&Closing{}))
   289  	time.Sleep(100 * time.Millisecond)
   290  	p.Done()
   291  	return nil
   292  }
   293  
   294  func (p *ProtocolChannels) Start() error {
   295  	for _, c := range p.Children() {
   296  		err := p.SendTo(c, &NodeTestMsg{12})
   297  		if err != nil {
   298  			return err
   299  		}
   300  	}
   301  	return nil
   302  }
   303  
// ProtocolHandlers is a test protocol that receives its messages through
// registered handler methods.
type ProtocolHandlers struct {
	*TreeNodeInstance
}

// IncomingHandlers is the package-level channel on which the ProtocolHandlers
// handlers report the TreeNodeInstance that received a message. It is not
// allocated here; TestTreeNodeProtocolHandlers allocates it before use.
var IncomingHandlers chan *TreeNodeInstance
   309  
   310  func NewProtocolHandlers(n *TreeNodeInstance) (ProtocolInstance, error) {
   311  	p := &ProtocolHandlers{
   312  		TreeNodeInstance: n,
   313  	}
   314  	if err := p.RegisterHandlers(p.HandleMessageOne,
   315  		p.HandleMessageAggregate); err != nil {
   316  		return nil, err
   317  	}
   318  	return p, nil
   319  }
   320  
   321  func (p *ProtocolHandlers) Start() error {
   322  	for _, c := range p.Children() {
   323  		err := p.SendTo(c, &NodeTestMsg{12})
   324  		if err != nil {
   325  			log.Error("Error sending to ", c.Name(), ":", err)
   326  		}
   327  	}
   328  	return nil
   329  }
   330  
   331  func (p *ProtocolHandlers) HandleMessageOne(msg struct {
   332  	*TreeNode
   333  	NodeTestMsg
   334  }) error {
   335  	IncomingHandlers <- p.TreeNodeInstance
   336  	return nil
   337  }
   338  
   339  func (p *ProtocolHandlers) HandleMessageAggregate(msg []struct {
   340  	*TreeNode
   341  	NodeTestAggMsg
   342  }) error {
   343  	log.Lvl3("Received message")
   344  	IncomingHandlers <- p.TreeNodeInstance
   345  	p.Done()
   346  	return nil
   347  }
   348  
// Dispatch does nothing and returns nil; ProtocolHandlers receives its
// messages through the handlers registered in NewProtocolHandlers.
func (p *ProtocolHandlers) Dispatch() error {
	return nil
}
   352  
   353  func TestNodeBlocking(t *testing.T) {
   354  	l := NewLocalTest(tSuite)
   355  	_, _, tree := l.GenTree(2, true)
   356  	defer l.CloseAll()
   357  
   358  	n1, err := l.StartProtocol("ProtocolBlocking", tree)
   359  	if err != nil {
   360  		t.Fatal("Couldn't start protocol")
   361  	}
   362  	n2, err := l.StartProtocol("ProtocolBlocking", tree)
   363  	if err != nil {
   364  		t.Fatal("Couldn't start protocol")
   365  	}
   366  
   367  	p1 := n1.(*BlockingProtocol)
   368  	p2 := n2.(*BlockingProtocol)
   369  	tn1 := p1.TreeNodeInstance
   370  	tn2 := p2.TreeNodeInstance
   371  	go func() {
   372  		// Send two messages to n1, which blocks the old interface
   373  		err := l.sendTreeNode("", tn2, tn1, &NodeTestMsg{})
   374  		require.NoError(t, err, "Couldn't send message")
   375  		err = l.sendTreeNode("", tn2, tn1, &NodeTestMsg{})
   376  		require.NoError(t, err, "Couldn't send message")
   377  		// Now send a message to n2, but in the old interface this
   378  		// blocks.
   379  		err = l.sendTreeNode("", tn1, tn2, &NodeTestMsg{})
   380  		require.NoError(t, err, "Couldn't send message")
   381  	}()
   382  	// Release p2
   383  	p2.stopBlockChan <- true
   384  	<-p2.doneChan
   385  	log.Lvl2("Node 2 done")
   386  	p1.stopBlockChan <- true
   387  	<-p1.doneChan
   388  
   389  	p1.Done()
   390  	p2.Done()
   391  }
   392  
// BlockingProtocol is a protocol whose Dispatch blocks until it receives a
// signal on stopBlockChan. TestNodeBlocking uses it to test the asynchronous,
// non-blocking handling of incoming messages.
type BlockingProtocol struct {
	*TreeNodeInstance
	// the protocol will signal on this channel that it is done
	doneChan chan bool
	// stopBlockChan is used to signal the protocol to stop blocking the
	// incoming messages on the Incoming chan
	stopBlockChan chan bool
	// Incoming receives NodeTestMsg messages; it is registered by pointer in
	// NewProtocolBlocking.
	Incoming chan struct {
		*TreeNode
		NodeTestMsg
	}
}
   408  
   409  func NewProtocolBlocking(node *TreeNodeInstance) (ProtocolInstance, error) {
   410  	bp := &BlockingProtocol{
   411  		TreeNodeInstance: node,
   412  		doneChan:         make(chan bool),
   413  		stopBlockChan:    make(chan bool),
   414  	}
   415  
   416  	log.ErrFatal(node.RegisterChannel(&bp.Incoming))
   417  	return bp, nil
   418  }
   419  
// Start does nothing and returns nil; the protocol's work happens in
// Dispatch.
func (bp *BlockingProtocol) Start() error {
	return nil
}
   423  
// Dispatch blocks until a value arrives on stopBlockChan, then waits for one
// message on Incoming and finally signals completion on doneChan. Because
// doneChan is unbuffered, Dispatch only returns once that signal has been
// received. It always returns nil.
func (bp *BlockingProtocol) Dispatch() error {
	// first wait on stopBlockChan
	<-bp.stopBlockChan
	log.Lvl2("BlockingProtocol: will continue")
	// Then wait on the actual message
	<-bp.Incoming
	log.Lvl2("BlockingProtocol: received message => signal Done")
	// then signal that you are done
	bp.doneChan <- true
	return nil
}
   434  }