go.dedis.ch/onet/v3@v3.2.11-0.20210930124529-e36530bca7ef/treenode.go (about)

     1  package onet
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"sync"
     7  
     8  	"go.dedis.ch/kyber/v3"
     9  	"go.dedis.ch/onet/v3/log"
    10  	"go.dedis.ch/onet/v3/network"
    11  	"golang.org/x/xerrors"
    12  )
    13  
// TreeNodeInstance represents a protocol-instance in a given TreeNode. It embeds an
// Overlay where all the tree-structures are stored.
//
// Every instance runs its own dispatch goroutine (dispatchMsgReader, started
// by newTreeNodeInstance). That goroutine takes queued messages and hands them
// to the channel or handler registered for their message type.
type TreeNodeInstance struct {
	// overlay is used to look up the tree, send messages and reach the
	// server; token identifies this protocol round (its TreeID selects the
	// tree, its ServiceID the service).
	overlay *Overlay
	token   *Token
	// cache for the TreeNode this Node is representing
	treeNode *TreeNode
	// cached list of all TreeNodes, filled lazily by List()
	treeNodeList []*TreeNode
	// mutex to synchronise creation of treeNodeList
	mtx sync.Mutex

	// channels holds all channels available for the different message-types,
	// as passed to RegisterChannelLength
	channels map[network.MessageTypeID]interface{}
	// registered handler-functions for that protocol, as passed to
	// RegisterHandler
	handlers map[network.MessageTypeID]interface{}
	// per-message-type flags (e.g. AggregateMessages)
	messageTypeFlags map[network.MessageTypeID]uint32
	// The protocolInstance belonging to that node (set by bind)
	instance ProtocolInstance
	// messages from children, buffered per type by aggregate() until one has
	// arrived from every child, so they can be dispatched at once
	msgQueue map[network.MessageTypeID][]*ProtocolMsg
	// optional callback consulted by Done(); returning false skips cleanup
	onDoneCallback func() bool
	// FIFO of messages waiting for dispatchMsgReader (filled by
	// ProcessProtocolMsg)
	msgDispatchQueue []*ProtocolMsg
	// guards msgDispatchQueue and closing
	msgDispatchQueueMutex sync.Mutex
	// wake-up signal for dispatchMsgReader (capacity 1); closed by
	// closeDispatch
	msgDispatchQueueWait chan bool
	// whether this node is closing; once set, no message is queued or sent
	closing bool

	// passed to Overlay.SendToTreeNode for every outgoing message
	protoIO MessageProxy

	// config is to be passed down in the first message of what the protocol is
	// sending if it is non nil. Set with `tni.SetConfig()`.
	// sentTo records the destinations for which SendTo has already decided
	// whether to attach the config; configMut guards both config and sentTo.
	config    *GenericConfig
	sentTo    map[TreeNodeID]bool
	configMut sync.Mutex

	// used for the CounterIO interface: sizes of sent (tx) and received (rx)
	// messages
	tx safeAdder
	rx safeAdder
}
    60  
    61  type safeAdder struct {
    62  	sync.RWMutex
    63  	x uint64
    64  }
    65  
    66  func (a *safeAdder) add(x uint64) {
    67  	a.Lock()
    68  	a.x += x
    69  	a.Unlock()
    70  }
    71  
    72  func (a *safeAdder) get() (x uint64) {
    73  	a.RLock()
    74  	x = a.x
    75  	a.RUnlock()
    76  	return
    77  }
    78  
    79  const (
    80  	// AggregateMessages (if set) tells to aggregate messages from all children
    81  	// before sending to the (parent) Node
    82  	AggregateMessages = 1
    83  
    84  	// DefaultChannelLength is the default number of messages that can wait
    85  	// in a channel.
    86  	DefaultChannelLength = 100
    87  )
    88  
    89  // MsgHandler is called upon reception of a certain message-type
    90  type MsgHandler func([]*interface{})
    91  
    92  // NewNode creates a new node
    93  func newTreeNodeInstance(o *Overlay, tok *Token, tn *TreeNode, io MessageProxy) *TreeNodeInstance {
    94  	n := &TreeNodeInstance{overlay: o,
    95  		token:                tok,
    96  		channels:             make(map[network.MessageTypeID]interface{}),
    97  		handlers:             make(map[network.MessageTypeID]interface{}),
    98  		messageTypeFlags:     make(map[network.MessageTypeID]uint32),
    99  		msgQueue:             make(map[network.MessageTypeID][]*ProtocolMsg),
   100  		treeNode:             tn,
   101  		msgDispatchQueue:     make([]*ProtocolMsg, 0, 1),
   102  		msgDispatchQueueWait: make(chan bool, 1),
   103  		protoIO:              io,
   104  		sentTo:               make(map[TreeNodeID]bool),
   105  	}
   106  	go n.dispatchMsgReader()
   107  	return n
   108  }
   109  
   110  // TreeNode gets the treeNode of this node. If there is no TreeNode for the
   111  // Token of this node, the function will return nil
   112  func (n *TreeNodeInstance) TreeNode() *TreeNode {
   113  	return n.treeNode
   114  }
   115  
   116  // ServerIdentity returns our entity
   117  func (n *TreeNodeInstance) ServerIdentity() *network.ServerIdentity {
   118  	return n.treeNode.ServerIdentity
   119  }
   120  
   121  // Parent returns the parent-TreeNode of ourselves
   122  func (n *TreeNodeInstance) Parent() *TreeNode {
   123  	return n.treeNode.Parent
   124  }
   125  
   126  // Children returns the children of ourselves
   127  func (n *TreeNodeInstance) Children() []*TreeNode {
   128  	return n.treeNode.Children
   129  }
   130  
   131  // Root returns the root-node of that tree
   132  func (n *TreeNodeInstance) Root() *TreeNode {
   133  	t := n.Tree()
   134  	if t != nil {
   135  		return t.Root
   136  	}
   137  	return nil
   138  }
   139  
   140  // IsRoot returns whether whether we are at the top of the tree
   141  func (n *TreeNodeInstance) IsRoot() bool {
   142  	return n.treeNode.Parent == nil
   143  }
   144  
   145  // IsLeaf returns whether whether we are at the bottom of the tree
   146  func (n *TreeNodeInstance) IsLeaf() bool {
   147  	return len(n.treeNode.Children) == 0
   148  }
   149  
   150  // SendTo sends to a given node
   151  func (n *TreeNodeInstance) SendTo(to *TreeNode, msg interface{}) error {
   152  	if to == nil {
   153  		return xerrors.New("Sent to a nil TreeNode")
   154  	}
   155  	n.msgDispatchQueueMutex.Lock()
   156  	if n.closing {
   157  		n.msgDispatchQueueMutex.Unlock()
   158  		return xerrors.New("is closing")
   159  	}
   160  	n.msgDispatchQueueMutex.Unlock()
   161  	var c *GenericConfig
   162  	// only sends the config once
   163  	n.configMut.Lock()
   164  	if !n.sentTo[to.ID] {
   165  		c = n.config
   166  		n.sentTo[to.ID] = true
   167  	}
   168  	n.configMut.Unlock()
   169  
   170  	sentLen, err := n.overlay.SendToTreeNode(n.token, to, msg, n.protoIO, c)
   171  	n.tx.add(sentLen)
   172  	if err != nil {
   173  		return xerrors.Errorf("sending: %v", err)
   174  	}
   175  	return nil
   176  }
   177  
   178  // Tree returns the tree of that node. Because the storage keeps the tree around
   179  // until the protocol is done, this will never return a nil value. It will panic
   180  // if the tree is nil.
   181  func (n *TreeNodeInstance) Tree() *Tree {
   182  	tree := n.overlay.treeStorage.Get(n.token.TreeID)
   183  	if tree == nil {
   184  		panic("tree should never be nil when called during a protocol; " +
   185  			"it might be that Tree() has been called after Done() which " +
   186  			"is wrong or the tree has not correctly been passed.")
   187  	}
   188  
   189  	return tree
   190  }
   191  
   192  // Roster returns the entity-list
   193  func (n *TreeNodeInstance) Roster() *Roster {
   194  	return n.Tree().Roster
   195  }
   196  
   197  // Suite can be used to get the kyber.Suite associated with the service. It can
   198  // be either the default suite or the one registered with the service.
   199  func (n *TreeNodeInstance) Suite() network.Suite {
   200  	suite := ServiceFactory.SuiteByID(n.token.ServiceID)
   201  	if suite != nil {
   202  		return suite
   203  	}
   204  
   205  	return n.overlay.suite()
   206  }
   207  
   208  // RegisterChannel is a compatibility-method for RegisterChannelLength
   209  // and setting up a channel with length 100.
   210  func (n *TreeNodeInstance) RegisterChannel(c interface{}) error {
   211  	err := n.RegisterChannelLength(c, DefaultChannelLength)
   212  	if err != nil {
   213  		return xerrors.Errorf("registering channel length: %v", err)
   214  	}
   215  	return nil
   216  }
   217  
   218  // RegisterChannelLength takes a channel with a struct that contains two
   219  // elements: a TreeNode and a message. The second argument is the length of
   220  // the channel. It will send every message that are the
   221  // same type to this channel.
   222  // This function handles also
   223  // - registration of the message-type
   224  // - aggregation or not of messages: if you give a channel of slices, the
   225  //   messages will be aggregated, else they will come one-by-one
   226  func (n *TreeNodeInstance) RegisterChannelLength(c interface{}, length int) error {
   227  	flags := uint32(0)
   228  	cr := reflect.TypeOf(c)
   229  	if cr.Kind() == reflect.Ptr {
   230  		val := reflect.ValueOf(c).Elem()
   231  		val.Set(reflect.MakeChan(val.Type(), length))
   232  		return n.RegisterChannel(reflect.Indirect(val).Interface())
   233  	} else if reflect.ValueOf(c).IsNil() {
   234  		return xerrors.New("Can not Register a (value) channel not initialized")
   235  	}
   236  	// Check we have the correct channel-type
   237  	if cr.Kind() != reflect.Chan {
   238  		return xerrors.New("Input is not channel")
   239  	}
   240  	if cr.Elem().Kind() == reflect.Slice {
   241  		flags += AggregateMessages
   242  		cr = cr.Elem()
   243  	}
   244  	if cr.Elem().Kind() != reflect.Struct {
   245  		return xerrors.New("Input is not channel of structure")
   246  	}
   247  	if cr.Elem().NumField() != 2 {
   248  		return xerrors.New("Input is not channel of structure with 2 elements")
   249  	}
   250  	if cr.Elem().Field(0).Type != reflect.TypeOf(&TreeNode{}) {
   251  		return xerrors.New("Input-channel doesn't have TreeNode as element")
   252  	}
   253  	// Automatic registration of the message to the network library.
   254  	m := reflect.New(cr.Elem().Field(1).Type)
   255  	typ := network.RegisterMessage(m.Interface())
   256  	n.channels[typ] = c
   257  	//typ := network.RTypeToUUID(cr.Elem().Field(1).Type) n.channels[typ] = c
   258  	n.messageTypeFlags[typ] = flags
   259  	log.Lvl4("Registered channel", typ, "with flags", flags)
   260  	return nil
   261  }
   262  
   263  // RegisterChannels registers a list of given channels by calling RegisterChannel above
   264  func (n *TreeNodeInstance) RegisterChannels(channels ...interface{}) error {
   265  	for _, ch := range channels {
   266  		if err := n.RegisterChannel(ch); err != nil {
   267  			return xerrors.Errorf("Error, could not register channel %T: %s",
   268  				ch, err.Error())
   269  		}
   270  	}
   271  	return nil
   272  }
   273  
   274  // RegisterChannelsLength is a convenience function to register a vararg of
   275  // channels with a given length.
   276  func (n *TreeNodeInstance) RegisterChannelsLength(length int, channels ...interface{}) error {
   277  	for _, ch := range channels {
   278  		if err := n.RegisterChannelLength(ch, length); err != nil {
   279  			return xerrors.Errorf("Error, could not register channel %T: %s",
   280  				ch, err.Error())
   281  		}
   282  	}
   283  	return nil
   284  }
   285  
   286  // RegisterHandler takes a function which takes a struct as argument that contains two
   287  // elements: a TreeNode and a message. It will send every message that are the
   288  // same type to this channel.
   289  //
   290  // This function also handles:
   291  //     - registration of the message-type
   292  //     - aggregation or not of messages: if you give a channel of slices, the
   293  //       messages will be aggregated, otherwise they will come one by one
   294  func (n *TreeNodeInstance) RegisterHandler(c interface{}) error {
   295  	flags := uint32(0)
   296  	cr := reflect.TypeOf(c)
   297  	// Check we have the correct channel-type
   298  	if cr.Kind() != reflect.Func {
   299  		return xerrors.New("Input is not function")
   300  	}
   301  	if cr.NumOut() != 1 {
   302  		return xerrors.New("Need exactly one return argument of type error")
   303  	}
   304  	if cr.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
   305  		return xerrors.New("return-type of message-handler needs to be error")
   306  	}
   307  	ci := cr.In(0)
   308  	if ci.Kind() == reflect.Slice {
   309  		flags += AggregateMessages
   310  		ci = ci.Elem()
   311  	}
   312  	if ci.Kind() != reflect.Struct {
   313  		return xerrors.New("Input is not a structure")
   314  	}
   315  	if ci.NumField() != 2 {
   316  		return xerrors.New("Input is not a structure with 2 elements")
   317  	}
   318  	if ci.Field(0).Type != reflect.TypeOf(&TreeNode{}) {
   319  		return xerrors.New("Input-handler doesn't have TreeNode as element")
   320  	}
   321  	// Automatic registration of the message to the network library.
   322  	ptr := reflect.New(ci.Field(1).Type)
   323  	typ := network.RegisterMessage(ptr.Interface())
   324  	n.handlers[typ] = c
   325  	n.messageTypeFlags[typ] = flags
   326  	log.Lvl3("Registered handler", typ, "with flags", flags)
   327  	return nil
   328  }
   329  
   330  // RegisterHandlers registers a list of given handlers by calling RegisterHandler above
   331  func (n *TreeNodeInstance) RegisterHandlers(handlers ...interface{}) error {
   332  	for _, h := range handlers {
   333  		if err := n.RegisterHandler(h); err != nil {
   334  			return xerrors.Errorf("Error, could not register handler %T: %s",
   335  				h, err.Error())
   336  		}
   337  	}
   338  	return nil
   339  }
   340  
   341  // ProtocolInstance returns the instance of the running protocol
   342  func (n *TreeNodeInstance) ProtocolInstance() ProtocolInstance {
   343  	return n.instance
   344  }
   345  
   346  // Dispatch - the standard dispatching function is empty
   347  func (n *TreeNodeInstance) Dispatch() error {
   348  	return nil
   349  }
   350  
   351  // Shutdown - standard Shutdown implementation. Define your own
   352  // in your protocol (if necessary)
   353  func (n *TreeNodeInstance) Shutdown() error {
   354  	return nil
   355  }
   356  
   357  // closeDispatch shuts down the go-routine and calls the protocolInstance-shutdown
   358  func (n *TreeNodeInstance) closeDispatch() error {
   359  	defer func() {
   360  		if r := recover(); r != nil {
   361  			log.Errorf("Recovered panic while closing protocol: %v", r)
   362  			log.Error(log.Stack())
   363  		}
   364  	}()
   365  	log.Lvl3("Closing node", n.Info())
   366  	n.msgDispatchQueueMutex.Lock()
   367  	n.closing = true
   368  	close(n.msgDispatchQueueWait)
   369  	n.msgDispatchQueueMutex.Unlock()
   370  	log.Lvl3("Closed node", n.Info())
   371  	pni := n.ProtocolInstance()
   372  	if pni == nil {
   373  		return xerrors.New("Can't shutdown empty ProtocolInstance")
   374  	}
   375  	err := pni.Shutdown()
   376  	if err != nil {
   377  		return xerrors.Errorf("shutdown: %v", err)
   378  	}
   379  	return nil
   380  }
   381  
   382  // ProtocolName will return the string representing that protocol
   383  func (n *TreeNodeInstance) ProtocolName() string {
   384  	return n.overlay.server.protocols.ProtocolIDToName(n.token.ProtoID)
   385  }
   386  
   387  func (n *TreeNodeInstance) dispatchHandler(msgSlice []*ProtocolMsg) error {
   388  	mt := msgSlice[0].MsgType
   389  	to := reflect.TypeOf(n.handlers[mt]).In(0)
   390  	f := reflect.ValueOf(n.handlers[mt])
   391  	var errV reflect.Value
   392  	if n.hasFlag(mt, AggregateMessages) {
   393  		msgs := reflect.MakeSlice(to, len(msgSlice), len(msgSlice))
   394  		for i, msg := range msgSlice {
   395  			m, err := n.createValueAndVerify(to.Elem(), msg)
   396  			if err != nil {
   397  				return xerrors.Errorf("processing message: %v", err)
   398  			}
   399  			msgs.Index(i).Set(m)
   400  		}
   401  		log.Lvl4("Dispatching aggregation to", n.ServerIdentity().Address)
   402  		errV = f.Call([]reflect.Value{msgs})[0]
   403  	} else {
   404  		for _, msg := range msgSlice {
   405  			if errV.IsValid() && !errV.IsNil() {
   406  				// Before overwriting an error, print it out
   407  				log.Errorf("%s: error while dispatching message %s: %s",
   408  					n.Name(), reflect.TypeOf(msg.Msg),
   409  					errV.Interface().(error))
   410  			}
   411  			log.Lvl4("Dispatching", msg, "to", n.ServerIdentity().Address)
   412  			m, err := n.createValueAndVerify(to, msg)
   413  			if err != nil {
   414  				return xerrors.Errorf("processing message: %v", err)
   415  			}
   416  			errV = f.Call([]reflect.Value{m})[0]
   417  		}
   418  	}
   419  	log.Lvlf4("%s Done with handler for %s", n.Name(), f.Type())
   420  	if !errV.IsNil() {
   421  		return xerrors.Errorf("handler: %v", errV.Interface())
   422  	}
   423  	return nil
   424  }
   425  
   426  func (n *TreeNodeInstance) createValueAndVerify(t reflect.Type, msg *ProtocolMsg) (reflect.Value, error) {
   427  	m := reflect.Indirect(reflect.New(t))
   428  	tr := n.Tree()
   429  	if t != nil {
   430  		tn := tr.Search(msg.From.TreeNodeID)
   431  		if tn != nil {
   432  			m.Field(0).Set(reflect.ValueOf(tn))
   433  			m.Field(1).Set(reflect.Indirect(reflect.ValueOf(msg.Msg)))
   434  		}
   435  		// Check whether the sender treenode actually is the same as the node who sent it.
   436  		// We can trust msg.ServerIdentity, because it is written in Router.handleConn and
   437  		// is not writable by the sending node.
   438  		if msg.ServerIdentity != nil && tn != nil && !tn.ServerIdentity.Equal(msg.ServerIdentity) {
   439  			return m, xerrors.Errorf("ServerIdentity in the tree node referenced by the message (%v) does not match the ServerIdentity of the message originator (%v)",
   440  				tn.ServerIdentity, msg.ServerIdentity)
   441  		}
   442  	}
   443  	return m, nil
   444  }
   445  
