github.com/wfusion/gofusion@v1.1.14/cron/asynq.go

     1  package cron
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math/rand"
     7  	"reflect"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/pkg/errors"
    13  	"github.com/robfig/cron/v3"
    14  	"go.uber.org/multierr"
    15  
    16  	"github.com/wfusion/gofusion/common/constant"
    17  	"github.com/wfusion/gofusion/common/infra/asynq"
    18  	"github.com/wfusion/gofusion/common/utils"
    19  	"github.com/wfusion/gofusion/common/utils/inspect"
    20  	"github.com/wfusion/gofusion/common/utils/serialize/json"
    21  	"github.com/wfusion/gofusion/config"
    22  	"github.com/wfusion/gofusion/lock"
    23  	"github.com/wfusion/gofusion/log"
    24  	"github.com/wfusion/gofusion/redis"
    25  	"github.com/wfusion/gofusion/routine"
    26  
    27  	rdsDrv "github.com/redis/go-redis/v9"
    28  
    29  	fusCtx "github.com/wfusion/gofusion/context"
    30  )
    31  
const (
	// asyncqTaskPayloadField is the name of the asynq.Task field holding the task payload.
	// NOTE(review): not referenced in this file; presumably accessed via reflection elsewhere in the package.
	asyncqTaskPayloadField  = "payload"
	// asyncqTaskTypenameField is the name of the asynq.Task field holding the task type name.
	// It is read and rewritten through inspect.GetField/inspect.SetField to add or strip the
	// "<app>:cron:" namespace (see GetConfigs and gatewayMiddleware).
	asyncqTaskTypenameField = "typename"
)
    36  
var (
	// asynqLoggerType is the reflect.Type of the asynq.Logger interface. newAsynq uses it to
	// convert a reflectively constructed Conf.Logger value into an asynq.Logger.
	asynqLoggerType                     = reflect.TypeOf((*asynq.Logger)(nil)).Elem()
	// asynqPeriodicTaskConfigProviderType is the reflect.Type of the
	// asynq.PeriodicTaskConfigProvider interface. newAsynq uses it to convert a reflectively
	// constructed Conf.TaskLoader value into a provider.
	asynqPeriodicTaskConfigProviderType = reflect.TypeOf((*asynq.PeriodicTaskConfigProvider)(nil)).Elem()
)
    41  
