go.uber.org/cadence@v1.2.9/internal/internal_task_pollers.go (about)

     1  // Copyright (c) 2017-2020 Uber Technologies Inc.
     2  // Portions of the Software are attributed to Copyright (c) 2020 Temporal Technologies Inc.
     3  //
     4  // Permission is hereby granted, free of charge, to any person obtaining a copy
     5  // of this software and associated documentation files (the "Software"), to deal
     6  // in the Software without restriction, including without limitation the rights
     7  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     8  // copies of the Software, and to permit persons to whom the Software is
     9  // furnished to do so, subject to the following conditions:
    10  //
    11  // The above copyright notice and this permission notice shall be included in
    12  // all copies or substantial portions of the Software.
    13  //
    14  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    15  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    16  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    17  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    18  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    19  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    20  // THE SOFTWARE.
    21  
    22  package internal
    23  
    24  // All code in this file is private to the package.
    25  
    26  import (
    27  	"context"
    28  	"errors"
    29  	"fmt"
    30  	"sync"
    31  	"time"
    32  
    33  	"github.com/opentracing/opentracing-go"
    34  	"github.com/pborman/uuid"
    35  	"github.com/uber-go/tally"
    36  	"go.uber.org/zap"
    37  
    38  	"go.uber.org/cadence/.gen/go/cadence/workflowserviceclient"
    39  	s "go.uber.org/cadence/.gen/go/shared"
    40  	"go.uber.org/cadence/internal/common"
    41  	"go.uber.org/cadence/internal/common/backoff"
    42  	"go.uber.org/cadence/internal/common/metrics"
    43  	"go.uber.org/cadence/internal/common/serializer"
    44  )
    45  
const (
	// pollTaskServiceTimeOut is the client-side timeout applied to the context that doPoll hands to
	// every poll function.
	pollTaskServiceTimeOut = 150 * time.Second // Server long poll is 2 * Minutes + delta

	// stickyDecisionScheduleToStartTimeoutSeconds is not referenced in the visible part of this file.
	// NOTE(review): presumably a default sticky schedule-to-start timeout in seconds — confirm at its use site.
	stickyDecisionScheduleToStartTimeoutSeconds = 5

	// ratioToForceCompleteDecisionTaskComplete is not referenced in the visible part of this file.
	// NOTE(review): presumably the fraction of the decision timeout after which a decision is force
	// completed — confirm at its use site.
	//
	// serviceBusy is the value of the cause tag put on the transient poll failure counter when a
	// decision poll fails with a ServiceBusyError.
	ratioToForceCompleteDecisionTaskComplete = 0.8
	serviceBusy                              = "serviceBusy"
)
    54  