// dispatchChannel takes a message and sends it to a channel
//
// msgSlice holds messages of a single type (the type of msgSlice[0]). They
// go to the channel registered for that type with RegisterChannelLength.
//
//   - Aggregated types (AggregateMessages flag set): one slice value built
//     from all messages is sent. This send does not check the channel's free
//     capacity, so it blocks while the channel is full.
//   - Other types: each message is sent on its own. If the channel has no
//     free buffer slot, an error is returned instead. If the instance is
//     closing, the message is silently dropped.
//
// A message that fails createValueAndVerify aborts dispatching with an
// error. Panics (e.g. sending on a closed channel) are recovered and logged;
// nil is returned in that case.
func (n *TreeNodeInstance) dispatchChannel(msgSlice []*ProtocolMsg) error {
	mt := msgSlice[0].MsgType
	defer func() {
		// In rare occasions we write to a closed channel which throws a panic.
		// Catch it here so we can find out better why this happens.
		if r := recover(); r != nil {
			log.Errorf("Couldn't dispatch protocol-message %s in %s: %v",
				mt, n.Info(), r)
			log.Error(log.Stack())
		}
	}()
	// to is the registered channel's type: chan S, or chan []S when aggregating.
	to := reflect.TypeOf(n.channels[mt])
	if n.hasFlag(mt, AggregateMessages) {
		log.Lvl4("Received aggregated message of type:", mt)
		// to becomes []S; to.Elem() below is the struct S.
		to = to.Elem()
		out := reflect.MakeSlice(to, len(msgSlice), len(msgSlice))
		for i, msg := range msgSlice {
			log.Lvl4("Dispatching aggregated to", to)
			m, err := n.createValueAndVerify(to.Elem(), msg)
			if err != nil {
				return xerrors.Errorf("processing message: %v", err)
			}
			log.Lvl4("Adding msg", m, "to", n.ServerIdentity().Address)
			out.Index(i).Set(m)
		}
		// Blocking send: unlike the branch below, free capacity is not checked.
		reflect.ValueOf(n.channels[mt]).Send(out)
	} else {
		for _, msg := range msgSlice {
			out := reflect.ValueOf(n.channels[mt])
			m, err := n.createValueAndVerify(to.Elem(), msg)
			if err != nil {
				return xerrors.Errorf("processing message: %v", err)
			}
			log.Lvl4(n.Name(), "Dispatching msg type", mt, " to", to, " :", m.Field(1).Interface())
			// Only send when a buffer slot is free, to avoid blocking on a
			// full channel; a full (or unbuffered) channel yields an error.
			if out.Len() < out.Cap() {
				n.msgDispatchQueueMutex.Lock()
				closing := n.closing
				n.msgDispatchQueueMutex.Unlock()
				// A closing instance silently drops the message.
				if !closing {
					out.Send(m)
				}
			} else {
				return xerrors.Errorf("channel too small for msg %s in %s: "+
					"please use RegisterChannelLength()",
					mt, n.ProtocolName())
			}
		}
	}
	return nil
}
   497  
   498  // ProcessProtocolMsg takes a message and puts it into a queue for later processing.
   499  // This allows a protocol to have a backlog of messages.
   500  func (n *TreeNodeInstance) ProcessProtocolMsg(msg *ProtocolMsg) {
   501  	log.Lvl4(n.Info(), "Received message")
   502  	n.msgDispatchQueueMutex.Lock()
   503  	defer n.msgDispatchQueueMutex.Unlock()
   504  	if n.closing {
   505  		log.Lvl3("Received message for closed protocol")
   506  		return
   507  	}
   508  	n.msgDispatchQueue = append(n.msgDispatchQueue, msg)
   509  	n.notifyDispatch()
   510  }
   511  
   512  func (n *TreeNodeInstance) notifyDispatch() {
   513  	select {
   514  	case n.msgDispatchQueueWait <- true:
   515  		return
   516  	default:
   517  		// Channel write would block: already been notified.
   518  		// So, nothing to do here.
   519  	}
   520  }
   521  
   522  func (n *TreeNodeInstance) dispatchMsgReader() {
   523  	log.TraceID(n.token.RoundID[:])
   524  	defer log.Lvl3("done tracing")
   525  	log.Lvl3("Starting node", n.Info())
   526  	for {
   527  		n.msgDispatchQueueMutex.Lock()
   528  		if n.closing {
   529  			log.Lvl3("Closing reader")
   530  			n.msgDispatchQueueMutex.Unlock()
   531  			return
   532  		}
   533  		if len(n.msgDispatchQueue) > 0 {
   534  			log.Lvl4(n.Info(), "Read message and dispatching it",
   535  				len(n.msgDispatchQueue))
   536  			msg := n.msgDispatchQueue[0]
   537  			n.msgDispatchQueue = n.msgDispatchQueue[1:]
   538  			n.msgDispatchQueueMutex.Unlock()
   539  			err := n.dispatchMsgToProtocol(msg)
   540  			if err != nil {
   541  				log.Errorf("%s: error while dispatching message %s: %s",
   542  					n.Name(), reflect.TypeOf(msg.Msg), err)
   543  			}
   544  		} else {
   545  			n.msgDispatchQueueMutex.Unlock()
   546  			log.Lvl4(n.Info(), "Waiting for message")
   547  			// Allow for closing of the channel
   548  			select {
   549  			case <-n.msgDispatchQueueWait:
   550  			}
   551  		}
   552  	}
   553  }
   554  
   555  // dispatchMsgToProtocol will dispatch this onet.Data to the right instance
   556  func (n *TreeNodeInstance) dispatchMsgToProtocol(onetMsg *ProtocolMsg) error {
   557  	log.Lvl3("Dispatching", onetMsg.MsgType)
   558  
   559  	n.rx.add(uint64(onetMsg.Size))
   560  
   561  	// if message comes from parent, dispatch directly
   562  	// if messages come from children we must aggregate them
   563  	// if we still need to wait for additional messages, we return
   564  	msgType, msgs, done := n.aggregate(onetMsg)
   565  	if !done {
   566  		log.Lvl3(n.Name(), "Not done aggregating children msgs")
   567  		return nil
   568  	}
   569  	log.Lvlf5("%s->%s: Message is: %+v", onetMsg.From, n.Name(), onetMsg.Msg)
   570  
   571  	var err error
   572  	switch {
   573  	case n.channels[msgType] != nil:
   574  		log.Lvl4(n.Name(), "Dispatching to channel")
   575  		err = n.dispatchChannel(msgs)
   576  	case n.handlers[msgType] != nil:
   577  		log.Lvl4("Dispatching to handler", n.ServerIdentity().Address)
   578  		err = n.dispatchHandler(msgs)
   579  	default:
   580  		return xerrors.Errorf("message-type not handled by the protocol: %s", reflect.TypeOf(onetMsg.Msg))
   581  	}
   582  	if err != nil {
   583  		return xerrors.Errorf("dispatch: %v", err)
   584  	}
   585  	return nil
   586  }
   587  
   588  // setFlag makes sure a given flag is set
   589  func (n *TreeNodeInstance) setFlag(mt network.MessageTypeID, f uint32) {
   590  	n.messageTypeFlags[mt] |= f
   591  }
   592  
   593  // clearFlag makes sure a given flag is removed
   594  func (n *TreeNodeInstance) clearFlag(mt network.MessageTypeID, f uint32) {
   595  	n.messageTypeFlags[mt] &^= f
   596  }
   597  
   598  // hasFlag returns true if the given flag is set
   599  func (n *TreeNodeInstance) hasFlag(mt network.MessageTypeID, f uint32) bool {
   600  	return n.messageTypeFlags[mt]&f != 0
   601  }
   602  
   603  // aggregate store the message for a protocol instance such that a protocol
   604  // instances will get all its children messages at once.
   605  // node is the node the host is representing in this Tree, and onetMsg is the
   606  // message being analyzed.
   607  func (n *TreeNodeInstance) aggregate(onetMsg *ProtocolMsg) (network.MessageTypeID, []*ProtocolMsg, bool) {
   608  	mt := onetMsg.MsgType
   609  	fromParent := !n.IsRoot() && onetMsg.From.TreeNodeID.Equal(n.Parent().ID)
   610  	if fromParent || !n.hasFlag(mt, AggregateMessages) {
   611  		return mt, []*ProtocolMsg{onetMsg}, true
   612  	}
   613  	// store the msg according to its type
   614  	if _, ok := n.msgQueue[mt]; !ok {
   615  		n.msgQueue[mt] = make([]*ProtocolMsg, 0)
   616  	}
   617  	msgs := append(n.msgQueue[mt], onetMsg)
   618  	n.msgQueue[mt] = msgs
   619  	log.Lvl4(n.ServerIdentity().Address, "received", len(msgs), "of", len(n.Children()), "messages")
   620  
   621  	// do we have everything yet or no
   622  	// get the node this host is in this tree
   623  	// OK we have all the children messages
   624  	if len(msgs) == len(n.Children()) {
   625  		// erase
   626  		delete(n.msgQueue, mt)
   627  		return mt, msgs, true
   628  	}
   629  	// no we still have to wait!
   630  	return mt, nil, false
   631  }
   632  
   633  // startProtocol calls the Start() on the underlying protocol which in turn will
   634  // initiate the first message to its children
   635  func (n *TreeNodeInstance) startProtocol() error {
   636  	err := n.instance.Start()
   637  	if err != nil {
   638  		return xerrors.Errorf("starting protocol: %v", err)
   639  	}
   640  	return nil
   641  }
   642  
   643  // Done calls onDoneCallback if available and only finishes when the return-
   644  // value is true.
   645  func (n *TreeNodeInstance) Done() {
   646  	if n.onDoneCallback != nil {
   647  		ok := n.onDoneCallback()
   648  		if !ok {
   649  			return
   650  		}
   651  	}
   652  	log.Lvl3(n.Info(), "has finished. Deleting its resources")
   653  	n.overlay.nodeDone(n.token)
   654  }
   655  
   656  // OnDoneCallback should be called if we want to control the Done() of the node.
   657  // It is used by protocols that uses others protocols inside and that want to
   658  // control when the final Done() should be called.
   659  // the function should return true if the real Done() has to be called otherwise
   660  // false.
   661  func (n *TreeNodeInstance) OnDoneCallback(fn func() bool) {
   662  	n.onDoneCallback = fn
   663  }
   664  
   665  // Private returns the private key of the service entity
   666  func (n *TreeNodeInstance) Private() kyber.Scalar {
   667  	serviceName := ServiceFactory.Name(n.token.ServiceID)
   668  
   669  	return n.Host().ServerIdentity.ServicePrivate(serviceName)
   670  }
   671  
   672  // Public returns the public key of the service, either the specific
   673  // or the default if not available
   674  func (n *TreeNodeInstance) Public() kyber.Point {
   675  	serviceName := ServiceFactory.Name(n.token.ServiceID)
   676  
   677  	return n.Host().ServerIdentity.ServicePublic(serviceName)
   678  }
   679  
   680  // Aggregate returns the sum of all public key of the roster for this TreeNodeInstance, either the specific
   681  // or the default if one or more of the nodes don't have the service-public key available.
   682  func (n *TreeNodeInstance) Aggregate() kyber.Point {
   683  	serviceName := ServiceFactory.Name(n.token.ServiceID)
   684  
   685  	agg, err := n.Roster().ServiceAggregate(serviceName)
   686  	if err != nil {
   687  		return n.Roster().Aggregate
   688  	}
   689  	return agg
   690  }
   691  
   692  // Publics makes a list of public keys for the service
   693  // associated with the instance
   694  func (n *TreeNodeInstance) Publics() []kyber.Point {
   695  	serviceName := ServiceFactory.Name(n.token.ServiceID)
   696  
   697  	return n.Roster().ServicePublics(serviceName)
   698  }
   699  
   700  // NodePublic returns the public key associated with the node's service
   701  // stored in the given server identity
   702  func (n *TreeNodeInstance) NodePublic(si *network.ServerIdentity) kyber.Point {
   703  	serviceName := ServiceFactory.Name(n.token.ServiceID)
   704  
   705  	return si.ServicePublic(serviceName)
   706  }
   707  
   708  // CloseHost closes the underlying onet.Host (which closes the overlay
   709  // and sends Shutdown to all protocol instances)
   710  // NOTE: It is to be used VERY carefully and is likely to disappear in the next
   711  // releases.
   712  func (n *TreeNodeInstance) CloseHost() error {
   713  	n.Host().callTestClose()
   714  	err := n.Host().Close()
   715  	if err != nil {
   716  		return xerrors.Errorf("closing host: %v", err)
   717  	}
   718  	return nil
   719  }
   720  
   721  // Name returns a human readable name of this Node (IP address).
   722  func (n *TreeNodeInstance) Name() string {
   723  	return n.ServerIdentity().Address.String()
   724  }
   725  
   726  // Info returns a human readable representation name of this Node
   727  // (IP address and TokenID).
   728  func (n *TreeNodeInstance) Info() string {
   729  	tid := n.TokenID()
   730  	name := protocols.ProtocolIDToName(n.token.ProtoID)
   731  	if name == "" {
   732  		name = n.overlay.server.protocols.ProtocolIDToName(n.token.ProtoID)
   733  	}
   734  	return fmt.Sprintf("%s (%s): %s", n.ServerIdentity().Address, tid.String(), name)
   735  }
   736  
   737  // TokenID returns the TokenID of the given node (to uniquely identify it)
   738  func (n *TreeNodeInstance) TokenID() TokenID {
   739  	return n.token.ID()
   740  }
   741  
   742  // Token returns a CLONE of the underlying onet.Token struct.
   743  // Useful for unit testing.
   744  func (n *TreeNodeInstance) Token() *Token {
   745  	return n.token.Clone()
   746  }
   747  
   748  // List returns the list of TreeNodes cached in the node (creating it if necessary)
   749  func (n *TreeNodeInstance) List() []*TreeNode {
   750  	n.mtx.Lock()
   751  	t := n.Tree()
   752  	if t != nil && n.treeNodeList == nil {
   753  		n.treeNodeList = t.List()
   754  	}
   755  	n.mtx.Unlock()
   756  	return n.treeNodeList
   757  }
   758  
   759  // Index returns the index of the node in the Roster
   760  func (n *TreeNodeInstance) Index() int {
   761  	return n.TreeNode().RosterIndex
   762  }
   763  
   764  // Broadcast sends a given message from the calling node directly to all other TreeNodes
   765  func (n *TreeNodeInstance) Broadcast(msg interface{}) []error {
   766  	var errs []error
   767  	for _, node := range n.List() {
   768  		if !node.Equal(n.TreeNode()) {
   769  			if err := n.SendTo(node, msg); err != nil {
   770  				errs = append(errs, xerrors.Errorf("sending: %v", err))
   771  			}
   772  		}
   773  	}
   774  	return errs
   775  }
   776  
   777  // Multicast ... XXX: should probably have a parallel more robust version like "SendToChildrenInParallel"
   778  func (n *TreeNodeInstance) Multicast(msg interface{}, nodes ...*TreeNode) []error {
   779  	var errs []error
   780  	for _, node := range nodes {
   781  		if err := n.SendTo(node, msg); err != nil {
   782  			errs = append(errs, xerrors.Errorf("sending: %v", err))
   783  		}
   784  	}
   785  	return errs
   786  }
   787  
   788  // SendToParent sends a given message to the parent of the calling node (unless it is the root)
   789  func (n *TreeNodeInstance) SendToParent(msg interface{}) error {
   790  	if n.IsRoot() {
   791  		return nil
   792  	}
   793  	err := n.SendTo(n.Parent(), msg)
   794  	if err != nil {
   795  		return xerrors.Errorf("sending: %v", err)
   796  	}
   797  	return nil
   798  }
   799  
   800  // SendToChildren sends a given message to all children of the calling node.
   801  // It stops sending if sending to one of the children fails. In that case it
   802  // returns an error. If the underlying node is a leaf node this function does
   803  // nothing.
   804  func (n *TreeNodeInstance) SendToChildren(msg interface{}) error {
   805  	if n.IsLeaf() {
   806  		return nil
   807  	}
   808  	for _, node := range n.Children() {
   809  		if err := n.SendTo(node, msg); err != nil {
   810  			return xerrors.Errorf("sending: %v", err)
   811  		}
   812  	}
   813  	return nil
   814  }
   815  
   816  // SendToChildrenInParallel sends a given message to all children of the calling
   817  // node. It has the following differences to node.SendToChildren:
   818  // The actual sending happens in a go routine (in parallel).
   819  // It continues sending to the other nodes if sending to one of the children
   820  // fails. In that case it will collect all errors in a slice.
   821  // If the underlying node is a leaf node this function does
   822  // nothing.
   823  func (n *TreeNodeInstance) SendToChildrenInParallel(msg interface{}) []error {
   824  	log.TraceID(n.token.RoundID[:])
   825  	if n.IsLeaf() {
   826  		return nil
   827  	}
   828  	children := n.Children()
   829  	var errs []error
   830  	eMut := sync.Mutex{}
   831  	wg := sync.WaitGroup{}
   832  	for _, node := range children {
   833  		name := node.Name()
   834  		wg.Add(1)
   835  		go func(n2 *TreeNode) {
   836  			log.TraceID(n.token.RoundID[:])
   837  			if err := n.SendTo(n2, msg); err != nil {
   838  				eMut.Lock()
   839  				errs = append(errs, xerrors.Errorf("%s: %v", name, err))
   840  				eMut.Unlock()
   841  			}
   842  			wg.Done()
   843  		}(node)
   844  	}
   845  	wg.Wait()
   846  	return errs
   847  }
   848  
   849  // CreateProtocol instantiates a new protocol of name "name" and
   850  // returns it with any error that might have happened during the creation. If
   851  // the TreeNodeInstance calling this is attached to a service, the new protocol
   852  // will also be attached to this same service. Else the new protocol will only
   853  // be handled by onet.
   854  func (n *TreeNodeInstance) CreateProtocol(name string, t *Tree) (ProtocolInstance, error) {
   855  	pi, err := n.overlay.CreateProtocol(name, t, n.Token().ServiceID)
   856  	if err != nil {
   857  		return nil, xerrors.Errorf("creating protocol: %v", err)
   858  	}
   859  	return pi, nil
   860  }
   861  
   862  // Host returns the underlying Host of this node.
   863  //
   864  // WARNING: you should not play with that feature unless you know what you are
   865  // doing. This feature is meant to access the low level parts of the API. For
   866  // example it is used to add a new tree config / new entity list to the Server.
   867  func (n *TreeNodeInstance) Host() *Server {
   868  	return n.overlay.server
   869  }
   870  
   871  // TreeNodeInstance returns itself (XXX quick hack for this services2 branch
   872  // version for the tests)
   873  func (n *TreeNodeInstance) TreeNodeInstance() *TreeNodeInstance {
   874  	return n
   875  }
   876  
   877  // SetConfig sets the GenericConfig c to be passed down in the first message
   878  // alongside with the protocol if it is non nil. This config can later be read
   879  // by Services in the NewProtocol method.
   880  func (n *TreeNodeInstance) SetConfig(c *GenericConfig) error {
   881  	n.configMut.Lock()
   882  	defer n.configMut.Unlock()
   883  	if n.config != nil {
   884  		return xerrors.New("Can't set config twice")
   885  	}
   886  	n.config = c
   887  	return nil
   888  }
   889  
   890  // Rx implements the CounterIO interface
   891  func (n *TreeNodeInstance) Rx() uint64 {
   892  	return n.rx.get()
   893  }
   894  
   895  // Tx implements the CounterIO interface
   896  func (n *TreeNodeInstance) Tx() uint64 {
   897  	return n.tx.get()
   898  }
   899  
   900  func (n *TreeNodeInstance) isBound() bool {
   901  	return n.instance != nil
   902  }
   903  
   904  func (n *TreeNodeInstance) bind(pi ProtocolInstance) {
   905  	n.instance = pi
   906  }