github.com/wfusion/gofusion@v1.1.14/cron/asynq.go (about) 1 package cron 2 3 import ( 4 "context" 5 "fmt" 6 "math/rand" 7 "reflect" 8 "strings" 9 "sync" 10 "time" 11 12 "github.com/pkg/errors" 13 "github.com/robfig/cron/v3" 14 "go.uber.org/multierr" 15 16 "github.com/wfusion/gofusion/common/constant" 17 "github.com/wfusion/gofusion/common/infra/asynq" 18 "github.com/wfusion/gofusion/common/utils" 19 "github.com/wfusion/gofusion/common/utils/inspect" 20 "github.com/wfusion/gofusion/common/utils/serialize/json" 21 "github.com/wfusion/gofusion/config" 22 "github.com/wfusion/gofusion/lock" 23 "github.com/wfusion/gofusion/log" 24 "github.com/wfusion/gofusion/redis" 25 "github.com/wfusion/gofusion/routine" 26 27 rdsDrv "github.com/redis/go-redis/v9" 28 29 fusCtx "github.com/wfusion/gofusion/context" 30 ) 31 32 const ( 33 asyncqTaskPayloadField = "payload" 34 asyncqTaskTypenameField = "typename" 35 ) 36 37 var ( 38 asynqLoggerType = reflect.TypeOf((*asynq.Logger)(nil)).Elem() 39 asynqPeriodicTaskConfigProviderType = reflect.TypeOf((*asynq.PeriodicTaskConfigProvider)(nil)).Elem() 40 ) 41 42 type asynqRouter struct { 43 *asynq.ServeMux 44 45 appName string 46 47 l sync.RWMutex 48 n string 49 c *Conf 50 51 mws []asynq.MiddlewareFunc 52 logger asynq.Logger 53 locker lock.Lockable 54 server *asynq.Server 55 trigger *asynq.PeriodicTaskManager 56 57 id string 58 lockDurations map[string]time.Duration 59 shouldShutdownServer bool 60 shouldShutdownTrigger bool 61 } 62 63 func newAsynq(ctx context.Context, appName, name string, conf *Conf) IRouter { 64 r := &asynqRouter{ 65 appName: appName, 66 n: name, 67 c: conf, 68 lockDurations: make(map[string]time.Duration, len(conf.Tasks)), 69 shouldShutdownServer: true, 70 shouldShutdownTrigger: true, 71 } 72 if utils.IsStrBlank(r.c.Queue) { 73 r.c.Queue = r.defaultQueue() 74 } 75 76 var rdsCli rdsDrv.UniversalClient 77 switch conf.InstanceType { 78 case instanceTypeRedis: 79 rdsCli = redis.Use(ctx, conf.Instance, redis.AppName(appName)) 80 case instanceTypeMysql: 81 fallthrough 82 default: 83 panic(errors.Errorf("unknown instance type: %s", conf.InstanceType)) 84 } 85 86 if r.logger == nil && utils.IsStrNotBlank(conf.Logger) { 87 loggerType := inspect.TypeOf(conf.Logger) 88 loggerValue := reflect.New(loggerType) 89 if loggerValue.Type().Implements(customLoggerType) { 90 logger := log.Use(conf.LogInstance, log.AppName(appName)) 91 loggerValue.Interface().(customLogger).Init(logger, appName, name) 92 } 93 r.logger = loggerValue.Convert(asynqLoggerType).Interface().(asynq.Logger) 94 } 95 if r.locker == nil && utils.IsStrNotBlank(conf.LockInstance) { 96 r.locker = lock.Use(conf.LockInstance, lock.AppName(appName)) 97 if r.locker == nil { 98 panic(errors.Errorf("locker instance not found: %s", conf.LockInstance)) 99 } 100 } 101 102 var provider asynq.PeriodicTaskConfigProvider 103 if utils.IsStrNotBlank(conf.TaskLoader) { 104 loaderType := inspect.TypeOf(conf.TaskLoader) 105 if loaderType == nil { 106 panic(errors.Errorf("%s not found", conf.TaskLoader)) 107 } 108 provider = reflect.New(loaderType). 109 Convert(asynqPeriodicTaskConfigProviderType).Interface().(asynq.PeriodicTaskConfigProvider) 110 } 111 112 logLevel := asynq.LogLevel(0) 113 utils.MustSuccess(logLevel.Set(conf.LogLevel)) 114 115 wrapper := &asynqWrapper{r: r, n: r.n, appName: appName, cli: rdsCli, provider: provider} 116 if conf.Trigger { 117 r.initTrigger(ctx, wrapper, logLevel) 118 } 119 if conf.Server { 120 r.initServer(ctx, wrapper, logLevel) 121 } 122 123 return r 124 } 125 126 func (a *asynqRouter) Use(mws ...routerMiddleware) { 127 for _, mw := range mws { 128 a.mws = append(a.mws, a.adaptMiddleware(mw)) 129 } 130 } 131 132 func (a *asynqRouter) Handle(pattern string, fn any, _ ...utils.OptionExtender) { 133 if !a.c.Server { 134 a.debug(context.Background(), "cannot handle task %s: client is not enabled", a.n) 135 return 136 } 137 138 a.ServeMux.Handle(a.formatTaskName(pattern), a.adaptAsynqHandlerFunc(fn)) 139 } 140 141 func (a *asynqRouter) Serve() (err error) { 142 defer a.info(context.Background(), "scheduler is running") 143 144 if a.c.Server { 145 a.ServeMux.Use(a.gatewayMiddleware) 146 a.ServeMux.Use(a.mws...) 147 } 148 149 if a.c.Trigger && !a.c.Server { 150 return a.trigger.Run() 151 } 152 if !a.c.Trigger && a.c.Server { 153 return a.server.Run(a.ServeMux) 154 } 155 156 a.shouldShutdownServer = false 157 if err = a.trigger.Start(); err != nil { 158 return 159 } 160 161 return a.server.Run(a.ServeMux) 162 } 163 164 func (a *asynqRouter) Start() (err error) { 165 defer a.info(context.Background(), "scheduler started") 166 167 if a.c.Trigger { 168 if err = a.trigger.Start(); err != nil { 169 return 170 } 171 } 172 173 if a.c.Server { 174 a.ServeMux.Use(a.gatewayMiddleware) 175 a.ServeMux.Use(a.mws...) 176 if err = a.server.Start(a.ServeMux); err != nil { 177 return 178 } 179 } 180 181 return 182 } 183 184 func (a *asynqRouter) shutdown() (err error) { 185 if a.c.Trigger { 186 _, catchErr := utils.Catch(a.trigger.Shutdown) 187 err = multierr.Append(err, errors.Cause(catchErr)) 188 } 189 if a.c.Server { 190 _, catchErr := utils.Catch(a.server.Shutdown) 191 err = multierr.Append(err, errors.Cause(catchErr)) 192 } 193 return 194 } 195 196 func (a *asynqRouter) initTrigger(ctx context.Context, wrapper *asynqWrapper, logLevel asynq.LogLevel) { 197 a.trigger = utils.Must( 198 asynq.NewPeriodicTaskManager(asynq.PeriodicTaskManagerOpts{ 199 PeriodicTaskConfigProvider: wrapper, 200 RedisConnOpt: wrapper, 201 SchedulerOpts: &asynq.SchedulerOpts{ 202 Logger: a.logger, 203 LogLevel: logLevel, 204 Location: utils.Must(time.LoadLocation(a.c.Timezone)), 205 DisableRedisConnClose: true, 206 PreEnqueueFunc: a.preEnqueueFunc(ctx), 207 PostEnqueueFunc: a.postEnqueueFunc(ctx), 208 EnqueueErrorHandler: func(task *asynq.Task, opts []asynq.Option, err error) { 209 ignored := []error{errDiscardMessage} 210 if a.locker == nil { 211 ignored = append(ignored, asynq.ErrDuplicateTask, asynq.ErrTaskIDConflict) 212 } 213 if err = utils.ErrIgnore(err, ignored...); err == nil { 214 return 215 } 216 taskName := "unknown" 217 if task != nil { 218 taskName = a.unformatTaskName(task.Type()) 219 } 220 a.warn(ctx, "enqueue task %s failed: %s", taskName, err) 221 }, 222 }, 223 SyncInterval: utils.Must(time.ParseDuration(a.c.RefreshTasksInterval)), 224 }), 225 ) 226 a.id = a.trigger.ID() 227 } 228 229 func (a *asynqRouter) initServer(ctx context.Context, wrapper *asynqWrapper, logLevel asynq.LogLevel) { 230 a.ServeMux = asynq.NewServeMux() 231 for pattern, taskCfg := range a.c.Tasks { 232 if utils.IsStrBlank(taskCfg.Callback) { 233 continue 234 } 235 handler := *(*routerHandleFunc)(inspect.FuncOf(taskCfg.Callback)) 236 a.ServeMux.Handle(a.formatTaskName(pattern), a.adaptAsynqHandlerFunc(handler)) 237 } 238 239 asynqCfg := asynq.Config{ 240 Concurrency: a.c.ServerConcurrency, 241 BaseContext: context.Background, 242 RetryDelayFunc: asynq.DefaultRetryDelayFunc, 243 IsFailure: func(err error) bool { return !errors.Is(err, errDiscardMessage) }, 244 Queues: nil, 245 StrictPriority: false, 246 ErrorHandler: asynq.ErrorHandlerFunc(func(ctx context.Context, task *asynq.Task, err error) { 247 taskName := "unknown" 248 if task != nil { 249 taskName = a.unformatTaskName(task.Type()) 250 } 251 a.info(ctx, "handle task %s message error %s", taskName, err) 252 }), 253 Logger: a.logger, 254 LogLevel: logLevel, 255 ShutdownTimeout: 8 * time.Second, 256 HealthCheckFunc: func(err error) { 257 if err != nil { 258 a.warn(ctx, "health check check failed: %s", err) 259 } 260 }, 261 HealthCheckInterval: 15 * time.Second, 262 DelayedTaskCheckInterval: 5 * time.Second, 263 GroupGracePeriod: 1 * time.Minute, 264 GroupMaxDelay: 0, 265 GroupMaxSize: 0, 266 GroupAggregator: nil, 267 DisableRedisConnClose: true, 268 } 269 if utils.IsStrNotBlank(a.c.Queue) { 270 asynqCfg.Queues = map[string]int{a.c.Queue: 3} 271 } 272 273 a.server = asynq.NewServer(wrapper, asynqCfg) 274 } 275 276 func (a *asynqRouter) preEnqueueFunc(ctx context.Context) func(*asynq.Task, []asynq.Option) error { 277 return func(task *asynq.Task, opts []asynq.Option) (err error) { 278 // when locker is disabled, we cannot determine which message should be discarded 279 if a.locker == nil { 280 return 281 } 282 283 taskName := a.unformatTaskName(task.Type()) 284 lockKey := a.formatLockKey(taskName) 285 if err = a.locker.Lock(ctx, lockKey, lock.Expire(tolerantOfTimeNotSync), lock.ReentrantKey(a.id)); err == nil { 286 a.info(ctx, "pre enqueue task %s success", taskName) 287 return 288 } 289 290 err = utils.ErrIgnore(err, lock.ErrTimeout, lock.ErrContextDone) 291 if err == nil { 292 a.debug(ctx, "pre enqueue discard task %s", taskName) 293 return errDiscardMessage 294 } 295 296 a.warn(ctx, "pre enqueue task %s failed: %s", taskName, err) 297 return 298 } 299 } 300 301 func (a *asynqRouter) postEnqueueFunc(ctx context.Context) func(info *asynq.TaskInfo, err error) { 302 return func(info *asynq.TaskInfo, err error) { 303 // release lock 304 if a.locker != nil { 305 defer routine.Go(a.releaseCronTaskLock, routine.Args(ctx, info), routine.AppName(a.appName)) 306 } 307 308 ignored := []error{errDiscardMessage} 309 if a.locker == nil { 310 ignored = append(ignored, asynq.ErrDuplicateTask, asynq.ErrTaskIDConflict) 311 } 312 313 if err = utils.ErrIgnore(err, ignored...); err == nil { 314 return 315 } 316 taskName := "unknown" 317 if info != nil { 318 taskName = a.unformatTaskName(info.Type) 319 } 320 a.debug(ctx, "post enqueue task %s failed: %s", taskName, err) 321 } 322 } 323 324 func (a *asynqRouter) releaseCronTaskLock(ctx context.Context, info *asynq.TaskInfo) { 325 if info == nil { 326 return 327 } 328 taskName := a.unformatTaskName(info.Type) 329 330 // 90 ~ 100ms jitter 331 jitter := 90*time.Millisecond + time.Duration(float64(10*time.Millisecond)*rand.Float64()) 332 333 a.l.RLock() 334 lockTime := a.lockDurations[info.Type] 335 a.l.RUnlock() 336 337 // prevent a negative tolerant 338 tolerant := utils.Min(tolerantOfTimeNotSync, lockTime) - jitter 339 tolerant = utils.Max(tolerant, 500*time.Millisecond) 340 timer := time.NewTimer(tolerant) 341 defer timer.Stop() 342 343 var e error 344 defer func() { 345 if e != nil { 346 a.warn(ctx, "post enqueue task %s release lock failed: %s", taskName, e) 347 } 348 }() 349 350 now := time.Now() 351 unlockKey := a.formatLockKey(taskName) 352 for { 353 select { 354 case <-ctx.Done(): 355 a.debug(ctx, "post enqueue task %s release lock: context done", taskName) 356 e = a.locker.Unlock(ctx, unlockKey, lock.ReentrantKey(a.id)) 357 return 358 case <-timer.C: 359 e = a.locker.Unlock(ctx, unlockKey, lock.ReentrantKey(a.id)) 360 return 361 default: 362 a.l.RLock() 363 newLockTime := a.lockDurations[info.Type] 364 a.l.RUnlock() 365 if newLockTime != lockTime { 366 lockTime = newLockTime 367 tolerant = utils.Min(tolerantOfTimeNotSync, lockTime) - jitter 368 tolerant = utils.Max(tolerant, 500*time.Millisecond) 369 tolerant = utils.Max(0, tolerant-time.Since(now)) 370 timer.Reset(tolerant) 371 } 372 } 373 } 374 } 375 376 func (a *asynqRouter) gatewayMiddleware(next asynq.Handler) asynq.Handler { 377 return asynq.HandlerFunc(func(ctx context.Context, raw *asynq.Task) (err error) { 378 taskName := a.unformatTaskName(raw.Type()) 379 inspect.SetField(raw, asyncqTaskTypenameField, taskName) 380 if utils.IsStrBlank(fusCtx.GetTraceID(ctx)) { 381 ctx = fusCtx.SetTraceID(ctx, utils.NginxID()) 382 } 383 if utils.IsStrBlank(fusCtx.GetCronTaskName(ctx)) { 384 ctx = fusCtx.SetCronTaskName(ctx, taskName) 385 } 386 return next.ProcessTask(ctx, raw) 387 }) 388 } 389 390 func (a *asynqRouter) adaptMiddleware(mw routerMiddleware) asynq.MiddlewareFunc { 391 return func(asynqNext asynq.Handler) asynq.Handler { 392 next := mw(a.adaptRouterHandlerFunc(asynqNext)) 393 return a.adaptAsynqHandlerFunc(next) 394 } 395 } 396 397 // adaptAsynqHandlerFunc support function signature 398 // - func(ctx context.Context) 399 // - func(ctx context.Context) error 400 // - func(ctx context.Context, args json.Serializable) 401 // - func(ctx context.Context, args *json.Serializable) error 402 func (a *asynqRouter) adaptAsynqHandlerFunc(h any) asynq.HandlerFunc { 403 if fn, ok := h.(routerHandleFunc); ok { 404 return func(ctx context.Context, raw *asynq.Task) (err error) { 405 return fn(ctx, a.newTask(raw)) 406 } 407 } 408 if fn, ok := h.(func(ctx context.Context, task Task) (err error)); ok { 409 return func(ctx context.Context, raw *asynq.Task) (err error) { 410 return fn(ctx, a.newTask(raw)) 411 } 412 } 413 414 var ( 415 hasArg bool 416 argType reflect.Type 417 argTypePtrDepth int 418 ) 419 if reflect.TypeOf(h).NumIn() > 1 { 420 argType = reflect.TypeOf(h).In(1) 421 for argType.Kind() == reflect.Ptr { 422 argType = argType.Elem() 423 argTypePtrDepth++ 424 } 425 hasArg = true 426 } 427 428 fn := utils.WrapFunc1[error](h) 429 return func(ctx context.Context, raw *asynq.Task) (err error) { 430 if !hasArg { 431 return fn(ctx) 432 } 433 arg := reflect.New(argType) 434 payload := raw.Payload() 435 if len(payload) == 0 { 436 payload = []byte("null") 437 } 438 if err = json.Unmarshal(payload, arg.Interface()); err != nil { 439 return 440 } 441 arg = arg.Elem() 442 for i := 0; i < argTypePtrDepth; i++ { 443 arg = arg.Addr() 444 } 445 446 return fn(ctx, arg.Interface()) 447 } 448 } 449 450 func (a *asynqRouter) adaptRouterHandlerFunc(h asynq.Handler) routerHandleFunc { 451 return func(ctx context.Context, raw Task) (err error) { 452 return h.ProcessTask(ctx, a.newAsynqTask(raw)) 453 } 454 } 455 456 func (a *asynqRouter) defaultQueue() (result string) { 457 return fmt.Sprintf("%s:cron", config.Use(a.appName).AppName()) 458 } 459 func (a *asynqRouter) formatLockKey(taskName string) string { 460 return fmt.Sprintf("cron_%s", taskName) 461 } 462 func (a *asynqRouter) formatTaskName(taskName string) (result string) { 463 return fmt.Sprintf("%s:cron:%s", config.Use(a.appName).AppName(), taskName) 464 } 465 func (a *asynqRouter) unformatTaskName(taskName string) (result string) { 466 return strings.TrimPrefix(taskName, fmt.Sprintf("%s:cron:", config.Use(a.appName).AppName())) 467 } 468 469 func (a *asynqRouter) newTask(raw *asynq.Task) (t Task) { 470 return &task{ 471 id: raw.Type(), 472 name: raw.Type(), 473 payload: raw.Payload(), 474 rawMessage: raw, 475 } 476 } 477 478 func (a *asynqRouter) newAsynqTask(raw Task) (t *asynq.Task) { 479 return raw.RawMessage().(*asynq.Task) 480 } 481 482 type asynqWrapper struct { 483 appName string 484 485 r *asynqRouter 486 n string 487 cli rdsDrv.UniversalClient 488 provider asynq.PeriodicTaskConfigProvider 489 } 490 491 func (a *asynqWrapper) MakeRedisClient() any { 492 return a.cli 493 } 494 495 func (a *asynqWrapper) GetConfigs() (result []*asynq.PeriodicTaskConfig, err error) { 496 result, err = a.getConfigs() 497 if err != nil { 498 return 499 } 500 501 a.r.l.Lock() 502 defer a.r.l.Unlock() 503 for _, cfg := range result { 504 // renaming 505 taskName := inspect.GetField[string](cfg.Task, asyncqTaskTypenameField) 506 inspect.SetField(cfg.Task, asyncqTaskTypenameField, a.r.formatTaskName(taskName)) 507 508 name := cfg.Task.Type() 509 a.r.lockDurations[name], err = a.getTaskExecuteInterval(cfg.Cronspec) 510 if err != nil { 511 return 512 } 513 } 514 515 return 516 } 517 518 func (a *asynqWrapper) getConfigs() (result []*asynq.PeriodicTaskConfig, err error) { 519 if a.provider != nil { 520 result, err = a.provider.GetConfigs() 521 if err != nil { 522 return 523 } 524 } 525 526 var confs map[string]*Conf 527 if err = config.Use(a.appName).LoadComponentConfig(config.ComponentCron, &confs); err != nil { 528 return 529 } 530 conf, ok := confs[a.n] 531 if !ok { 532 return nil, errors.Errorf("%s cron config not found", a.n) 533 } 534 535 loc, _ := time.LoadLocation(a.r.c.Timezone) 536 if loc == nil { 537 loc = constant.DefaultLocation() 538 } 539 540 queue := conf.Queue 541 if utils.IsStrBlank(queue) { 542 queue = a.r.c.Queue 543 } 544 for name, cfg := range conf.Tasks { 545 var ( 546 deadline time.Time 547 interval, timeout time.Duration 548 opts []asynq.Option 549 ) 550 if interval, err = a.getTaskExecuteInterval(cfg.Crontab); err != nil { 551 return 552 } 553 if utils.IsStrNotBlank(cfg.Timeout) { 554 if timeout, err = time.ParseDuration(cfg.Timeout); err != nil { 555 return 556 } 557 opts = append(opts, asynq.Timeout(timeout)) 558 } else { 559 opts = append(opts, asynq.Timeout(interval)) 560 } 561 if utils.IsStrNotBlank(cfg.Deadline) { 562 if deadline, err = time.ParseInLocation(constant.StdTimeLayout, cfg.Deadline, loc); err != nil { 563 return 564 } 565 opts = append(opts, asynq.Deadline(deadline)) 566 } 567 568 result = append(result, &asynq.PeriodicTaskConfig{ 569 Cronspec: cfg.Crontab, 570 Task: asynq.NewTask(name, []byte(cfg.Payload)), 571 Opts: append(opts, []asynq.Option{ 572 asynq.TaskID(name), 573 asynq.Unique(utils.Min(interval, tolerantOfTimeNotSync)), 574 asynq.Queue(queue), 575 asynq.MaxRetry(utils.Max(0, cfg.Retry)), 576 }...), 577 }) 578 } 579 return 580 } 581 582 func (a *asynqWrapper) getTaskExecuteInterval(spec string) (interval time.Duration, err error) { 583 now := time.Now() 584 scheduler, err := cron.ParseStandard(spec) 585 if err != nil { 586 return 0, err 587 } 588 next := scheduler.Next(now) 589 interval = scheduler.Next(next).Sub(next) 590 return 591 } 592 593 func init() { 594 rand.Seed(time.Now().UnixMicro()) 595 }