type (
	// taskPoller interface to poll and process for task
	taskPoller interface {
		// PollTask polls for one new task
		PollTask() (interface{}, error)
		// ProcessTask processes a task
		ProcessTask(interface{}) error
	}

	// basePoller is the base class for all poller implementations
	basePoller struct {
		// shutdownC is watched by shuttingDown and doPoll; once it becomes readable the poller
		// reports errShutdown instead of polling or processing.
		shutdownC <-chan struct{}
	}

	// workflowTaskPoller implements polling/processing a workflow task
	workflowTaskPoller struct {
		basePoller
		domain       string
		taskListName string
		identity     string
		service      workflowserviceclient.Interface
		taskHandler  WorkflowTaskHandler
		// ldaTunnel, when non-nil, is offered activities scheduled on this poller's own task list
		// so they can be dispatched locally (see RespondTaskCompleted).
		ldaTunnel    *locallyDispatchedActivityTunnel
		metricsScope *metrics.TaggedScope
		logger       *zap.Logger

		// stickyUUID is generated once per poller and is passed to getWorkerTaskList to derive the
		// sticky task list name.
		stickyUUID                   string
		disableStickyExecution       bool
		StickyScheduleToStartTimeout time.Duration

		// The poll counters and stickyBacklog are read and written under requestLock by
		// getNextPollRequest, release and updateBacklog.
		pendingRegularPollCount int
		pendingStickyPollCount  int
		stickyBacklog           int64
		requestLock             sync.Mutex
		featureFlags            FeatureFlags
	}

	// activityTaskPoller implements polling/processing an activity task
	activityTaskPoller struct {
		basePoller
		domain              string
		taskListName        string
		identity            string
		service             workflowserviceclient.Interface
		taskHandler         ActivityTaskHandler
		metricsScope        *metrics.TaggedScope
		logger              *zap.Logger
		activitiesPerSecond float64
		featureFlags        FeatureFlags
	}

	// locallyDispatchedActivityTaskPoller implements polling/processing a locally dispatched activity task
	locallyDispatchedActivityTaskPoller struct {
		activityTaskPoller
		ldaTunnel *locallyDispatchedActivityTunnel
	}

	// historyIteratorImpl holds the state needed to page through a workflow execution's history.
	historyIteratorImpl struct {
		iteratorFunc   func(nextPageToken []byte) (*s.History, []byte, error)
		execution      *s.WorkflowExecution
		nextPageToken  []byte
		domain         string
		service        workflowserviceclient.Interface
		metricsScope   tally.Scope
		startedEventID int64
		maxEventID     int64 // Equivalent to History Count
		featureFlags   FeatureFlags
	}

	// localActivityTaskPoller takes local activity tasks from laTunnel and runs them through handler.
	localActivityTaskPoller struct {
		basePoller
		handler      *localActivityTaskHandler
		metricsScope tally.Scope
		logger       *zap.Logger
		laTunnel     *localActivityTunnel
	}

	// localActivityTaskHandler executes a single local activity task (see executeLocalActivityTask).
	localActivityTaskHandler struct {
		// userContext is the root context for local activities; context.Background() is used when nil.
		userContext        context.Context
		metricsScope       *metrics.TaggedScope
		logger             *zap.Logger
		dataConverter      DataConverter
		contextPropagators []ContextPropagator
		tracer             opentracing.Tracer
	}

	// localActivityResult is what executeLocalActivityTask produces for one local activity task.
	localActivityResult struct {
		result  []byte
		err     error // original error type, possibly an un-encoded user error
		task    *localActivityTask
		backoff time.Duration
	}

	// localActivityTunnel carries local activity tasks to the local activity poller.
	// taskCh is created with a buffer of 1000 and resultCh unbuffered (see newLocalActivityTunnel).
	localActivityTunnel struct {
		taskCh   chan *localActivityTask
		resultCh chan interface{}
		stopCh   <-chan struct{}
	}

	// locallyDispatchedActivityTunnel carries activities offered for local dispatch to a waiting poller.
	// taskCh is created unbuffered (see newLocallyDispatchedActivityTunnel).
	locallyDispatchedActivityTunnel struct {
		taskCh       chan *locallyDispatchedActivityTask
		stopCh       <-chan struct{}
		metricsScope *metrics.TaggedScope
	}
)
   160  
   161  func newLocalActivityTunnel(stopCh <-chan struct{}) *localActivityTunnel {
   162  	return &localActivityTunnel{
   163  		taskCh:   make(chan *localActivityTask, 1000),
   164  		resultCh: make(chan interface{}),
   165  		stopCh:   stopCh,
   166  	}
   167  }
   168  
   169  func (lat *localActivityTunnel) getTask() *localActivityTask {
   170  	select {
   171  	case task := <-lat.taskCh:
   172  		return task
   173  	case <-lat.stopCh:
   174  		return nil
   175  	}
   176  }
   177  
   178  func (lat *localActivityTunnel) sendTask(task *localActivityTask) bool {
   179  	select {
   180  	case lat.taskCh <- task:
   181  		return true
   182  	case <-lat.stopCh:
   183  		return false
   184  	}
   185  }
   186  
   187  func newLocallyDispatchedActivityTunnel(stopCh <-chan struct{}) *locallyDispatchedActivityTunnel {
   188  	return &locallyDispatchedActivityTunnel{
   189  		taskCh: make(chan *locallyDispatchedActivityTask),
   190  		stopCh: stopCh,
   191  	}
   192  }
   193  
   194  func (ldat *locallyDispatchedActivityTunnel) getTask() *locallyDispatchedActivityTask {
   195  	var task *locallyDispatchedActivityTask
   196  	select {
   197  	case task = <-ldat.taskCh:
   198  	case <-ldat.stopCh:
   199  		return nil
   200  	}
   201  
   202  	select {
   203  	case ready := <-task.readyCh:
   204  		if ready {
   205  			return task
   206  		} else {
   207  			return nil
   208  		}
   209  	case <-ldat.stopCh:
   210  		return nil
   211  	}
   212  }
   213  
   214  func (ldat *locallyDispatchedActivityTunnel) sendTask(task *locallyDispatchedActivityTask) bool {
   215  	select {
   216  	case ldat.taskCh <- task:
   217  		return true
   218  	default:
   219  		return false
   220  	}
   221  }
   222  
   223  func isClientSideError(err error) bool {
   224  	// If an activity execution exceeds deadline.
   225  	if err == context.DeadlineExceeded {
   226  		return true
   227  	}
   228  
   229  	return false
   230  }
   231  
   232  // shuttingDown returns true if worker is shutting down right now
   233  func (bp *basePoller) shuttingDown() bool {
   234  	select {
   235  	case <-bp.shutdownC:
   236  		return true
   237  	default:
   238  		return false
   239  	}
   240  }
   241  
   242  // doPoll runs the given pollFunc in a separate go routine. Returns when either of the conditions are met:
   243  // - poll succeeds, poll fails or worker is shutting down
   244  func (bp *basePoller) doPoll(
   245  	featureFlags FeatureFlags,
   246  	pollFunc func(ctx context.Context) (interface{}, error),
   247  ) (interface{}, error) {
   248  	if bp.shuttingDown() {
   249  		return nil, errShutdown
   250  	}
   251  
   252  	var err error
   253  	var result interface{}
   254  
   255  	doneC := make(chan struct{})
   256  	ctx, cancel, _ := newChannelContext(context.Background(), featureFlags, chanTimeout(pollTaskServiceTimeOut))
   257  
   258  	go func() {
   259  		result, err = pollFunc(ctx)
   260  		cancel()
   261  		close(doneC)
   262  	}()
   263  
   264  	select {
   265  	case <-doneC:
   266  		return result, err
   267  	case <-bp.shutdownC:
   268  		cancel()
   269  		return nil, errShutdown
   270  	}
   271  }
   272  
   273  // newWorkflowTaskPoller creates a new workflow task poller which must have a one to one relationship to workflow worker
   274  func newWorkflowTaskPoller(
   275  	taskHandler WorkflowTaskHandler,
   276  	ldaTunnel *locallyDispatchedActivityTunnel,
   277  	service workflowserviceclient.Interface,
   278  	domain string,
   279  	params workerExecutionParameters,
   280  ) *workflowTaskPoller {
   281  	return &workflowTaskPoller{
   282  		basePoller:                   basePoller{shutdownC: params.WorkerStopChannel},
   283  		service:                      service,
   284  		domain:                       domain,
   285  		taskListName:                 params.TaskList,
   286  		identity:                     params.Identity,
   287  		taskHandler:                  taskHandler,
   288  		ldaTunnel:                    ldaTunnel,
   289  		metricsScope:                 metrics.NewTaggedScope(params.MetricsScope),
   290  		logger:                       params.Logger,
   291  		stickyUUID:                   uuid.New(),
   292  		disableStickyExecution:       params.DisableStickyExecution,
   293  		StickyScheduleToStartTimeout: params.StickyScheduleToStartTimeout,
   294  		featureFlags:                 params.FeatureFlags,
   295  	}
   296  }
   297  
   298  // PollTask polls a new task
   299  func (wtp *workflowTaskPoller) PollTask() (interface{}, error) {
   300  	// Get the task.
   301  	workflowTask, err := wtp.doPoll(wtp.featureFlags, wtp.poll)
   302  	if err != nil {
   303  		return nil, err
   304  	}
   305  
   306  	return workflowTask, nil
   307  }
   308  
   309  // ProcessTask processes a task which could be workflow task or local activity result
   310  func (wtp *workflowTaskPoller) ProcessTask(task interface{}) error {
   311  	if wtp.shuttingDown() {
   312  		return errShutdown
   313  	}
   314  
   315  	switch task.(type) {
   316  	case *workflowTask:
   317  		return wtp.processWorkflowTask(task.(*workflowTask))
   318  	case *resetStickinessTask:
   319  		return wtp.processResetStickinessTask(task.(*resetStickinessTask))
   320  	default:
   321  		panic("unknown task type.")
   322  	}
   323  }
   324  
   325  func (wtp *workflowTaskPoller) processWorkflowTask(task *workflowTask) error {
   326  	if task.task == nil {
   327  		// We didn't have task, poll might have timeout.
   328  		traceLog(func() {
   329  			wtp.logger.Debug("Workflow task unavailable")
   330  		})
   331  		return nil
   332  	}
   333  	doneCh := make(chan struct{})
   334  	laResultCh := make(chan *localActivityResult)
   335  	// close doneCh so local activity worker won't get blocked forever when trying to send back result to laResultCh.
   336  	defer close(doneCh)
   337  
   338  	for {
   339  		var response *s.RespondDecisionTaskCompletedResponse
   340  		startTime := time.Now()
   341  		task.doneCh = doneCh
   342  		task.laResultCh = laResultCh
   343  		// Process the task.
   344  		completedRequest, err := wtp.taskHandler.ProcessWorkflowTask(
   345  			task,
   346  			func(response interface{}, startTime time.Time) (*workflowTask, error) {
   347  				wtp.logger.Debug("Force RespondDecisionTaskCompleted.", zap.Int64("TaskStartedEventID", task.task.GetStartedEventId()))
   348  				wtp.metricsScope.Counter(metrics.DecisionTaskForceCompleted).Inc(1)
   349  				heartbeatResponse, err := wtp.RespondTaskCompletedWithMetrics(response, nil, task.task, startTime)
   350  				if err != nil {
   351  					return nil, err
   352  				}
   353  				if heartbeatResponse == nil || heartbeatResponse.DecisionTask == nil {
   354  					return nil, nil
   355  				}
   356  				task := wtp.toWorkflowTask(heartbeatResponse.DecisionTask)
   357  				task.doneCh = doneCh
   358  				task.laResultCh = laResultCh
   359  				return task, nil
   360  			},
   361  		)
   362  		if completedRequest == nil && err == nil {
   363  			return nil
   364  		}
   365  		if _, ok := err.(decisionHeartbeatError); ok {
   366  			return err
   367  		}
   368  		response, err = wtp.RespondTaskCompletedWithMetrics(completedRequest, err, task.task, startTime)
   369  		if err != nil {
   370  			return err
   371  		}
   372  
   373  		if response == nil || response.DecisionTask == nil {
   374  			return nil
   375  		}
   376  
   377  		// we are getting new decision task, so reset the workflowTask and continue process the new one
   378  		task = wtp.toWorkflowTask(response.DecisionTask)
   379  	}
   380  }
   381  
   382  func (wtp *workflowTaskPoller) processResetStickinessTask(rst *resetStickinessTask) error {
   383  	tchCtx, cancel, opt := newChannelContext(context.Background(), wtp.featureFlags)
   384  	defer cancel()
   385  	wtp.metricsScope.Counter(metrics.StickyCacheEvict).Inc(1)
   386  	if _, err := wtp.service.ResetStickyTaskList(tchCtx, rst.task, opt...); err != nil {
   387  		wtp.logger.Warn("ResetStickyTaskList failed",
   388  			zap.String(tagWorkflowID, rst.task.Execution.GetWorkflowId()),
   389  			zap.String(tagRunID, rst.task.Execution.GetRunId()),
   390  			zap.Error(err))
   391  		return err
   392  	}
   393  
   394  	return nil
   395  }
   396  
   397  func (wtp *workflowTaskPoller) RespondTaskCompletedWithMetrics(completedRequest interface{}, taskErr error, task *s.PollForDecisionTaskResponse, startTime time.Time) (response *s.RespondDecisionTaskCompletedResponse, err error) {
   398  
   399  	metricsScope := wtp.metricsScope.GetTaggedScope(tagWorkflowType, task.WorkflowType.GetName())
   400  	if taskErr != nil {
   401  		metricsScope.Counter(metrics.DecisionExecutionFailedCounter).Inc(1)
   402  		wtp.logger.Warn("Failed to process decision task.",
   403  			zap.String(tagWorkflowType, task.WorkflowType.GetName()),
   404  			zap.String(tagWorkflowID, task.WorkflowExecution.GetWorkflowId()),
   405  			zap.String(tagRunID, task.WorkflowExecution.GetRunId()),
   406  			zap.Error(taskErr))
   407  		// convert err to DecisionTaskFailed
   408  		completedRequest = errorToFailDecisionTask(task.TaskToken, taskErr, wtp.identity)
   409  	} else {
   410  		metricsScope.Counter(metrics.DecisionTaskCompletedCounter).Inc(1)
   411  	}
   412  
   413  	metricsScope.Timer(metrics.DecisionExecutionLatency).Record(time.Now().Sub(startTime))
   414  
   415  	responseStartTime := time.Now()
   416  	if response, err = wtp.RespondTaskCompleted(completedRequest, task); err != nil {
   417  		metricsScope.Counter(metrics.DecisionResponseFailedCounter).Inc(1)
   418  		return
   419  	}
   420  	metricsScope.Timer(metrics.DecisionResponseLatency).Record(time.Now().Sub(responseStartTime))
   421  
   422  	return
   423  }
   424  
