github.com/kaisenlinux/docker.io@v0.0.0-20230510090727-ea55db55fac7/swarmkit/manager/dispatcher/dispatcher.go (about)

     1  package dispatcher
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"strconv"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/docker/go-events"
    12  	"github.com/docker/go-metrics"
    13  	"github.com/docker/swarmkit/api"
    14  	"github.com/docker/swarmkit/api/equality"
    15  	"github.com/docker/swarmkit/ca"
    16  	"github.com/docker/swarmkit/log"
    17  	"github.com/docker/swarmkit/manager/drivers"
    18  	"github.com/docker/swarmkit/manager/state/store"
    19  	"github.com/docker/swarmkit/protobuf/ptypes"
    20  	"github.com/docker/swarmkit/remotes"
    21  	"github.com/docker/swarmkit/watch"
    22  	gogotypes "github.com/gogo/protobuf/types"
    23  	"github.com/pkg/errors"
    24  	"github.com/sirupsen/logrus"
    25  	"google.golang.org/grpc/codes"
    26  	"google.golang.org/grpc/status"
    27  )
    28  
    29  const (
    30  	// DefaultHeartBeatPeriod is used for setting default value in cluster config
    31  	// and in case if cluster config is missing.
    32  	DefaultHeartBeatPeriod       = 5 * time.Second
    33  	defaultHeartBeatEpsilon      = 500 * time.Millisecond
    34  	defaultGracePeriodMultiplier = 3
    35  	defaultRateLimitPeriod       = 8 * time.Second
    36  
    37  	// maxBatchItems is the threshold of queued writes that should
    38  	// trigger an actual transaction to commit them to the shared store.
    39  	maxBatchItems = 10000
    40  
    41  	// maxBatchInterval needs to strike a balance between keeping
    42  	// latency low, and realizing opportunities to combine many writes
    43  	// into a single transaction. A fraction of a second feels about
    44  	// right.
    45  	maxBatchInterval = 100 * time.Millisecond
    46  
    47  	modificationBatchLimit = 100
    48  	batchingWaitTime       = 100 * time.Millisecond
    49  
    50  	// defaultNodeDownPeriod specifies the default time period we
    51  	// wait before moving tasks assigned to down nodes to ORPHANED
    52  	// state.
    53  	defaultNodeDownPeriod = 24 * time.Hour
    54  )
    55  
    56  var (
    57  	// ErrNodeAlreadyRegistered returned if node with same ID was already
    58  	// registered with this dispatcher.
    59  	ErrNodeAlreadyRegistered = errors.New("node already registered")
    60  	// ErrNodeNotRegistered returned if node with such ID wasn't registered
    61  	// with this dispatcher.
    62  	ErrNodeNotRegistered = errors.New("node not registered")
    63  	// ErrSessionInvalid returned when the session in use is no longer valid.
    64  	// The node should re-register and start a new session.
    65  	ErrSessionInvalid = errors.New("session invalid")
    66  	// ErrNodeNotFound returned when the Node doesn't exist in raft.
    67  	ErrNodeNotFound = errors.New("node not found")
    68  
    69  	// Scheduling delay timer.
    70  	schedulingDelayTimer metrics.Timer
    71  )
    72  
    73  func init() {
    74  	ns := metrics.NewNamespace("swarm", "dispatcher", nil)
    75  	schedulingDelayTimer = ns.NewTimer("scheduling_delay",
    76  		"Scheduling delay is the time a task takes to go from NEW to RUNNING state.")
    77  	metrics.Register(ns)
    78  }
    79  
     80  // Config is the configuration for a Dispatcher. Use DefaultConfig to obtain
     81  // a Config populated with the package defaults.
     82  type Config struct {
         	// HeartbeatPeriod is handed to the node store by Init. Run overwrites
         	// it with the cluster spec's dispatcher heartbeat period when one is set.
     83  	HeartbeatPeriod  time.Duration
         	// HeartbeatEpsilon is handed to the node store together with
         	// HeartbeatPeriod.
     84  	HeartbeatEpsilon time.Duration
     85  	// RateLimitPeriod specifies how often a node with the same ID can try to
     86  	// register a new session.
     87  	RateLimitPeriod       time.Duration
         	// GracePeriodMultiplier is handed to the node store together with
         	// HeartbeatPeriod.
     88  	GracePeriodMultiplier int
     89  }
    90  
    91  // DefaultConfig returns default config for Dispatcher.
    92  func DefaultConfig() *Config {
    93  	return &Config{
    94  		HeartbeatPeriod:       DefaultHeartBeatPeriod,
    95  		HeartbeatEpsilon:      defaultHeartBeatEpsilon,
    96  		RateLimitPeriod:       defaultRateLimitPeriod,
    97  		GracePeriodMultiplier: defaultGracePeriodMultiplier,
    98  	}
    99  }
   100  
   101  // Cluster is interface which represent raft cluster. manager/state/raft.Node
   102  // is implements it. This interface needed only for easier unit-testing.
   103  type Cluster interface {
   104  	GetMemberlist() map[uint64]*api.RaftMember
   105  	SubscribePeers() (chan events.Event, func())
   106  	MemoryStore() *store.MemoryStore
   107  }
   108  
    109  // nodeUpdate is a pending change to a node object, queued in
    110  // Dispatcher.nodeUpdates and written to the store in a batch by processUpdates.
         // A nil field means "leave that part of the node unchanged".
    111  type nodeUpdate struct {
         	// status, when non-nil, replaces the node's state and message; its
         	// Addr is applied only when non-empty.
    112  	status      *api.NodeStatus
         	// description, when non-nil, replaces the node's description.
    113  	description *api.NodeDescription
    114  }
   115  
    116  // clusterUpdate carries a change to cluster-wide state that should trigger
    117  // a new session message. Fields are pointers so that "there is no update"
    118  // (nil pointer) can be told apart from "update this to nil".
         // Run publishes values of this type on Dispatcher.clusterUpdateQueue.
    119  type clusterUpdate struct {
         	// managerUpdate is the new list of manager peers.
    120  	managerUpdate      *[]*api.WeightedPeer
         	// bootstrapKeyUpdate is the new set of network bootstrap keys.
    121  	bootstrapKeyUpdate *[]*api.EncryptionKey
         	// rootCAUpdate is the new root CA certificate.
    122  	rootCAUpdate       *[]byte
    123  }
   124  
    125  // Dispatcher is responsible for dispatching tasks and tracking agent health.
         // It is allocated by New, configured by Init and started by Run. Stop shuts
         // it down; Run refuses to start while a previous run is still active.
    126  type Dispatcher struct {
    127  	// Mutex to synchronize access to dispatcher shared state e.g. nodes,
    128  	// lastSeenManagers, networkBootstrapKeys etc.
    129  	// TODO(anshul): This can potentially be removed and rpcRW used in its place.
    130  	mu sync.Mutex
    131  	// WaitGroup to handle the case when Stop() gets called before Run()
    132  	// has finished initializing the dispatcher.
         	// Run holds one count for its whole lifetime, so Stop's Wait returns
         	// only after Run has exited.
    133  	wg sync.WaitGroup
    134  	// This RWMutex synchronizes RPC handlers and the dispatcher stop().
    135  	// The RPC handlers use the read lock while stop() uses the write lock
    136  	// and acts as a barrier to shutdown.
    137  	rpcRW                sync.RWMutex
         	// nodes tracks registered nodes and their heartbeat expiry. Init
         	// rebuilds it and Stop cleans it.
    138  	nodes                *nodeStore
         	// store is the cluster's memory store, obtained from cluster in Init.
    139  	store                *store.MemoryStore
         	// lastSeenManagers, networkBootstrapKeys and lastSeenRootCert cache the
         	// latest values Run has observed. Run writes them while holding mu.
    140  	lastSeenManagers     []*api.WeightedPeer
    141  	networkBootstrapKeys []*api.EncryptionKey
    142  	lastSeenRootCert     []byte
         	// config is the active configuration. Run may overwrite its
         	// HeartbeatPeriod from the cluster spec.
    143  	config               *Config
    144  	cluster              Cluster
         	// ctx and cancel belong to the current Run. A nil or cancelled ctx
         	// means the dispatcher is not running (see isRunning).
    145  	ctx                  context.Context
    146  	cancel               context.CancelFunc
         	// clusterUpdateQueue carries clusterUpdate values published by Run.
         	// Stop closes it.
    147  	clusterUpdateQueue   *watch.Queue
    148  	dp                   *drivers.DriverProvider
         	// securityConfig supplies the node ID that processUpdates records as
         	// AppliedBy on task status updates.
    149  	securityConfig       *ca.SecurityConfig
    150  
         	// Pending task and node updates, each guarded by its own lock and
         	// flushed to the store in batches by processUpdates.
    151  	taskUpdates     map[string]*api.TaskStatus // indexed by task ID
    152  	taskUpdatesLock sync.Mutex
    153  
    154  	nodeUpdates     map[string]nodeUpdate // indexed by node ID
    155  	nodeUpdatesLock sync.Mutex
    156  
         	// downNodes tracks nodes known to be DOWN (created in New with
         	// defaultNodeDownPeriod). Entries added by markNodesUnknown move the
         	// node's tasks to ORPHANED on expiry; markNodeReady removes a node.
    157  	downNodes *nodeStore
    158  
         	// processUpdatesTrigger has a buffer of one. A send asks Run's loop to
         	// call processUpdates immediately once maxBatchItems updates are queued.
    159  	processUpdatesTrigger chan struct{}
    160  
    161  	// for waiting for the next task/node batch update
         	// processUpdatesCond is broadcast by processUpdates after committing a
         	// batch and by Stop. markNodeReady waits on it.
    162  	processUpdatesLock sync.Mutex
    163  	processUpdatesCond *sync.Cond
    164  }
   165  
   166  // New returns Dispatcher with cluster interface(usually raft.Node).
   167  func New() *Dispatcher {
   168  	d := &Dispatcher{
   169  		downNodes:             newNodeStore(defaultNodeDownPeriod, 0, 1, 0),
   170  		processUpdatesTrigger: make(chan struct{}, 1),
   171  	}
   172  
   173  	d.processUpdatesCond = sync.NewCond(&d.processUpdatesLock)
   174  
   175  	return d
   176  }
   177  
   178  // Init is used to initialize the dispatcher and
   179  // is typically called before starting the dispatcher
   180  // when a manager becomes a leader.
   181  // The dispatcher is a grpc server, and unlike other components,
   182  // it can't simply be recreated on becoming a leader.
   183  // This function ensures the dispatcher restarts with a clean slate.
   184  func (d *Dispatcher) Init(cluster Cluster, c *Config, dp *drivers.DriverProvider, securityConfig *ca.SecurityConfig) {
   185  	d.cluster = cluster
   186  	d.config = c
   187  	d.securityConfig = securityConfig
   188  	d.dp = dp
   189  	d.store = cluster.MemoryStore()
   190  	d.nodes = newNodeStore(c.HeartbeatPeriod, c.HeartbeatEpsilon, c.GracePeriodMultiplier, c.RateLimitPeriod)
   191  }
   192  
   193  func getWeightedPeers(cluster Cluster) []*api.WeightedPeer {
   194  	members := cluster.GetMemberlist()
   195  	var mgrs []*api.WeightedPeer
   196  	for _, m := range members {
   197  		mgrs = append(mgrs, &api.WeightedPeer{
   198  			Peer: &api.Peer{
   199  				NodeID: m.NodeID,
   200  				Addr:   m.Addr,
   201  			},
   202  
   203  			// TODO(stevvooe): Calculate weight of manager selection based on
   204  			// cluster-level observations, such as number of connections and
   205  			// load.
   206  			Weight: remotes.DefaultObservationWeight,
   207  		})
   208  	}
   209  	return mgrs
   210  }
   211  
   212  // Run runs dispatcher tasks which should be run on leader dispatcher.
   213  // Dispatcher can be stopped with cancelling ctx or calling Stop().
   214  func (d *Dispatcher) Run(ctx context.Context) error {
   215  	ctx = log.WithModule(ctx, "dispatcher")
   216  	log.G(ctx).Info("dispatcher starting")
   217  
   218  	d.taskUpdatesLock.Lock()
   219  	d.taskUpdates = make(map[string]*api.TaskStatus)
   220  	d.taskUpdatesLock.Unlock()
   221  
   222  	d.nodeUpdatesLock.Lock()
   223  	d.nodeUpdates = make(map[string]nodeUpdate)
   224  	d.nodeUpdatesLock.Unlock()
   225  
   226  	d.mu.Lock()
   227  	if d.isRunning() {
   228  		d.mu.Unlock()
   229  		return errors.New("dispatcher is already running")
   230  	}
   231  	if err := d.markNodesUnknown(ctx); err != nil {
   232  		log.G(ctx).Errorf(`failed to move all nodes to "unknown" state: %v`, err)
   233  	}
   234  	configWatcher, cancel, err := store.ViewAndWatch(
   235  		d.store,
   236  		func(readTx store.ReadTx) error {
   237  			clusters, err := store.FindClusters(readTx, store.ByName(store.DefaultClusterName))
   238  			if err != nil {
   239  				return err
   240  			}
   241  			if len(clusters) == 1 {
   242  				heartbeatPeriod, err := gogotypes.DurationFromProto(clusters[0].Spec.Dispatcher.HeartbeatPeriod)
   243  				if err == nil && heartbeatPeriod > 0 {
   244  					d.config.HeartbeatPeriod = heartbeatPeriod
   245  				}
   246  				if clusters[0].NetworkBootstrapKeys != nil {
   247  					d.networkBootstrapKeys = clusters[0].NetworkBootstrapKeys
   248  				}
   249  				d.lastSeenRootCert = clusters[0].RootCA.CACert
   250  			}
   251  			return nil
   252  		},
   253  		api.EventUpdateCluster{},
   254  	)
   255  	if err != nil {
   256  		d.mu.Unlock()
   257  		return err
   258  	}
   259  	// set queue here to guarantee that Close will close it
   260  	d.clusterUpdateQueue = watch.NewQueue()
   261  
   262  	peerWatcher, peerCancel := d.cluster.SubscribePeers()
   263  	defer peerCancel()
   264  	d.lastSeenManagers = getWeightedPeers(d.cluster)
   265  
   266  	defer cancel()
   267  	d.ctx, d.cancel = context.WithCancel(ctx)
   268  	ctx = d.ctx
   269  	d.wg.Add(1)
   270  	defer d.wg.Done()
   271  	d.mu.Unlock()
   272  
   273  	publishManagers := func(peers []*api.Peer) {
   274  		var mgrs []*api.WeightedPeer
   275  		for _, p := range peers {
   276  			mgrs = append(mgrs, &api.WeightedPeer{
   277  				Peer:   p,
   278  				Weight: remotes.DefaultObservationWeight,
   279  			})
   280  		}
   281  		d.mu.Lock()
   282  		d.lastSeenManagers = mgrs
   283  		d.mu.Unlock()
   284  		d.clusterUpdateQueue.Publish(clusterUpdate{managerUpdate: &mgrs})
   285  	}
   286  
   287  	batchTimer := time.NewTimer(maxBatchInterval)
   288  	defer batchTimer.Stop()
   289  
   290  	for {
   291  		select {
   292  		case ev := <-peerWatcher:
   293  			publishManagers(ev.([]*api.Peer))
   294  		case <-d.processUpdatesTrigger:
   295  			d.processUpdates(ctx)
   296  			batchTimer.Stop()
   297  			// drain the timer, if it has already expired
   298  			select {
   299  			case <-batchTimer.C:
   300  			default:
   301  			}
   302  			batchTimer.Reset(maxBatchInterval)
   303  		case <-batchTimer.C:
   304  			d.processUpdates(ctx)
   305  			// batch timer has already expired, so no need to drain
   306  			batchTimer.Reset(maxBatchInterval)
   307  		case v := <-configWatcher:
   308  			cluster := v.(api.EventUpdateCluster)
   309  			d.mu.Lock()
   310  			if cluster.Cluster.Spec.Dispatcher.HeartbeatPeriod != nil {
   311  				// ignore error, since Spec has passed validation before
   312  				heartbeatPeriod, _ := gogotypes.DurationFromProto(cluster.Cluster.Spec.Dispatcher.HeartbeatPeriod)
   313  				if heartbeatPeriod != d.config.HeartbeatPeriod {
   314  					// only call d.nodes.updatePeriod when heartbeatPeriod changes
   315  					d.config.HeartbeatPeriod = heartbeatPeriod
   316  					d.nodes.updatePeriod(d.config.HeartbeatPeriod, d.config.HeartbeatEpsilon, d.config.GracePeriodMultiplier)
   317  				}
   318  			}
   319  			d.lastSeenRootCert = cluster.Cluster.RootCA.CACert
   320  			d.networkBootstrapKeys = cluster.Cluster.NetworkBootstrapKeys
   321  			d.mu.Unlock()
   322  			d.clusterUpdateQueue.Publish(clusterUpdate{
   323  				bootstrapKeyUpdate: &cluster.Cluster.NetworkBootstrapKeys,
   324  				rootCAUpdate:       &cluster.Cluster.RootCA.CACert,
   325  			})
   326  		case <-ctx.Done():
   327  			return nil
   328  		}
   329  	}
   330  }
   331  
    332  // Stop stops the dispatcher. It cancels the dispatcher's context (RPC
         // handlers and streams are expected to observe the cancellation and end),
         // wakes goroutines blocked in markNodeReady, waits for in-flight RPC
         // handlers to release rpcRW, clears the node stores, closes the cluster
         // update queue and finally waits for Run to return.
         //
         // It returns an error if the dispatcher is not running.
    333  func (d *Dispatcher) Stop() error {
    334  	d.mu.Lock()
    335  	if !d.isRunning() {
    336  		d.mu.Unlock()
    337  		return errors.New("dispatcher is already stopped")
    338  	}
    339  
    340  	log := log.G(d.ctx).WithField("method", "(*Dispatcher).Stop")
    341  	log.Info("dispatcher stopping")
    342  	d.cancel()
    343  	d.mu.Unlock()
    344  
    345  	d.processUpdatesLock.Lock()
    346  	// when we called d.cancel(), there may be routines, servicing RPC calls to
    347  	// the (*Dispatcher).Session endpoint, currently waiting at
    348  	// d.processUpdatesCond.Wait() inside of (*Dispatcher).markNodeReady().
    349  	//
    350  	// these routines are typically woken by a call to
    351  	// d.processUpdatesCond.Broadcast() at the end of
    352  	// (*Dispatcher).processUpdates() as part of the main Run loop. However,
    353  	// when d.cancel() is called, the main Run loop is stopped, and there are
    354  	// no more opportunities for processUpdates to be called. Any calls to
    355  	// Session would be stuck waiting on a call to Broadcast that will never
    356  	// come.
    357  	//
    358  	// Further, because the rpcRW write lock cannot be obtained until every RPC
    359  	// has exited and released its read lock, then Stop would be stuck forever.
    360  	//
    361  	// To avoid this case, we acquire the processUpdatesLock (so that no new
    362  	// waits can start) and then do a Broadcast to wake all of the waiting
    363  	// routines. Further, if any routines are waiting in markNodeReady to
    364  	// acquire this lock, but not yet waiting, those routines will check the
    365  	// context cancelation, see the context is canceled, and exit before doing
    366  	// the Wait.
    367  	//
    368  	// This call to Broadcast must occur here. If we called Broadcast before
    369  	// context cancelation, then some new routines could enter the wait. If we
    370  	// call Broadcast after attempting to acquire the rpcRW lock, we will be
    371  	// deadlocked. If we do this Broadcast without obtaining this lock (as is
    372  	// done in the processUpdates method), then it would be possible for that
    373  	// broadcast to come after the context cancelation check in markNodeReady,
    374  	// but before the call to Wait.
    375  	d.processUpdatesCond.Broadcast()
    376  	d.processUpdatesLock.Unlock()
    377  
    378  	// The active nodes list can be cleaned out only when all
    379  	// existing RPCs have finished.
    380  	// RPCs that start after rpcRW.Unlock() should find the context
    381  	// cancelled and should fail organically.
    382  	d.rpcRW.Lock()
    383  	d.nodes.Clean()
    384  	d.downNodes.Clean()
    385  	d.rpcRW.Unlock()
    386  
         	// Close the queue Run publishes cluster updates on.
    387  	d.clusterUpdateQueue.Close()
    388  
    389  	// TODO(anshul): This use of Wait() could be unsafe.
    390  	// According to go's documentation on WaitGroup,
    391  	// Add() with a positive delta that occur when the counter is zero
    392  	// must happen before a Wait().
    393  	// As is, dispatcher Stop() can race with Run().
    394  	d.wg.Wait()
    395  
    396  	return nil
    397  }
   398  
   399  func (d *Dispatcher) isRunningLocked() (context.Context, error) {
   400  	d.mu.Lock()
   401  	if !d.isRunning() {
   402  		d.mu.Unlock()
   403  		return nil, status.Errorf(codes.Aborted, "dispatcher is stopped")
   404  	}
   405  	ctx := d.ctx
   406  	d.mu.Unlock()
   407  	return ctx, nil
   408  }
   409  
   410  func (d *Dispatcher) markNodesUnknown(ctx context.Context) error {
   411  	log := log.G(ctx).WithField("method", "(*Dispatcher).markNodesUnknown")
   412  	var nodes []*api.Node
   413  	var err error
   414  	d.store.View(func(tx store.ReadTx) {
   415  		nodes, err = store.FindNodes(tx, store.All)
   416  	})
   417  	if err != nil {
   418  		return errors.Wrap(err, "failed to get list of nodes")
   419  	}
   420  	err = d.store.Batch(func(batch *store.Batch) error {
   421  		for _, n := range nodes {
   422  			err := batch.Update(func(tx store.Tx) error {
   423  				// check if node is still here
   424  				node := store.GetNode(tx, n.ID)
   425  				if node == nil {
   426  					return nil
   427  				}
   428  				// do not try to resurrect down nodes
   429  				if node.Status.State == api.NodeStatus_DOWN {
   430  					nodeCopy := node
   431  					expireFunc := func() {
   432  						log.Infof("moving tasks to orphaned state for node: %s", nodeCopy.ID)
   433  						if err := d.moveTasksToOrphaned(nodeCopy.ID); err != nil {
   434  							log.WithError(err).Errorf(`failed to move all tasks for node %s to "ORPHANED" state`, node.ID)
   435  						}
   436  
   437  						d.downNodes.Delete(nodeCopy.ID)
   438  					}
   439  
   440  					log.Infof(`node %s was found to be down when marking unknown on dispatcher start`, node.ID)
   441  					d.downNodes.Add(nodeCopy, expireFunc)
   442  					return nil
   443  				}
   444  
   445  				node.Status.State = api.NodeStatus_UNKNOWN
   446  				node.Status.Message = `Node moved to "unknown" state due to leadership change in cluster`
   447  
   448  				nodeID := node.ID
   449  
   450  				expireFunc := func() {
   451  					log := log.WithField("node", nodeID)
   452  					log.Infof(`heartbeat expiration for node %s in state "unknown"`, nodeID)
   453  					if err := d.markNodeNotReady(nodeID, api.NodeStatus_DOWN, `heartbeat failure for node in "unknown" state`); err != nil {
   454  						log.WithError(err).Error(`failed deregistering node after heartbeat expiration for node in "unknown" state`)
   455  					}
   456  				}
   457  				if err := d.nodes.AddUnknown(node, expireFunc); err != nil {
   458  					return errors.Wrapf(err, `adding node %s in "unknown" state to node store failed`, nodeID)
   459  				}
   460  				if err := store.UpdateNode(tx, node); err != nil {
   461  					return errors.Wrapf(err, "update for node %s failed", nodeID)
   462  				}
   463  				return nil
   464  			})
   465  			if err != nil {
   466  				log.WithField("node", n.ID).WithError(err).Error(`failed to move node to "unknown" state`)
   467  			}
   468  		}
   469  		return nil
   470  	})
   471  	return err
   472  }
   473  
   474  func (d *Dispatcher) isRunning() bool {
   475  	if d.ctx == nil {
   476  		return false
   477  	}
   478  	select {
   479  	case <-d.ctx.Done():
   480  		return false
   481  	default:
   482  	}
   483  	return true
   484  }
   485  
    486  // markNodeReady queues an update that sets the node's status to READY with the
    487  // given address (and, when description is non-nil, replaces its description),
    488  // then blocks until a batch of updates has been processed.
         // It is called from register (and presumably also when a session reports a
         // changed node description).
         //
         // The queued status carries an empty Message, so applying it clears the
         // node's status message; an empty addr leaves the stored address unchanged.
         //
         // It returns ctx.Err() if ctx is done while requesting an early flush or
         // before the wait starts. A wait that has started ends only on a broadcast
         // of processUpdatesCond (from processUpdates, or from Stop after it cancels
         // the context).
    489  func (d *Dispatcher) markNodeReady(ctx context.Context, nodeID string, description *api.NodeDescription, addr string) error {
    490  	d.nodeUpdatesLock.Lock()
    491  	d.nodeUpdates[nodeID] = nodeUpdate{
    492  		status: &api.NodeStatus{
    493  			State: api.NodeStatus_READY,
    494  			Addr:  addr,
    495  		},
    496  		description: description,
    497  	}
    498  	numUpdates := len(d.nodeUpdates)
    499  	d.nodeUpdatesLock.Unlock()
    500  
    501  	// Node is marked ready. Remove the node from down nodes if it
    502  	// is there.
    503  	d.downNodes.Delete(nodeID)
    504  
    505  	if numUpdates >= maxBatchItems {
         		// Enough updates are queued to flush now rather than waiting for
         		// Run's batch timer. The channel has a buffer of one, so this send
         		// blocks only while an earlier trigger has not been consumed yet.
    506  		select {
    507  		case d.processUpdatesTrigger <- struct{}{}:
    508  		case <-ctx.Done():
    509  			return ctx.Err()
    510  		}
    511  
    512  	}
    513  
    514  	// Wait until the node update batch happens before unblocking register.
    515  	d.processUpdatesLock.Lock()
    516  	defer d.processUpdatesLock.Unlock()
    517  
         	// Checking ctx while holding processUpdatesLock pairs with Stop, which
         	// cancels the context and then broadcasts under this same lock.
    518  	select {
    519  	case <-ctx.Done():
    520  		return ctx.Err()
    521  	default:
    522  	}
         	// NOTE(review): processUpdates broadcasts without holding
         	// processUpdatesLock. Two consequences follow.
         	//   - A broadcast issued after the update above was queued but before
         	//     this Wait is missed; this call then returns only on a later
         	//     broadcast.
         	//   - A batch already in flight when the update was queued can wake
         	//     this call before its own update is committed.
    523  	d.processUpdatesCond.Wait()
    524  
    525  	return nil
    526  }
   527  
   528  // gets the node IP from the context of a grpc call
   529  func nodeIPFromContext(ctx context.Context) (string, error) {
   530  	nodeInfo, err := ca.RemoteNode(ctx)
   531  	if err != nil {
   532  		return "", err
   533  	}
   534  	addr, _, err := net.SplitHostPort(nodeInfo.RemoteAddr)
   535  	if err != nil {
   536  		return "", errors.Wrap(err, "unable to get ip from addr:port")
   537  	}
   538  	return addr, nil
   539  }
   540  
   541  // register is used for registration of node with particular dispatcher.
   542  func (d *Dispatcher) register(ctx context.Context, nodeID string, description *api.NodeDescription) (string, error) {
   543  	logLocal := log.G(ctx).WithField("method", "(*Dispatcher).register")
   544  	// prevent register until we're ready to accept it
   545  	dctx, err := d.isRunningLocked()
   546  	if err != nil {
   547  		return "", err
   548  	}
   549  
   550  	if err := d.nodes.CheckRateLimit(nodeID); err != nil {
   551  		return "", err
   552  	}
   553  
   554  	// TODO(stevvooe): Validate node specification.
   555  	var node *api.Node
   556  	d.store.View(func(tx store.ReadTx) {
   557  		node = store.GetNode(tx, nodeID)
   558  	})
   559  	if node == nil {
   560  		return "", ErrNodeNotFound
   561  	}
   562  
   563  	addr, err := nodeIPFromContext(ctx)
   564  	if err != nil {
   565  		logLocal.WithError(err).Debug("failed to get remote node IP")
   566  	}
   567  
   568  	if err := d.markNodeReady(dctx, nodeID, description, addr); err != nil {
   569  		return "", err
   570  	}
   571  
   572  	expireFunc := func() {
   573  		log.G(ctx).Debugf("heartbeat expiration for worker %s, setting worker status to NodeStatus_DOWN ", nodeID)
   574  		if err := d.markNodeNotReady(nodeID, api.NodeStatus_DOWN, "heartbeat failure"); err != nil {
   575  			log.G(ctx).WithError(err).Errorf("failed deregistering node after heartbeat expiration")
   576  		}
   577  	}
   578  
   579  	rn := d.nodes.Add(node, expireFunc)
   580  	logLocal.Infof("worker %s was successfully registered", nodeID)
   581  
   582  	// NOTE(stevvooe): We need be a little careful with re-registration. The
   583  	// current implementation just matches the node id and then gives away the
   584  	// sessionID. If we ever want to use sessionID as a secret, which we may
   585  	// want to, this is giving away the keys to the kitchen.
   586  	//
   587  	// The right behavior is going to be informed by identity. Basically, each
   588  	// time a node registers, we invalidate the session and issue a new
   589  	// session, once identity is proven. This will cause misbehaved agents to
   590  	// be kicked when multiple connections are made.
   591  	return rn.SessionID, nil
   592  }
   593  
   594  // UpdateTaskStatus updates status of task. Node should send such updates
   595  // on every status change of its tasks.
   596  func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStatusRequest) (*api.UpdateTaskStatusResponse, error) {
   597  	d.rpcRW.RLock()
   598  	defer d.rpcRW.RUnlock()
   599  
   600  	dctx, err := d.isRunningLocked()
   601  	if err != nil {
   602  		return nil, err
   603  	}
   604  
   605  	nodeInfo, err := ca.RemoteNode(ctx)
   606  	if err != nil {
   607  		return nil, err
   608  	}
   609  	nodeID := nodeInfo.NodeID
   610  	fields := logrus.Fields{
   611  		"node.id":      nodeID,
   612  		"node.session": r.SessionID,
   613  		"method":       "(*Dispatcher).UpdateTaskStatus",
   614  	}
   615  	if nodeInfo.ForwardedBy != nil {
   616  		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
   617  	}
   618  	log := log.G(ctx).WithFields(fields)
   619  
   620  	if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
   621  		return nil, err
   622  	}
   623  
   624  	validTaskUpdates := make([]*api.UpdateTaskStatusRequest_TaskStatusUpdate, 0, len(r.Updates))
   625  
   626  	// Validate task updates
   627  	for _, u := range r.Updates {
   628  		if u.Status == nil {
   629  			log.WithField("task.id", u.TaskID).Warn("task report has nil status")
   630  			continue
   631  		}
   632  
   633  		var t *api.Task
   634  		d.store.View(func(tx store.ReadTx) {
   635  			t = store.GetTask(tx, u.TaskID)
   636  		})
   637  		if t == nil {
   638  			// Task may have been deleted
   639  			log.WithField("task.id", u.TaskID).Debug("cannot find target task in store")
   640  			continue
   641  		}
   642  
   643  		if t.NodeID != nodeID {
   644  			err := status.Errorf(codes.PermissionDenied, "cannot update a task not assigned this node")
   645  			log.WithField("task.id", u.TaskID).Error(err)
   646  			return nil, err
   647  		}
   648  
   649  		validTaskUpdates = append(validTaskUpdates, u)
   650  	}
   651  
   652  	d.taskUpdatesLock.Lock()
   653  	// Enqueue task updates
   654  	for _, u := range validTaskUpdates {
   655  		d.taskUpdates[u.TaskID] = u.Status
   656  	}
   657  
   658  	numUpdates := len(d.taskUpdates)
   659  	d.taskUpdatesLock.Unlock()
   660  
   661  	if numUpdates >= maxBatchItems {
   662  		select {
   663  		case d.processUpdatesTrigger <- struct{}{}:
   664  		case <-dctx.Done():
   665  		}
   666  	}
   667  	return nil, nil
   668  }
   669  
   670  func (d *Dispatcher) processUpdates(ctx context.Context) {
   671  	var (
   672  		taskUpdates map[string]*api.TaskStatus
   673  		nodeUpdates map[string]nodeUpdate
   674  	)
   675  	d.taskUpdatesLock.Lock()
   676  	if len(d.taskUpdates) != 0 {
   677  		taskUpdates = d.taskUpdates
   678  		d.taskUpdates = make(map[string]*api.TaskStatus)
   679  	}
   680  	d.taskUpdatesLock.Unlock()
   681  
   682  	d.nodeUpdatesLock.Lock()
   683  	if len(d.nodeUpdates) != 0 {
   684  		nodeUpdates = d.nodeUpdates
   685  		d.nodeUpdates = make(map[string]nodeUpdate)
   686  	}
   687  	d.nodeUpdatesLock.Unlock()
   688  
   689  	if len(taskUpdates) == 0 && len(nodeUpdates) == 0 {
   690  		return
   691  	}
   692  
   693  	log := log.G(ctx).WithFields(logrus.Fields{
   694  		"method": "(*Dispatcher).processUpdates",
   695  	})
   696  
   697  	err := d.store.Batch(func(batch *store.Batch) error {
   698  		for taskID, status := range taskUpdates {
   699  			err := batch.Update(func(tx store.Tx) error {
   700  				logger := log.WithField("task.id", taskID)
   701  				task := store.GetTask(tx, taskID)
   702  				if task == nil {
   703  					// Task may have been deleted
   704  					logger.Debug("cannot find target task in store")
   705  					return nil
   706  				}
   707  
   708  				logger = logger.WithField("state.transition", fmt.Sprintf("%v->%v", task.Status.State, status.State))
   709  
   710  				if task.Status == *status {
   711  					logger.Debug("task status identical, ignoring")
   712  					return nil
   713  				}
   714  
   715  				if task.Status.State > status.State {
   716  					logger.Debug("task status invalid transition")
   717  					return nil
   718  				}
   719  
   720  				// Update scheduling delay metric for running tasks.
   721  				// We use the status update time on the leader to calculate the scheduling delay.
   722  				// Because of this, the recorded scheduling delay will be an overestimate and include
   723  				// the network delay between the worker and the leader.
   724  				// This is not ideal, but its a known overestimation, rather than using the status update time
   725  				// from the worker node, which may cause unknown incorrect results due to possible clock skew.
   726  				if status.State == api.TaskStateRunning {
   727  					start := time.Unix(status.AppliedAt.GetSeconds(), int64(status.AppliedAt.GetNanos()))
   728  					schedulingDelayTimer.UpdateSince(start)
   729  				}
   730  
   731  				task.Status = *status
   732  				task.Status.AppliedBy = d.securityConfig.ClientTLSCreds.NodeID()
   733  				task.Status.AppliedAt = ptypes.MustTimestampProto(time.Now())
   734  				logger.Debugf("state for task %v updated to %v", task.GetID(), task.Status.State)
   735  				if err := store.UpdateTask(tx, task); err != nil {
   736  					logger.WithError(err).Error("failed to update task status")
   737  					return nil
   738  				}
   739  				logger.Debug("dispatcher committed status update to store")
   740  				return nil
   741  			})
   742  			if err != nil {
   743  				log.WithError(err).Error("dispatcher task update transaction failed")
   744  			}
   745  		}
   746  
   747  		for nodeID, nodeUpdate := range nodeUpdates {
   748  			err := batch.Update(func(tx store.Tx) error {
   749  				logger := log.WithField("node.id", nodeID)
   750  				node := store.GetNode(tx, nodeID)
   751  				if node == nil {
   752  					logger.Errorf("node unavailable")
   753  					return nil
   754  				}
   755  
   756  				if nodeUpdate.status != nil {
   757  					node.Status.State = nodeUpdate.status.State
   758  					node.Status.Message = nodeUpdate.status.Message
   759  					if nodeUpdate.status.Addr != "" {
   760  						node.Status.Addr = nodeUpdate.status.Addr
   761  					}
   762  				}
   763  				if nodeUpdate.description != nil {
   764  					node.Description = nodeUpdate.description
   765  				}
   766  
   767  				if err := store.UpdateNode(tx, node); err != nil {
   768  					logger.WithError(err).Error("failed to update node status")
   769  					return nil
   770  				}
   771  				logger.Debug("node status updated")
   772  				return nil
   773  			})
   774  			if err != nil {
   775  				log.WithError(err).Error("dispatcher node update transaction failed")
   776  			}
   777  		}
   778  
   779  		return nil
   780  	})
   781  	if err != nil {
   782  		log.WithError(err).Error("dispatcher batch failed")
   783  	}
   784  
   785  	d.processUpdatesCond.Broadcast()
   786  }
   787  
