github.com/zhiqiangxu/go-ethereum@v1.9.16-0.20210824055606-be91cfdebc48/p2p/nodestate/nodestate.go (about)

     1  // Copyright 2020 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package nodestate
    18  
    19  import (
    20  	"errors"
    21  	"reflect"
    22  	"sync"
    23  	"time"
    24  	"unsafe"
    25  
    26  	"github.com/zhiqiangxu/go-ethereum/common/mclock"
    27  	"github.com/zhiqiangxu/go-ethereum/ethdb"
    28  	"github.com/zhiqiangxu/go-ethereum/log"
    29  	"github.com/zhiqiangxu/go-ethereum/metrics"
    30  	"github.com/zhiqiangxu/go-ethereum/p2p/enode"
    31  	"github.com/zhiqiangxu/go-ethereum/p2p/enr"
    32  	"github.com/zhiqiangxu/go-ethereum/rlp"
    33  )
    34  
type (
	// NodeStateMachine connects different system components operating on subsets of
	// network nodes. Node states are represented by 64 bit vectors with each bit assigned
	// to a state flag. Each state flag has a descriptor structure and the mapping is
	// created automatically. It is possible to subscribe to subsets of state flags and
	// receive a callback if one of the nodes has a relevant state flag changed.
	// Callbacks can also modify further flags of the same node or other nodes. State
	// updates only return after all immediate effects throughout the system have happened
	// (deadlocks should be avoided by design of the implemented state logic). The caller
	// can also add timeouts assigned to a certain node and a subset of state flags.
	// If the timeout elapses, the flags are reset. If all relevant flags are reset then
	// the timer is dropped. State flags with no timeout are persisted in the database
	// if the flag descriptor enables saving. If a node has no state flags set at any
	// moment then it is discarded.
	//
	// Extra node fields can also be registered so system components can also store more
	// complex state for each node that is relevant to them, without creating a custom
	// peer set. Fields can be shared across multiple components if they all know the
	// field ID. Subscription to fields is also possible. Persistent fields should have
	// an encoder and a decoder function.
	NodeStateMachine struct {
		// started and stopped are set by Start and Stop respectively; after Stop,
		// state and field updates are ignored. lock serializes access to the node
		// set and the per-node data. Database entries are keyed by dbNodeKey
		// followed by the node ID. nodes holds every node that has at least one
		// state flag set. offlineCallbackList queues the snapshots delivered by
		// offlineCallbacks when the machine starts (loaded nodes) or stops (all
		// known nodes).
		started, stopped    bool
		lock                sync.Mutex
		clock               mclock.Clock
		db                  ethdb.KeyValueStore
		dbNodeKey           []byte
		nodes               map[enode.ID]*nodeInfo
		offlineCallbackList []offlineCallback

		// Registered state flags or fields. Modifications are allowed
		// only when the node state machine has not been started.
		// saveFlags has a bit set for every persistent flag of the setup.
		setup     *Setup
		fields    []*fieldInfo
		saveFlags bitMask

		// Installed callbacks. Modifications are allowed only when the
		// node state machine has not been started.
		stateSubs []stateSub

		// Testing hooks, only for testing purposes.
		// saveNodeHook is invoked after a node entry was written successfully.
		saveNodeHook func(*nodeInfo)
	}

	// Flags represents a set of flags from a certain setup
	Flags struct {
		mask  bitMask
		setup *Setup
	}

	// Field represents a field from a certain setup; index is its registration order.
	Field struct {
		index int
		setup *Setup
	}

	// flagDefinition describes a node state flag. Each registered instance is automatically
	// mapped to a bit of the 64 bit node states.
	// If persistent is true then the flag is saved to the database (on Stop or Persist)
	// while it is set without a pending timeout.
	flagDefinition struct {
		name       string
		persistent bool
	}

	// fieldDefinition describes an optional node field of the given type. The contents
	// of the field are only retained for each node as long as at least one of the
	// state flags is set. encode/decode are nil for fields that are not persisted.
	fieldDefinition struct {
		name   string
		ftype  reflect.Type
		encode func(interface{}) ([]byte, error)
		decode func([]byte) (interface{}, error)
	}

	// Setup contains the list of flags and fields used by the application.
	// Version is stored with every persisted node; entries saved under a different
	// Version are deleted from the database when loaded.
	Setup struct {
		Version uint
		flags   []flagDefinition
		fields  []fieldDefinition
	}

	// bitMask describes a node state or state mask. It represents a subset
	// of node flags with each bit assigned to a flag index (LSB represents flag 0).
	bitMask uint64

	// StateCallback is a subscription callback which is called when one of the
	// state flags that is included in the subscription state mask is changed.
	// Note: oldState and newState are also masked with the subscription mask so only
	// the relevant bits are included.
	StateCallback func(n *enode.Node, oldState, newState Flags)

	// FieldCallback is a subscription callback which is called when the value of
	// a specific field is changed.
	FieldCallback func(n *enode.Node, state Flags, oldValue, newValue interface{})

	// nodeInfo contains node state, fields and state timeouts.
	// db reports whether the node currently has an entry in the database and dirty
	// whether its persisted state or fields may have changed since the last save.
	nodeInfo struct {
		node      *enode.Node
		state     bitMask
		timeouts  []*nodeStateTimeout
		fields    []interface{}
		db, dirty bool
	}

	// nodeInfoEnc is the RLP database representation of a node. Fields holds the
	// encoded persistent fields by field index, with trailing empty entries trimmed.
	// The field order defines the encoding, so it must not be changed.
	nodeInfoEnc struct {
		Enr     enr.Record
		Version uint
		State   bitMask
		Fields  [][]byte
	}

	// stateSub is a state subscription installed by SubscribeState.
	stateSub struct {
		mask     bitMask
		callback StateCallback
	}

	// nodeStateTimeout is a pending timer that resets the flags in mask when it
	// fires; mask shrinks as flags are de-associated by removeTimeouts.
	nodeStateTimeout struct {
		mask  bitMask
		timer mclock.Timer
	}

	// fieldInfo is the runtime descriptor of a registered field with its subscribers.
	fieldInfo struct {
		fieldDefinition
		subs []FieldCallback
	}

	// offlineCallback is a snapshot of a node's state and fields queued for
	// delivery by offlineCallbacks at Start or Stop.
	offlineCallback struct {
		node   *enode.Node
		state  bitMask
		fields []interface{}
	}
)
   166  