// RespondTaskCompleted sends completedRequest to the service, going through backoff.Retry with
// the policy from createDynamicServiceRetryPolicy and isServiceTransientError as the retryable
// check. The concrete type of completedRequest selects the service call:
//
//   - *s.RespondDecisionTaskFailedRequest: RespondDecisionTaskFailed, but only when task.Attempt
//     is set and zero; otherwise nothing is sent and no error is produced.
//   - *s.RespondDecisionTaskCompletedRequest: RespondDecisionTaskCompleted, after filling in sticky
//     attributes and offering matching activities for local dispatch. This is the only case that
//     sets the returned response.
//   - *s.RespondQueryTaskCompletedRequest: RespondQueryTaskCompleted.
//
// Any other request type panics. The returned err is whatever backoff.Retry returns.
func (wtp *workflowTaskPoller) RespondTaskCompleted(completedRequest interface{}, task *s.PollForDecisionTaskResponse) (response *s.RespondDecisionTaskCompletedResponse, err error) {
	ctx := context.Background()
	// Respond task completion.
	err = backoff.Retry(ctx,
		func() error {
			// Each invocation of this closure gets its own call context, released when it returns.
			tchCtx, cancel, opt := newChannelContext(ctx, wtp.featureFlags)
			defer cancel()
			var err1 error
			switch request := completedRequest.(type) {
			case *s.RespondDecisionTaskFailedRequest:
				// Only fail decision on first attempt, subsequent failure on the same decision task will timeout.
				// This is to avoid spin on the failed decision task. Checking Attempt not nil for older server.
				if task.Attempt != nil && task.GetAttempt() == 0 {
					err1 = wtp.service.RespondDecisionTaskFailed(tchCtx, request, opt...)
					if err1 != nil {
						traceLog(func() {
							wtp.logger.Debug("RespondDecisionTaskFailed failed.", zap.Error(err1))
						})
					}
				}
			case *s.RespondDecisionTaskCompletedRequest:
				// Attach this worker's sticky task list when sticky execution is enabled and the
				// request does not carry sticky attributes yet; otherwise ask the service not to
				// return a new decision task.
				// NOTE(review): the request is mutated in place, so if this closure runs again for
				// the same request StickyAttributes is already non-nil and the else branch is taken
				// even with sticky execution enabled — confirm that is intended.
				if request.StickyAttributes == nil && !wtp.disableStickyExecution {
					request.StickyAttributes = &s.StickyExecutionAttributes{
						WorkerTaskList:                &s.TaskList{Name: common.StringPtr(getWorkerTaskList(wtp.stickyUUID))},
						ScheduleToStartTimeoutSeconds: common.Int32Ptr(common.Int32Ceil(wtp.StickyScheduleToStartTimeout.Seconds())),
					}
				} else {
					request.ReturnNewDecisionTask = common.BoolPtr(false)
				}
				// activityTasks collects the activities that a waiting poller accepted for local dispatch.
				var activityTasks []*locallyDispatchedActivityTask
				if wtp.ldaTunnel != nil {
					for _, decision := range request.Decisions {
						attr := decision.ScheduleActivityTaskDecisionAttributes
						// Only activities scheduled on this poller's own task list are candidates.
						if attr != nil && wtp.taskListName == attr.TaskList.GetName() {
							// assume the activity type is in registry otherwise the activity would be failed and retried from server
							activityTask := &locallyDispatchedActivityTask{
								// buffered so the deferred notification below does not block on the receiver
								readyCh:                       make(chan bool, 1),
								ActivityId:                    attr.ActivityId,
								ActivityType:                  attr.ActivityType,
								Input:                         attr.Input,
								Header:                        attr.Header,
								WorkflowDomain:                common.StringPtr(wtp.domain),
								ScheduleToCloseTimeoutSeconds: attr.ScheduleToCloseTimeoutSeconds,
								StartToCloseTimeoutSeconds:    attr.StartToCloseTimeoutSeconds,
								HeartbeatTimeoutSeconds:       attr.HeartbeatTimeoutSeconds,
								WorkflowExecution:             task.WorkflowExecution,
								WorkflowType:                  task.WorkflowType,
							}
							// sendTask does not block: it succeeds only if the tunnel can take the task right now.
							// NOTE(review): RequestLocalDispatch is set on the decision and is not cleared in
							// the else branch, so it would stay true if this closure runs again and sendTask
							// then fails — confirm.
							if wtp.ldaTunnel.sendTask(activityTask) {
								wtp.metricsScope.Counter(metrics.ActivityLocalDispatchSucceedCounter).Inc(1)
								decision.ScheduleActivityTaskDecisionAttributes.RequestLocalDispatch = common.BoolPtr(true)
								activityTasks = append(activityTasks, activityTask)
							} else {
								// all pollers are busy - no room to optimize
								wtp.metricsScope.Counter(metrics.ActivityLocalDispatchFailedCounter).Inc(1)
							}
						}
					}
				}
				// Runs when this closure returns, i.e. after the respond call below: every accepted
				// activity task is told through readyCh whether the service started it. An activity
				// counts as started only if the call succeeded and the response lists it under
				// ActivitiesToDispatchLocally, in which case timestamps and task token are copied over.
				defer func() {
					for _, at := range activityTasks {
						started := false
						if response != nil && err1 == nil {
							if adl, ok := response.ActivitiesToDispatchLocally[*at.ActivityId]; ok {
								at.ScheduledTimestamp = adl.ScheduledTimestamp
								at.StartedTimestamp = adl.StartedTimestamp
								at.ScheduledTimestampOfThisAttempt = adl.ScheduledTimestampOfThisAttempt
								at.TaskToken = adl.TaskToken
								started = true
							}
						}
						at.readyCh <- started
					}
				}()
				response, err1 = wtp.service.RespondDecisionTaskCompleted(tchCtx, request, opt...)
				if err1 != nil {
					traceLog(func() {
						wtp.logger.Debug("RespondDecisionTaskCompleted failed.", zap.Error(err1))
					})
				}

			case *s.RespondQueryTaskCompletedRequest:
				err1 = wtp.service.RespondQueryTaskCompleted(tchCtx, request, opt...)
				if err1 != nil {
					traceLog(func() {
						wtp.logger.Debug("RespondQueryTaskCompleted failed.", zap.Error(err1))
					})
				}
			default:
				// should not happen
				panic("unknown request type from ProcessWorkflowTask()")
			}

			return err1
		}, createDynamicServiceRetryPolicy(ctx), isServiceTransientError)

	return
}
   523  
   524  func newLocalActivityPoller(params workerExecutionParameters, laTunnel *localActivityTunnel) *localActivityTaskPoller {
   525  	handler := &localActivityTaskHandler{
   526  		userContext:        params.UserContext,
   527  		metricsScope:       metrics.NewTaggedScope(params.MetricsScope),
   528  		logger:             params.Logger,
   529  		dataConverter:      params.DataConverter,
   530  		contextPropagators: params.ContextPropagators,
   531  		tracer:             params.Tracer,
   532  	}
   533  	return &localActivityTaskPoller{
   534  		basePoller:   basePoller{shutdownC: params.WorkerStopChannel},
   535  		handler:      handler,
   536  		metricsScope: params.MetricsScope,
   537  		logger:       params.Logger,
   538  		laTunnel:     laTunnel,
   539  	}
   540  }
   541  
   542  func (latp *localActivityTaskPoller) PollTask() (interface{}, error) {
   543  	return latp.laTunnel.getTask(), nil
   544  }
   545  
   546  func (latp *localActivityTaskPoller) ProcessTask(task interface{}) error {
   547  	if latp.shuttingDown() {
   548  		return errShutdown
   549  	}
   550  
   551  	result := latp.handler.executeLocalActivityTask(task.(*localActivityTask))
   552  	// We need to send back the local activity result to unblock workflowTaskPoller.processWorkflowTask() which is
   553  	// synchronously listening on the laResultCh. We also want to make sure we don't block here forever in case
   554  	// processWorkflowTask() already returns and nobody is receiving from laResultCh. We guarantee that doneCh is closed
   555  	// before returning from workflowTaskPoller.processWorkflowTask().
   556  	select {
   557  	case result.task.workflowTask.laResultCh <- result:
   558  		return nil
   559  	case <-result.task.workflowTask.doneCh:
   560  		// processWorkflowTask() already returns, just drop this local activity result.
   561  		return nil
   562  	}
   563  }
   564  