// Tasks streams the set of tasks assigned to the calling node. Each message
// carries the full list of tasks the node should be running; a task that is
// absent from the list should be terminated by the agent.
//
// The caller's node ID comes from the stream context (ca.RemoteNode), and
// r.SessionID must be accepted by d.nodes.GetWithSession both before the
// store watch is set up and before every send. Only tasks whose state is
// ASSIGNED or later are sent. After the initial snapshot, store events are
// coalesced by the batching loop below and one full snapshot is sent per
// batch.
//
// The stream returns an error when the session check or a Send fails, or
// when either the stream context or dctx (from isRunningLocked) is done.
// d.rpcRW is read-locked for the whole lifetime of the stream.
func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServer) error {
	d.rpcRW.RLock()
	defer d.rpcRW.RUnlock()

	// dctx is selected on in the batching loop so the stream ends once it is
	// done.
	dctx, err := d.isRunningLocked()
	if err != nil {
		return err
	}

	nodeInfo, err := ca.RemoteNode(stream.Context())
	if err != nil {
		return err
	}
	nodeID := nodeInfo.NodeID

	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": r.SessionID,
		"method":       "(*Dispatcher).Tasks",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log.G(stream.Context()).WithFields(fields).Debug("")

	if _, err = d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		return err
	}

	// tasksMap is this stream's local view of the node's tasks, keyed by task
	// ID. It is seeded by the ViewAndWatch callback and then kept current
	// from the create/update/delete events filtered to this node ID.
	tasksMap := make(map[string]*api.Task)
	nodeTasks, cancel, err := store.ViewAndWatch(
		d.store,
		func(readTx store.ReadTx) error {
			tasks, err := store.FindTasks(readTx, store.ByNodeID(nodeID))
			if err != nil {
				return err
			}
			for _, t := range tasks {
				tasksMap[t.ID] = t
			}
			return nil
		},
		api.EventCreateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
		api.EventUpdateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
		api.EventDeleteTask{Task: &api.Task{NodeID: nodeID},
			Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
	)
	if err != nil {
		return err
	}
	defer cancel()

	for {
		// Stop streaming as soon as the session is no longer accepted.
		if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
			return err
		}

		var tasks []*api.Task
		for _, t := range tasksMap {
			// dispatcher only sends tasks that have been assigned to a node
			if t != nil && t.Status.State >= api.TaskStateAssigned {
				tasks = append(tasks, t)
			}
		}

		if err := stream.Send(&api.TasksMessage{Tasks: tasks}); err != nil {
			return err
		}

		// Bursty events are coalesced and sent out as one full snapshot.
		// The loop below keeps consuming events until either
		// modificationBatchLimit counted modifications have been seen, or
		// batchingWaitTime passes with no new event. The timer is started or
		// reset on every event except agent-irrelevant updates, which hit
		// `continue`.
		var (
			modificationCnt int
			batchingTimer   *time.Timer
			batchingTimeout <-chan time.Time
		)

	batchingLoop:
		for modificationCnt < modificationBatchLimit {
			select {
			case event := <-nodeTasks:
				switch v := event.(type) {
				case api.EventCreateTask:
					tasksMap[v.Task.ID] = v.Task
					modificationCnt++
				case api.EventUpdateTask:
					if oldTask, exists := tasksMap[v.Task.ID]; exists {
						// States ASSIGNED and below are set by the orchestrator/scheduler,
						// not the agent, so tasks in these states need to be sent to the
						// agent even if nothing else has changed.
						if equality.TasksEqualStable(oldTask, v.Task) && v.Task.Status.State > api.TaskStateAssigned {
							// this update should not trigger action at agent.
							// Keep the newer copy, but neither count it nor
							// touch the batching timer (`continue` restarts
							// batchingLoop).
							tasksMap[v.Task.ID] = v.Task
							continue
						}
					}
					tasksMap[v.Task.ID] = v.Task
					modificationCnt++
				case api.EventDeleteTask:
					delete(tasksMap, v.Task.ID)
					modificationCnt++
				}
				// NOTE(review): Reset does not drain a tick that already fired
				// while this event case was selected (pre-Go-1.23 timer
				// semantics). The next iteration may then end the batch
				// early. That only causes an earlier send.
				if batchingTimer != nil {
					batchingTimer.Reset(batchingWaitTime)
				} else {
					batchingTimer = time.NewTimer(batchingWaitTime)
					batchingTimeout = batchingTimer.C
				}
			case <-batchingTimeout:
				break batchingLoop
			case <-stream.Context().Done():
				return stream.Context().Err()
			case <-dctx.Done():
				return dctx.Err()
			}
		}

		if batchingTimer != nil {
			batchingTimer.Stop()
		}
	}
}
   914  
