gopkg.in/dedis/onet.v2@v2.0.0-20181115163211-c8f3724038a7/treenode.go

package onet

import (
	"errors"
	"fmt"
	"reflect"
	"sync"

	"gopkg.in/dedis/kyber.v2"
	"gopkg.in/dedis/onet.v2/log"
	"gopkg.in/dedis/onet.v2/network"
)

// TreeNodeInstance represents a protocol-instance in a given TreeNode. It holds
// a reference to the Overlay where all the tree-structures are stored.
type TreeNodeInstance struct {
	overlay *Overlay
	token   *Token
	// cache for the TreeNode this Node is representing
	treeNode *TreeNode
	// cached list of all TreeNodes
	treeNodeList []*TreeNode
	// mutex to synchronise creation of treeNodeList
	mtx sync.Mutex

	// channels holds all channels available for the different message-types
	channels map[network.MessageTypeID]interface{}
	// registered handler-functions for that protocol
	handlers map[network.MessageTypeID]interface{}
	// flags for messages - only one channel/handler possible
	messageTypeFlags map[network.MessageTypeID]uint32
	// The protocolInstance belonging to that node
	instance ProtocolInstance
	// aggregate messages in order to dispatch them at once in the protocol
	// instance
	msgQueue map[network.MessageTypeID][]*ProtocolMsg
	// done callback
	onDoneCallback func() bool
	// queue holding msgs
	msgDispatchQueue []*ProtocolMsg
	// locking for msgqueue
	msgDispatchQueueMutex sync.Mutex
	// kicking off new message
	msgDispatchQueueWait chan bool
	// whether this node is closing
	closing bool

	protoIO MessageProxy

	// config is passed down in the first message the protocol sends, if it is
	// non-nil. Set with `tni.SetConfig()`.
	config    *GenericConfig
	sentTo    map[TreeNodeID]bool
	configMut sync.Mutex

	// used for the CounterIO interface
	tx safeAdder
	rx safeAdder
}

type safeAdder struct {
	sync.RWMutex
	x uint64
}

func (a *safeAdder) add(x uint64) {
	a.Lock()
	a.x += x
	a.Unlock()
}

func (a *safeAdder) get() (x uint64) {
	a.RLock()
	x = a.x
	a.RUnlock()
	return
}

const (
	// AggregateMessages (if set) tells onet to aggregate messages from all
	// children before passing them on to the protocol instance.
	AggregateMessages = 1

	// DefaultChannelLength is the default number of messages that can wait
	// in a channel.
	DefaultChannelLength = 100
)

// MsgHandler is called upon reception of a certain message-type
type MsgHandler func([]*interface{})

// newTreeNodeInstance creates a new node
func newTreeNodeInstance(o *Overlay, tok *Token, tn *TreeNode, io MessageProxy) *TreeNodeInstance {
	n := &TreeNodeInstance{overlay: o,
		token:                tok,
		channels:             make(map[network.MessageTypeID]interface{}),
		handlers:             make(map[network.MessageTypeID]interface{}),
		messageTypeFlags:     make(map[network.MessageTypeID]uint32),
		msgQueue:             make(map[network.MessageTypeID][]*ProtocolMsg),
		treeNode:             tn,
		msgDispatchQueue:     make([]*ProtocolMsg, 0, 1),
		msgDispatchQueueWait: make(chan bool, 1),
		protoIO:              io,
		sentTo:               make(map[TreeNodeID]bool),
	}
	go n.dispatchMsgReader()
	return n
}

// TreeNode gets the treeNode of this node. If there is no TreeNode for the
// Token of this node, the function will return nil
func (n *TreeNodeInstance) TreeNode() *TreeNode {
	return n.treeNode
}

// ServerIdentity returns our entity
func (n *TreeNodeInstance) ServerIdentity() *network.ServerIdentity {
	return n.treeNode.ServerIdentity
}

// Parent returns the parent-TreeNode of ourselves
func (n *TreeNodeInstance) Parent() *TreeNode {
	return n.treeNode.Parent
}

// Children returns the children of ourselves
func (n *TreeNodeInstance) Children() []*TreeNode {
	return n.treeNode.Children
}

// Root returns the root-node of that tree
func (n *TreeNodeInstance) Root() *TreeNode {
	return n.Tree().Root
}

// IsRoot returns whether we are at the top of the tree
func (n *TreeNodeInstance) IsRoot() bool {
	return n.treeNode.Parent == nil
}

// IsLeaf returns whether we are at the bottom of the tree
func (n *TreeNodeInstance) IsLeaf() bool {
	return len(n.treeNode.Children) == 0
}

// SendTo sends a given message to a given node
func (n *TreeNodeInstance) SendTo(to *TreeNode, msg interface{}) error {
	if to == nil {
		return errors.New("Sent to a nil TreeNode")
	}
	n.msgDispatchQueueMutex.Lock()
	if n.closing {
		n.msgDispatchQueueMutex.Unlock()
		return errors.New("is closing")
	}
	n.msgDispatchQueueMutex.Unlock()
	var c *GenericConfig
	// only sends the config once
	n.configMut.Lock()
	if !n.sentTo[to.ID] {
		c = n.config
		n.sentTo[to.ID] = true
	}
	n.configMut.Unlock()

	sentLen, err := n.overlay.SendToTreeNode(n.token, to, msg, n.protoIO, c)
	n.tx.add(sentLen)
	return err
}

// Tree returns the tree of that node
func (n *TreeNodeInstance) Tree() *Tree {
	return n.overlay.treeCache.GetFromToken(n.token)
}

// Roster returns the entity-list
func (n *TreeNodeInstance) Roster() *Roster {
	return n.Tree().Roster
}

// Suite can be used to get the current kyber.Suite (currently hardcoded into
// the network library).
func (n *TreeNodeInstance) Suite() network.Suite {
	return n.overlay.suite()
}

// RegisterChannel is a compatibility-method for RegisterChannelLength,
// setting up a channel of length DefaultChannelLength (100).
func (n *TreeNodeInstance) RegisterChannel(c interface{}) error {
	return n.RegisterChannelLength(c, DefaultChannelLength)
}

// RegisterChannelLength takes a channel with a struct that contains two
// elements: a TreeNode and a message. The second argument is the length of
// the channel. Every incoming message of the same type will be sent to this
// channel.
//
// This function also handles:
// - registration of the message-type
// - aggregation or not of messages: if you give a channel of slices, the
//   messages will be aggregated, otherwise they will come one by one
func (n *TreeNodeInstance) RegisterChannelLength(c interface{}, length int) error {
	flags := uint32(0)
	cr := reflect.TypeOf(c)
	if cr.Kind() == reflect.Ptr {
		val := reflect.ValueOf(c).Elem()
		val.Set(reflect.MakeChan(val.Type(), length))
		return n.RegisterChannel(reflect.Indirect(val).Interface())
	} else if reflect.ValueOf(c).IsNil() {
		return errors.New("Can not Register a (value) channel not initialized")
	}
	// Check we have the correct channel-type
	if cr.Kind() != reflect.Chan {
		return errors.New("Input is not channel")
	}
	if cr.Elem().Kind() == reflect.Slice {
		flags += AggregateMessages
		cr = cr.Elem()
	}
	if cr.Elem().Kind() != reflect.Struct {
		return errors.New("Input is not channel of structure")
	}
	if cr.Elem().NumField() != 2 {
		return errors.New("Input is not channel of structure with 2 elements")
	}
	if cr.Elem().Field(0).Type != reflect.TypeOf(&TreeNode{}) {
		return errors.New("Input-channel doesn't have TreeNode as element")
	}
	// Automatic registration of the message to the network library.
	m := reflect.New(cr.Elem().Field(1).Type)
	typ := network.RegisterMessage(m.Interface())
	n.channels[typ] = c
	n.messageTypeFlags[typ] = flags
	log.Lvl4("Registered channel", typ, "with flags", flags)
	return nil
}
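
// A minimal usage sketch (not part of the original file; the Announce message,
// its wrapper struct and the tni variable are hypothetical). Passing a pointer
// lets onet create the channel with the requested length; a channel of slices
// enables AggregateMessages, so all children's messages arrive as one slice:
//
//	type Announce struct{ Message string }
//
//	type announceWrapper struct {
//		*TreeNode
//		Announce
//	}
//
//	var announceChan chan announceWrapper
//	if err := tni.RegisterChannelLength(&announceChan, 10); err != nil {
//		return err
//	}
//
//	var aggregated chan []announceWrapper
//	if err := tni.RegisterChannel(&aggregated); err != nil {
//		return err
//	}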

// RegisterChannels registers a list of given channels by calling RegisterChannel above
func (n *TreeNodeInstance) RegisterChannels(channels ...interface{}) error {
	for _, ch := range channels {
		if err := n.RegisterChannel(ch); err != nil {
			return fmt.Errorf("Error, could not register channel %T: %s",
				ch, err.Error())
		}
	}
	return nil
}

// RegisterChannelsLength is a convenience function to register a vararg of
// channels with a given length.
func (n *TreeNodeInstance) RegisterChannelsLength(length int, channels ...interface{}) error {
	for _, ch := range channels {
		if err := n.RegisterChannelLength(ch, length); err != nil {
			return fmt.Errorf("Error, could not register channel %T: %s",
				ch, err.Error())
		}
	}
	return nil
}

// RegisterHandler takes a function which takes a struct as argument that contains two
// elements: a TreeNode and a message. Every incoming message of the same type
// will be passed to this handler.
//
// This function also handles:
//     - registration of the message-type
//     - aggregation or not of messages: if the handler takes a slice, the
//       messages will be aggregated, otherwise they will come one by one
func (n *TreeNodeInstance) RegisterHandler(c interface{}) error {
	flags := uint32(0)
	cr := reflect.TypeOf(c)
	// Check we have the correct function-type
	if cr.Kind() != reflect.Func {
		return errors.New("Input is not function")
	}
	if cr.NumOut() != 1 {
		return errors.New("Need exactly one return argument of type error")
	}
	if cr.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
		return errors.New("return-type of message-handler needs to be error")
	}
	ci := cr.In(0)
	if ci.Kind() == reflect.Slice {
		flags += AggregateMessages
		ci = ci.Elem()
	}
	if ci.Kind() != reflect.Struct {
		return errors.New("Input is not a structure")
	}
	if ci.NumField() != 2 {
		return errors.New("Input is not a structure with 2 elements")
	}
	if ci.Field(0).Type != reflect.TypeOf(&TreeNode{}) {
		return errors.New("Input-handler doesn't have TreeNode as element")
	}
	// Automatic registration of the message to the network library.
	ptr := reflect.New(ci.Field(1).Type)
	typ := network.RegisterMessage(ptr.Interface())
	n.handlers[typ] = c
	n.messageTypeFlags[typ] = flags
	log.Lvl3("Registered handler", typ, "with flags", flags)
	return nil
}
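
// A minimal usage sketch (not part of the original file; the Reply message,
// its wrapper struct and the tni variable are hypothetical). The handler takes
// the same kind of wrapper struct and must return an error; a slice argument
// switches on AggregateMessages, so the handler is called once with all the
// children's messages:
//
//	type Reply struct{ Sum int }
//
//	type replyWrapper struct {
//		*TreeNode
//		Reply
//	}
//
//	err := tni.RegisterHandler(func(r replyWrapper) error {
//		log.Lvl3("got reply from", r.ServerIdentity.Address)
//		return nil
//	})
//
//	errAggr := tni.RegisterHandler(func(rs []replyWrapper) error {
//		// called once with one entry per child
//		return nil
//	})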

// RegisterHandlers registers a list of given handlers by calling RegisterHandler above
func (n *TreeNodeInstance) RegisterHandlers(handlers ...interface{}) error {
	for _, h := range handlers {
		if err := n.RegisterHandler(h); err != nil {
			return fmt.Errorf("Error, could not register handler %T: %s",
				h, err.Error())
		}
	}
	return nil
}

// ProtocolInstance returns the instance of the running protocol
func (n *TreeNodeInstance) ProtocolInstance() ProtocolInstance {
	return n.instance
}

// Dispatch - the standard dispatching function is empty
func (n *TreeNodeInstance) Dispatch() error {
	return nil
}

// Shutdown - standard Shutdown implementation. Define your own
// in your protocol (if necessary)
func (n *TreeNodeInstance) Shutdown() error {
	return nil
}

// closeDispatch shuts down the go-routine and calls the protocolInstance-shutdown
func (n *TreeNodeInstance) closeDispatch() error {
	log.Lvl3("Closing node", n.Info())
	n.msgDispatchQueueMutex.Lock()
	n.closing = true
	close(n.msgDispatchQueueWait)
	n.msgDispatchQueueMutex.Unlock()
	log.Lvl3("Closed node", n.Info())
	pni := n.ProtocolInstance()
	if pni == nil {
		return errors.New("Can't shutdown empty ProtocolInstance")
	}
	return pni.Shutdown()
}

// ProtocolName will return the string representing that protocol
func (n *TreeNodeInstance) ProtocolName() string {
	return n.overlay.server.protocols.ProtocolIDToName(n.token.ProtoID)
}

func (n *TreeNodeInstance) dispatchHandler(msgSlice []*ProtocolMsg) error {
	mt := msgSlice[0].MsgType
	to := reflect.TypeOf(n.handlers[mt]).In(0)
	f := reflect.ValueOf(n.handlers[mt])
	var errV reflect.Value
	if n.hasFlag(mt, AggregateMessages) {
		msgs := reflect.MakeSlice(to, len(msgSlice), len(msgSlice))
		for i, msg := range msgSlice {
			msgs.Index(i).Set(n.reflectCreate(to.Elem(), msg))
		}
		log.Lvl4("Dispatching aggregation to", n.ServerIdentity().Address)
		errV = f.Call([]reflect.Value{msgs})[0]
	} else {
		for _, msg := range msgSlice {
			if errV.IsValid() && !errV.IsNil() {
				// Before overwriting an error, print it out
				log.Errorf("%s: error while dispatching message %s: %s",
					n.Name(), reflect.TypeOf(msg.Msg),
					errV.Interface().(error))
			}
			log.Lvl4("Dispatching", msg, "to", n.ServerIdentity().Address)
			m := n.reflectCreate(to, msg)
			errV = f.Call([]reflect.Value{m})[0]
		}
	}
	log.Lvlf4("%s Done with handler for %s", n.Name(), f.Type())
	if !errV.IsNil() {
		return errV.Interface().(error)
	}
	return nil
}

func (n *TreeNodeInstance) reflectCreate(t reflect.Type, msg *ProtocolMsg) reflect.Value {
	m := reflect.Indirect(reflect.New(t))
	tn := n.Tree().Search(msg.From.TreeNodeID)
	if tn != nil {
		m.Field(0).Set(reflect.ValueOf(tn))
		m.Field(1).Set(reflect.Indirect(reflect.ValueOf(msg.Msg)))
	}
	return m
}

// dispatchChannel takes a message and sends it to a channel
func (n *TreeNodeInstance) dispatchChannel(msgSlice []*ProtocolMsg) error {
	mt := msgSlice[0].MsgType
	defer func() {
		// On rare occasions we write to a closed channel, which throws a panic.
		// Catch it here so we can better find out why this happens.
		if r := recover(); r != nil {
			log.Error("Shouldn't happen, please report an issue:", n.Info(), r, mt)
		}
	}()
	to := reflect.TypeOf(n.channels[mt])
	if n.hasFlag(mt, AggregateMessages) {
		log.Lvl4("Received aggregated message of type:", mt)
		to = to.Elem()
		out := reflect.MakeSlice(to, len(msgSlice), len(msgSlice))
		for i, msg := range msgSlice {
			log.Lvl4("Dispatching aggregated to", to)
			m := n.reflectCreate(to.Elem(), msg)
			log.Lvl4("Adding msg", m, "to", n.ServerIdentity().Address)
			out.Index(i).Set(m)
		}
		reflect.ValueOf(n.channels[mt]).Send(out)
	} else {
		for _, msg := range msgSlice {
			out := reflect.ValueOf(n.channels[mt])
			m := n.reflectCreate(to.Elem(), msg)
			log.Lvl4(n.Name(), "Dispatching msg type", mt, " to", to, " :", m.Field(1).Interface())
			if out.Len() < out.Cap() {
				n.msgDispatchQueueMutex.Lock()
				closing := n.closing
				n.msgDispatchQueueMutex.Unlock()
				if !closing {
					out.Send(m)
				}
			} else {
				return fmt.Errorf("channel too small for msg %s in %s: "+
					"please use RegisterChannelLength()",
					mt, n.ProtocolName())
			}
		}
	}
	return nil
}

// ProcessProtocolMsg takes a message and puts it into a queue for later processing.
// This allows a protocol to have a backlog of messages.
func (n *TreeNodeInstance) ProcessProtocolMsg(msg *ProtocolMsg) {
	log.Lvl4(n.Info(), "Received message")
	n.msgDispatchQueueMutex.Lock()
	defer n.msgDispatchQueueMutex.Unlock()
	if n.closing {
		log.Lvl3("Received message for closed protocol")
		return
	}
	n.msgDispatchQueue = append(n.msgDispatchQueue, msg)
	n.notifyDispatch()
}

func (n *TreeNodeInstance) notifyDispatch() {
	select {
	case n.msgDispatchQueueWait <- true:
		return
	default:
		// Channel write would block: already been notified.
		// So, nothing to do here.
	}
}

func (n *TreeNodeInstance) dispatchMsgReader() {
	log.Lvl3("Starting node", n.Info())
	for {
		n.msgDispatchQueueMutex.Lock()
		if n.closing {
			log.Lvl3("Closing reader")
			n.msgDispatchQueueMutex.Unlock()
			return
		}
		if len(n.msgDispatchQueue) > 0 {
			log.Lvl4(n.Info(), "Read message and dispatching it",
				len(n.msgDispatchQueue))
			msg := n.msgDispatchQueue[0]
			n.msgDispatchQueue = n.msgDispatchQueue[1:]
			n.msgDispatchQueueMutex.Unlock()
			err := n.dispatchMsgToProtocol(msg)
			if err != nil {
				log.Errorf("%s: error while dispatching message %s: %s",
					n.Name(), reflect.TypeOf(msg.Msg), err)
			}
		} else {
			n.msgDispatchQueueMutex.Unlock()
			log.Lvl4(n.Info(), "Waiting for message")
			// Allow for closing of the channel
			<-n.msgDispatchQueueWait
		}
	}
}

// dispatchMsgToProtocol will dispatch this ProtocolMsg to the right instance
func (n *TreeNodeInstance) dispatchMsgToProtocol(onetMsg *ProtocolMsg) error {

	n.rx.add(uint64(onetMsg.Size))

	// if message comes from parent, dispatch directly
	// if messages come from children we must aggregate them
	// if we still need to wait for additional messages, we return
	msgType, msgs, done := n.aggregate(onetMsg)
	if !done {
		log.Lvl3(n.Name(), "Not done aggregating children msgs")
		return nil
	}
	log.Lvlf5("%s->%s: Message is: %+v", onetMsg.From, n.Name(), onetMsg.Msg)

	var err error
	switch {
	case n.channels[msgType] != nil:
		log.Lvl4(n.Name(), "Dispatching to channel")
		err = n.dispatchChannel(msgs)
	case n.handlers[msgType] != nil:
		log.Lvl4("Dispatching to handler", n.ServerIdentity().Address)
		err = n.dispatchHandler(msgs)
	default:
		return fmt.Errorf("message-type not handled by the protocol: %s", reflect.TypeOf(onetMsg.Msg))
	}
	return err
}

// setFlag makes sure a given flag is set
func (n *TreeNodeInstance) setFlag(mt network.MessageTypeID, f uint32) {
	n.messageTypeFlags[mt] |= f
}

// clearFlag makes sure a given flag is removed
func (n *TreeNodeInstance) clearFlag(mt network.MessageTypeID, f uint32) {
	n.messageTypeFlags[mt] &^= f
}

// hasFlag returns true if the given flag is set
func (n *TreeNodeInstance) hasFlag(mt network.MessageTypeID, f uint32) bool {
	return n.messageTypeFlags[mt]&f != 0
}

// aggregate stores the message for a protocol instance such that the protocol
// instance will get all its children's messages at once.
// onetMsg is the message being analyzed.
func (n *TreeNodeInstance) aggregate(onetMsg *ProtocolMsg) (network.MessageTypeID, []*ProtocolMsg, bool) {
	mt := onetMsg.MsgType
	fromParent := !n.IsRoot() && onetMsg.From.TreeNodeID.Equal(n.Parent().ID)
	if fromParent || !n.hasFlag(mt, AggregateMessages) {
		return mt, []*ProtocolMsg{onetMsg}, true
	}
	// store the msg according to its type
	if _, ok := n.msgQueue[mt]; !ok {
		n.msgQueue[mt] = make([]*ProtocolMsg, 0)
	}
	msgs := append(n.msgQueue[mt], onetMsg)
	n.msgQueue[mt] = msgs
	log.Lvl4(n.ServerIdentity().Address, "received", len(msgs), "of", len(n.Children()), "messages")

	// do we have everything yet or not?
	// once we have all the children's messages, dispatch them
	if len(msgs) == len(n.Children()) {
		// erase
		delete(n.msgQueue, mt)
		return mt, msgs, true
	}
	// no, we still have to wait!
	return mt, nil, false
}

// startProtocol calls the Start() on the underlying protocol which in turn will
// initiate the first message to its children
func (n *TreeNodeInstance) startProtocol() error {
	return n.instance.Start()
}

// Done calls onDoneCallback if available and only finishes when the return-
// value is true.
func (n *TreeNodeInstance) Done() {
	if n.onDoneCallback != nil {
		ok := n.onDoneCallback()
		if !ok {
			return
		}
	}
	log.Lvl3(n.Info(), "has finished. Deleting its resources")
	n.overlay.nodeDone(n.token)
}

// OnDoneCallback should be called if we want to control the Done() of the node.
// It is used by protocols that use other protocols inside and want to
// control when the final Done() should be called.
// The function should return true if the real Done() has to be called,
// otherwise false.
func (n *TreeNodeInstance) OnDoneCallback(fn func() bool) {
	n.onDoneCallback = fn
}
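
// A minimal usage sketch (not part of the original file; tni and
// allSubProtocolsFinished are hypothetical). A protocol that wraps other
// protocols can keep the node's resources alive until its own bookkeeping
// says it is finished:
//
//	tni.OnDoneCallback(func() bool {
//		// the real Done() only runs once this returns true
//		return allSubProtocolsFinished()
//	})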

// Private returns the private key of the entity
func (n *TreeNodeInstance) Private() kyber.Scalar {
	return n.Host().private
}

// Public returns the public key of the entity
func (n *TreeNodeInstance) Public() kyber.Point {
	return n.ServerIdentity().Public
}

// CloseHost closes the underlying onet.Host (which closes the overlay
// and sends Shutdown to all protocol instances)
// NOTE: It is to be used VERY carefully and is likely to disappear in the next
// releases.
func (n *TreeNodeInstance) CloseHost() error {
	return n.Host().Close()
}

// Name returns a human-readable name of this Node (IP address).
func (n *TreeNodeInstance) Name() string {
	return n.ServerIdentity().Address.String()
}

// Info returns a human-readable representation of this Node
// (IP address and TokenID).
func (n *TreeNodeInstance) Info() string {
	tid := n.TokenID()
	name := protocols.ProtocolIDToName(n.token.ProtoID)
	if name == "" {
		name = n.overlay.server.protocols.ProtocolIDToName(n.token.ProtoID)
	}
	return fmt.Sprintf("%s (%s): %s", n.ServerIdentity().Address, tid.String(), name)
}

// TokenID returns the TokenID of the given node (to uniquely identify it)
func (n *TreeNodeInstance) TokenID() TokenID {
	return n.token.ID()
}

// Token returns a CLONE of the underlying onet.Token struct.
// Useful for unit testing.
func (n *TreeNodeInstance) Token() *Token {
	return n.token.Clone()
}

// List returns the list of TreeNodes cached in the node (creating it if necessary)
func (n *TreeNodeInstance) List() []*TreeNode {
	n.mtx.Lock()
	if n.treeNodeList == nil {
		n.treeNodeList = n.Tree().List()
	}
	n.mtx.Unlock()
	return n.treeNodeList
}

// Index returns the index of the node in the Roster
func (n *TreeNodeInstance) Index() int {
	return n.TreeNode().RosterIndex
}

// Broadcast sends a given message from the calling node directly to all other TreeNodes
func (n *TreeNodeInstance) Broadcast(msg interface{}) []error {
	var errs []error
	for _, node := range n.List() {
		if !node.Equal(n.TreeNode()) {
			if err := n.SendTo(node, msg); err != nil {
				errs = append(errs, err)
			}
		}
	}
	return errs
}

// Multicast sends a given message to all the given nodes.
// XXX: should probably have a parallel, more robust version like "SendToChildrenInParallel"
func (n *TreeNodeInstance) Multicast(msg interface{}, nodes ...*TreeNode) []error {
	var errs []error
	for _, node := range nodes {
		if err := n.SendTo(node, msg); err != nil {
			errs = append(errs, err)
		}
	}
	return errs
}

// SendToParent sends a given message to the parent of the calling node (unless it is the root)
func (n *TreeNodeInstance) SendToParent(msg interface{}) error {
	if n.IsRoot() {
		return nil
	}
	return n.SendTo(n.Parent(), msg)
}

// SendToChildren sends a given message to all children of the calling node.
// It stops sending if sending to one of the children fails. In that case it
// returns an error. If the underlying node is a leaf node this function does
// nothing.
func (n *TreeNodeInstance) SendToChildren(msg interface{}) error {
	if n.IsLeaf() {
		return nil
	}
	for _, node := range n.Children() {
		if err := n.SendTo(node, msg); err != nil {
			return err
		}
	}
	return nil
}

// SendToChildrenInParallel sends a given message to all children of the calling
// node. It has the following differences from SendToChildren:
// the actual sending happens in a goroutine (in parallel), and it continues
// sending to the other nodes if sending to one of the children fails, in
// which case it collects all errors in a slice.
// If the underlying node is a leaf node this function does nothing.
func (n *TreeNodeInstance) SendToChildrenInParallel(msg interface{}) []error {
	if n.IsLeaf() {
		return nil
	}
	children := n.Children()
	var errs []error
	eMut := sync.Mutex{}
	wg := sync.WaitGroup{}
	for _, node := range children {
		name := node.Name()
		wg.Add(1)
		go func(n2 *TreeNode) {
			if err := n.SendTo(n2, msg); err != nil {
				eMut.Lock()
				errs = append(errs, errors.New(name+": "+err.Error()))
				eMut.Unlock()
			}
			wg.Done()
		}(node)
	}
	wg.Wait()
	return errs
}

// CreateProtocol instantiates a new protocol of name "name" and
// returns it with any error that might have happened during the creation. If
// the TreeNodeInstance calling this is attached to a service, the new protocol
// will also be attached to this same service. Otherwise the new protocol will
// only be handled by onet.
func (n *TreeNodeInstance) CreateProtocol(name string, t *Tree) (ProtocolInstance, error) {
	pi, err := n.overlay.CreateProtocol(name, t, n.Token().ServiceID)
	return pi, err
}
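
// A minimal usage sketch (not part of the original file; the protocol name
// "ExampleSub" is hypothetical). A running protocol can start a sub-protocol
// over the same tree:
//
//	pi, err := n.CreateProtocol("ExampleSub", n.Tree())
//	if err != nil {
//		return err
//	}
//	go func() {
//		if err := pi.Start(); err != nil {
//			log.Error(err)
//		}
//	}()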

// Host returns the underlying Host of this node.
//
// WARNING: you should not play with that feature unless you know what you are
// doing. This feature is meant to access the low level parts of the API. For
// example it is used to add a new tree config / new entity list to the Server.
func (n *TreeNodeInstance) Host() *Server {
	return n.overlay.server
}

// TreeNodeInstance returns itself (XXX quick hack for this services2 branch
// version for the tests)
func (n *TreeNodeInstance) TreeNodeInstance() *TreeNodeInstance {
	return n
}

// SetConfig sets the GenericConfig c to be passed down in the first message
// alongside the protocol if it is non-nil. This config can later be read
// by Services in the NewProtocol method.
func (n *TreeNodeInstance) SetConfig(c *GenericConfig) error {
	n.configMut.Lock()
	defer n.configMut.Unlock()
	if n.config != nil {
		return errors.New("Can't set config twice")
	}
	n.config = c
	return nil
}
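
// A minimal usage sketch (not part of the original file; tni and buf are
// hypothetical, with tni the root's *TreeNodeInstance of a freshly created
// protocol). The config set here travels with the first message and can be
// read by Services in NewProtocol via the GenericConfig's Data field:
//
//	if err := tni.SetConfig(&GenericConfig{Data: buf}); err != nil {
//		return err
//	}
//	go func() {
//		if err := tni.ProtocolInstance().Start(); err != nil {
//			log.Error(err)
//		}
//	}()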

// Rx implements the CounterIO interface
func (n *TreeNodeInstance) Rx() uint64 {
	return n.rx.get()
}

// Tx implements the CounterIO interface
func (n *TreeNodeInstance) Tx() uint64 {
	return n.tx.get()
}

func (n *TreeNodeInstance) isBound() bool {
	return n.instance != nil
}

func (n *TreeNodeInstance) bind(pi ProtocolInstance) {
	n.instance = pi
}