// executeLocalActivityTask runs one local activity and returns its outcome. The returned
// localActivityResult always carries the task; err is set when
//   - a context propagator fails to extract from the task header,
//   - the task was already canceled (ErrCanceled),
//   - the activity context is canceled (ErrCanceled) or passes its deadline (ErrDeadlineExceeded)
//     before the activity finishes,
//   - the activity function panics (a panic error with the stack trace), or
//   - the activity function itself returns an error.
//
// The activity function runs on its own goroutine; this method waits for either that goroutine
// or the activity context to finish. When the context finishes first the method returns without
// waiting for the goroutine, whose eventual result is not used.
func (lath *localActivityTaskHandler) executeLocalActivityTask(task *localActivityTask) (result *localActivityResult) {
	workflowType := task.params.WorkflowInfo.WorkflowType.Name
	activityType := task.params.ActivityType
	metricsScope := getMetricsScopeForLocalActivity(lath.metricsScope, workflowType, activityType)

	// keep in sync with regular activity logger tags
	logger := lath.logger.With(
		zap.String(tagLocalActivityID, task.activityID),
		zap.String(tagLocalActivityType, activityType),
		zap.String(tagWorkflowType, workflowType),
		zap.String(tagWorkflowID, task.params.WorkflowInfo.WorkflowExecution.ID),
		zap.String(tagRunID, task.params.WorkflowInfo.WorkflowExecution.RunID))

	metricsScope.Counter(metrics.LocalActivityTotalCounter).Inc(1)

	ae := activityExecutor{name: activityType, fn: task.params.ActivityFn}

	// The user supplied context is the root of the activity context; fall back to Background.
	rootCtx := lath.userContext
	if rootCtx == nil {
		rootCtx = context.Background()
	}

	// local copy so that the environment can hold a pointer to it
	workflowTypeLocal := task.params.WorkflowInfo.WorkflowType

	ctx := context.WithValue(rootCtx, activityEnvContextKey, &activityEnvironment{
		workflowType:      &workflowTypeLocal,
		workflowDomain:    task.params.WorkflowInfo.Domain,
		taskList:          task.params.WorkflowInfo.TaskListName,
		activityType:      ActivityType{Name: activityType},
		activityID:        fmt.Sprintf("%v", task.activityID),
		workflowExecution: task.params.WorkflowInfo.WorkflowExecution,
		logger:            logger,
		metricsScope:      metricsScope,
		isLocalActivity:   true,
		dataConverter:     lath.dataConverter,
		attempt:           task.attempt,
	})

	// propagate context information into the local activity activity context from the headers
	for _, ctxProp := range lath.contextPropagators {
		var err error
		if ctx, err = ctxProp.Extract(ctx, NewHeaderReader(task.header)); err != nil {
			// This return happens before the failure-counting defer below is registered, so a
			// propagation error is not counted as a local activity failure.
			result = &localActivityResult{
				task:   task,
				result: nil,
				err:    fmt.Errorf("unable to propagate context %v", err),
			}
			return result
		}
	}

	// count all failures beyond this point, as they come from the activity itself
	defer func() {
		// result is the named return value, so this sees whatever is being returned
		if result.err != nil {
			metricsScope.Counter(metrics.LocalActivityFailedCounter).Inc(1)
		}
	}()

	// The deadline is now + ScheduleToCloseTimeoutSeconds, shortened to task.expireTime on a
	// retry (attempt > 0) when expireTime is set and earlier.
	timeoutDuration := time.Duration(task.params.ScheduleToCloseTimeoutSeconds) * time.Second
	deadline := time.Now().Add(timeoutDuration)
	if task.attempt > 0 && !task.expireTime.IsZero() && task.expireTime.Before(deadline) {
		// this is attempt and expire time is before SCHEDULE_TO_CLOSE timeout
		deadline = task.expireTime
	}

	ctx, cancel := context.WithDeadline(ctx, deadline)
	defer cancel()

	// Under the task lock: bail out if the task was canceled already, otherwise publish the
	// cancel func on the task so the activity context can be canceled through it.
	task.Lock()
	if task.canceled {
		task.Unlock()
		return &localActivityResult{err: ErrCanceled, task: task}
	}
	task.cancelFunc = cancel
	task.Unlock()

	// laResult and err are written by the goroutine below and read here only after doneCh is closed.
	var laResult []byte
	var err error
	doneCh := make(chan struct{})
	go func(ch chan struct{}) {
		// Registered first, so it runs last: ch is closed only after the recover handler below
		// has had the chance to store a panic error in err.
		defer close(ch)

		defer func() {
			if p := recover(); p != nil {
				topLine := fmt.Sprintf("local activity for %s [panic]:", activityType)
				st := getStackTraceRaw(topLine, 7, 0)
				logger.Error("LocalActivity panic.",
					zap.String(tagPanicError, fmt.Sprintf("%v", p)),
					zap.String(tagPanicStack, st))
				metricsScope.Counter(metrics.LocalActivityPanicCounter).Inc(1)
				err = newPanicError(p, st)
			}
		}()

		laStartTime := time.Now()
		// ctx is shadowed here: the activity runs with the context returned alongside the span
		ctx, span := createOpenTracingActivitySpan(ctx, lath.tracer, time.Now(), task.params.ActivityType, task.params.WorkflowInfo.WorkflowExecution.ID, task.params.WorkflowInfo.WorkflowExecution.RunID)
		defer span.Finish()
		laResult, err = ae.ExecuteWithActualArgs(ctx, task.params.InputArgs)
		executionLatency := time.Now().Sub(laStartTime)
		metricsScope.Timer(metrics.LocalActivityExecutionLatency).Record(executionLatency)
		if executionLatency > timeoutDuration {
			// If local activity takes longer than expected timeout, the context would already be DeadlineExceeded and
			// the result would be discarded. Print a warning in this case.
			logger.Warn("LocalActivity takes too long to complete.",
				zap.Int32("ScheduleToCloseTimeoutSeconds", task.params.ScheduleToCloseTimeoutSeconds),
				zap.Duration("ActualExecutionDuration", executionLatency))
		}
	}(doneCh)

Wait_Result:
	select {
	case <-ctx.Done():
		// The context and the activity may have finished at about the same time; prefer the
		// activity's own result when it is already available.
		select {
		case <-doneCh:
			// double check if result is ready.
			break Wait_Result
		default:
		}

		// context is done
		if ctx.Err() == context.Canceled {
			metricsScope.Counter(metrics.LocalActivityCanceledCounter).Inc(1)
			return &localActivityResult{err: ErrCanceled, task: task}
		} else if ctx.Err() == context.DeadlineExceeded {
			metricsScope.Counter(metrics.LocalActivityTimeoutCounter).Inc(1)
			return &localActivityResult{err: ErrDeadlineExceeded, task: task}
		} else {
			// should not happen
			return &localActivityResult{err: NewCustomError("unexpected context done"), task: task}
		}
	case <-doneCh:
		// local activity completed
	}

	return &localActivityResult{result: laResult, err: err, task: task}
}
   701  
   702  func (wtp *workflowTaskPoller) release(kind s.TaskListKind) {
   703  	if wtp.disableStickyExecution {
   704  		return
   705  	}
   706  
   707  	wtp.requestLock.Lock()
   708  	if kind == s.TaskListKindSticky {
   709  		wtp.pendingStickyPollCount--
   710  	} else {
   711  		wtp.pendingRegularPollCount--
   712  	}
   713  	wtp.requestLock.Unlock()
   714  }
   715  
   716  func (wtp *workflowTaskPoller) updateBacklog(taskListKind s.TaskListKind, backlogCountHint int64) {
   717  	if taskListKind == s.TaskListKindNormal || wtp.disableStickyExecution {
   718  		// we only care about sticky backlog for now.
   719  		return
   720  	}
   721  	wtp.requestLock.Lock()
   722  	wtp.stickyBacklog = backlogCountHint
   723  	wtp.requestLock.Unlock()
   724  }
   725  
   726  // getNextPollRequest returns appropriate next poll request based on poller configuration.
   727  // Simple rules:
   728  //  1. if sticky execution is disabled, always poll for regular task list
   729  //  2. otherwise:
   730  //     2.1) if sticky task list has backlog, always prefer to process sticky task first
   731  //     2.2) poll from the task list that has less pending requests (prefer sticky when they are the same).
   732  //
   733  // TODO: make this more smart to auto adjust based on poll latency
   734  func (wtp *workflowTaskPoller) getNextPollRequest() (request *s.PollForDecisionTaskRequest) {
   735  	taskListName := wtp.taskListName
   736  	taskListKind := s.TaskListKindNormal
   737  	if !wtp.disableStickyExecution {
   738  		wtp.requestLock.Lock()
   739  		if wtp.stickyBacklog > 0 || wtp.pendingStickyPollCount <= wtp.pendingRegularPollCount {
   740  			wtp.pendingStickyPollCount++
   741  			taskListName = getWorkerTaskList(wtp.stickyUUID)
   742  			taskListKind = s.TaskListKindSticky
   743  		} else {
   744  			wtp.pendingRegularPollCount++
   745  		}
   746  		wtp.requestLock.Unlock()
   747  	}
   748  
   749  	taskList := s.TaskList{
   750  		Name: common.StringPtr(taskListName),
   751  		Kind: common.TaskListKindPtr(taskListKind),
   752  	}
   753  	return &s.PollForDecisionTaskRequest{
   754  		Domain:         common.StringPtr(wtp.domain),
   755  		TaskList:       common.TaskListPtr(taskList),
   756  		Identity:       common.StringPtr(wtp.identity),
   757  		BinaryChecksum: common.StringPtr(getBinaryChecksum()),
   758  	}
   759  }
   760  
   761  // Poll for a single workflow task from the service
   762  func (wtp *workflowTaskPoller) poll(ctx context.Context) (interface{}, error) {
   763  	startTime := time.Now()
   764  	wtp.metricsScope.Counter(metrics.DecisionPollCounter).Inc(1)
   765  
   766  	traceLog(func() {
   767  		wtp.logger.Debug("workflowTaskPoller::Poll")
   768  	})
   769  
   770  	request := wtp.getNextPollRequest()
   771  	defer wtp.release(request.TaskList.GetKind())
   772  
   773  	response, err := wtp.service.PollForDecisionTask(ctx, request, getYarpcCallOptions(wtp.featureFlags)...)
   774  	if err != nil {
   775  		retryable := isServiceTransientError(err)
   776  
   777  		if retryable {
   778  			if target := (*s.ServiceBusyError)(nil); errors.As(err, &target) {
   779  				wtp.metricsScope.Tagged(map[string]string{causeTag: serviceBusy}).Counter(metrics.DecisionPollTransientFailedCounter).Inc(1)
   780  			} else {
   781  				wtp.metricsScope.Counter(metrics.DecisionPollTransientFailedCounter).Inc(1)
   782  			}
   783  		} else {
   784  			wtp.metricsScope.Counter(metrics.DecisionPollFailedCounter).Inc(1)
   785  		}
   786  		wtp.updateBacklog(request.TaskList.GetKind(), 0)
   787  
   788  		// pause for the retry delay if present.
   789  		// failures also have an exponential backoff, implemented at a higher level,
   790  		// but this ensures a minimum is respected.
   791  		retryAfter := backoff.ErrRetryableAfter(err)
   792  		if retryAfter > 0 {
   793  			t := time.NewTimer(retryAfter)
   794  			select {
   795  			case <-ctx.Done():
   796  				t.Stop()
   797  			case <-t.C:
   798  			}
   799  		}
   800  
   801  		return nil, err
   802  	}
   803  
   804  	if response == nil || len(response.TaskToken) == 0 {
   805  		wtp.metricsScope.Counter(metrics.DecisionPollNoTaskCounter).Inc(1)
   806  		wtp.updateBacklog(request.TaskList.GetKind(), 0)
   807  		return &workflowTask{}, nil
   808  	}
   809  
   810  	wtp.updateBacklog(request.TaskList.GetKind(), response.GetBacklogCountHint())
   811  
   812  	task := wtp.toWorkflowTask(response)
   813  	traceLog(func() {
   814  		var firstEventID int64 = -1
   815  		if response.History != nil && len(response.History.Events) > 0 {
   816  			firstEventID = response.History.Events[0].GetEventId()
   817  		}
   818  		wtp.logger.Debug("workflowTaskPoller::Poll Succeed",
   819  			zap.Int64("StartedEventID", response.GetStartedEventId()),
   820  			zap.Int64("Attempt", response.GetAttempt()),
   821  			zap.Int64("FirstEventID", firstEventID),
   822  			zap.Bool("IsQueryTask", response.Query != nil))
   823  	})
   824  
   825  	metricsScope := wtp.metricsScope.GetTaggedScope(tagWorkflowType, response.WorkflowType.GetName())
   826  	metricsScope.Counter(metrics.DecisionPollSucceedCounter).Inc(1)
   827  	metricsScope.Timer(metrics.DecisionPollLatency).Record(time.Now().Sub(startTime))
   828  
   829  	scheduledToStartLatency := time.Duration(response.GetStartedTimestamp() - response.GetScheduledTimestamp())
   830  	metricsScope.Timer(metrics.DecisionScheduledToStartLatency).Record(scheduledToStartLatency)
   831  	return task, nil
   832  }
   833  
   834  func (wtp *workflowTaskPoller) toWorkflowTask(response *s.PollForDecisionTaskResponse) *workflowTask {
   835  	startEventID := response.GetStartedEventId()
   836  	nextEventID := response.GetNextEventId()
   837  	if nextEventID != 0 && startEventID != 0 {
   838  		// first case is for normal decision, the second is for transient decision
   839  		if nextEventID-1 != startEventID && nextEventID+1 != startEventID {
   840  			wtp.logger.Warn("Invalid PollForDecisionTaskResponse, nextEventID doesn't match startedEventID",
   841  				zap.Int64("StartedEventID", startEventID),
   842  				zap.Int64("NextEventID", nextEventID),
   843  			)
   844  			wtp.metricsScope.Counter(metrics.DecisionPollInvalidCounter).Inc(1)
   845  		} else {
   846  			// in transient decision case, set nextEventID to be one more than startEventID in case
   847  			// we can need to use the field to truncate history for decision task (check comments in newGetHistoryPageFunc)
   848  			// this is safe as
   849  			// - currently we are not using nextEventID for decision task
   850  			// - for query task, startEventID is not assigned, so we won't reach here.
   851  			nextEventID = startEventID + 1
   852  		}
   853  	}
   854  	historyIterator := &historyIteratorImpl{
   855  		nextPageToken:  response.NextPageToken,
   856  		execution:      response.WorkflowExecution,
   857  		domain:         wtp.domain,
   858  		service:        wtp.service,
   859  		metricsScope:   wtp.metricsScope,
   860  		startedEventID: startEventID,
   861  		maxEventID:     nextEventID - 1,
   862  		featureFlags:   wtp.featureFlags,
   863  	}
   864  	task := &workflowTask{
   865  		task:            response,
   866  		historyIterator: historyIterator,
   867  	}
   868  	return task
   869  }
   870  
   871  func (h *historyIteratorImpl) GetNextPage() (*s.History, error) {
   872  	if h.iteratorFunc == nil {
   873  		h.iteratorFunc = newGetHistoryPageFunc(
   874  			context.Background(),
   875  			h.service,
   876  			h.domain,
   877  			h.execution,
   878  			h.startedEventID,
   879  			h.maxEventID,
   880  			h.metricsScope,
   881  			h.featureFlags)
   882  	}
   883  
   884  	history, token, err := h.iteratorFunc(h.nextPageToken)
   885  	if err != nil {
   886  		return nil, err
   887  	}
   888  	h.nextPageToken = token
   889  	return history, nil
   890  }
   891  
   892  func (h *historyIteratorImpl) Reset() {
   893  	h.nextPageToken = nil
   894  }
   895  
   896  func (h *historyIteratorImpl) HasNextPage() bool {
   897  	return h.nextPageToken != nil
   898  }
   899  
   900  // GetHistoryCount returns History Event Count of current history (aka maxEventID)
   901  func (h *historyIteratorImpl) GetHistoryCount() int64 {
   902  	return h.maxEventID
   903  }
   904  
   905  func newGetHistoryPageFunc(
   906  	ctx context.Context,
   907  	service workflowserviceclient.Interface,
   908  	domain string,
   909  	execution *s.WorkflowExecution,
   910  	atDecisionTaskCompletedEventID int64,
   911  	maxEventID int64,
   912  	metricsScope tally.Scope,
   913  	featureFlags FeatureFlags,
   914  ) func(nextPageToken []byte) (*s.History, []byte, error) {
   915  	return func(nextPageToken []byte) (*s.History, []byte, error) {
   916  		metricsScope.Counter(metrics.WorkflowGetHistoryCounter).Inc(1)
   917  		startTime := time.Now()
   918  		var resp *s.GetWorkflowExecutionHistoryResponse
   919  		err := backoff.Retry(ctx,
   920  			func() error {
   921  				tchCtx, cancel, opt := newChannelContext(ctx, featureFlags)
   922  				defer cancel()
   923  
   924  				var err1 error
   925  				resp, err1 = service.GetWorkflowExecutionHistory(tchCtx, &s.GetWorkflowExecutionHistoryRequest{
   926  					Domain:        common.StringPtr(domain),
   927  					Execution:     execution,
   928  					NextPageToken: nextPageToken,
   929  				}, opt...)
   930  				return err1
   931  			}, createDynamicServiceRetryPolicy(ctx), isServiceTransientError)
   932  		if err != nil {
   933  			metricsScope.Counter(metrics.WorkflowGetHistoryFailedCounter).Inc(1)
   934  			return nil, nil, err
   935  		}
   936  
   937  		metricsScope.Counter(metrics.WorkflowGetHistorySucceedCounter).Inc(1)
   938  		metricsScope.Timer(metrics.WorkflowGetHistoryLatency).Record(time.Now().Sub(startTime))
   939  
   940  		var h *s.History
   941  
   942  		if resp.RawHistory != nil {
   943  			var err1 error
   944  			h, err1 = serializer.DeserializeBlobDataToHistoryEvents(resp.RawHistory, s.HistoryEventFilterTypeAllEvent)
   945  			if err1 != nil {
   946  				return nil, nil, err1
   947  			}
   948  		} else {
   949  			h = resp.History
   950  		}
   951  
   952  		// TODO: is this check valid/useful? atDecisionTaskCompletedEventID is startedEventID in pollForDecisionTaskResponse and
   953  		// - For decision tasks, since there's only one inflight decision task, there won't be any event after startEventID.
   954  		//   Those events will be buffered. If there're too many buffer events, the current decision will be failed and events passed
   955  		//   startEventID may be returned. In that case, the last event after truncation is still decision task started event not completed.
   956  		// - For query tasks startEventID is not assigned so this check is never executed.
   957  		if shouldTruncateHistory(h, atDecisionTaskCompletedEventID) {
   958  			first := h.Events[0].GetEventId() // eventIds start from 1
   959  			h.Events = h.Events[:atDecisionTaskCompletedEventID-first+1]
   960  			if h.Events[len(h.Events)-1].GetEventType() != s.EventTypeDecisionTaskCompleted {
   961  				return nil, nil, fmt.Errorf("newGetHistoryPageFunc: atDecisionTaskCompletedEventID(%v) "+
   962  					"points to event that is not DecisionTaskCompleted", atDecisionTaskCompletedEventID)
   963  			}
   964  			return h, nil, nil
   965  		}
   966  
   967  		// TODO: Apply the check to decision tasks (remove the last condition)
   968  		// after validating maxEventID always equal to atDecisionTaskCompletedEventID (startedEventID).
   969  		// For now only apply to query task to be safe.
   970  		if shouldTruncateHistory(h, maxEventID) && isQueryTask(atDecisionTaskCompletedEventID) {
   971  			first := h.Events[0].GetEventId()
   972  			h.Events = h.Events[:maxEventID-first+1]
   973  			return h, nil, nil
   974  		}
   975  
   976  		return h, resp.NextPageToken, nil
   977  	}
   978  }
   979  
   980  func shouldTruncateHistory(h *s.History, maxEventID int64) bool {
   981  	size := len(h.Events)
   982  	return size > 0 && maxEventID > 0 && h.Events[size-1].GetEventId() > maxEventID
   983  }
   984  
   985  func isQueryTask(atDecisionTaskCompletedEventID int64) bool {
   986  	return atDecisionTaskCompletedEventID == 0
   987  }
   988  
   989  func newActivityTaskPoller(taskHandler ActivityTaskHandler, service workflowserviceclient.Interface,
   990  	domain string, params workerExecutionParameters) *activityTaskPoller {
   991  
   992  	activityTaskPoller := &activityTaskPoller{
   993  		basePoller:          basePoller{shutdownC: params.WorkerStopChannel},
   994  		taskHandler:         taskHandler,
   995  		service:             service,
   996  		domain:              domain,
   997  		taskListName:        params.TaskList,
   998  		identity:            params.Identity,
   999  		logger:              params.Logger,
  1000  		metricsScope:        metrics.NewTaggedScope(params.MetricsScope),
  1001  		activitiesPerSecond: params.TaskListActivitiesPerSecond,
  1002  		featureFlags:        params.FeatureFlags,
  1003  	}
  1004  	return activityTaskPoller
  1005  }
  1006  
  1007  // Poll for a single activity task from the service
  1008  func (atp *activityTaskPoller) poll(ctx context.Context) (*s.PollForActivityTaskResponse, time.Time, error) {
  1009  
  1010  	atp.metricsScope.Counter(metrics.ActivityPollCounter).Inc(1)
  1011  	startTime := time.Now()
  1012  
  1013  	traceLog(func() {
  1014  		atp.logger.Debug("activityTaskPoller::Poll")
  1015  	})
  1016  	request := &s.PollForActivityTaskRequest{
  1017  		Domain:           common.StringPtr(atp.domain),
  1018  		TaskList:         common.TaskListPtr(s.TaskList{Name: common.StringPtr(atp.taskListName)}),
  1019  		Identity:         common.StringPtr(atp.identity),
  1020  		TaskListMetadata: &s.TaskListMetadata{MaxTasksPerSecond: &atp.activitiesPerSecond},
  1021  	}
  1022  	response, err := atp.service.PollForActivityTask(ctx, request, getYarpcCallOptions(atp.featureFlags)...)
  1023  
  1024  	if err != nil {
  1025  		retryable := isServiceTransientError(err)
  1026  		if retryable {
  1027  
  1028  			if target := (*s.ServiceBusyError)(nil); errors.As(err, &target) {
  1029  				atp.metricsScope.Tagged(map[string]string{causeTag: serviceBusy}).Counter(metrics.ActivityPollTransientFailedCounter).Inc(1)
  1030  			} else {
  1031  				atp.metricsScope.Counter(metrics.ActivityPollTransientFailedCounter).Inc(1)
  1032  			}
  1033  		} else {
  1034  			atp.metricsScope.Counter(metrics.ActivityPollFailedCounter).Inc(1)
  1035  		}
  1036  
  1037  		// pause for the retry delay if present.
  1038  		// failures also have an exponential backoff, implemented at a higher level,
  1039  		// but this ensures a minimum is respected.
  1040  		retryAfter := backoff.ErrRetryableAfter(err)
  1041  		if retryAfter > 0 {
  1042  			t := time.NewTimer(retryAfter)
  1043  			select {
  1044  			case <-ctx.Done():
  1045  				t.Stop()
  1046  			case <-t.C:
  1047  			}
  1048  		}
  1049  
  1050  		return nil, startTime, err
  1051  	}
  1052  	if response == nil || len(response.TaskToken) == 0 {
  1053  		atp.metricsScope.Counter(metrics.ActivityPollNoTaskCounter).Inc(1)
  1054  		return nil, startTime, nil
  1055  	}
  1056  
  1057  	return response, startTime, err
  1058  }
  1059  
  1060  type pollFunc func(ctx context.Context) (*s.PollForActivityTaskResponse, time.Time, error)
  1061  
  1062  func (atp *activityTaskPoller) pollWithMetricsFunc(
  1063  	pollFunc pollFunc) func(ctx context.Context) (interface{}, error) {
  1064  	return func(ctx context.Context) (interface{}, error) { return atp.pollWithMetrics(ctx, pollFunc) }
  1065  }
  1066  
  1067  func (atp *activityTaskPoller) pollWithMetrics(ctx context.Context,
  1068  	pollFunc func(ctx context.Context) (*s.PollForActivityTaskResponse, time.Time, error)) (interface{}, error) {
  1069  
  1070  	response, startTime, err := pollFunc(ctx)
  1071  	if err != nil {
  1072  		return nil, err
  1073  	}
  1074  	if response == nil || len(response.TaskToken) == 0 {
  1075  		return &activityTask{}, nil
  1076  	}
  1077  
  1078  	workflowType := response.WorkflowType.GetName()
  1079  	activityType := response.ActivityType.GetName()
  1080  	metricsScope := getMetricsScopeForActivity(atp.metricsScope, workflowType, activityType)
  1081  	metricsScope.Counter(metrics.ActivityPollSucceedCounter).Inc(1)
  1082  	metricsScope.Timer(metrics.ActivityPollLatency).Record(time.Now().Sub(startTime))
  1083  
  1084  	scheduledToStartLatency := time.Duration(response.GetStartedTimestamp() - response.GetScheduledTimestampOfThisAttempt())
  1085  	metricsScope.Timer(metrics.ActivityScheduledToStartLatency).Record(scheduledToStartLatency)
  1086  
  1087  	return &activityTask{task: response, pollStartTime: startTime}, nil
  1088  }
  1089  
  1090  // PollTask polls a new task
  1091  func (atp *activityTaskPoller) PollTask() (interface{}, error) {
  1092  	// Get the task.
  1093  	activityTask, err := atp.doPoll(atp.featureFlags, atp.pollWithMetricsFunc(atp.poll))
  1094  	if err != nil {
  1095  		return nil, err
  1096  	}
  1097  	return activityTask, nil
  1098  }
  1099  
  1100  // ProcessTask processes a new task
  1101  func (atp *activityTaskPoller) ProcessTask(task interface{}) error {
  1102  	if atp.shuttingDown() {
  1103  		return errShutdown
  1104  	}
  1105  
  1106  	activityTask := task.(*activityTask)
  1107  	if activityTask.task == nil {
  1108  		// We didn't have task, poll might have timeout.
  1109  		traceLog(func() {
  1110  			atp.logger.Debug("Activity task unavailable")
  1111  		})
  1112  		return nil
  1113  	}
  1114  
  1115  	workflowType := activityTask.task.WorkflowType.GetName()
  1116  	activityType := activityTask.task.ActivityType.GetName()
  1117  	metricsScope := getMetricsScopeForActivity(atp.metricsScope, workflowType, activityType)
  1118  
  1119  	executionStartTime := time.Now()
  1120  	// Process the activity task.
  1121  	request, err := atp.taskHandler.Execute(atp.taskListName, activityTask.task)
  1122  	if err != nil {
  1123  		metricsScope.Counter(metrics.ActivityExecutionFailedCounter).Inc(1)
  1124  		return err
  1125  	}
  1126  	metricsScope.Timer(metrics.ActivityExecutionLatency).Record(time.Now().Sub(executionStartTime))
  1127  
  1128  	if request == ErrActivityResultPending {
  1129  		return nil
  1130  	}
  1131  
  1132  	// if worker is shutting down, don't bother reporting activity completion
  1133  	if atp.shuttingDown() {
  1134  		return errShutdown
  1135  	}
  1136  
  1137  	responseStartTime := time.Now()
  1138  	reportErr := reportActivityComplete(context.Background(), atp.service, request, metricsScope, atp.featureFlags)
  1139  	if reportErr != nil {
  1140  		metricsScope.Counter(metrics.ActivityResponseFailedCounter).Inc(1)
  1141  		traceLog(func() {
  1142  			atp.logger.Debug("reportActivityComplete failed", zap.Error(reportErr))
  1143  		})
  1144  		return reportErr
  1145  	}
  1146  
  1147  	metricsScope.Timer(metrics.ActivityResponseLatency).Record(time.Now().Sub(responseStartTime))
  1148  	metricsScope.Timer(metrics.ActivityEndToEndLatency).Record(time.Now().Sub(activityTask.pollStartTime))
  1149  	return nil
  1150  }
  1151  
  1152  func newLocallyDispatchedActivityTaskPoller(taskHandler ActivityTaskHandler, service workflowserviceclient.Interface,
  1153  	domain string, params workerExecutionParameters) *locallyDispatchedActivityTaskPoller {
  1154  	locallyDispatchedActivityTaskPoller := &locallyDispatchedActivityTaskPoller{
  1155  		activityTaskPoller: *newActivityTaskPoller(taskHandler, service, domain, params),
  1156  		ldaTunnel:          newLocallyDispatchedActivityTunnel(params.WorkerStopChannel),
  1157  	}
  1158  	return locallyDispatchedActivityTaskPoller
  1159  }
  1160  
  1161  // PollTask polls a new task
  1162  func (atp *locallyDispatchedActivityTaskPoller) PollTask() (interface{}, error) {
  1163  	// Get the task.
  1164  	activityTask, err := atp.doPoll(atp.featureFlags, atp.pollWithMetricsFunc(atp.pollLocallyDispatchedActivity))
  1165  	if err != nil {
  1166  		return nil, err
  1167  	}
  1168  
  1169  	return activityTask, nil
  1170  }
  1171  
  1172  func (atp *locallyDispatchedActivityTaskPoller) pollLocallyDispatchedActivity(ctx context.Context) (*s.PollForActivityTaskResponse, time.Time, error) {
  1173  	task := atp.ldaTunnel.getTask()
  1174  	atp.metricsScope.Counter(metrics.LocallyDispatchedActivityPollCounter).Inc(1)
  1175  	// consider to remove the poll latency metric for local dispatch as unnecessary
  1176  	startTime := time.Now()
  1177  	if task == nil {
  1178  		atp.metricsScope.Counter(metrics.LocallyDispatchedActivityPollNoTaskCounter).Inc(1)
  1179  		return nil, startTime, nil
  1180  	}
  1181  	// to be backwards compatible, update total poll counter if optimization succeeded only
  1182  	atp.metricsScope.Counter(metrics.ActivityPollCounter).Inc(1)
  1183  	atp.metricsScope.Counter(metrics.LocallyDispatchedActivityPollSucceedCounter).Inc(1)
  1184  	response := &s.PollForActivityTaskResponse{}
  1185  	response.ActivityId = task.ActivityId
  1186  	response.ActivityType = task.ActivityType
  1187  	response.Header = task.Header
  1188  	response.Input = task.Input
  1189  	response.WorkflowExecution = task.WorkflowExecution
  1190  	response.ScheduledTimestampOfThisAttempt = task.ScheduledTimestampOfThisAttempt
  1191  	response.ScheduledTimestamp = task.ScheduledTimestamp
  1192  	response.ScheduleToCloseTimeoutSeconds = task.ScheduleToCloseTimeoutSeconds
  1193  	response.StartedTimestamp = task.StartedTimestamp
  1194  	response.StartToCloseTimeoutSeconds = task.StartToCloseTimeoutSeconds
  1195  	response.HeartbeatTimeoutSeconds = task.HeartbeatTimeoutSeconds
  1196  	response.TaskToken = task.TaskToken
  1197  	response.WorkflowType = task.WorkflowType
  1198  	response.WorkflowDomain = task.WorkflowDomain
  1199  	response.Attempt = common.Int32Ptr(0)
  1200  	return response, startTime, nil
  1201  }
  1202  
  1203  func reportActivityComplete(
  1204  	ctx context.Context,
  1205  	service workflowserviceclient.Interface,
  1206  	request interface{},
  1207  	metricsScope tally.Scope,
  1208  	featureFlags FeatureFlags,
  1209  ) error {
  1210  	if request == nil {
  1211  		// nothing to report
  1212  		return nil
  1213  	}
  1214  
  1215  	var reportErr error
  1216  	switch request := request.(type) {
  1217  	case *s.RespondActivityTaskCanceledRequest:
  1218  		reportErr = backoff.Retry(ctx,
  1219  			func() error {
  1220  				tchCtx, cancel, opt := newChannelContext(ctx, featureFlags)
  1221  				defer cancel()
  1222  
  1223  				return service.RespondActivityTaskCanceled(tchCtx, request, opt...)
  1224  			}, createDynamicServiceRetryPolicy(ctx), isServiceTransientError)
  1225  	case *s.RespondActivityTaskFailedRequest:
  1226  		reportErr = backoff.Retry(ctx,
  1227  			func() error {
  1228  				tchCtx, cancel, opt := newChannelContext(ctx, featureFlags)
  1229  				defer cancel()
  1230  
  1231  				return service.RespondActivityTaskFailed(tchCtx, request, opt...)
  1232  			}, createDynamicServiceRetryPolicy(ctx), isServiceTransientError)
  1233  	case *s.RespondActivityTaskCompletedRequest:
  1234  		reportErr = backoff.Retry(ctx,
  1235  			func() error {
  1236  				tchCtx, cancel, opt := newChannelContext(ctx, featureFlags)
  1237  				defer cancel()
  1238  
  1239  				return service.RespondActivityTaskCompleted(tchCtx, request, opt...)
  1240  			}, createDynamicServiceRetryPolicy(ctx), isServiceTransientError)
  1241  	}
  1242  	if reportErr == nil {
  1243  		switch request.(type) {
  1244  		case *s.RespondActivityTaskCanceledRequest:
  1245  			metricsScope.Counter(metrics.ActivityTaskCanceledCounter).Inc(1)
  1246  		case *s.RespondActivityTaskFailedRequest:
  1247  			metricsScope.Counter(metrics.ActivityTaskFailedCounter).Inc(1)
  1248  		case *s.RespondActivityTaskCompletedRequest:
  1249  			metricsScope.Counter(metrics.ActivityTaskCompletedCounter).Inc(1)
  1250  		}
  1251  	}
  1252  
  1253  	return reportErr
  1254  }
  1255  
  1256  func reportActivityCompleteByID(
  1257  	ctx context.Context,
  1258  	service workflowserviceclient.Interface,
  1259  	request interface{},
  1260  	metricsScope tally.Scope,
  1261  	featureFlags FeatureFlags,
  1262  ) error {
  1263  	if request == nil {
  1264  		// nothing to report
  1265  		return nil
  1266  	}
  1267  
  1268  	var reportErr error
  1269  	switch request := request.(type) {
  1270  	case *s.RespondActivityTaskCanceledByIDRequest:
  1271  		reportErr = backoff.Retry(ctx,
  1272  			func() error {
  1273  				tchCtx, cancel, opt := newChannelContext(ctx, featureFlags)
  1274  				defer cancel()
  1275  
  1276  				return service.RespondActivityTaskCanceledByID(tchCtx, request, opt...)
  1277  			}, createDynamicServiceRetryPolicy(ctx), isServiceTransientError)
  1278  	case *s.RespondActivityTaskFailedByIDRequest:
  1279  		reportErr = backoff.Retry(ctx,
  1280  			func() error {
  1281  				tchCtx, cancel, opt := newChannelContext(ctx, featureFlags)
  1282  				defer cancel()
  1283  
  1284  				return service.RespondActivityTaskFailedByID(tchCtx, request, opt...)
  1285  			}, createDynamicServiceRetryPolicy(ctx), isServiceTransientError)
  1286  	case *s.RespondActivityTaskCompletedByIDRequest:
  1287  		reportErr = backoff.Retry(ctx,
  1288  			func() error {
  1289  				tchCtx, cancel, opt := newChannelContext(ctx, featureFlags)
  1290  				defer cancel()
  1291  
  1292  				return service.RespondActivityTaskCompletedByID(tchCtx, request, opt...)
  1293  			}, createDynamicServiceRetryPolicy(ctx), isServiceTransientError)
  1294  	}
  1295  	if reportErr == nil {
  1296  		switch request.(type) {
  1297  		case *s.RespondActivityTaskCanceledByIDRequest:
  1298  			metricsScope.Counter(metrics.ActivityTaskCanceledByIDCounter).Inc(1)
  1299  		case *s.RespondActivityTaskFailedByIDRequest:
  1300  			metricsScope.Counter(metrics.ActivityTaskFailedByIDCounter).Inc(1)
  1301  		case *s.RespondActivityTaskCompletedByIDRequest:
  1302  			metricsScope.Counter(metrics.ActivityTaskCompletedByIDCounter).Inc(1)
  1303  		}
  1304  	}
  1305  
  1306  	return reportErr
  1307  }
  1308  
  1309  func convertActivityResultToRespondRequest(identity string, taskToken, result []byte, err error,
  1310  	dataConverter DataConverter) interface{} {
  1311  	if err == ErrActivityResultPending {
  1312  		// activity result is pending and will be completed asynchronously.
  1313  		// nothing to report at this point
  1314  		return ErrActivityResultPending
  1315  	}
  1316  
  1317  	if err == nil {
  1318  		return &s.RespondActivityTaskCompletedRequest{
  1319  			TaskToken: taskToken,
  1320  			Result:    result,
  1321  			Identity:  common.StringPtr(identity)}
  1322  	}
  1323  
  1324  	reason, details := getErrorDetails(err, dataConverter)
  1325  	if _, ok := err.(*CanceledError); ok || err == context.Canceled {
  1326  		return &s.RespondActivityTaskCanceledRequest{
  1327  			TaskToken: taskToken,
  1328  			Details:   details,
  1329  			Identity:  common.StringPtr(identity)}
  1330  	}
  1331  
  1332  	return &s.RespondActivityTaskFailedRequest{
  1333  		TaskToken: taskToken,
  1334  		Reason:    common.StringPtr(reason),
  1335  		Details:   details,
  1336  		Identity:  common.StringPtr(identity)}
  1337  }
  1338  
  1339  func convertActivityResultToRespondRequestByID(identity, domain, workflowID, runID, activityID string,
  1340  	result []byte, err error, dataConverter DataConverter) interface{} {
  1341  	if err == ErrActivityResultPending {
  1342  		// activity result is pending and will be completed asynchronously.
  1343  		// nothing to report at this point
  1344  		return nil
  1345  	}
  1346  
  1347  	if err == nil {
  1348  		return &s.RespondActivityTaskCompletedByIDRequest{
  1349  			Domain:     common.StringPtr(domain),
  1350  			WorkflowID: common.StringPtr(workflowID),
  1351  			RunID:      common.StringPtr(runID),
  1352  			ActivityID: common.StringPtr(activityID),
  1353  			Result:     result,
  1354  			Identity:   common.StringPtr(identity)}
  1355  	}
  1356  
  1357  	reason, details := getErrorDetails(err, dataConverter)
  1358  	if _, ok := err.(*CanceledError); ok || err == context.Canceled {
  1359  		return &s.RespondActivityTaskCanceledByIDRequest{
  1360  			Domain:     common.StringPtr(domain),
  1361  			WorkflowID: common.StringPtr(workflowID),
  1362  			RunID:      common.StringPtr(runID),
  1363  			ActivityID: common.StringPtr(activityID),
  1364  			Details:    details,
  1365  			Identity:   common.StringPtr(identity)}
  1366  	}
  1367  
  1368  	return &s.RespondActivityTaskFailedByIDRequest{
  1369  		Domain:     common.StringPtr(domain),
  1370  		WorkflowID: common.StringPtr(workflowID),
  1371  		RunID:      common.StringPtr(runID),
  1372  		ActivityID: common.StringPtr(activityID),
  1373  		Reason:     common.StringPtr(reason),
  1374  		Details:    details,
  1375  		Identity:   common.StringPtr(identity)}
  1376  }