// Assignments is a stream of assignments for a node. The first message is a
// COMPLETE snapshot of the node's assignment set; every later message is an
// INCREMENTAL update built from the same assignmentSet after a batch of
// changes.
//
// Messages are chained: each message's AppliesTo equals the previous
// message's ResultsIn, and ResultsIn is a per-stream counter starting at "1".
// The first message has an empty AppliesTo.
//
// The caller's node ID comes from the stream context (ca.RemoteNode), and
// r.SessionID must be accepted by d.nodes.GetWithSession before the watch is
// set up and at the top of every batch. The stream returns when a session
// check or Send fails, or when the stream context or dctx is done. d.rpcRW is
// read-locked for the whole lifetime of the stream.
func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatcher_AssignmentsServer) error {
	d.rpcRW.RLock()
	defer d.rpcRW.RUnlock()

	dctx, err := d.isRunningLocked()
	if err != nil {
		return err
	}

	nodeInfo, err := ca.RemoteNode(stream.Context())
	if err != nil {
		return err
	}
	nodeID := nodeInfo.NodeID

	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": r.SessionID,
		"method":       "(*Dispatcher).Assignments",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	// Shadows the log package for the rest of this function.
	log := log.G(stream.Context()).WithFields(fields)
	log.Debug("")

	if _, err = d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		return err
	}

	var (
		sequence    int64
		appliesTo   string
		assignments = newAssignmentSet(log, d.dp)
	)

	// sendMessage stamps msg with the chain identifiers (AppliesTo = the
	// previous ResultsIn; ResultsIn = the next sequence number) and the given
	// type, then sends it on the stream.
	sendMessage := func(msg api.AssignmentsMessage, assignmentType api.AssignmentsMessage_Type) error {
		sequence++
		msg.AppliesTo = appliesTo
		msg.ResultsIn = strconv.FormatInt(sequence, 10)
		appliesTo = msg.ResultsIn
		msg.Type = assignmentType

		return stream.Send(&msg)
	}

	// TODO(aaronl): Also send node secrets that should be exposed to
	// this node.
	nodeTasks, cancel, err := store.ViewAndWatch(
		d.store,
		func(readTx store.ReadTx) error {
			tasks, err := store.FindTasks(readTx, store.ByNodeID(nodeID))
			if err != nil {
				return err
			}

			for _, t := range tasks {
				assignments.addOrUpdateTask(readTx, t)
			}

			return nil
		},
		api.EventUpdateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
		api.EventDeleteTask{Task: &api.Task{NodeID: nodeID},
			Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
	)
	if err != nil {
		return err
	}
	defer cancel()

	if err := sendMessage(assignments.message(), api.AssignmentsMessage_COMPLETE); err != nil {
		return err
	}

	for {
		// Check for session expiration
		if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
			return err
		}

		// bursty events should be processed in batches and sent out together
		var (
			modificationCnt int
			batchingTimer   *time.Timer
			batchingTimeout <-chan time.Time
		)

		// oneModification counts a change that actually altered the
		// assignment set and (re)starts the batching timer.
		oneModification := func() {
			modificationCnt++

			if batchingTimer != nil {
				batchingTimer.Reset(batchingWaitTime)
			} else {
				batchingTimer = time.NewTimer(batchingWaitTime)
				batchingTimeout = batchingTimer.C
			}
		}

		// The batching loop waits for batchingWaitTime after the most recent
		// change, or until modificationBatchLimit is reached. The
		// worst case latency is modificationBatchLimit * batchingWaitTime,
		// which is 10 seconds with the current constants (100 * 100ms).
		// Events that leave the assignment set unchanged neither count nor
		// touch the timer.
	batchingLoop:
		for modificationCnt < modificationBatchLimit {
			select {
			case event := <-nodeTasks:
				switch v := event.(type) {
				// We don't monitor EventCreateTask because tasks are
				// never created in the ASSIGNED state. First tasks are
				// created by the orchestrator, then the scheduler moves
				// them to ASSIGNED. If this ever changes, we will need
				// to monitor task creations as well.
				case api.EventUpdateTask:
					d.store.View(func(readTx store.ReadTx) {
						if assignments.addOrUpdateTask(readTx, v.Task) {
							oneModification()
						}
					})
				case api.EventDeleteTask:
					if assignments.removeTask(v.Task) {
						oneModification()
					}
					// TODO(aaronl): For node secrets, we'll need to handle
					// EventCreateSecret.
				}
			case <-batchingTimeout:
				break batchingLoop
			case <-stream.Context().Done():
				return stream.Context().Err()
			case <-dctx.Done():
				return dctx.Err()
			}
		}

		if batchingTimer != nil {
			batchingTimer.Stop()
		}

		// Defensive: the loop above only exits after at least one counted
		// modification, because the timer is created by oneModification.
		if modificationCnt > 0 {
			if err := sendMessage(assignments.message(), api.AssignmentsMessage_INCREMENTAL); err != nil {
				return err
			}
		}
	}
}
  1064  
  1065  func (d *Dispatcher) moveTasksToOrphaned(nodeID string) error {
  1066  	err := d.store.Batch(func(batch *store.Batch) error {
  1067  		var (
  1068  			tasks []*api.Task
  1069  			err   error
  1070  		)
  1071  
  1072  		d.store.View(func(tx store.ReadTx) {
  1073  			tasks, err = store.FindTasks(tx, store.ByNodeID(nodeID))
  1074  		})
  1075  		if err != nil {
  1076  			return err
  1077  		}
  1078  
  1079  		for _, task := range tasks {
  1080  			// Tasks running on an unreachable node need to be marked as
  1081  			// orphaned since we have no idea whether the task is still running
  1082  			// or not.
  1083  			//
  1084  			// This only applies for tasks that could have made progress since
  1085  			// the agent became unreachable (assigned<->running)
  1086  			//
  1087  			// Tasks in a final state (e.g. rejected) *cannot* have made
  1088  			// progress, therefore there's no point in marking them as orphaned
  1089  			if task.Status.State >= api.TaskStateAssigned && task.Status.State <= api.TaskStateRunning {
  1090  				task.Status.State = api.TaskStateOrphaned
  1091  			}
  1092  
  1093  			err := batch.Update(func(tx store.Tx) error {
  1094  				return store.UpdateTask(tx, task)
  1095  			})
  1096  			if err != nil {
  1097  				return err
  1098  			}
  1099  
  1100  		}
  1101  
  1102  		return nil
  1103  	})
  1104  
  1105  	return err
  1106  }
  1107  
  1108  // markNodeNotReady sets the node state to some state other than READY
  1109  func (d *Dispatcher) markNodeNotReady(id string, state api.NodeStatus_State, message string) error {
  1110  	logLocal := log.G(d.ctx).WithField("method", "(*Dispatcher).markNodeNotReady")
  1111  
  1112  	dctx, err := d.isRunningLocked()
  1113  	if err != nil {
  1114  		return err
  1115  	}
  1116  
  1117  	// Node is down. Add it to down nodes so that we can keep
  1118  	// track of tasks assigned to the node.
  1119  	var node *api.Node
  1120  	d.store.View(func(readTx store.ReadTx) {
  1121  		node = store.GetNode(readTx, id)
  1122  		if node == nil {
  1123  			err = fmt.Errorf("could not find node %s while trying to add to down nodes store", id)
  1124  		}
  1125  	})
  1126  	if err != nil {
  1127  		return err
  1128  	}
  1129  
  1130  	expireFunc := func() {
  1131  		log.G(dctx).Debugf(`worker timed-out %s in "down" state, moving all tasks to "ORPHANED" state`, id)
  1132  		if err := d.moveTasksToOrphaned(id); err != nil {
  1133  			log.G(dctx).WithError(err).Error(`failed to move all tasks to "ORPHANED" state`)
  1134  		}
  1135  
  1136  		d.downNodes.Delete(id)
  1137  	}
  1138  
  1139  	d.downNodes.Add(node, expireFunc)
  1140  	logLocal.Debugf("added node %s to down nodes list", node.ID)
  1141  
  1142  	status := &api.NodeStatus{
  1143  		State:   state,
  1144  		Message: message,
  1145  	}
  1146  
  1147  	d.nodeUpdatesLock.Lock()
  1148  	// pluck the description out of nodeUpdates. this protects against a case
  1149  	// where a node is marked ready and a description is added, but then the
  1150  	// node is immediately marked not ready. this preserves that description
  1151  	d.nodeUpdates[id] = nodeUpdate{status: status, description: d.nodeUpdates[id].description}
  1152  	numUpdates := len(d.nodeUpdates)
  1153  	d.nodeUpdatesLock.Unlock()
  1154  
  1155  	if numUpdates >= maxBatchItems {
  1156  		select {
  1157  		case d.processUpdatesTrigger <- struct{}{}:
  1158  		case <-dctx.Done():
  1159  		}
  1160  	}
  1161  
  1162  	if rn := d.nodes.Delete(id); rn == nil {
  1163  		return errors.Errorf("node %s is not found in local storage", id)
  1164  	}
  1165  	logLocal.Debugf("deleted node %s from node store", node.ID)
  1166  
  1167  	return nil
  1168  }
  1169  
  1170  // Heartbeat is heartbeat method for nodes. It returns new TTL in response.
  1171  // Node should send new heartbeat earlier than now + TTL, otherwise it will
  1172  // be deregistered from dispatcher and its status will be updated to NodeStatus_DOWN
  1173  func (d *Dispatcher) Heartbeat(ctx context.Context, r *api.HeartbeatRequest) (*api.HeartbeatResponse, error) {
  1174  	d.rpcRW.RLock()
  1175  	defer d.rpcRW.RUnlock()
  1176  
  1177  	// TODO(anshul) Explore if its possible to check context here without locking.
  1178  	if _, err := d.isRunningLocked(); err != nil {
  1179  		return nil, status.Errorf(codes.Aborted, "dispatcher is stopped")
  1180  	}
  1181  
  1182  	nodeInfo, err := ca.RemoteNode(ctx)
  1183  	if err != nil {
  1184  		return nil, err
  1185  	}
  1186  
  1187  	period, err := d.nodes.Heartbeat(nodeInfo.NodeID, r.SessionID)
  1188  
  1189  	log.G(ctx).WithField("method", "(*Dispatcher).Heartbeat").Debugf("received heartbeat from worker %v, expect next heartbeat in %v", nodeInfo, period)
  1190  	return &api.HeartbeatResponse{Period: period}, err
  1191  }
  1192  
  1193  func (d *Dispatcher) getManagers() []*api.WeightedPeer {
  1194  	d.mu.Lock()
  1195  	defer d.mu.Unlock()
  1196  	return d.lastSeenManagers
  1197  }
  1198  
  1199  func (d *Dispatcher) getNetworkBootstrapKeys() []*api.EncryptionKey {
  1200  	d.mu.Lock()
  1201  	defer d.mu.Unlock()
  1202  	return d.networkBootstrapKeys
  1203  }
  1204  
  1205  func (d *Dispatcher) getRootCACert() []byte {
  1206  	d.mu.Lock()
  1207  	defer d.mu.Unlock()
  1208  	return d.lastSeenRootCert
  1209  }
  1210  
