github.com/wfusion/gofusion@v1.1.14/common/infra/asynq/server.go (about)

     1  // Copyright 2020 Kentaro Hibino. All rights reserved.
     2  // Use of this source code is governed by a MIT license
     3  // that can be found in the LICENSE file.
     4  
     5  package asynq
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"math"
    12  	"math/rand"
    13  	"runtime"
    14  	"strings"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/redis/go-redis/v9"
    19  	"github.com/wfusion/gofusion/common/infra/asynq/pkg/base"
    20  	"github.com/wfusion/gofusion/common/infra/asynq/pkg/log"
    21  	"github.com/wfusion/gofusion/common/infra/asynq/pkg/rdb"
    22  )
    23  
    24  // Server is responsible for task processing and task lifecycle management.
    25  //
    26  // Server pulls tasks off queues and processes them.
    27  // If the processing of a task is unsuccessful, the server will schedule it for a retry.
    28  //
    29  // A task will be retried until either the task gets processed successfully
    30  // or until it reaches its max retry count.
    31  //
    32  // If a task exhausts its retries, it will be moved to the archive and
    33  // will be kept in the archive set.
    34  // Note that the archive size is finite and once it reaches its max size,
    35  // the oldest tasks in the archive will be deleted.
    36  type Server struct {
    37  	logger *log.Logger
    38  
    39  	broker base.Broker
    40  
    41  	state *serverState
    42  
    43  	disableRedisConnClose bool
    44  
    45  	// wait group to wait for all goroutines to finish.
    46  	wg            sync.WaitGroup
    47  	forwarder     *forwarder
    48  	processor     *processor
    49  	syncer        *syncer
    50  	heartbeater   *heartbeater
    51  	subscriber    *subscriber
    52  	recoverer     *recoverer
    53  	healthchecker *healthchecker
    54  	janitor       *janitor
    55  	aggregator    *aggregator
    56  }
    57  
    58  type serverState struct {
    59  	mu    sync.Mutex
    60  	value serverStateValue
    61  }
    62  
    63  type serverStateValue int
    64  
    65  const (
     66  	// StateNew represents a new server. The server begins in
     67  	// this state and then transitions to StateActive when
     68  	// Start or Run is called.
    69  	srvStateNew serverStateValue = iota
    70  
    71  	// StateActive indicates the server is up and active.
    72  	srvStateActive
    73  
    74  	// StateStopped indicates the server is up but no longer processing new tasks.
    75  	srvStateStopped
    76  
    77  	// StateClosed indicates the server has been shutdown.
    78  	srvStateClosed
    79  )
    80  
    81  var serverStates = []string{
    82  	"new",
    83  	"active",
    84  	"stopped",
    85  	"closed",
    86  }
    87  
    88  func (s serverStateValue) String() string {
    89  	if srvStateNew <= s && s <= srvStateClosed {
    90  		return serverStates[s]
    91  	}
    92  	return "unknown status"
    93  }
    94  
    95  // Config specifies the server's background-task processing behavior.
    96  type Config struct {
    97  	// Maximum number of concurrent processing of tasks.
    98  	//
    99  	// If set to a zero or negative value, NewServer will overwrite the value
    100  	// with the number of CPUs usable by the current process.
   101  	Concurrency int
   102  
   103  	// BaseContext optionally specifies a function that returns the base context for Handler invocations on this server.
   104  	//
   105  	// If BaseContext is nil, the default is context.Background().
    106  	// If this is defined, then it MUST return a non-nil context.
   107  	BaseContext func() context.Context
   108  
   109  	// Function to calculate retry delay for a failed task.
   110  	//
   111  	// By default, it uses exponential backoff algorithm to calculate the delay.
   112  	RetryDelayFunc RetryDelayFunc
   113  
   114  	// Predicate function to determine whether the error returned from Handler is a failure.
   115  	// If the function returns false, Server will not increment the retried counter for the task,
   116  	// and Server won't record the queue stats (processed and failed stats) to avoid skewing the error
   117  	// rate of the queue.
   118  	//
   119  	// By default, if the given error is non-nil the function returns true.
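         	//
         	// Example (illustrative sketch; ErrRateLimited is a hypothetical sentinel
         	// error): treat rate-limit errors as non-failures so they do not count
         	// toward the retry counter or the queue's failure stats:
         	//
         	//     IsFailure: func(err error) bool {
         	//         return err != nil && !errors.Is(err, ErrRateLimited)
         	//     },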
   120  	IsFailure func(error) bool
   121  
   122  	// List of queues to process with given priority value. Keys are the names of the
   123  	// queues and values are associated priority value.
   124  	//
   125  	// If set to nil or not specified, the server will process only the "default" queue.
   126  	//
   127  	// Priority is treated as follows to avoid starving low priority queues.
   128  	//
   129  	// Example:
   130  	//
   131  	//     Queues: map[string]int{
   132  	//         "critical": 6,
   133  	//         "default":  3,
   134  	//         "low":      1,
   135  	//     }
   136  	//
   137  	// With the above config and given that all queues are not empty, the tasks
   138  	// in "critical", "default", "low" should be processed 60%, 30%, 10% of
   139  	// the time respectively.
   140  	//
   141  	// If a queue has a zero or negative priority value, the queue will be ignored.
   142  	Queues map[string]int
   143  
   144  	// StrictPriority indicates whether the queue priority should be treated strictly.
   145  	//
    146  	// If set to true, tasks in the queue with the highest priority are processed first.
   147  	// The tasks in lower priority queues are processed only when those queues with
   148  	// higher priorities are empty.
   149  	StrictPriority bool
   150  
   151  	// ErrorHandler handles errors returned by the task handler.
   152  	//
   153  	// HandleError is invoked only if the task handler returns a non-nil error.
   154  	//
   155  	// Example:
   156  	//
    157  	//     func reportError(ctx context.Context, task *asynq.Task, err error) {
    158  	//         retried, _ := asynq.GetRetryCount(ctx)
    159  	//         maxRetry, _ := asynq.GetMaxRetry(ctx)
    160  	//         if retried >= maxRetry {
    161  	//             err = fmt.Errorf("retry exhausted for task %s: %w", task.Type(), err)
    162  	//         }
    163  	//         errorReportingService.Notify(err)
    164  	//     }
    165  	//
    166  	//     ErrorHandler: asynq.ErrorHandlerFunc(reportError)
    167  	//
    168  	//     We can also handle panic errors like:
    169  	//     func reportError(ctx context.Context, task *asynq.Task, err error) {
    170  	//         if asynq.IsPanic(err) {
    171  	//             errorReportingService.Notify(err)
    172  	//         }
    173  	//     }
    174  	//
    175  	//     ErrorHandler: asynq.ErrorHandlerFunc(reportError)
    176  	//
   177  	ErrorHandler ErrorHandler
   178  
   179  	// Logger specifies the logger used by the server instance.
   180  	//
   181  	// If unset, default logger is used.
   182  	Logger Logger
   183  
   184  	// LogLevel specifies the minimum log level to enable.
   185  	//
   186  	// If unset, InfoLevel is used by default.
   187  	LogLevel LogLevel
   188  
   189  	// ShutdownTimeout specifies the duration to wait to let workers finish their tasks
   190  	// before forcing them to abort when stopping the server.
   191  	//
   192  	// If unset or zero, default timeout of 8 seconds is used.
   193  	ShutdownTimeout time.Duration
   194  
   195  	// HealthCheckFunc is called periodically with any errors encountered during ping to the
   196  	// connected redis server.
   197  	HealthCheckFunc func(error)
   198  
   199  	// HealthCheckInterval specifies the interval between healthchecks.
   200  	//
   201  	// If unset or zero, the interval is set to 15 seconds.
   202  	HealthCheckInterval time.Duration
   203  
   204  	// DelayedTaskCheckInterval specifies the interval between checks run on 'scheduled' and 'retry'
   205  	// tasks, and forwarding them to 'pending' state if they are ready to be processed.
   206  	//
   207  	// If unset or zero, the interval is set to 5 seconds.
   208  	DelayedTaskCheckInterval time.Duration
   209  
   210  	// GroupGracePeriod specifies the amount of time the server will wait for an incoming task before aggregating
   211  	// the tasks in a group. If an incoming task is received within this period, the server will wait for another
   212  	// period of the same length, up to GroupMaxDelay if specified.
   213  	//
   214  	// If unset or zero, the grace period is set to 1 minute.
    215  	// The minimum duration for GroupGracePeriod is 1 second. If the value specified is less than a second,
    216  	// the call to NewServer will panic.
   217  	GroupGracePeriod time.Duration
   218  
   219  	// GroupMaxDelay specifies the maximum amount of time the server will wait for incoming tasks before aggregating
   220  	// the tasks in a group.
   221  	//
   222  	// If unset or zero, no delay limit is used.
   223  	GroupMaxDelay time.Duration
   224  
   225  	// GroupMaxSize specifies the maximum number of tasks that can be aggregated into a single task within a group.
   226  	// If GroupMaxSize is reached, the server will aggregate the tasks into one immediately.
   227  	//
   228  	// If unset or zero, no size limit is used.
   229  	GroupMaxSize int
   230  
   231  	// GroupAggregator specifies the aggregation function used to aggregate multiple tasks in a group into one task.
   232  	//
   233  	// If unset or nil, the group aggregation feature will be disabled on the server.
   234  	GroupAggregator GroupAggregator
   235  
   236  	DisableRedisConnClose bool
   237  }
   238  
   239  // GroupAggregator aggregates a group of tasks into one before the tasks are passed to the Handler.
   240  type GroupAggregator interface {
   241  	// Aggregate aggregates the given tasks in a group with the given group name,
   242  	// and returns a new task which is the aggregation of those tasks.
   243  	//
   244  	// Use NewTask(typename, payload, opts...) to set any options for the aggregated task.
   245  	// The Queue option, if provided, will be ignored and the aggregated task will always be enqueued
    246  	// to the same queue the group belonged to.
   247  	Aggregate(group string, tasks []*Task) *Task
   248  }
   249  
    250  // The GroupAggregatorFunc type is an adapter to allow the use of ordinary functions as a GroupAggregator.
   251  // If f is a function with the appropriate signature, GroupAggregatorFunc(f) is a GroupAggregator that calls f.
   252  type GroupAggregatorFunc func(group string, tasks []*Task) *Task
   253  
   254  // Aggregate calls fn(group, tasks)
   255  func (fn GroupAggregatorFunc) Aggregate(group string, tasks []*Task) *Task {
   256  	return fn(group, tasks)
   257  }
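         
         // Example (illustrative sketch; the "batched:" task type naming is arbitrary,
         // and the package's NewTask constructor and Task.Payload accessor are assumed):
         // a GroupAggregatorFunc that concatenates the payloads of grouped tasks into a
         // single aggregated task.
         //
         //	var aggregate GroupAggregator = GroupAggregatorFunc(func(group string, tasks []*Task) *Task {
         //		var payload []byte
         //		for _, t := range tasks {
         //			payload = append(payload, t.Payload()...)
         //			payload = append(payload, '\n')
         //		}
         //		return NewTask("batched:"+group, payload)
         //	})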
   258  
   259  // An ErrorHandler handles an error occurred during task processing.
   260  type ErrorHandler interface {
   261  	HandleError(ctx context.Context, task *Task, err error)
   262  }
   263  
    264  // The ErrorHandlerFunc type is an adapter to allow the use of ordinary functions as an ErrorHandler.
    265  // If f is a function with the appropriate signature, ErrorHandlerFunc(f) is an ErrorHandler that calls f.
   266  type ErrorHandlerFunc func(ctx context.Context, task *Task, err error)
   267  
   268  // HandleError calls fn(ctx, task, err)
   269  func (fn ErrorHandlerFunc) HandleError(ctx context.Context, task *Task, err error) {
   270  	fn(ctx, task, err)
   271  }
   272  
   273  // RetryDelayFunc calculates the retry delay duration for a failed task given
   274  // the retry count, error, and the task.
   275  //
   276  // n is the number of times the task has been retried.
   277  // e is the error returned by the task handler.
   278  // t is the task in question.
   279  type RetryDelayFunc func(n int, e error, t *Task) time.Duration
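         
         // Example (illustrative sketch): a RetryDelayFunc implementing a simple linear
         // backoff of 30 seconds per retry, capped at one hour.
         //
         //	linearDelay := RetryDelayFunc(func(n int, e error, t *Task) time.Duration {
         //		d := time.Duration(n) * 30 * time.Second
         //		if d > time.Hour {
         //			d = time.Hour
         //		}
         //		return d
         //	})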
   280  
   281  // Logger supports logging at various log levels.
   282  type Logger interface {
   283  	// Debug logs a message at Debug level.
   284  	Debug(args ...any)
   285  
   286  	// Info logs a message at Info level.
   287  	Info(args ...any)
   288  
   289  	// Warn logs a message at Warning level.
   290  	Warn(args ...any)
   291  
   292  	// Error logs a message at Error level.
   293  	Error(args ...any)
   294  
   295  	// Fatal logs a message at Fatal level
    296  	// and the process will exit with status set to 1.
   297  	Fatal(args ...any)
   298  }
   299  
   300  // LogLevel represents logging level.
   301  //
   302  // It satisfies flag.Value interface.
   303  type LogLevel int32
   304  
   305  const (
   306  	// Note: reserving value zero to differentiate unspecified case.
   307  	level_unspecified LogLevel = iota
   308  
   309  	// DebugLevel is the lowest level of logging.
   310  	// Debug logs are intended for debugging and development purposes.
   311  	DebugLevel
   312  
   313  	// InfoLevel is used for general informational log messages.
   314  	InfoLevel
   315  
   316  	// WarnLevel is used for undesired but relatively expected events,
   317  	// which may indicate a problem.
   318  	WarnLevel
   319  
   320  	// ErrorLevel is used for undesired and unexpected events that
   321  	// the program can recover from.
   322  	ErrorLevel
   323  
   324  	// FatalLevel is used for undesired and unexpected events that
   325  	// the program cannot recover from.
   326  	FatalLevel
   327  )
   328  
   329  // String is part of the flag.Value interface.
   330  func (l *LogLevel) String() string {
   331  	switch *l {
   332  	case DebugLevel:
   333  		return "debug"
   334  	case InfoLevel:
   335  		return "info"
   336  	case WarnLevel:
   337  		return "warn"
   338  	case ErrorLevel:
   339  		return "error"
   340  	case FatalLevel:
   341  		return "fatal"
   342  	}
   343  	panic(fmt.Sprintf("asynq: unexpected log level: %v", *l))
   344  }
   345  
   346  // Set is part of the flag.Value interface.
   347  func (l *LogLevel) Set(val string) error {
   348  	switch strings.ToLower(val) {
   349  	case "debug":
   350  		*l = DebugLevel
   351  	case "info":
   352  		*l = InfoLevel
   353  	case "warn", "warning":
   354  		*l = WarnLevel
   355  	case "error":
   356  		*l = ErrorLevel
   357  	case "fatal":
   358  		*l = FatalLevel
   359  	default:
   360  		return fmt.Errorf("asynq: unsupported log level %q", val)
   361  	}
   362  	return nil
   363  }
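         
         // Example (illustrative sketch): since *LogLevel implements flag.Value, it can
         // be registered directly with the standard flag package.
         //
         //	var level LogLevel
         //	flag.Var(&level, "loglevel", "minimum log level (debug|info|warn|error|fatal)")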
   364  
   365  func toInternalLogLevel(l LogLevel) log.Level {
   366  	switch l {
   367  	case DebugLevel:
   368  		return log.DebugLevel
   369  	case InfoLevel:
   370  		return log.InfoLevel
   371  	case WarnLevel:
   372  		return log.WarnLevel
   373  	case ErrorLevel:
   374  		return log.ErrorLevel
   375  	case FatalLevel:
   376  		return log.FatalLevel
   377  	}
   378  	panic(fmt.Sprintf("asynq: unexpected log level: %v", l))
   379  }
   380  
   381  // DefaultRetryDelayFunc is the default RetryDelayFunc used if one is not specified in Config.
    382  // It uses an exponential back-off strategy to calculate the retry delay.
   383  func DefaultRetryDelayFunc(n int, e error, t *Task) time.Duration {
   384  	r := rand.New(rand.NewSource(time.Now().UnixNano()))
   385  	// Formula taken from https://github.com/mperham/sidekiq.
   386  	s := int(math.Pow(float64(n), 4)) + 15 + (r.Intn(30) * (n + 1))
   387  	return time.Duration(s) * time.Second
   388  }
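         
         // For reference, the formula above works out to roughly:
         //
         //	n=1: 1^4 + 15 + rand(0..29)*2 seconds -> 16s to 74s
         //	n=5: 5^4 + 15 + rand(0..29)*6 seconds -> 640s to 814s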
   389  
   390  func defaultIsFailureFunc(err error) bool { return err != nil }
   391  
   392  var defaultQueueConfig = map[string]int{
   393  	base.DefaultQueueName: 1,
   394  }
   395  
   396  const (
   397  	defaultShutdownTimeout = 8 * time.Second
   398  
   399  	defaultHealthCheckInterval = 15 * time.Second
   400  
   401  	defaultDelayedTaskCheckInterval = 5 * time.Second
   402  
   403  	defaultGroupGracePeriod = 1 * time.Minute
   404  )
   405  
   406  // NewServer returns a new Server given a redis connection option
   407  // and server configuration.
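         //
         // Example (illustrative sketch; the Redis address is a placeholder and
         // RedisClientOpt is assumed to be the usual asynq connection option):
         //
         //	srv := NewServer(RedisClientOpt{Addr: "localhost:6379"}, Config{
         //		Concurrency: 10,
         //		Queues:      map[string]int{"critical": 6, "default": 3, "low": 1},
         //	})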
   408  func NewServer(r RedisConnOpt, cfg Config) *Server {
   409  	c, ok := r.MakeRedisClient().(redis.UniversalClient)
   410  	if !ok {
   411  		panic(fmt.Sprintf("asynq: unsupported RedisConnOpt type %T", r))
   412  	}
   413  	baseCtxFn := cfg.BaseContext
   414  	if baseCtxFn == nil {
   415  		baseCtxFn = context.Background
   416  	}
   417  	n := cfg.Concurrency
   418  	if n < 1 {
   419  		n = runtime.NumCPU()
   420  	}
   421  	delayFunc := cfg.RetryDelayFunc
   422  	if delayFunc == nil {
   423  		delayFunc = DefaultRetryDelayFunc
   424  	}
   425  	isFailureFunc := cfg.IsFailure
   426  	if isFailureFunc == nil {
   427  		isFailureFunc = defaultIsFailureFunc
   428  	}
   429  	queues := make(map[string]int)
   430  	for qname, p := range cfg.Queues {
   431  		if err := base.ValidateQueueName(qname); err != nil {
   432  			continue // ignore invalid queue names
   433  		}
   434  		if p > 0 {
   435  			queues[qname] = p
   436  		}
   437  	}
   438  	if len(queues) == 0 {
   439  		queues = defaultQueueConfig
   440  	}
   441  	var qnames []string
   442  	for q := range queues {
   443  		qnames = append(qnames, q)
   444  	}
   445  	shutdownTimeout := cfg.ShutdownTimeout
   446  	if shutdownTimeout == 0 {
   447  		shutdownTimeout = defaultShutdownTimeout
   448  	}
   449  	healthcheckInterval := cfg.HealthCheckInterval
   450  	if healthcheckInterval == 0 {
   451  		healthcheckInterval = defaultHealthCheckInterval
   452  	}
   453  	// TODO: Create a helper to check for zero value and fall back to default (e.g. getDurationOrDefault())
   454  	groupGracePeriod := cfg.GroupGracePeriod
   455  	if groupGracePeriod == 0 {
   456  		groupGracePeriod = defaultGroupGracePeriod
   457  	}
   458  	if groupGracePeriod < time.Second {
   459  		panic("GroupGracePeriod cannot be less than a second")
   460  	}
   461  	logger := log.NewLogger(cfg.Logger)
   462  	loglevel := cfg.LogLevel
   463  	if loglevel == level_unspecified {
   464  		loglevel = InfoLevel
   465  	}
   466  	logger.SetLevel(toInternalLogLevel(loglevel))
   467  
   468  	rdb := rdb.NewRDB(c)
   469  	starting := make(chan *workerInfo)
   470  	finished := make(chan *base.TaskMessage)
   471  	syncCh := make(chan *syncRequest)
   472  	srvState := &serverState{value: srvStateNew}
   473  	cancels := base.NewCancelations()
   474  
   475  	syncer := newSyncer(syncerParams{
   476  		logger:     logger,
   477  		requestsCh: syncCh,
   478  		interval:   5 * time.Second,
   479  	})
   480  	heartbeater := newHeartbeater(heartbeaterParams{
   481  		logger:         logger,
   482  		broker:         rdb,
   483  		interval:       5 * time.Second,
   484  		concurrency:    n,
   485  		queues:         queues,
   486  		strictPriority: cfg.StrictPriority,
   487  		state:          srvState,
   488  		starting:       starting,
   489  		finished:       finished,
   490  	})
   491  	delayedTaskCheckInterval := cfg.DelayedTaskCheckInterval
   492  	if delayedTaskCheckInterval == 0 {
   493  		delayedTaskCheckInterval = defaultDelayedTaskCheckInterval
   494  	}
   495  	forwarder := newForwarder(forwarderParams{
   496  		logger:   logger,
   497  		broker:   rdb,
   498  		queues:   qnames,
   499  		interval: delayedTaskCheckInterval,
   500  	})
   501  	subscriber := newSubscriber(subscriberParams{
   502  		logger:       logger,
   503  		broker:       rdb,
   504  		cancelations: cancels,
   505  	})
   506  	processor := newProcessor(processorParams{
   507  		logger:          logger,
   508  		broker:          rdb,
   509  		retryDelayFunc:  delayFunc,
   510  		baseCtxFn:       baseCtxFn,
   511  		isFailureFunc:   isFailureFunc,
   512  		syncCh:          syncCh,
   513  		cancelations:    cancels,
   514  		concurrency:     n,
   515  		queues:          queues,
   516  		strictPriority:  cfg.StrictPriority,
   517  		errHandler:      cfg.ErrorHandler,
   518  		shutdownTimeout: shutdownTimeout,
   519  		starting:        starting,
   520  		finished:        finished,
   521  	})
   522  	recoverer := newRecoverer(recovererParams{
   523  		logger:         logger,
   524  		broker:         rdb,
   525  		retryDelayFunc: delayFunc,
   526  		isFailureFunc:  isFailureFunc,
   527  		queues:         qnames,
   528  		interval:       1 * time.Minute,
   529  	})
   530  	healthchecker := newHealthChecker(healthcheckerParams{
   531  		logger:          logger,
   532  		broker:          rdb,
   533  		interval:        healthcheckInterval,
   534  		healthcheckFunc: cfg.HealthCheckFunc,
   535  	})
   536  	janitor := newJanitor(janitorParams{
   537  		logger:   logger,
   538  		broker:   rdb,
   539  		queues:   qnames,
   540  		interval: 8 * time.Second,
   541  	})
   542  	aggregator := newAggregator(aggregatorParams{
   543  		logger:          logger,
   544  		broker:          rdb,
   545  		queues:          qnames,
   546  		gracePeriod:     groupGracePeriod,
   547  		maxDelay:        cfg.GroupMaxDelay,
   548  		maxSize:         cfg.GroupMaxSize,
   549  		groupAggregator: cfg.GroupAggregator,
   550  	})
   551  	return &Server{
   552  		logger:                logger,
   553  		broker:                rdb,
   554  		state:                 srvState,
   555  		disableRedisConnClose: cfg.DisableRedisConnClose,
   556  		forwarder:             forwarder,
   557  		processor:             processor,
   558  		syncer:                syncer,
   559  		heartbeater:           heartbeater,
   560  		subscriber:            subscriber,
   561  		recoverer:             recoverer,
   562  		healthchecker:         healthchecker,
   563  		janitor:               janitor,
   564  		aggregator:            aggregator,
   565  	}
   566  }
   567  
   568  // A Handler processes tasks.
   569  //
   570  // ProcessTask should return nil if the processing of a task
   571  // is successful.
   572  //
   573  // If ProcessTask returns a non-nil error or panics, the task
    574  // will be retried after a delay if it has retries remaining;
    575  // otherwise the task will be archived.
   576  //
   577  // One exception to this rule is when ProcessTask returns a SkipRetry error.
   578  // If the returned error is SkipRetry or an error wraps SkipRetry, retry is
   579  // skipped and the task will be immediately archived instead.
   580  type Handler interface {
   581  	ProcessTask(context.Context, *Task) error
   582  }
   583  
   584  // The HandlerFunc type is an adapter to allow the use of
   585  // ordinary functions as a Handler. If f is a function
   586  // with the appropriate signature, HandlerFunc(f) is a
   587  // Handler that calls f.
   588  type HandlerFunc func(context.Context, *Task) error
   589  
   590  // ProcessTask calls fn(ctx, task)
   591  func (fn HandlerFunc) ProcessTask(ctx context.Context, task *Task) error {
   592  	return fn(ctx, task)
   593  }
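         
         // Example (illustrative sketch; the "email:welcome" task type is hypothetical):
         //
         //	h := HandlerFunc(func(ctx context.Context, t *Task) error {
         //		switch t.Type() {
         //		case "email:welcome":
         //			// handle the task payload here
         //			return nil
         //		default:
         //			return fmt.Errorf("unexpected task type %q", t.Type())
         //		}
         //	})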
   594  
    595  // ErrServerClosed indicates that the operation is now illegal because the server has been shut down.
   596  var ErrServerClosed = errors.New("asynq: Server closed")
   597  
   598  // Run starts the task processing and blocks until
    599  // an OS signal to exit the program is received. Once it receives
    600  // a signal, it gracefully shuts down all active workers and other
    601  // goroutines processing the tasks.
   602  //
   603  // Run returns any error encountered at server startup time.
   604  // If the server has already been shutdown, ErrServerClosed is returned.
   605  func (srv *Server) Run(handler Handler) error {
   606  	if err := srv.Start(handler); err != nil {
   607  		return err
   608  	}
   609  	srv.waitForSignals()
   610  	srv.Shutdown()
   611  	return nil
   612  }
   613  
   614  // Start starts the worker server. Once the server has started,
    615  // it pulls tasks off queues and starts a worker goroutine for each task,
    616  // then calls the Handler to process it.
    617  // Tasks are processed concurrently by the workers, up to the concurrency
    618  // level specified in Config.Concurrency.
   619  //
   620  // Start returns any error encountered at server startup time.
   621  // If the server has already been shutdown, ErrServerClosed is returned.
   622  func (srv *Server) Start(handler Handler) error {
   623  	if handler == nil {
   624  		return fmt.Errorf("asynq: server cannot run with nil handler")
   625  	}
   626  	srv.processor.handler = handler
   627  
   628  	if err := srv.start(); err != nil {
   629  		return err
   630  	}
   631  	srv.logger.Info("[Common] asynq starting processing")
   632  
   633  	srv.heartbeater.start(&srv.wg)
   634  	srv.healthchecker.start(&srv.wg)
   635  	srv.subscriber.start(&srv.wg)
   636  	srv.syncer.start(&srv.wg)
   637  	srv.recoverer.start(&srv.wg)
   638  	srv.forwarder.start(&srv.wg)
   639  	srv.processor.start(&srv.wg)
   640  	srv.janitor.start(&srv.wg)
   641  	srv.aggregator.start(&srv.wg)
   642  	return nil
   643  }
   644  
   645  // Checks server state and returns an error if pre-condition is not met.
   646  // Otherwise it sets the server state to active.
   647  func (srv *Server) start() error {
   648  	srv.state.mu.Lock()
   649  	defer srv.state.mu.Unlock()
   650  	switch srv.state.value {
   651  	case srvStateActive:
   652  		return fmt.Errorf("asynq: the server is already running")
   653  	case srvStateStopped:
   654  		return fmt.Errorf("asynq: the server is in the stopped state. Waiting for shutdown.")
   655  	case srvStateClosed:
   656  		return ErrServerClosed
   657  	}
   658  	srv.state.value = srvStateActive
   659  	return nil
   660  }
   661  
   662  // Shutdown gracefully shuts down the server.
   663  // It gracefully closes all active workers. The server will wait for
    664  // active workers to finish processing tasks for the duration specified in Config.ShutdownTimeout.
    665  // If a worker doesn't finish processing a task within the timeout, the task will be pushed back to Redis.
   666  func (srv *Server) Shutdown() {
   667  	srv.state.mu.Lock()
   668  	if srv.state.value == srvStateNew || srv.state.value == srvStateClosed {
   669  		srv.state.mu.Unlock()
   670  		// server is not running, do nothing and return.
   671  		return
   672  	}
   673  	srv.state.value = srvStateClosed
   674  	srv.state.mu.Unlock()
   675  
   676  	srv.logger.Info("[Common] asynq starting graceful shutdown")
   677  	// Note: The order of shutdown is important.
   678  	// Sender goroutines should be terminated before the receiver goroutines.
   679  	// processor -> syncer (via syncCh)
   680  	// processor -> heartbeater (via starting, finished channels)
   681  	srv.forwarder.shutdown()
   682  	srv.processor.shutdown()
   683  	srv.recoverer.shutdown()
   684  	srv.syncer.shutdown()
   685  	srv.subscriber.shutdown()
   686  	srv.janitor.shutdown()
   687  	srv.aggregator.shutdown()
   688  	srv.healthchecker.shutdown()
   689  	srv.heartbeater.shutdown()
   690  	srv.wg.Wait()
   691  
   692  	if !srv.disableRedisConnClose {
   693  		_ = srv.broker.Close()
   694  	}
   695  	srv.logger.Info("[Common] asynq exiting")
   696  }
   697  
   698  // Stop signals the server to stop pulling new tasks off queues.
   699  // Stop can be used before shutting down the server to ensure that all
   700  // currently active tasks are processed before server shutdown.
   701  //
    702  // Stop does not shut down the server; make sure to call Shutdown before exiting.
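         //
         // Example (illustrative sketch): drain in-flight tasks before exiting.
         //
         //	srv.Stop()
         //	srv.Shutdown()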
   703  func (srv *Server) Stop() {
   704  	srv.state.mu.Lock()
   705  	if srv.state.value != srvStateActive {
    706  		// Invalid call to Stop; the server can only go from the Active state to the Stopped state.
   707  		srv.state.mu.Unlock()
   708  		return
   709  	}
   710  	srv.state.value = srvStateStopped
   711  	srv.state.mu.Unlock()
   712  
   713  	srv.logger.Info("[Common] asynq stopping processor")
   714  	srv.processor.stop()
   715  	srv.logger.Info("[Common] asynq processor stopped")
   716  }