// asynqRouter is the asynq-backed IRouter. Depending on its Conf it runs a trigger (an
// asynq.PeriodicTaskManager that enqueues cron tasks), a server (an asynq.Server that
// processes them), or both.
type asynqRouter struct {
	// ServeMux routes task types to handlers; it is only created when the server role is enabled.
	*asynq.ServeMux

	// appName selects the application whose config, log, lock and redis instances are used.
	appName string

	// l guards lockDurations: GetConfigs writes it under Lock, releaseCronTaskLock reads it under RLock.
	l sync.RWMutex
	// n is the router (cron config entry) name.
	n string
	// c is this router's cron configuration; newAsynq defaults c.Queue in place when blank.
	c *Conf

	// mws holds user middlewares already adapted to asynq; Serve/Start install them on ServeMux.
	mws     []asynq.MiddlewareFunc
	logger  asynq.Logger
	// locker, when configured, is taken per task before enqueueing so that an enqueue that
	// cannot get the lock is discarded (see preEnqueueFunc).
	locker  lock.Lockable
	server  *asynq.Server
	trigger *asynq.PeriodicTaskManager

	// id is the trigger's ID, used as the reentrant key when locking/unlocking task locks.
	id                    string
	// lockDurations maps namespaced task type names to the interval between two consecutive
	// runs of their cron spec; it bounds how long the per-task lock is held.
	lockDurations         map[string]time.Duration
	// NOTE(review): both flags start true and Serve sets shouldShutdownServer to false when
	// running both roles; presumably consulted by shutdown logic in another file — confirm.
	shouldShutdownServer  bool
	shouldShutdownTrigger bool
}
    62  
    63  func newAsynq(ctx context.Context, appName, name string, conf *Conf) IRouter {
    64  	r := &asynqRouter{
    65  		appName:               appName,
    66  		n:                     name,
    67  		c:                     conf,
    68  		lockDurations:         make(map[string]time.Duration, len(conf.Tasks)),
    69  		shouldShutdownServer:  true,
    70  		shouldShutdownTrigger: true,
    71  	}
    72  	if utils.IsStrBlank(r.c.Queue) {
    73  		r.c.Queue = r.defaultQueue()
    74  	}
    75  
    76  	var rdsCli rdsDrv.UniversalClient
    77  	switch conf.InstanceType {
    78  	case instanceTypeRedis:
    79  		rdsCli = redis.Use(ctx, conf.Instance, redis.AppName(appName))
    80  	case instanceTypeMysql:
    81  		fallthrough
    82  	default:
    83  		panic(errors.Errorf("unknown instance type: %s", conf.InstanceType))
    84  	}
    85  
    86  	if r.logger == nil && utils.IsStrNotBlank(conf.Logger) {
    87  		loggerType := inspect.TypeOf(conf.Logger)
    88  		loggerValue := reflect.New(loggerType)
    89  		if loggerValue.Type().Implements(customLoggerType) {
    90  			logger := log.Use(conf.LogInstance, log.AppName(appName))
    91  			loggerValue.Interface().(customLogger).Init(logger, appName, name)
    92  		}
    93  		r.logger = loggerValue.Convert(asynqLoggerType).Interface().(asynq.Logger)
    94  	}
    95  	if r.locker == nil && utils.IsStrNotBlank(conf.LockInstance) {
    96  		r.locker = lock.Use(conf.LockInstance, lock.AppName(appName))
    97  		if r.locker == nil {
    98  			panic(errors.Errorf("locker instance not found: %s", conf.LockInstance))
    99  		}
   100  	}
   101  
   102  	var provider asynq.PeriodicTaskConfigProvider
   103  	if utils.IsStrNotBlank(conf.TaskLoader) {
   104  		loaderType := inspect.TypeOf(conf.TaskLoader)
   105  		if loaderType == nil {
   106  			panic(errors.Errorf("%s not found", conf.TaskLoader))
   107  		}
   108  		provider = reflect.New(loaderType).
   109  			Convert(asynqPeriodicTaskConfigProviderType).Interface().(asynq.PeriodicTaskConfigProvider)
   110  	}
   111  
   112  	logLevel := asynq.LogLevel(0)
   113  	utils.MustSuccess(logLevel.Set(conf.LogLevel))
   114  
   115  	wrapper := &asynqWrapper{r: r, n: r.n, appName: appName, cli: rdsCli, provider: provider}
   116  	if conf.Trigger {
   117  		r.initTrigger(ctx, wrapper, logLevel)
   118  	}
   119  	if conf.Server {
   120  		r.initServer(ctx, wrapper, logLevel)
   121  	}
   122  
   123  	return r
   124  }
   125  
   126  func (a *asynqRouter) Use(mws ...routerMiddleware) {
   127  	for _, mw := range mws {
   128  		a.mws = append(a.mws, a.adaptMiddleware(mw))
   129  	}
   130  }
   131  
   132  func (a *asynqRouter) Handle(pattern string, fn any, _ ...utils.OptionExtender) {
   133  	if !a.c.Server {
   134  		a.debug(context.Background(), "cannot handle task %s: client is not enabled", a.n)
   135  		return
   136  	}
   137  
   138  	a.ServeMux.Handle(a.formatTaskName(pattern), a.adaptAsynqHandlerFunc(fn))
   139  }
   140  
   141  func (a *asynqRouter) Serve() (err error) {
   142  	defer a.info(context.Background(), "scheduler is running")
   143  
   144  	if a.c.Server {
   145  		a.ServeMux.Use(a.gatewayMiddleware)
   146  		a.ServeMux.Use(a.mws...)
   147  	}
   148  
   149  	if a.c.Trigger && !a.c.Server {
   150  		return a.trigger.Run()
   151  	}
   152  	if !a.c.Trigger && a.c.Server {
   153  		return a.server.Run(a.ServeMux)
   154  	}
   155  
   156  	a.shouldShutdownServer = false
   157  	if err = a.trigger.Start(); err != nil {
   158  		return
   159  	}
   160  
   161  	return a.server.Run(a.ServeMux)
   162  }
   163  
   164  func (a *asynqRouter) Start() (err error) {
   165  	defer a.info(context.Background(), "scheduler started")
   166  
   167  	if a.c.Trigger {
   168  		if err = a.trigger.Start(); err != nil {
   169  			return
   170  		}
   171  	}
   172  
   173  	if a.c.Server {
   174  		a.ServeMux.Use(a.gatewayMiddleware)
   175  		a.ServeMux.Use(a.mws...)
   176  		if err = a.server.Start(a.ServeMux); err != nil {
   177  			return
   178  		}
   179  	}
   180  
   181  	return
   182  }
   183  
   184  func (a *asynqRouter) shutdown() (err error) {
   185  	if a.c.Trigger {
   186  		_, catchErr := utils.Catch(a.trigger.Shutdown)
   187  		err = multierr.Append(err, errors.Cause(catchErr))
   188  	}
   189  	if a.c.Server {
   190  		_, catchErr := utils.Catch(a.server.Shutdown)
   191  		err = multierr.Append(err, errors.Cause(catchErr))
   192  	}
   193  	return
   194  }
   195  
   196  func (a *asynqRouter) initTrigger(ctx context.Context, wrapper *asynqWrapper, logLevel asynq.LogLevel) {
   197  	a.trigger = utils.Must(
   198  		asynq.NewPeriodicTaskManager(asynq.PeriodicTaskManagerOpts{
   199  			PeriodicTaskConfigProvider: wrapper,
   200  			RedisConnOpt:               wrapper,
   201  			SchedulerOpts: &asynq.SchedulerOpts{
   202  				Logger:                a.logger,
   203  				LogLevel:              logLevel,
   204  				Location:              utils.Must(time.LoadLocation(a.c.Timezone)),
   205  				DisableRedisConnClose: true,
   206  				PreEnqueueFunc:        a.preEnqueueFunc(ctx),
   207  				PostEnqueueFunc:       a.postEnqueueFunc(ctx),
   208  				EnqueueErrorHandler: func(task *asynq.Task, opts []asynq.Option, err error) {
   209  					ignored := []error{errDiscardMessage}
   210  					if a.locker == nil {
   211  						ignored = append(ignored, asynq.ErrDuplicateTask, asynq.ErrTaskIDConflict)
   212  					}
   213  					if err = utils.ErrIgnore(err, ignored...); err == nil {
   214  						return
   215  					}
   216  					taskName := "unknown"
   217  					if task != nil {
   218  						taskName = a.unformatTaskName(task.Type())
   219  					}
   220  					a.warn(ctx, "enqueue task %s failed: %s", taskName, err)
   221  				},
   222  			},
   223  			SyncInterval: utils.Must(time.ParseDuration(a.c.RefreshTasksInterval)),
   224  		}),
   225  	)
   226  	a.id = a.trigger.ID()
   227  }
   228  
   229  func (a *asynqRouter) initServer(ctx context.Context, wrapper *asynqWrapper, logLevel asynq.LogLevel) {
   230  	a.ServeMux = asynq.NewServeMux()
   231  	for pattern, taskCfg := range a.c.Tasks {
   232  		if utils.IsStrBlank(taskCfg.Callback) {
   233  			continue
   234  		}
   235  		handler := *(*routerHandleFunc)(inspect.FuncOf(taskCfg.Callback))
   236  		a.ServeMux.Handle(a.formatTaskName(pattern), a.adaptAsynqHandlerFunc(handler))
   237  	}
   238  
   239  	asynqCfg := asynq.Config{
   240  		Concurrency:    a.c.ServerConcurrency,
   241  		BaseContext:    context.Background,
   242  		RetryDelayFunc: asynq.DefaultRetryDelayFunc,
   243  		IsFailure:      func(err error) bool { return !errors.Is(err, errDiscardMessage) },
   244  		Queues:         nil,
   245  		StrictPriority: false,
   246  		ErrorHandler: asynq.ErrorHandlerFunc(func(ctx context.Context, task *asynq.Task, err error) {
   247  			taskName := "unknown"
   248  			if task != nil {
   249  				taskName = a.unformatTaskName(task.Type())
   250  			}
   251  			a.info(ctx, "handle task %s message error %s", taskName, err)
   252  		}),
   253  		Logger:          a.logger,
   254  		LogLevel:        logLevel,
   255  		ShutdownTimeout: 8 * time.Second,
   256  		HealthCheckFunc: func(err error) {
   257  			if err != nil {
   258  				a.warn(ctx, "health check check failed: %s", err)
   259  			}
   260  		},
   261  		HealthCheckInterval:      15 * time.Second,
   262  		DelayedTaskCheckInterval: 5 * time.Second,
   263  		GroupGracePeriod:         1 * time.Minute,
   264  		GroupMaxDelay:            0,
   265  		GroupMaxSize:             0,
   266  		GroupAggregator:          nil,
   267  		DisableRedisConnClose:    true,
   268  	}
   269  	if utils.IsStrNotBlank(a.c.Queue) {
   270  		asynqCfg.Queues = map[string]int{a.c.Queue: 3}
   271  	}
   272  
   273  	a.server = asynq.NewServer(wrapper, asynqCfg)
   274  }
   275  
   276  func (a *asynqRouter) preEnqueueFunc(ctx context.Context) func(*asynq.Task, []asynq.Option) error {
   277  	return func(task *asynq.Task, opts []asynq.Option) (err error) {
   278  		// when locker is disabled, we cannot determine which message should be discarded
   279  		if a.locker == nil {
   280  			return
   281  		}
   282  
   283  		taskName := a.unformatTaskName(task.Type())
   284  		lockKey := a.formatLockKey(taskName)
   285  		if err = a.locker.Lock(ctx, lockKey, lock.Expire(tolerantOfTimeNotSync), lock.ReentrantKey(a.id)); err == nil {
   286  			a.info(ctx, "pre enqueue task %s success", taskName)
   287  			return
   288  		}
   289  
   290  		err = utils.ErrIgnore(err, lock.ErrTimeout, lock.ErrContextDone)
   291  		if err == nil {
   292  			a.debug(ctx, "pre enqueue discard task %s", taskName)
   293  			return errDiscardMessage
   294  		}
   295  
   296  		a.warn(ctx, "pre enqueue task %s failed: %s", taskName, err)
   297  		return
   298  	}
   299  }
   300  
   301  func (a *asynqRouter) postEnqueueFunc(ctx context.Context) func(info *asynq.TaskInfo, err error) {
   302  	return func(info *asynq.TaskInfo, err error) {
   303  		// release lock
   304  		if a.locker != nil {
   305  			defer routine.Go(a.releaseCronTaskLock, routine.Args(ctx, info), routine.AppName(a.appName))
   306  		}
   307  
   308  		ignored := []error{errDiscardMessage}
   309  		if a.locker == nil {
   310  			ignored = append(ignored, asynq.ErrDuplicateTask, asynq.ErrTaskIDConflict)
   311  		}
   312  
   313  		if err = utils.ErrIgnore(err, ignored...); err == nil {
   314  			return
   315  		}
   316  		taskName := "unknown"
   317  		if info != nil {
   318  			taskName = a.unformatTaskName(info.Type)
   319  		}
   320  		a.debug(ctx, "post enqueue task %s failed: %s", taskName, err)
   321  	}
   322  }
   323  
   324  func (a *asynqRouter) releaseCronTaskLock(ctx context.Context, info *asynq.TaskInfo) {
   325  	if info == nil {
   326  		return
   327  	}
   328  	taskName := a.unformatTaskName(info.Type)
   329  
   330  	// 90 ~ 100ms jitter
   331  	jitter := 90*time.Millisecond + time.Duration(float64(10*time.Millisecond)*rand.Float64())
   332  
   333  	a.l.RLock()
   334  	lockTime := a.lockDurations[info.Type]
   335  	a.l.RUnlock()
   336  
   337  	// prevent a negative tolerant
   338  	tolerant := utils.Min(tolerantOfTimeNotSync, lockTime) - jitter
   339  	tolerant = utils.Max(tolerant, 500*time.Millisecond)
   340  	timer := time.NewTimer(tolerant)
   341  	defer timer.Stop()
   342  
   343  	var e error
   344  	defer func() {
   345  		if e != nil {
   346  			a.warn(ctx, "post enqueue task %s release lock failed: %s", taskName, e)
   347  		}
   348  	}()
   349  
   350  	now := time.Now()
   351  	unlockKey := a.formatLockKey(taskName)
   352  	for {
   353  		select {
   354  		case <-ctx.Done():
   355  			a.debug(ctx, "post enqueue task %s release lock: context done", taskName)
   356  			e = a.locker.Unlock(ctx, unlockKey, lock.ReentrantKey(a.id))
   357  			return
   358  		case <-timer.C:
   359  			e = a.locker.Unlock(ctx, unlockKey, lock.ReentrantKey(a.id))
   360  			return
   361  		default:
   362  			a.l.RLock()
   363  			newLockTime := a.lockDurations[info.Type]
   364  			a.l.RUnlock()
   365  			if newLockTime != lockTime {
   366  				lockTime = newLockTime
   367  				tolerant = utils.Min(tolerantOfTimeNotSync, lockTime) - jitter
   368  				tolerant = utils.Max(tolerant, 500*time.Millisecond)
   369  				tolerant = utils.Max(0, tolerant-time.Since(now))
   370  				timer.Reset(tolerant)
   371  			}
   372  		}
   373  	}
   374  }
   375  
   376  func (a *asynqRouter) gatewayMiddleware(next asynq.Handler) asynq.Handler {
   377  	return asynq.HandlerFunc(func(ctx context.Context, raw *asynq.Task) (err error) {
   378  		taskName := a.unformatTaskName(raw.Type())
   379  		inspect.SetField(raw, asyncqTaskTypenameField, taskName)
   380  		if utils.IsStrBlank(fusCtx.GetTraceID(ctx)) {
   381  			ctx = fusCtx.SetTraceID(ctx, utils.NginxID())
   382  		}
   383  		if utils.IsStrBlank(fusCtx.GetCronTaskName(ctx)) {
   384  			ctx = fusCtx.SetCronTaskName(ctx, taskName)
   385  		}
   386  		return next.ProcessTask(ctx, raw)
   387  	})
   388  }
   389  
   390  func (a *asynqRouter) adaptMiddleware(mw routerMiddleware) asynq.MiddlewareFunc {
   391  	return func(asynqNext asynq.Handler) asynq.Handler {
   392  		next := mw(a.adaptRouterHandlerFunc(asynqNext))
   393  		return a.adaptAsynqHandlerFunc(next)
   394  	}
   395  }
   396  
   397  // adaptAsynqHandlerFunc support function signature
   398  // - func(ctx context.Context)
   399  // - func(ctx context.Context) error
   400  // - func(ctx context.Context, args json.Serializable)
   401  // - func(ctx context.Context, args *json.Serializable) error
   402  func (a *asynqRouter) adaptAsynqHandlerFunc(h any) asynq.HandlerFunc {
   403  	if fn, ok := h.(routerHandleFunc); ok {
   404  		return func(ctx context.Context, raw *asynq.Task) (err error) {
   405  			return fn(ctx, a.newTask(raw))
   406  		}
   407  	}
   408  	if fn, ok := h.(func(ctx context.Context, task Task) (err error)); ok {
   409  		return func(ctx context.Context, raw *asynq.Task) (err error) {
   410  			return fn(ctx, a.newTask(raw))
   411  		}
   412  	}
   413  
   414  	var (
   415  		hasArg          bool
   416  		argType         reflect.Type
   417  		argTypePtrDepth int
   418  	)
   419  	if reflect.TypeOf(h).NumIn() > 1 {
   420  		argType = reflect.TypeOf(h).In(1)
   421  		for argType.Kind() == reflect.Ptr {
   422  			argType = argType.Elem()
   423  			argTypePtrDepth++
   424  		}
   425  		hasArg = true
   426  	}
   427  
   428  	fn := utils.WrapFunc1[error](h)
   429  	return func(ctx context.Context, raw *asynq.Task) (err error) {
   430  		if !hasArg {
   431  			return fn(ctx)
   432  		}
   433  		arg := reflect.New(argType)
   434  		payload := raw.Payload()
   435  		if len(payload) == 0 {
   436  			payload = []byte("null")
   437  		}
   438  		if err = json.Unmarshal(payload, arg.Interface()); err != nil {
   439  			return
   440  		}
   441  		arg = arg.Elem()
   442  		for i := 0; i < argTypePtrDepth; i++ {
   443  			arg = arg.Addr()
   444  		}
   445  
   446  		return fn(ctx, arg.Interface())
   447  	}
   448  }
   449  
   450  func (a *asynqRouter) adaptRouterHandlerFunc(h asynq.Handler) routerHandleFunc {
   451  	return func(ctx context.Context, raw Task) (err error) {
   452  		return h.ProcessTask(ctx, a.newAsynqTask(raw))
   453  	}
   454  }
   455  
   456  func (a *asynqRouter) defaultQueue() (result string) {
   457  	return fmt.Sprintf("%s:cron", config.Use(a.appName).AppName())
   458  }
   459  func (a *asynqRouter) formatLockKey(taskName string) string {
   460  	return fmt.Sprintf("cron_%s", taskName)
   461  }
   462  func (a *asynqRouter) formatTaskName(taskName string) (result string) {
   463  	return fmt.Sprintf("%s:cron:%s", config.Use(a.appName).AppName(), taskName)
   464  }
   465  func (a *asynqRouter) unformatTaskName(taskName string) (result string) {
   466  	return strings.TrimPrefix(taskName, fmt.Sprintf("%s:cron:", config.Use(a.appName).AppName()))
   467  }
   468  
   469  func (a *asynqRouter) newTask(raw *asynq.Task) (t Task) {
   470  	return &task{
   471  		id:         raw.Type(),
   472  		name:       raw.Type(),
   473  		payload:    raw.Payload(),
   474  		rawMessage: raw,
   475  	}
   476  }
   477  
   478  func (a *asynqRouter) newAsynqTask(raw Task) (t *asynq.Task) {
   479  	return raw.RawMessage().(*asynq.Task)
   480  }
   481  