// Session is the stream that controls an agent's connection to this
// dispatcher.
//
// If r.SessionID is not accepted by d.nodes.GetWithSession, the node is
// registered afresh via d.register and gets a new session ID. Otherwise the
// existing session is kept and the node is marked ready with r.Description
// and the peer IP taken from the stream context.
//
// Every SessionMessage carries the session ID, the node object from the store,
// the weighted manager list, the network bootstrap keys and the root CA
// certificate. One message is sent right away. Another is sent whenever a
// cluster update or an update to this node arrives. A final one is sent when
// the node is told to disconnect (node.Disconnect fires or dctx is done).
// After that final message the node is marked DISCONNECTED and the stream
// returns a codes.Aborted status. No Disconnect field is set on the message.
//
// d.rpcRW is read-locked for the whole lifetime of the stream.
func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_SessionServer) error {
	d.rpcRW.RLock()
	defer d.rpcRW.RUnlock()

	dctx, err := d.isRunningLocked()
	if err != nil {
		return err
	}

	ctx := stream.Context()

	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return err
	}
	nodeID := nodeInfo.NodeID

	var sessionID string
	if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		// register the node.
		sessionID, err = d.register(ctx, nodeID, r.Description)
		if err != nil {
			return err
		}
	} else {
		sessionID = r.SessionID
		// get the node IP addr. A failure is only logged; addr is still
		// passed to markNodeReady.
		addr, err := nodeIPFromContext(stream.Context())
		if err != nil {
			log.G(ctx).WithError(err).Debug("failed to get remote node IP")
		}
		// update the node description
		if err := d.markNodeReady(dctx, nodeID, r.Description, addr); err != nil {
			return err
		}
	}

	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": sessionID,
		"method":       "(*Dispatcher).Session",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	// Shadows the log package for the rest of this function.
	log := log.G(ctx).WithFields(fields)

	// Read the current node object and watch for updates to this node ID.
	var nodeObj *api.Node
	nodeUpdates, cancel, err := store.ViewAndWatch(d.store, func(readTx store.ReadTx) error {
		nodeObj = store.GetNode(readTx, nodeID)
		return nil
	}, api.EventUpdateNode{Node: &api.Node{ID: nodeID},
		Checks: []api.NodeCheckFunc{api.NodeCheckID}},
	)
	if cancel != nil {
		defer cancel()
	}

	// A watch failure is logged but not fatal: the session continues. In that
	// case nodeUpdates is presumably nil, so its select case never fires
	// (TODO confirm).
	if err != nil {
		log.WithError(err).Error("ViewAndWatch Node failed")
	}

	if _, err = d.nodes.GetWithSession(nodeID, sessionID); err != nil {
		return err
	}

	clusterUpdatesCh, clusterCancel := d.clusterUpdateQueue.Watch()
	defer clusterCancel()

	if err := stream.Send(&api.SessionMessage{
		SessionID:            sessionID,
		Node:                 nodeObj,
		Managers:             d.getManagers(),
		NetworkBootstrapKeys: d.getNetworkBootstrapKeys(),
		RootCA:               d.getRootCACert(),
	}); err != nil {
		return err
	}

	// disconnectNode marks the node DISCONNECTED via markNodeNotReady (a
	// failure there is only logged) and returns a codes.Aborted status,
	// which ends the stream.
	disconnectNode := func() error {
		log.Infof("dispatcher session dropped, marking node %s down", nodeID)
		if err := d.markNodeNotReady(nodeID, api.NodeStatus_DISCONNECTED, "node is currently trying to find new manager"); err != nil {
			log.WithError(err).Error("failed to remove node")
		}
		// still return an abort if the transport closure was ineffective.
		return status.Errorf(codes.Aborted, "node must disconnect")
	}

	for {
		// After each message send, we need to check the nodes sessionID hasn't
		// changed. If it has, we will shut down the stream and make the node
		// re-register.
		node, err := d.nodes.GetWithSession(nodeID, sessionID)
		if err != nil {
			return err
		}

		var (
			disconnect bool
			mgrs       []*api.WeightedPeer
			netKeys    []*api.EncryptionKey
			rootCert   []byte
		)

		select {
		case ev := <-clusterUpdatesCh:
			// A cluster update may carry only some fields. Fields left nil
			// here are filled from the dispatcher's current values below.
			update := ev.(clusterUpdate)
			if update.managerUpdate != nil {
				mgrs = *update.managerUpdate
			}
			if update.bootstrapKeyUpdate != nil {
				netKeys = *update.bootstrapKeyUpdate
			}
			if update.rootCAUpdate != nil {
				rootCert = *update.rootCAUpdate
			}
		case ev := <-nodeUpdates:
			nodeObj = ev.(api.EventUpdateNode).Node
		case <-stream.Context().Done():
			return stream.Context().Err()
		case <-node.Disconnect:
			disconnect = true
		case <-dctx.Done():
			disconnect = true
		}
		if mgrs == nil {
			mgrs = d.getManagers()
		}
		if netKeys == nil {
			netKeys = d.getNetworkBootstrapKeys()
		}
		if rootCert == nil {
			rootCert = d.getRootCACert()
		}

		if err := stream.Send(&api.SessionMessage{
			SessionID:            sessionID,
			Node:                 nodeObj,
			Managers:             mgrs,
			NetworkBootstrapKeys: netKeys,
			RootCA:               rootCert,
		}); err != nil {
			return err
		}
		// The final message has been sent; now tear the session down.
		if disconnect {
			return disconnectNode()
		}
	}
}