github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/model/job/worker.go (about)

package job

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"math/rand"
	"runtime"
	"runtime/debug"
	"sync/atomic"
	"time"

	"github.com/cozy/cozy-stack/model/instance"
	"github.com/cozy/cozy-stack/pkg/consts"
	"github.com/cozy/cozy-stack/pkg/logger"
	"github.com/cozy/cozy-stack/pkg/metrics"
	"github.com/cozy/cozy-stack/pkg/prefixer"
	"github.com/cozy/cozy-stack/pkg/realtime"
	"github.com/prometheus/client_golang/prometheus"
)

var (
	defaultConcurrency  = runtime.NumCPU()
	defaultMaxExecCount = 1
	defaultRetryDelay   = 60 * time.Millisecond
	defaultTimeout      = 10 * time.Second
)
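// These defaults are applied per job by defaultedConf below: a zero value in
// the worker configuration falls back to the corresponding default, and the
// job's own options can only narrow MaxExecCount and Timeout, never extend
// them.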

type (
	// WorkerInitFunc is called at the start of the worker system, only once.
	// It is not called before every job process. It can be useful to
	// initialize a global variable used by the worker.
	WorkerInitFunc func() error

	// WorkerStartFunc is optionally called at the beginning of each job
	// process and can produce a context value.
	WorkerStartFunc func(ctx *TaskContext) (*TaskContext, error)

	// WorkerFunc represents the work function that a worker should implement.
	WorkerFunc func(ctx *TaskContext) error

	// WorkerCommit is an optional method that is always called once after the
	// execution of the WorkerFunc.
	WorkerCommit func(ctx *TaskContext, errjob error) error

	// WorkerBeforeHook is an optional method that is always called before the
	// job is pushed into the queue. It can be useful to skip the job
	// beforehand.
	WorkerBeforeHook func(job *Job) (bool, error)

	// JobErrorCheckerHook is an optional method called at the beginning of the
	// job execution to prevent a retry according to the previous error
	// (specifically useful in the retries loop).
	JobErrorCheckerHook func(err error) bool

	// WorkerConfig is the configuration parameter of a worker defined by the
	// job system. It contains the parameters of the worker along with the main
	// worker function that performs the work against a job's message.
	WorkerConfig struct {
		WorkerInit   WorkerInitFunc
		WorkerStart  WorkerStartFunc
		WorkerFunc   WorkerFunc
		WorkerCommit WorkerCommit
		WorkerType   string
		BeforeHook   WorkerBeforeHook
		ErrorHook    JobErrorCheckerHook
		Concurrency  int
		MaxExecCount int
		Reserved     bool // true when the clients must not push jobs for this worker
		Timeout      time.Duration
		RetryDelay   time.Duration
	}

	// Worker is a unit of work that will consume from a queue and execute the
	// do method for each job it pulls.
	Worker struct {
		Type    string
		Conf    *WorkerConfig
		jobs    chan *Job
		running uint32
		closing chan struct{}
		closed  chan struct{}
	}

	// TaskContext is a context.Context passed to the worker for each task
	// execution and contains specific values from the job.
	TaskContext struct {
		context.Context
		Instance *instance.Instance
		job      *Job
		log      logger.Logger
		id       string
		cookie   interface{}
		noRetry  bool
	}
)
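// As an illustration only (not part of this file's API surface; the
// "example-log" worker type and its message fields are made up): a worker
// declaration wiring the pieces above together could look roughly like this.
//
//	var exampleConf = &WorkerConfig{
//		WorkerType:   "example-log", // hypothetical worker type
//		Concurrency:  4,
//		MaxExecCount: 2,
//		Timeout:      30 * time.Second,
//		WorkerFunc: func(ctx *TaskContext) error {
//			var msg struct {
//				Text string `json:"text"` // hypothetical message field
//			}
//			if err := ctx.UnmarshalMessage(&msg); err != nil {
//				return err
//			}
//			ctx.Logger().Debugf("message: %s", msg.Text)
//			return nil
//		},
//	}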

var slots chan struct{}

func setNbSlots(nb int) {
	slots = make(chan struct{}, nb)
	for i := 0; i < nb; i++ {
		slots <- struct{}{}
	}
}

// Clone clones the worker config.
func (w *WorkerConfig) Clone() *WorkerConfig {
	cloned := *w
	return &cloned
}

// NewTaskContext returns a context.Context usable by a worker.
func NewTaskContext(workerID string, job *Job, inst *instance.Instance) (*TaskContext, context.CancelFunc) {
	ctx, cancel := context.WithCancel(context.Background())
	id := fmt.Sprintf("%s/%s", workerID, job.ID())
	entry := logger.WithDomain(job.Domain).WithNamespace("jobs")

	if job.ForwardLogs {
		hook := realtime.LogHook(job, realtime.GetHub(), consts.Jobs, job.ID())
		entry.AddHook(hook)
	}

	log := entry.
		WithField("job_id", job.ID()).
		WithField("worker_id", workerID)

	return &TaskContext{
		Context:  ctx,
		Instance: inst,
		job:      job,
		log:      log,
		id:       id,
	}, cancel
}

// WithTimeout returns a clone of the context with a different deadline.
func (c *TaskContext) WithTimeout(timeout time.Duration) (*TaskContext, context.CancelFunc) {
	ctx, cancel := context.WithTimeout(c.Context, timeout)
	newCtx := c.clone()
	newCtx.Context = ctx
	return newCtx, cancel
}

// WithCookie returns a clone of the context with a new cookie value.
func (c *TaskContext) WithCookie(cookie interface{}) *TaskContext {
	newCtx := c.clone()
	newCtx.cookie = cookie
	return newCtx
}

// SetNoRetry sets the no-retry flag to prevent a retry on the next execution.
func (c *TaskContext) SetNoRetry() {
	c.noRetry = true
}

// NoRetry returns the no-retry flag.
func (c *TaskContext) NoRetry() bool {
	return c.noRetry
}

func (c *TaskContext) clone() *TaskContext {
	return &TaskContext{
		Context:  c.Context,
		Instance: c.Instance,
		job:      c.job,
		log:      c.log,
		id:       c.id,
		cookie:   c.cookie,
	}
}

// ID returns a unique identifier for the worker context.
func (c *TaskContext) ID() string {
	return c.id
}

// Logger returns the logger associated with the worker context.
func (c *TaskContext) Logger() logger.Logger {
	return c.log
}

// UnmarshalMessage unmarshals the message contained in the worker context.
func (c *TaskContext) UnmarshalMessage(v interface{}) error {
	return c.job.Message.Unmarshal(v)
}

// UnmarshalEvent unmarshals the event contained in the worker context.
func (c *TaskContext) UnmarshalEvent(v interface{}) error {
	if c.job == nil || c.job.Event == nil {
		return errors.New("jobs: does not have an event associated")
	}
	return c.job.Event.Unmarshal(v)
}

// UnmarshalPayload unmarshals the payload contained in the worker context.
func (c *TaskContext) UnmarshalPayload() (map[string]interface{}, error) {
	var payload map[string]interface{}
	if err := c.job.Payload.Unmarshal(&payload); err != nil {
		return nil, err
	}
	return payload, nil
}

// TriggerID returns the possible trigger identifier responsible for launching
// the job.
func (c *TaskContext) TriggerID() (string, bool) {
	triggerID := c.job.TriggerID
	return triggerID, triggerID != ""
}

// Cookie returns the cookie associated with the worker context.
func (c *TaskContext) Cookie() interface{} {
	return c.cookie
}

// Manual returns whether the job was started manually.
func (c *TaskContext) Manual() bool {
	return c.job.Manual
}

// NewWorker creates a new instance of Worker with the given configuration.
func NewWorker(conf *WorkerConfig) *Worker {
	return &Worker{
		Type: conf.WorkerType,
		Conf: conf,
	}
}

// Start is used to start the worker consumption of messages from its queue.
func (w *Worker) Start(jobs chan *Job) error {
	if !atomic.CompareAndSwapUint32(&w.running, 0, 1) {
		return ErrClosed
	}
	w.jobs = jobs
	w.closing = make(chan struct{}, w.Conf.Concurrency)
	w.closed = make(chan struct{})
	if w.Conf.WorkerInit != nil {
		if err := w.Conf.WorkerInit(); err != nil {
			return fmt.Errorf("Could not start worker %s: %s", w.Type, err)
		}
	}
	for i := 0; i < w.Conf.Concurrency; i++ {
		name := fmt.Sprintf("%s/%d", w.Type, i)
		joblog.Debugf("Start worker %s", name)
		go w.work(name)
	}
	return nil
}
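// Illustrative sketch only (the broker wiring around the worker is assumed,
// not shown in this file): a worker is typically created from its config,
// started on a jobs channel, and later stopped with Shutdown (defined below).
//
//	w := NewWorker(conf)
//	jobsCh := make(chan *Job) // hypothetical channel; the real broker owns it
//	if err := w.Start(jobsCh); err != nil {
//		return err
//	}
//	// ... push jobs on jobsCh ...
//	stopCtx, stop := context.WithTimeout(context.Background(), 10*time.Second)
//	defer stop()
//	_ = w.Shutdown(stopCtx)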

// Shutdown is used to close the worker, waiting for all tasks to end.
func (w *Worker) Shutdown(ctx context.Context) error {
	if !atomic.CompareAndSwapUint32(&w.running, 1, 0) {
		return ErrClosed
	}
	close(w.jobs)
	for i := 0; i < w.Conf.Concurrency; i++ {
		w.closing <- struct{}{}
	}
	for i := 0; i < w.Conf.Concurrency; i++ {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-w.closed:
		}
	}
	return nil
}

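// work is the loop run by each of the Conf.Concurrency goroutines started by
// Start: it pulls jobs from the shared channel, resolves the instance for the
// job's domain, skips instances blocked on an unsigned TOS (except for the
// sendmail and migrations workers), and exits once the channel has been
// closed by Shutdown.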
func (w *Worker) work(workerID string) {
	for job := range w.jobs {
		domain := job.Domain
		if domain == "" {
			joblog.Errorf("%s: missing domain from job request", workerID)
			continue
		}
		var inst *instance.Instance
		if domain != prefixer.GlobalPrefixer.DomainName() {
			var err error
			inst, err = instance.Get(job.Domain)
			if err != nil {
				joblog.Errorf("Instance not found for %s: %s", job.Domain, err)
				continue
			}
			// Do not execute jobs for instances with a blocking, not-signed
			// TOS, except for:
			// - mails, because the user may need an email to log in and
			//   accept the new TOS (2FA, password reset, etc.)
			// - migrations, because the old version may no longer be
			//   supported when the user signs the TOS
			if w.Type != "sendmail" && w.Type != "migrations" {
				notSigned, deadline := inst.CheckTOSNotSignedAndDeadline()
				if notSigned && deadline == instance.TOSBlocked {
					continue
				}
			}
		}
		w.runTask(inst, workerID, job)
	}
	joblog.Debugf("%s: worker shut down", workerID)
	w.closed <- struct{}{}
}

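// runTask acknowledges the job as consumed, runs the task in a separate
// goroutine so that a closing worker can cancel its context while still
// waiting for the result, then acks or nacks the job, updates the Prometheus
// counters, and deletes the trigger on a BadTriggerError.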
func (w *Worker) runTask(inst *instance.Instance, workerID string, job *Job) {
	taskCtx, cancel := NewTaskContext(workerID, job, inst)
	defer cancel()
	if err := job.AckConsumed(); err != nil {
		taskCtx.Logger().Errorf("error acking consumed job: %s",
			err.Error())
		return
	}
	t := &task{
		w:    w,
		ctx:  taskCtx,
		job:  job,
		conf: w.defaultedConf(job.Options),
	}

	ch := make(chan error)
	go func() {
		errRun := t.run()
		if errRun == ErrAbort {
			errRun = nil
		}
		ch <- errRun
	}()

	var errRun error
	select {
	case <-w.closing:
		cancel()
		errRun = <-ch
	case errRun = <-ch:
	}

	var runResultLabel string
	var errAck error
	if errRun != nil {
		taskCtx.Logger().Errorf("error while performing job: %s",
			errRun.Error())
		runResultLabel = metrics.WorkerExecResultErrored
		errAck = job.Nack(errRun.Error())
	} else {
		runResultLabel = metrics.WorkerExecResultSuccess
		errAck = job.Ack()
	}

	// Distinguish classic job executions from konnector runs triggered by an
	// account deletion.
	msg := struct {
		Account        string `json:"account"`
		AccountRev     string `json:"account_rev"`
		Konnector      string `json:"konnector"`
		AccountDeleted bool   `json:"account_deleted"`
	}{}
	err := json.Unmarshal(job.Message, &msg)

	if err == nil && w.Type == "konnector" && msg.AccountDeleted {
		metrics.WorkerKonnectorExecDeleteCounter.WithLabelValues(w.Type, runResultLabel).Inc()
	} else {
		metrics.WorkerExecCounter.WithLabelValues(w.Type, runResultLabel).Inc()
	}

	if errAck != nil {
		taskCtx.Logger().Errorf("error while acking job done: %s",
			errAck.Error())
	}

	// Delete the trigger associated with the job (if any) when we receive a
	// BadTriggerError.
	if job.TriggerID != "" && globalJobSystem != nil {
		if _, ok := errRun.(BadTriggerError); ok {
			_ = globalJobSystem.DeleteTrigger(job, job.TriggerID)
		}
	}
}

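// defaultedConf returns the effective configuration for one job: zero values
// in the worker config fall back to the package defaults, and the job options
// may only narrow MaxExecCount and Timeout. For example, with the defaults
// above (MaxExecCount 1, Timeout 10s), a job pushed with options asking for 3
// executions and a 30s timeout still runs at most once with a 10s timeout,
// while options asking for a 5s timeout do lower the timeout to 5s.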
func (w *Worker) defaultedConf(opts *JobOptions) *WorkerConfig {
	c := w.Conf.Clone()
	if c.Concurrency == 0 {
		c.Concurrency = defaultConcurrency
	}
	if c.MaxExecCount == 0 {
		c.MaxExecCount = defaultMaxExecCount
	}
	if c.RetryDelay == 0 {
		c.RetryDelay = defaultRetryDelay
	}
	if c.Timeout == 0 {
		c.Timeout = defaultTimeout
	}
	if opts == nil {
		return c
	}
	if opts.MaxExecCount != 0 && opts.MaxExecCount < c.MaxExecCount {
		c.MaxExecCount = opts.MaxExecCount
	}
	if opts.Timeout > 0 && opts.Timeout < c.Timeout {
		c.Timeout = opts.Timeout
	}
	return c
}

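// task tracks a single job execution: it owns the task context, the per-job
// configuration derived from the worker defaults, and the timing and retry
// counters used by run.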
type task struct {
	w    *Worker
	ctx  *TaskContext
	conf *WorkerConfig
	job  *Job

	startTime time.Time
	endTime   time.Time
	execCount int
}

func (t *task) run() (err error) {
	t.startTime = time.Now()
	t.execCount = 0

	if t.conf.WorkerStart != nil {
		t.ctx, err = t.conf.WorkerStart(t.ctx)
		if err != nil {
			return err
		}
	}
	defer func() {
		if t.conf.WorkerCommit != nil {
			t.ctx.log = t.ctx.Logger().WithField("exec_time", t.endTime.Sub(t.startTime))
			if errc := t.conf.WorkerCommit(t.ctx, err); errc != nil {
				t.ctx.Logger().Warnf("Error while committing job: %s",
					errc.Error())
			}
		}
	}()
	for {
		retry, delay, timeout := t.nextDelay(err)

		// The optional ErrorHook function can prevent a retry depending on
		// the previous error.
		if retry && t.conf.ErrorHook != nil {
			retry = t.conf.ErrorHook(err)
		}
		if !retry {
			break
		}
		if err != nil {
			t.ctx.Logger().Warnf("Error while performing job: %s (retry in %s)",
				err.Error(), delay)
		}

		if delay > 0 {
			time.Sleep(delay)
		}

		t.ctx.Logger().Debugf("Executing job (%d) (timeout set to %s)",
			t.execCount, timeout)

		var execResultLabel string
		timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
			metrics.WorkerExecDurations.WithLabelValues(t.w.Type, execResultLabel).Observe(v)
		}))

		ctx, cancel := t.ctx.WithTimeout(timeout)
		err = t.exec(ctx)
		if err == nil {
			execResultLabel = metrics.WorkerExecResultSuccess
			timer.ObserveDuration()
			t.endTime = time.Now()
			cancel()
			break
		}
		execResultLabel = metrics.WorkerExecResultErrored
		timer.ObserveDuration()
		t.endTime = time.Now()

		// Increment the timeouts counter
		if t.job.Message != nil {
			var slug string
			var msg map[string]interface{}

			if errd := json.Unmarshal(t.job.Message, &msg); errd != nil {
				ctx.Logger().Errorf("Cannot unmarshal job message %s", t.job.Message)
			} else {
				switch t.w.Type {
				case "konnector":
					slug, _ = msg["konnector"].(string)
				case "service":
					slug, _ = msg["slug"].(string)
				default:
					slug = ""
				}

				// Force the timeouts counter to 0 if it has not been initialized
				metrics.WorkerExecTimeoutsCounter.WithLabelValues(t.w.Type, slug)

				if errors.Is(err, context.DeadlineExceeded) { // This is a timeout
					metrics.WorkerExecTimeoutsCounter.WithLabelValues(t.w.Type, slug).Inc()
				}
			}
		}

		// Even though ctx should have expired already, it is good practice to
		// call its cancelation function in any case. Failure to do so may keep
		// the context and its parent alive longer than necessary.
		cancel()
		t.execCount++

		if ctx.NoRetry() {
			break
		}
	}

	metrics.WorkerExecRetries.WithLabelValues(t.w.Type).Observe(float64(t.execCount))
	return
}

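// exec runs the WorkerFunc for a single attempt. When the global slots
// channel has been set with setNbSlots, it takes a token from it for the
// duration of the call, bounding the number of tasks running at the same
// time across all workers. It also converts a panic in the WorkerFunc into
// an error, so that run treats it like any other failure.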
func (t *task) exec(ctx *TaskContext) (err error) {
	var slot struct{}
	if slots != nil {
		slot = <-slots
	}
	defer func() {
		if slots != nil {
			slots <- slot
		}
		if r := recover(); r != nil {
			var ok bool
			err, ok = r.(error)
			if !ok {
				err = fmt.Errorf("%v", r)
			}
			ctx.Logger().Errorf("[panic] %s: %s", r, debug.Stack())
		}
	}()
	return t.conf.WorkerFunc(ctx)
}

func (t *task) nextDelay(prevError error) (bool, time.Duration, time.Duration) {
	// for certain kinds of errors, we do not retry since these errors cannot
	// be recovered from
	{
		if _, ok := prevError.(BadTriggerError); ok {
			return false, 0, 0
		}
		switch prevError {
		case ErrAbort, ErrMessageUnmarshal, ErrMessageNil:
			return false, 0, 0
		}
	}

	c := t.conf

	if t.execCount >= c.MaxExecCount {
		return false, 0, 0
	}

	// the worker timeout should take into account the maximum execution time
	// allowed to the task
	timeout := c.Timeout

	var nextDelay time.Duration
	if t.execCount == 0 {
		// on first execution, execute immediately
		nextDelay = 0
	} else {
		nextDelay = c.RetryDelay << uint(t.execCount-1)

		// fuzz the delay to a value between delay * 0.9 and delay * 1.1
		fuzzDelay := int(0.1 * float64(nextDelay))
		nextDelay += time.Duration((rand.Intn(2*fuzzDelay) - fuzzDelay))
	}

	return true, nextDelay, timeout
}
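// For example, with the default 60ms retry delay and a worker configured for
// 4 executions, nextDelay asks for roughly 60ms before the second attempt,
// 120ms before the third, and 240ms before the fourth (each fuzzed by up to
// +/- 10%), and reports no further retry after that.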