// asynqWrapper adapts a router to what asynq needs. newAsynq passes it to the periodic task
// manager as both the PeriodicTaskConfigProvider and the RedisConnOpt, and to asynq.NewServer
// as the redis connection option.
type asynqWrapper struct {
	// appName selects the application whose cron component config is (re)loaded.
	appName string

	// r is the owning router (its Conf, task-name formatting and lockDurations).
	r        *asynqRouter
	// n is the router name used to look up this router's Conf in the reloaded config.
	n        string
	// cli is the shared redis client handed out by MakeRedisClient.
	cli      rdsDrv.UniversalClient
	// provider optionally supplies additional task configs (built from Conf.TaskLoader); may be nil.
	provider asynq.PeriodicTaskConfigProvider
}
   490  
   491  func (a *asynqWrapper) MakeRedisClient() any {
   492  	return a.cli
   493  }
   494  
   495  func (a *asynqWrapper) GetConfigs() (result []*asynq.PeriodicTaskConfig, err error) {
   496  	result, err = a.getConfigs()
   497  	if err != nil {
   498  		return
   499  	}
   500  
   501  	a.r.l.Lock()
   502  	defer a.r.l.Unlock()
   503  	for _, cfg := range result {
   504  		// renaming
   505  		taskName := inspect.GetField[string](cfg.Task, asyncqTaskTypenameField)
   506  		inspect.SetField(cfg.Task, asyncqTaskTypenameField, a.r.formatTaskName(taskName))
   507  
   508  		name := cfg.Task.Type()
   509  		a.r.lockDurations[name], err = a.getTaskExecuteInterval(cfg.Cronspec)
   510  		if err != nil {
   511  			return
   512  		}
   513  	}
   514  
   515  	return
   516  }
   517  
   518  func (a *asynqWrapper) getConfigs() (result []*asynq.PeriodicTaskConfig, err error) {
   519  	if a.provider != nil {
   520  		result, err = a.provider.GetConfigs()
   521  		if err != nil {
   522  			return
   523  		}
   524  	}
   525  
   526  	var confs map[string]*Conf
   527  	if err = config.Use(a.appName).LoadComponentConfig(config.ComponentCron, &confs); err != nil {
   528  		return
   529  	}
   530  	conf, ok := confs[a.n]
   531  	if !ok {
   532  		return nil, errors.Errorf("%s cron config not found", a.n)
   533  	}
   534  
   535  	loc, _ := time.LoadLocation(a.r.c.Timezone)
   536  	if loc == nil {
   537  		loc = constant.DefaultLocation()
   538  	}
   539  
   540  	queue := conf.Queue
   541  	if utils.IsStrBlank(queue) {
   542  		queue = a.r.c.Queue
   543  	}
   544  	for name, cfg := range conf.Tasks {
   545  		var (
   546  			deadline          time.Time
   547  			interval, timeout time.Duration
   548  			opts              []asynq.Option
   549  		)
   550  		if interval, err = a.getTaskExecuteInterval(cfg.Crontab); err != nil {
   551  			return
   552  		}
   553  		if utils.IsStrNotBlank(cfg.Timeout) {
   554  			if timeout, err = time.ParseDuration(cfg.Timeout); err != nil {
   555  				return
   556  			}
   557  			opts = append(opts, asynq.Timeout(timeout))
   558  		} else {
   559  			opts = append(opts, asynq.Timeout(interval))
   560  		}
   561  		if utils.IsStrNotBlank(cfg.Deadline) {
   562  			if deadline, err = time.ParseInLocation(constant.StdTimeLayout, cfg.Deadline, loc); err != nil {
   563  				return
   564  			}
   565  			opts = append(opts, asynq.Deadline(deadline))
   566  		}
   567  
   568  		result = append(result, &asynq.PeriodicTaskConfig{
   569  			Cronspec: cfg.Crontab,
   570  			Task:     asynq.NewTask(name, []byte(cfg.Payload)),
   571  			Opts: append(opts, []asynq.Option{
   572  				asynq.TaskID(name),
   573  				asynq.Unique(utils.Min(interval, tolerantOfTimeNotSync)),
   574  				asynq.Queue(queue),
   575  				asynq.MaxRetry(utils.Max(0, cfg.Retry)),
   576  			}...),
   577  		})
   578  	}
   579  	return
   580  }
   581  
   582  func (a *asynqWrapper) getTaskExecuteInterval(spec string) (interval time.Duration, err error) {
   583  	now := time.Now()
   584  	scheduler, err := cron.ParseStandard(spec)
   585  	if err != nil {
   586  		return 0, err
   587  	}
   588  	next := scheduler.Next(now)
   589  	interval = scheduler.Next(next).Sub(next)
   590  	return
   591  }
   592  
   593  func init() {
   594  	rand.Seed(time.Now().UnixMicro())
   595  }