// offlineState is a special state that is assumed to be set before a node is loaded from
// the database and after it is shut down. It is bit 0, the "offline" flag that
// NewFlag/NewPersistentFlag register as the first flag of every Setup.
const offlineState = bitMask(1)
   170  
   171  // NewFlag creates a new node state flag
   172  func (s *Setup) NewFlag(name string) Flags {
   173  	if s.flags == nil {
   174  		s.flags = []flagDefinition{{name: "offline"}}
   175  	}
   176  	f := Flags{mask: bitMask(1) << uint(len(s.flags)), setup: s}
   177  	s.flags = append(s.flags, flagDefinition{name: name})
   178  	return f
   179  }
   180  
   181  // NewPersistentFlag creates a new persistent node state flag
   182  func (s *Setup) NewPersistentFlag(name string) Flags {
   183  	if s.flags == nil {
   184  		s.flags = []flagDefinition{{name: "offline"}}
   185  	}
   186  	f := Flags{mask: bitMask(1) << uint(len(s.flags)), setup: s}
   187  	s.flags = append(s.flags, flagDefinition{name: name, persistent: true})
   188  	return f
   189  }
   190  
   191  // OfflineFlag returns the system-defined offline flag belonging to the given setup
   192  func (s *Setup) OfflineFlag() Flags {
   193  	return Flags{mask: offlineState, setup: s}
   194  }
   195  
   196  // NewField creates a new node state field
   197  func (s *Setup) NewField(name string, ftype reflect.Type) Field {
   198  	f := Field{index: len(s.fields), setup: s}
   199  	s.fields = append(s.fields, fieldDefinition{
   200  		name:  name,
   201  		ftype: ftype,
   202  	})
   203  	return f
   204  }
   205  
   206  // NewPersistentField creates a new persistent node field
   207  func (s *Setup) NewPersistentField(name string, ftype reflect.Type, encode func(interface{}) ([]byte, error), decode func([]byte) (interface{}, error)) Field {
   208  	f := Field{index: len(s.fields), setup: s}
   209  	s.fields = append(s.fields, fieldDefinition{
   210  		name:   name,
   211  		ftype:  ftype,
   212  		encode: encode,
   213  		decode: decode,
   214  	})
   215  	return f
   216  }
   217  
   218  // flagOp implements binary flag operations and also checks whether the operands belong to the same setup
   219  func flagOp(a, b Flags, trueIfA, trueIfB, trueIfBoth bool) Flags {
   220  	if a.setup == nil {
   221  		if a.mask != 0 {
   222  			panic("Node state flags have no setup reference")
   223  		}
   224  		a.setup = b.setup
   225  	}
   226  	if b.setup == nil {
   227  		if b.mask != 0 {
   228  			panic("Node state flags have no setup reference")
   229  		}
   230  		b.setup = a.setup
   231  	}
   232  	if a.setup != b.setup {
   233  		panic("Node state flags belong to a different setup")
   234  	}
   235  	res := Flags{setup: a.setup}
   236  	if trueIfA {
   237  		res.mask |= a.mask & ^b.mask
   238  	}
   239  	if trueIfB {
   240  		res.mask |= b.mask & ^a.mask
   241  	}
   242  	if trueIfBoth {
   243  		res.mask |= a.mask & b.mask
   244  	}
   245  	return res
   246  }
   247  
   248  // And returns the set of flags present in both a and b
   249  func (a Flags) And(b Flags) Flags { return flagOp(a, b, false, false, true) }
   250  
   251  // AndNot returns the set of flags present in a but not in b
   252  func (a Flags) AndNot(b Flags) Flags { return flagOp(a, b, true, false, false) }
   253  
   254  // Or returns the set of flags present in either a or b
   255  func (a Flags) Or(b Flags) Flags { return flagOp(a, b, true, true, true) }
   256  
   257  // Xor returns the set of flags present in either a or b but not both
   258  func (a Flags) Xor(b Flags) Flags { return flagOp(a, b, true, true, false) }
   259  
   260  // HasAll returns true if b is a subset of a
   261  func (a Flags) HasAll(b Flags) bool { return flagOp(a, b, false, true, false).mask == 0 }
   262  
   263  // HasNone returns true if a and b have no shared flags
   264  func (a Flags) HasNone(b Flags) bool { return flagOp(a, b, false, false, true).mask == 0 }
   265  
   266  // Equals returns true if a and b have the same flags set
   267  func (a Flags) Equals(b Flags) bool { return flagOp(a, b, true, true, false).mask == 0 }
   268  
   269  // IsEmpty returns true if a has no flags set
   270  func (a Flags) IsEmpty() bool { return a.mask == 0 }
   271  
   272  // MergeFlags merges multiple sets of state flags
   273  func MergeFlags(list ...Flags) Flags {
   274  	if len(list) == 0 {
   275  		return Flags{}
   276  	}
   277  	res := list[0]
   278  	for i := 1; i < len(list); i++ {
   279  		res = res.Or(list[i])
   280  	}
   281  	return res
   282  }
   283  
   284  // String returns a list of the names of the flags specified in the bit mask
   285  func (f Flags) String() string {
   286  	if f.mask == 0 {
   287  		return "[]"
   288  	}
   289  	s := "["
   290  	comma := false
   291  	for index, flag := range f.setup.flags {
   292  		if f.mask&(bitMask(1)<<uint(index)) != 0 {
   293  			if comma {
   294  				s = s + ", "
   295  			}
   296  			s = s + flag.name
   297  			comma = true
   298  		}
   299  	}
   300  	s = s + "]"
   301  	return s
   302  }
   303  
   304  // NewNodeStateMachine creates a new node state machine.
   305  // If db is not nil then the node states, fields and active timeouts are persisted.
   306  // Persistence can be enabled or disabled for each state flag and field.
   307  func NewNodeStateMachine(db ethdb.KeyValueStore, dbKey []byte, clock mclock.Clock, setup *Setup) *NodeStateMachine {
   308  	if setup.flags == nil {
   309  		panic("No state flags defined")
   310  	}
   311  	if len(setup.flags) > 8*int(unsafe.Sizeof(bitMask(0))) {
   312  		panic("Too many node state flags")
   313  	}
   314  	ns := &NodeStateMachine{
   315  		db:        db,
   316  		dbNodeKey: dbKey,
   317  		clock:     clock,
   318  		setup:     setup,
   319  		nodes:     make(map[enode.ID]*nodeInfo),
   320  		fields:    make([]*fieldInfo, len(setup.fields)),
   321  	}
   322  	stateNameMap := make(map[string]int)
   323  	for index, flag := range setup.flags {
   324  		if _, ok := stateNameMap[flag.name]; ok {
   325  			panic("Node state flag name collision")
   326  		}
   327  		stateNameMap[flag.name] = index
   328  		if flag.persistent {
   329  			ns.saveFlags |= bitMask(1) << uint(index)
   330  		}
   331  	}
   332  	fieldNameMap := make(map[string]int)
   333  	for index, field := range setup.fields {
   334  		if _, ok := fieldNameMap[field.name]; ok {
   335  			panic("Node field name collision")
   336  		}
   337  		ns.fields[index] = &fieldInfo{fieldDefinition: field}
   338  		fieldNameMap[field.name] = index
   339  	}
   340  	return ns
   341  }
   342  
   343  // stateMask checks whether the set of flags belongs to the same setup and returns its internal bit mask
   344  func (ns *NodeStateMachine) stateMask(flags Flags) bitMask {
   345  	if flags.setup != ns.setup && flags.mask != 0 {
   346  		panic("Node state flags belong to a different setup")
   347  	}
   348  	return flags.mask
   349  }
   350  
   351  // fieldIndex checks whether the field belongs to the same setup and returns its internal index
   352  func (ns *NodeStateMachine) fieldIndex(field Field) int {
   353  	if field.setup != ns.setup {
   354  		panic("Node field belongs to a different setup")
   355  	}
   356  	return field.index
   357  }
   358  
   359  // SubscribeState adds a node state subscription. The callback is called while the state
   360  // machine mutex is not held and it is allowed to make further state updates. All immediate
   361  // changes throughout the system are processed in the same thread/goroutine. It is the
   362  // responsibility of the implemented state logic to avoid deadlocks caused by the callbacks,
   363  // infinite toggling of flags or hazardous/non-deterministic state changes.
   364  // State subscriptions should be installed before loading the node database or making the
   365  // first state update.
   366  func (ns *NodeStateMachine) SubscribeState(flags Flags, callback StateCallback) {
   367  	ns.lock.Lock()
   368  	defer ns.lock.Unlock()
   369  
   370  	if ns.started {
   371  		panic("state machine already started")
   372  	}
   373  	ns.stateSubs = append(ns.stateSubs, stateSub{ns.stateMask(flags), callback})
   374  }
   375  
   376  // SubscribeField adds a node field subscription. Same rules apply as for SubscribeState.
   377  func (ns *NodeStateMachine) SubscribeField(field Field, callback FieldCallback) {
   378  	ns.lock.Lock()
   379  	defer ns.lock.Unlock()
   380  
   381  	if ns.started {
   382  		panic("state machine already started")
   383  	}
   384  	f := ns.fields[ns.fieldIndex(field)]
   385  	f.subs = append(f.subs, callback)
   386  }
   387  
   388  // newNode creates a new nodeInfo
   389  func (ns *NodeStateMachine) newNode(n *enode.Node) *nodeInfo {
   390  	return &nodeInfo{node: n, fields: make([]interface{}, len(ns.fields))}
   391  }
   392  
   393  // checkStarted checks whether the state machine has already been started and panics otherwise.
   394  func (ns *NodeStateMachine) checkStarted() {
   395  	if !ns.started {
   396  		panic("state machine not started yet")
   397  	}
   398  }
   399  
   400  // Start starts the state machine, enabling state and field operations and disabling
   401  // further subscriptions.
   402  func (ns *NodeStateMachine) Start() {
   403  	ns.lock.Lock()
   404  	if ns.started {
   405  		panic("state machine already started")
   406  	}
   407  	ns.started = true
   408  	if ns.db != nil {
   409  		ns.loadFromDb()
   410  	}
   411  	ns.lock.Unlock()
   412  	ns.offlineCallbacks(true)
   413  }
   414  
   415  // Stop stops the state machine and saves its state if a database was supplied
   416  func (ns *NodeStateMachine) Stop() {
   417  	ns.lock.Lock()
   418  	for _, node := range ns.nodes {
   419  		fields := make([]interface{}, len(node.fields))
   420  		copy(fields, node.fields)
   421  		ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields})
   422  	}
   423  	ns.stopped = true
   424  	if ns.db != nil {
   425  		ns.saveToDb()
   426  		ns.lock.Unlock()
   427  	} else {
   428  		ns.lock.Unlock()
   429  	}
   430  	ns.offlineCallbacks(false)
   431  }
   432  
   433  // loadFromDb loads persisted node states from the database
   434  func (ns *NodeStateMachine) loadFromDb() {
   435  	it := ns.db.NewIterator(ns.dbNodeKey, nil)
   436  	for it.Next() {
   437  		var id enode.ID
   438  		if len(it.Key()) != len(ns.dbNodeKey)+len(id) {
   439  			log.Error("Node state db entry with invalid length", "found", len(it.Key()), "expected", len(ns.dbNodeKey)+len(id))
   440  			continue
   441  		}
   442  		copy(id[:], it.Key()[len(ns.dbNodeKey):])
   443  		ns.decodeNode(id, it.Value())
   444  	}
   445  }
   446  
   447  type dummyIdentity enode.ID
   448  
   449  func (id dummyIdentity) Verify(r *enr.Record, sig []byte) error { return nil }
   450  func (id dummyIdentity) NodeAddr(r *enr.Record) []byte          { return id[:] }
   451  
   452  // decodeNode decodes a node database entry and adds it to the node set if successful
   453  func (ns *NodeStateMachine) decodeNode(id enode.ID, data []byte) {
   454  	var enc nodeInfoEnc
   455  	if err := rlp.DecodeBytes(data, &enc); err != nil {
   456  		log.Error("Failed to decode node info", "id", id, "error", err)
   457  		return
   458  	}
   459  	n, _ := enode.New(dummyIdentity(id), &enc.Enr)
   460  	node := ns.newNode(n)
   461  	node.db = true
   462  
   463  	if enc.Version != ns.setup.Version {
   464  		log.Debug("Removing stored node with unknown version", "current", ns.setup.Version, "stored", enc.Version)
   465  		ns.deleteNode(id)
   466  		return
   467  	}
   468  	if len(enc.Fields) > len(ns.setup.fields) {
   469  		log.Error("Invalid node field count", "id", id, "stored", len(enc.Fields))
   470  		return
   471  	}
   472  	// Resolve persisted node fields
   473  	for i, encField := range enc.Fields {
   474  		if len(encField) == 0 {
   475  			continue
   476  		}
   477  		if decode := ns.fields[i].decode; decode != nil {
   478  			if field, err := decode(encField); err == nil {
   479  				node.fields[i] = field
   480  			} else {
   481  				log.Error("Failed to decode node field", "id", id, "field name", ns.fields[i].name, "error", err)
   482  				return
   483  			}
   484  		} else {
   485  			log.Error("Cannot decode node field", "id", id, "field name", ns.fields[i].name)
   486  			return
   487  		}
   488  	}
   489  	// It's a compatible node record, add it to set.
   490  	ns.nodes[id] = node
   491  	node.state = enc.State
   492  	fields := make([]interface{}, len(node.fields))
   493  	copy(fields, node.fields)
   494  	ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields})
   495  	log.Debug("Loaded node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup})
   496  }
   497  
   498  // saveNode saves the given node info to the database
   499  func (ns *NodeStateMachine) saveNode(id enode.ID, node *nodeInfo) error {
   500  	if ns.db == nil {
   501  		return nil
   502  	}
   503  
   504  	storedState := node.state & ns.saveFlags
   505  	for _, t := range node.timeouts {
   506  		storedState &= ^t.mask
   507  	}
   508  	if storedState == 0 {
   509  		if node.db {
   510  			node.db = false
   511  			ns.deleteNode(id)
   512  		}
   513  		node.dirty = false
   514  		return nil
   515  	}
   516  
   517  	enc := nodeInfoEnc{
   518  		Enr:     *node.node.Record(),
   519  		Version: ns.setup.Version,
   520  		State:   storedState,
   521  		Fields:  make([][]byte, len(ns.fields)),
   522  	}
   523  	log.Debug("Saved node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup})
   524  	lastIndex := -1
   525  	for i, f := range node.fields {
   526  		if f == nil {
   527  			continue
   528  		}
   529  		encode := ns.fields[i].encode
   530  		if encode == nil {
   531  			continue
   532  		}
   533  		blob, err := encode(f)
   534  		if err != nil {
   535  			return err
   536  		}
   537  		enc.Fields[i] = blob
   538  		lastIndex = i
   539  	}
   540  	enc.Fields = enc.Fields[:lastIndex+1]
   541  	data, err := rlp.EncodeToBytes(&enc)
   542  	if err != nil {
   543  		return err
   544  	}
   545  	if err := ns.db.Put(append(ns.dbNodeKey, id[:]...), data); err != nil {
   546  		return err
   547  	}
   548  	node.dirty, node.db = false, true
   549  
   550  	if ns.saveNodeHook != nil {
   551  		ns.saveNodeHook(node)
   552  	}
   553  	return nil
   554  }
   555  
   556  // deleteNode removes a node info from the database
   557  func (ns *NodeStateMachine) deleteNode(id enode.ID) {
   558  	ns.db.Delete(append(ns.dbNodeKey, id[:]...))
   559  }
   560  
   561  // saveToDb saves the persistent flags and fields of all nodes that have been changed
   562  func (ns *NodeStateMachine) saveToDb() {
   563  	for id, node := range ns.nodes {
   564  		if node.dirty {
   565  			err := ns.saveNode(id, node)
   566  			if err != nil {
   567  				log.Error("Failed to save node", "id", id, "error", err)
   568  			}
   569  		}
   570  	}
   571  }
   572  
   573  // updateEnode updates the enode entry belonging to the given node if it already exists
   574  func (ns *NodeStateMachine) updateEnode(n *enode.Node) (enode.ID, *nodeInfo) {
   575  	id := n.ID()
   576  	node := ns.nodes[id]
   577  	if node != nil && n.Seq() > node.node.Seq() {
   578  		node.node = n
   579  	}
   580  	return id, node
   581  }
   582  
   583  // Persist saves the persistent state and fields of the given node immediately
   584  func (ns *NodeStateMachine) Persist(n *enode.Node) error {
   585  	ns.lock.Lock()
   586  	defer ns.lock.Unlock()
   587  
   588  	ns.checkStarted()
   589  	if id, node := ns.updateEnode(n); node != nil && node.dirty {
   590  		err := ns.saveNode(id, node)
   591  		if err != nil {
   592  			log.Error("Failed to save node", "id", id, "error", err)
   593  		}
   594  		return err
   595  	}
   596  	return nil
   597  }
   598  
// SetState updates the given node state flags and processes all resulting callbacks.
// It only returns after all subsequent immediate changes (including those changed by the
// callbacks) have been processed. If a flag with a timeout is set again, the operation
// removes or replaces the existing timeout.
//
// A flag present in both setFlags and resetFlags ends up set. A non-zero timeout is
// registered for the flags in setFlags (see addTimeout). The call panics before Start
// and is ignored after Stop; on an unknown node it is ignored unless setFlags is
// non-empty. A node whose state becomes empty is discarded (and its database entry
// removed), and field subscribers then receive a nil new value for each of its
// non-nil fields. State subscribers receive old/new states masked with their
// subscription mask. All callbacks run after ns.lock has been released.
func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) {
	ns.lock.Lock()
	ns.checkStarted()
	if ns.stopped {
		ns.lock.Unlock()
		return
	}

	set, reset := ns.stateMask(setFlags), ns.stateMask(resetFlags)
	id, node := ns.updateEnode(n)
	if node == nil {
		// Unknown node: only create an entry if some flag is actually being set.
		if set == 0 {
			ns.lock.Unlock()
			return
		}
		node = ns.newNode(n)
		ns.nodes[id] = node
	}
	oldState := node.state
	// set is applied after reset, so it takes precedence.
	newState := (node.state & (^reset)) | set
	changed := oldState ^ newState
	node.state = newState

	// Remove the timeout callbacks for all reset and set flags,
	// even if they do not exist (which is a no-op).
	ns.removeTimeouts(node, set|reset)

	// Register the timeout callback if the new state is not empty
	// and timeout itself is required.
	if timeout != 0 && newState != 0 {
		ns.addTimeout(n, set, timeout)
	}
	if newState == oldState {
		ns.lock.Unlock()
		return
	}
	if newState == 0 {
		// No flags left: drop the node from memory and from the database.
		delete(ns.nodes, id)
		if node.db {
			ns.deleteNode(id)
		}
	} else {
		// Only changes of persistent flags require the node to be saved again.
		if changed&ns.saveFlags != 0 {
			node.dirty = true
		}
	}
	ns.lock.Unlock()
	// call state update subscription callbacks without holding the mutex
	// (ns.stateSubs can only be modified before Start, so reading it here is safe)
	for _, sub := range ns.stateSubs {
		if changed&sub.mask != 0 {
			sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup})
		}
	}
	if newState == 0 {
		// call field subscriptions for discarded fields
		for i, v := range node.fields {
			if v != nil {
				f := ns.fields[i]
				if len(f.subs) > 0 {
					for _, cb := range f.subs {
						cb(n, Flags{setup: ns.setup}, v, nil)
					}
				}
			}
		}
	}
}
   670  
   671  // offlineCallbacks calls state update callbacks at startup or shutdown
   672  func (ns *NodeStateMachine) offlineCallbacks(start bool) {
   673  	for _, cb := range ns.offlineCallbackList {
   674  		for _, sub := range ns.stateSubs {
   675  			offState := offlineState & sub.mask
   676  			onState := cb.state & sub.mask
   677  			if offState != onState {
   678  				if start {
   679  					sub.callback(cb.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup})
   680  				} else {
   681  					sub.callback(cb.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup})
   682  				}
   683  			}
   684  		}
   685  		for i, f := range cb.fields {
   686  			if f != nil && ns.fields[i].subs != nil {
   687  				for _, fsub := range ns.fields[i].subs {
   688  					if start {
   689  						fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, nil, f)
   690  					} else {
   691  						fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, f, nil)
   692  					}
   693  				}
   694  			}
   695  		}
   696  	}
   697  	ns.offlineCallbackList = nil
   698  }
   699  
   700  // AddTimeout adds a node state timeout associated to the given state flag(s).
   701  // After the specified time interval, the relevant states will be reset.
   702  func (ns *NodeStateMachine) AddTimeout(n *enode.Node, flags Flags, timeout time.Duration) {
   703  	ns.lock.Lock()
   704  	defer ns.lock.Unlock()
   705  
   706  	ns.checkStarted()
   707  	if ns.stopped {
   708  		return
   709  	}
   710  	ns.addTimeout(n, ns.stateMask(flags), timeout)
   711  }
   712  
   713  // addTimeout adds a node state timeout associated to the given state flag(s).
   714  func (ns *NodeStateMachine) addTimeout(n *enode.Node, mask bitMask, timeout time.Duration) {
   715  	_, node := ns.updateEnode(n)
   716  	if node == nil {
   717  		return
   718  	}
   719  	mask &= node.state
   720  	if mask == 0 {
   721  		return
   722  	}
   723  	ns.removeTimeouts(node, mask)
   724  	t := &nodeStateTimeout{mask: mask}
   725  	t.timer = ns.clock.AfterFunc(timeout, func() {
   726  		ns.SetState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0)
   727  	})
   728  	node.timeouts = append(node.timeouts, t)
   729  	if mask&ns.saveFlags != 0 {
   730  		node.dirty = true
   731  	}
   732  }
   733  
   734  // removeTimeout removes node state timeouts associated to the given state flag(s).
   735  // If a timeout was associated to multiple flags which are not all included in the
   736  // specified remove mask then only the included flags are de-associated and the timer
   737  // stays active.
   738  func (ns *NodeStateMachine) removeTimeouts(node *nodeInfo, mask bitMask) {
   739  	for i := 0; i < len(node.timeouts); i++ {
   740  		t := node.timeouts[i]
   741  		match := t.mask & mask
   742  		if match == 0 {
   743  			continue
   744  		}
   745  		t.mask -= match
   746  		if t.mask != 0 {
   747  			continue
   748  		}
   749  		t.timer.Stop()
   750  		node.timeouts[i] = node.timeouts[len(node.timeouts)-1]
   751  		node.timeouts = node.timeouts[:len(node.timeouts)-1]
   752  		i--
   753  		if match&ns.saveFlags != 0 {
   754  			node.dirty = true
   755  		}
   756  	}
   757  }
   758  
   759  // GetField retrieves the given field of the given node
   760  func (ns *NodeStateMachine) GetField(n *enode.Node, field Field) interface{} {
   761  	ns.lock.Lock()
   762  	defer ns.lock.Unlock()
   763  
   764  	ns.checkStarted()
   765  	if ns.stopped {
   766  		return nil
   767  	}
   768  	if _, node := ns.updateEnode(n); node != nil {
   769  		return node.fields[ns.fieldIndex(field)]
   770  	}
   771  	return nil
   772  }
   773  
   774  // SetField sets the given field of the given node
   775  func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface{}) error {
   776  	ns.lock.Lock()
   777  	ns.checkStarted()
   778  	if ns.stopped {
   779  		ns.lock.Unlock()
   780  		return nil
   781  	}
   782  	_, node := ns.updateEnode(n)
   783  	if node == nil {
   784  		ns.lock.Unlock()
   785  		return nil
   786  	}
   787  	fieldIndex := ns.fieldIndex(field)
   788  	f := ns.fields[fieldIndex]
   789  	if value != nil && reflect.TypeOf(value) != f.ftype {
   790  		log.Error("Invalid field type", "type", reflect.TypeOf(value), "required", f.ftype)
   791  		ns.lock.Unlock()
   792  		return errors.New("invalid field type")
   793  	}
   794  	oldValue := node.fields[fieldIndex]
   795  	if value == oldValue {
   796  		ns.lock.Unlock()
   797  		return nil
   798  	}
   799  	node.fields[fieldIndex] = value
   800  	if f.encode != nil {
   801  		node.dirty = true
   802  	}
   803  
   804  	state := node.state
   805  	ns.lock.Unlock()
   806  	if len(f.subs) > 0 {
   807  		for _, cb := range f.subs {
   808  			cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value)
   809  		}
   810  	}
   811  	return nil
   812  }
   813  
   814  // ForEach calls the callback for each node having all of the required and none of the
   815  // disabled flags set
   816  func (ns *NodeStateMachine) ForEach(requireFlags, disableFlags Flags, cb func(n *enode.Node, state Flags)) {
   817  	ns.lock.Lock()
   818  	ns.checkStarted()
   819  	type callback struct {
   820  		node  *enode.Node
   821  		state bitMask
   822  	}
   823  	require, disable := ns.stateMask(requireFlags), ns.stateMask(disableFlags)
   824  	var callbacks []callback
   825  	for _, node := range ns.nodes {
   826  		if node.state&require == require && node.state&disable == 0 {
   827  			callbacks = append(callbacks, callback{node.node, node.state & (require | disable)})
   828  		}
   829  	}
   830  	ns.lock.Unlock()
   831  	for _, c := range callbacks {
   832  		cb(c.node, Flags{mask: c.state, setup: ns.setup})
   833  	}
   834  }
   835  
   836  // GetNode returns the enode currently associated with the given ID
   837  func (ns *NodeStateMachine) GetNode(id enode.ID) *enode.Node {
   838  	ns.lock.Lock()
   839  	defer ns.lock.Unlock()
   840  
   841  	ns.checkStarted()
   842  	if node := ns.nodes[id]; node != nil {
   843  		return node.node
   844  	}
   845  	return nil
   846  }
   847  
   848  // AddLogMetrics adds logging and/or metrics for nodes entering, exiting and currently
   849  // being in a given set specified by required and disabled state flags
   850  func (ns *NodeStateMachine) AddLogMetrics(requireFlags, disableFlags Flags, name string, inMeter, outMeter metrics.Meter, gauge metrics.Gauge) {
   851  	var count int64
   852  	ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState Flags) {
   853  		oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags)
   854  		newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags)
   855  		if newMatch == oldMatch {
   856  			return
   857  		}
   858  
   859  		if newMatch {
   860  			count++
   861  			if name != "" {
   862  				log.Debug("Node entered", "set", name, "id", n.ID(), "count", count)
   863  			}
   864  			if inMeter != nil {
   865  				inMeter.Mark(1)
   866  			}
   867  		} else {
   868  			count--
   869  			if name != "" {
   870  				log.Debug("Node left", "set", name, "id", n.ID(), "count", count)
   871  			}
   872  			if outMeter != nil {
   873  				outMeter.Mark(1)
   874  			}
   875  		}
   876  		if gauge != nil {
   877  			gauge.Update(count)
   878  		}
   879  	})
   880  }