github.com/kaisenlinux/docker.io@v0.0.0-20230510090727-ea55db55fac7/swarmkit/agent/session.go (about)

     1  package agent
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"math"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/docker/swarmkit/api"
    11  	"github.com/docker/swarmkit/connectionbroker"
    12  	"github.com/docker/swarmkit/log"
    13  	"github.com/sirupsen/logrus"
    14  	"google.golang.org/grpc"
    15  	"google.golang.org/grpc/codes"
    16  	"google.golang.org/grpc/status"
    17  )
    18  
    19  var (
    20  	dispatcherRPCTimeout = 5 * time.Second
    21  	errSessionClosed     = errors.New("agent: session closed")
    22  )
    23  
// session encapsulates one round of registration with the manager. session
// starts the registration and heartbeat control cycle. Any failure will result
// in a complete shutdown of the session and it must be reestablished.
//
// All communication with the manager is done through session. Changes that
// flow into the agent, such as task assignment, are delivered through the
// errs, messages, assignments and subscriptions channels.
type session struct {
	// conn is the manager connection selected in newSession. It stays nil
	// when connection selection failed.
	conn *connectionbroker.Conn

	agent *Agent
	// sessionID is the ID passed to newSession; start overwrites it with the
	// ID carried by the first SessionMessage.
	sessionID string
	// session is the dispatcher session stream, set by start once the first
	// message has been received.
	session api.Dispatcher_SessionClient
	// errs carries errors out of the session. It is created with a buffer of
	// one in newSession.
	errs chan error
	// messages, assignments and subscriptions are unbuffered channels fed by
	// handleSessionMessage, watch and logSubscriptions respectively.
	messages      chan *api.SessionMessage
	assignments   chan *api.AssignmentsMessage
	subscriptions chan *api.SubscriptionMessage

	cancel     func()        // this is assumed to be never nil, and set whenever a session is created
	registered chan struct{} // closed by run once registration has completed
	closed     chan struct{} // closed by close
	closeOnce  sync.Once     // guards the teardown performed by close
}
    47  
    48  func newSession(ctx context.Context, agent *Agent, delay time.Duration, sessionID string, description *api.NodeDescription) *session {
    49  	sessionCtx, sessionCancel := context.WithCancel(ctx)
    50  	s := &session{
    51  		agent:         agent,
    52  		sessionID:     sessionID,
    53  		errs:          make(chan error, 1),
    54  		messages:      make(chan *api.SessionMessage),
    55  		assignments:   make(chan *api.AssignmentsMessage),
    56  		subscriptions: make(chan *api.SubscriptionMessage),
    57  		registered:    make(chan struct{}),
    58  		closed:        make(chan struct{}),
    59  		cancel:        sessionCancel,
    60  	}
    61  
    62  	// TODO(stevvooe): Need to move connection management up a level or create
    63  	// independent connection for log broker client.
    64  
    65  	cc, err := agent.config.ConnBroker.Select(
    66  		grpc.WithTransportCredentials(agent.config.Credentials),
    67  		grpc.WithTimeout(dispatcherRPCTimeout),
    68  		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
    69  	)
    70  
    71  	if err != nil {
    72  		// since we are returning without launching the session goroutine, we
    73  		// need to provide the delay that is guaranteed by calling this
    74  		// function. We launch a goroutine so that we only delay the retry and
    75  		// avoid blocking the main loop.
    76  		go func() {
    77  			time.Sleep(delay)
    78  			s.errs <- err
    79  		}()
    80  		return s
    81  	}
    82  
    83  	log.G(ctx).Infof("manager selected by agent for new session: %v", cc.Peer())
    84  
    85  	s.conn = cc
    86  
    87  	go s.run(sessionCtx, delay, description)
    88  	return s
    89  }
    90  
    91  func (s *session) run(ctx context.Context, delay time.Duration, description *api.NodeDescription) {
    92  	timer := time.NewTimer(delay) // delay before registering.
    93  	log.G(ctx).Infof("waiting %v before registering session", delay)
    94  	defer timer.Stop()
    95  	select {
    96  	case <-timer.C:
    97  	case <-ctx.Done():
    98  		return
    99  	}
   100  
   101  	if err := s.start(ctx, description); err != nil {
   102  		select {
   103  		case s.errs <- err:
   104  		case <-s.closed:
   105  		case <-ctx.Done():
   106  		}
   107  		return
   108  	}
   109  
   110  	ctx = log.WithLogger(ctx, log.G(ctx).WithField("session.id", s.sessionID))
   111  
   112  	go runctx(ctx, s.closed, s.errs, s.heartbeat)
   113  	go runctx(ctx, s.closed, s.errs, s.watch)
   114  	go runctx(ctx, s.closed, s.errs, s.listen)
   115  	go runctx(ctx, s.closed, s.errs, s.logSubscriptions)
   116  
   117  	close(s.registered)
   118  }
   119  
   120  // start begins the session and returns the first SessionMessage.
   121  func (s *session) start(ctx context.Context, description *api.NodeDescription) error {
   122  	log.G(ctx).Debugf("(*session).start")
   123  
   124  	errChan := make(chan error, 1)
   125  	var (
   126  		msg    *api.SessionMessage
   127  		stream api.Dispatcher_SessionClient
   128  		err    error
   129  	)
   130  	// Note: we don't defer cancellation of this context, because the
   131  	// streaming RPC is used after this function returned. We only cancel
   132  	// it in the timeout case to make sure the goroutine completes.
   133  
   134  	// We also fork this context again from the `run` context, because on
   135  	// `dispatcherRPCTimeout`, we want to cancel establishing a session and
   136  	// return an error.  If we cancel the `run` context instead of forking,
   137  	// then in `run` it's possible that we just terminate the function because
   138  	// `ctx` is done and hence fail to propagate the timeout error to the agent.
   139  	// If the error is not propogated to the agent, the agent will not close
   140  	// the session or rebuild a new session.
   141  	sessionCtx, cancelSession := context.WithCancel(ctx) //nolint:govet
   142  
   143  	// Need to run Session in a goroutine since there's no way to set a
   144  	// timeout for an individual Recv call in a stream.
   145  	go func() {
   146  		client := api.NewDispatcherClient(s.conn.ClientConn)
   147  
   148  		stream, err = client.Session(sessionCtx, &api.SessionRequest{
   149  			Description: description,
   150  			SessionID:   s.sessionID,
   151  		})
   152  		if err != nil {
   153  			errChan <- err
   154  			return
   155  		}
   156  
   157  		msg, err = stream.Recv()
   158  		errChan <- err
   159  	}()
   160  
   161  	select {
   162  	case err := <-errChan:
   163  		if err != nil {
   164  			return err //nolint:govet
   165  		}
   166  	case <-time.After(dispatcherRPCTimeout):
   167  		cancelSession()
   168  		return errors.New("session initiation timed out")
   169  	}
   170  
   171  	s.sessionID = msg.SessionID
   172  	s.session = stream
   173  
   174  	return s.handleSessionMessage(ctx, msg)
   175  }
   176  
   177  func (s *session) heartbeat(ctx context.Context) error {
   178  	log.G(ctx).Debugf("(*session).heartbeat")
   179  	client := api.NewDispatcherClient(s.conn.ClientConn)
   180  	heartbeat := time.NewTimer(1) // send out a heartbeat right away
   181  	defer heartbeat.Stop()
   182  
   183  	fields := logrus.Fields{
   184  		"sessionID": s.sessionID,
   185  		"method":    "(*session).heartbeat",
   186  	}
   187  
   188  	for {
   189  		select {
   190  		case <-heartbeat.C:
   191  			heartbeatCtx, cancel := context.WithTimeout(ctx, dispatcherRPCTimeout)
   192  			// TODO(anshul) log manager info in all logs in this function.
   193  			log.G(ctx).WithFields(fields).Debugf("sending heartbeat to manager %v with timeout %v", s.conn.Peer(), dispatcherRPCTimeout)
   194  			resp, err := client.Heartbeat(heartbeatCtx, &api.HeartbeatRequest{
   195  				SessionID: s.sessionID,
   196  			})
   197  			cancel()
   198  			if err != nil {
   199  				log.G(ctx).WithFields(fields).WithError(err).Errorf("heartbeat to manager %v failed", s.conn.Peer())
   200  				st, _ := status.FromError(err)
   201  				if st.Code() == codes.NotFound {
   202  					err = errNodeNotRegistered
   203  				}
   204  
   205  				return err
   206  			}
   207  
   208  			log.G(ctx).WithFields(fields).Debugf("heartbeat successful to manager %v, next heartbeat period: %v", s.conn.Peer(), resp.Period)
   209  
   210  			heartbeat.Reset(resp.Period)
   211  		case <-s.closed:
   212  			return errSessionClosed
   213  		case <-ctx.Done():
   214  			return ctx.Err()
   215  		}
   216  	}
   217  }
   218  
   219  func (s *session) listen(ctx context.Context) error {
   220  	defer s.session.CloseSend()
   221  	log.G(ctx).Debugf("(*session).listen")
   222  	for {
   223  		msg, err := s.session.Recv()
   224  		if err != nil {
   225  			return err
   226  		}
   227  
   228  		if err := s.handleSessionMessage(ctx, msg); err != nil {
   229  			return err
   230  		}
   231  	}
   232  }
   233  
   234  func (s *session) handleSessionMessage(ctx context.Context, msg *api.SessionMessage) error {
   235  	select {
   236  	case s.messages <- msg:
   237  		return nil
   238  	case <-s.closed:
   239  		return errSessionClosed
   240  	case <-ctx.Done():
   241  		return ctx.Err()
   242  	}
   243  }
   244  
   245  func (s *session) logSubscriptions(ctx context.Context) error {
   246  	log := log.G(ctx).WithFields(logrus.Fields{"method": "(*session).logSubscriptions"})
   247  	log.Debugf("")
   248  
   249  	client := api.NewLogBrokerClient(s.conn.ClientConn)
   250  	subscriptions, err := client.ListenSubscriptions(ctx, &api.ListenSubscriptionsRequest{})
   251  	if err != nil {
   252  		return err
   253  	}
   254  	defer subscriptions.CloseSend()
   255  
   256  	for {
   257  		resp, err := subscriptions.Recv()
   258  		st, _ := status.FromError(err)
   259  		if st.Code() == codes.Unimplemented {
   260  			log.Warning("manager does not support log subscriptions")
   261  			// Don't return, because returning would bounce the session
   262  			select {
   263  			case <-s.closed:
   264  				return errSessionClosed
   265  			case <-ctx.Done():
   266  				return ctx.Err()
   267  			}
   268  		}
   269  		if err != nil {
   270  			return err
   271  		}
   272  
   273  		select {
   274  		case s.subscriptions <- resp:
   275  		case <-s.closed:
   276  			return errSessionClosed
   277  		case <-ctx.Done():
   278  			return ctx.Err()
   279  		}
   280  	}
   281  }
   282  
   283  func (s *session) watch(ctx context.Context) error {
   284  	log := log.G(ctx).WithFields(logrus.Fields{"method": "(*session).watch"})
   285  	log.Debugf("")
   286  	var (
   287  		resp            *api.AssignmentsMessage
   288  		assignmentWatch api.Dispatcher_AssignmentsClient
   289  		tasksWatch      api.Dispatcher_TasksClient
   290  		streamReference string
   291  		tasksFallback   bool
   292  		err             error
   293  	)
   294  
   295  	client := api.NewDispatcherClient(s.conn.ClientConn)
   296  	for {
   297  		// If this is the first time we're running the loop, or there was a reference mismatch
   298  		// attempt to get the assignmentWatch
   299  		if assignmentWatch == nil && !tasksFallback {
   300  			assignmentWatch, err = client.Assignments(ctx, &api.AssignmentsRequest{SessionID: s.sessionID})
   301  			if err != nil {
   302  				return err
   303  			}
   304  		}
   305  		// We have an assignmentWatch, let's try to receive an AssignmentMessage
   306  		if assignmentWatch != nil {
   307  			// If we get a code = 12 desc = unknown method Assignments, try to use tasks
   308  			resp, err = assignmentWatch.Recv()
   309  			if err != nil {
   310  				st, _ := status.FromError(err)
   311  				if st.Code() != codes.Unimplemented {
   312  					return err
   313  				}
   314  				tasksFallback = true
   315  				assignmentWatch = nil
   316  				log.WithError(err).Infof("falling back to Tasks")
   317  			}
   318  		}
   319  
   320  		// This code is here for backwards compatibility (so that newer clients can use the
   321  		// older method Tasks)
   322  		if tasksWatch == nil && tasksFallback {
   323  			tasksWatch, err = client.Tasks(ctx, &api.TasksRequest{SessionID: s.sessionID})
   324  			if err != nil {
   325  				return err
   326  			}
   327  		}
   328  		if tasksWatch != nil {
   329  			// When falling back to Tasks because of an old managers, we wrap the tasks in assignments.
   330  			var taskResp *api.TasksMessage
   331  			var assignmentChanges []*api.AssignmentChange
   332  			taskResp, err = tasksWatch.Recv()
   333  			if err != nil {
   334  				return err
   335  			}
   336  			for _, t := range taskResp.Tasks {
   337  				taskChange := &api.AssignmentChange{
   338  					Assignment: &api.Assignment{
   339  						Item: &api.Assignment_Task{
   340  							Task: t,
   341  						},
   342  					},
   343  					Action: api.AssignmentChange_AssignmentActionUpdate,
   344  				}
   345  
   346  				assignmentChanges = append(assignmentChanges, taskChange)
   347  			}
   348  			resp = &api.AssignmentsMessage{Type: api.AssignmentsMessage_COMPLETE, Changes: assignmentChanges}
   349  		}
   350  
   351  		// If there seems to be a gap in the stream, let's break out of the inner for and
   352  		// re-sync (by calling Assignments again).
   353  		if streamReference != "" && streamReference != resp.AppliesTo {
   354  			assignmentWatch = nil
   355  		} else {
   356  			streamReference = resp.ResultsIn
   357  		}
   358  
   359  		select {
   360  		case s.assignments <- resp:
   361  		case <-s.closed:
   362  			return errSessionClosed
   363  		case <-ctx.Done():
   364  			return ctx.Err()
   365  		}
   366  	}
   367  }
   368  
   369  // sendTaskStatus uses the current session to send the status of a single task.
   370  func (s *session) sendTaskStatus(ctx context.Context, taskID string, taskStatus *api.TaskStatus) error {
   371  	client := api.NewDispatcherClient(s.conn.ClientConn)
   372  	if _, err := client.UpdateTaskStatus(ctx, &api.UpdateTaskStatusRequest{
   373  		SessionID: s.sessionID,
   374  		Updates: []*api.UpdateTaskStatusRequest_TaskStatusUpdate{
   375  			{
   376  				TaskID: taskID,
   377  				Status: taskStatus,
   378  			},
   379  		},
   380  	}); err != nil {
   381  		// TODO(stevvooe): Dispatcher should not return this error. Status
   382  		// reports for unknown tasks should be ignored.
   383  		st, _ := status.FromError(err)
   384  		if st.Code() == codes.NotFound {
   385  			return errTaskUnknown
   386  		}
   387  
   388  		return err
   389  	}
   390  
   391  	return nil
   392  }
   393  
   394  func (s *session) sendTaskStatuses(ctx context.Context, updates ...*api.UpdateTaskStatusRequest_TaskStatusUpdate) ([]*api.UpdateTaskStatusRequest_TaskStatusUpdate, error) {
   395  	if len(updates) < 1 {
   396  		return nil, nil
   397  	}
   398  
   399  	const batchSize = 1024
   400  	select {
   401  	case <-s.registered:
   402  		select {
   403  		case <-s.closed:
   404  			return updates, ErrClosed
   405  		default:
   406  		}
   407  	case <-s.closed:
   408  		return updates, ErrClosed
   409  	case <-ctx.Done():
   410  		return updates, ctx.Err()
   411  	}
   412  
   413  	client := api.NewDispatcherClient(s.conn.ClientConn)
   414  	n := batchSize
   415  
   416  	if len(updates) < n {
   417  		n = len(updates)
   418  	}
   419  
   420  	if _, err := client.UpdateTaskStatus(ctx, &api.UpdateTaskStatusRequest{
   421  		SessionID: s.sessionID,
   422  		Updates:   updates[:n],
   423  	}); err != nil {
   424  		log.G(ctx).WithError(err).Errorf("failed sending task status batch size of %d", len(updates[:n]))
   425  		return updates, err
   426  	}
   427  
   428  	return updates[n:], nil
   429  }
   430  
   431  // sendError is used to send errors to errs channel and trigger session recreation
   432  func (s *session) sendError(err error) {
   433  	select {
   434  	case s.errs <- err:
   435  	case <-s.closed:
   436  	}
   437  }
   438  
   439  // close the given session. It should be called only in <-session.errs branch
   440  // of event loop, or when cleaning up the agent.
   441  func (s *session) close() error {
   442  	s.closeOnce.Do(func() {
   443  		s.cancel()
   444  		if s.conn != nil {
   445  			s.conn.Close(false)
   446  		}
   447  		close(s.closed)
   448  	})
   449  
   450  	return nil
   451  }