github.com/wfusion/gofusion@v1.1.14/async/asynq.go (about) 1 package async 2 3 import ( 4 "context" 5 "fmt" 6 "reflect" 7 "strings" 8 "time" 9 10 "github.com/pkg/errors" 11 "github.com/wfusion/gofusion/log" 12 "go.uber.org/multierr" 13 14 "github.com/wfusion/gofusion/common/infra/asynq" 15 "github.com/wfusion/gofusion/common/utils" 16 "github.com/wfusion/gofusion/common/utils/compress" 17 "github.com/wfusion/gofusion/common/utils/inspect" 18 "github.com/wfusion/gofusion/common/utils/serialize" 19 "github.com/wfusion/gofusion/config" 20 "github.com/wfusion/gofusion/redis" 21 22 rdsDrv "github.com/redis/go-redis/v9" 23 24 pd "github.com/wfusion/gofusion/internal/util/payload" 25 ) 26 27 const ( 28 asyncqTaskTypenameField = "typename" 29 ) 30 31 var ( 32 asynqLoggerType = reflect.TypeOf((*asynq.Logger)(nil)).Elem() 33 ) 34 35 type asynqConsumer struct { 36 *asynq.ServeMux 37 38 appName string 39 n string 40 c *Conf 41 42 mws []asynq.MiddlewareFunc 43 logger asynq.Logger 44 consumer *asynq.Server 45 } 46 47 func newAsynqConsumer(ctx context.Context, appName, name string, conf *Conf) Consumable { 48 consumer := &asynqConsumer{appName: appName, n: name, c: conf} 49 50 var rdsCli rdsDrv.UniversalClient 51 switch conf.InstanceType { 52 case instanceTypeRedis: 53 rdsCli = redis.Use(ctx, conf.Instance, redis.AppName(appName)) 54 case instanceTypeDB: 55 fallthrough 56 default: 57 panic(errors.Errorf("unknown instance type: %s", conf.InstanceType)) 58 } 59 60 if consumer.logger == nil && utils.IsStrNotBlank(conf.Logger) { 61 loggerType := inspect.TypeOf(conf.Logger) 62 loggerValue := reflect.New(loggerType) 63 if loggerValue.Type().Implements(customLoggerType) { 64 logger := log.Use(conf.LogInstance, log.AppName(appName)) 65 loggerValue.Interface().(customLogger).Init(logger, appName, name) 66 } 67 consumer.logger = loggerValue.Convert(asynqLoggerType).Interface().(asynq.Logger) 68 } 69 70 logLevel := asynq.LogLevel(0) 71 utils.MustSuccess(logLevel.Set(conf.LogLevel)) 72 73 consumer.ServeMux = asynq.NewServeMux() 74 asynqCfg := asynq.Config{ 75 Concurrency: conf.ConsumerConcurrency, 76 BaseContext: context.Background, 77 RetryDelayFunc: asynq.DefaultRetryDelayFunc, 78 IsFailure: nil, 79 Queues: nil, 80 StrictPriority: conf.StrictPriority, 81 ErrorHandler: asynq.ErrorHandlerFunc(func(ctx context.Context, task *asynq.Task, err error) { 82 taskName := "unknown" 83 if task != nil { 84 taskName = consumer.unformatTaskName(task.Type()) 85 } 86 consumer.info(ctx, "handle task %s message error %s", taskName, err) 87 }), 88 Logger: consumer.logger, 89 LogLevel: logLevel, 90 ShutdownTimeout: 8 * time.Second, 91 HealthCheckFunc: func(err error) { consumer.warn(ctx, "health check check failed: %s", err) }, 92 HealthCheckInterval: 15 * time.Second, 93 DelayedTaskCheckInterval: 5 * time.Second, 94 GroupGracePeriod: 1 * time.Minute, 95 GroupMaxDelay: 0, 96 GroupMaxSize: 0, 97 GroupAggregator: nil, 98 DisableRedisConnClose: true, 99 } 100 if len(conf.Queues) > 0 { 101 asynqCfg.Queues = make(map[string]int, len(conf.Queues)) 102 for _, queue := range conf.Queues { 103 if _, ok := asynqCfg.Queues[queue.Name]; ok { 104 panic(ErrDuplicatedQueueName) 105 } 106 if utils.IsStrBlank(queue.Name) { 107 queue.Name = defaultQueue(appName) 108 } 109 asynqCfg.Queues[queue.Name] = queue.Level 110 } 111 } else { 112 asynqCfg.Queues = map[string]int{defaultQueue(appName): 3} 113 } 114 115 consumer.consumer = asynq.NewServer(&asynqRedisConnOpt{UniversalClient: rdsCli}, asynqCfg) 116 return consumer 117 } 118 119 func (a *asynqConsumer) Use(mws ...routerMiddleware) { 120 for _, mw := range mws { 121 a.mws = append(a.mws, a.adaptMiddleware(mw)) 122 } 123 } 124 125 func (a *asynqConsumer) Handle(pattern string, fn any, _ ...utils.OptionExtender) { 126 if !a.c.Consumer { 127 a.debug(context.Background(), "cannot handle task: consumer is not enabled") 128 return 129 } 130 name := formatTaskName(a.appName, pattern) 131 funcName := formatTaskName(a.appName, utils.GetFuncName(fn)) 132 133 callbackMapLock.Lock() 134 defer callbackMapLock.Unlock() 135 if callbackMap[a.appName] == nil { 136 callbackMap[a.appName] = make(map[string]any) 137 } 138 if funcNameToTaskName[a.appName] == nil { 139 funcNameToTaskName[a.appName] = make(map[string]string) 140 } 141 if _, ok := callbackMap[a.appName][name]; ok { 142 panic(ErrDuplicatedHandlerName) 143 } 144 callbackMap[a.appName][name] = fn 145 callbackMap[a.appName][funcName] = fn 146 funcNameToTaskName[a.appName][funcName] = name 147 148 typ, embed := wrapParams(fn) 149 a.ServeMux.Handle(name, a.adaptAsynqHandlerFunc(fn, typ, embed)) 150 if name != funcName { 151 a.ServeMux.Handle(funcName, a.adaptAsynqHandlerFunc(fn, typ, embed)) 152 } 153 } 154 155 func (a *asynqConsumer) HandleFunc(fn any, _ ...utils.OptionExtender) { 156 if !a.c.Consumer { 157 a.debug(context.Background(), "cannot handle task: consumer is not enabled") 158 return 159 } 160 funcName := formatTaskName(a.appName, utils.GetFuncName(fn)) 161 162 callbackMapLock.Lock() 163 defer callbackMapLock.Unlock() 164 if callbackMap[a.appName] == nil { 165 callbackMap[a.appName] = make(map[string]any) 166 } 167 if funcNameToTaskName[a.appName] == nil { 168 funcNameToTaskName[a.appName] = make(map[string]string) 169 } 170 if _, ok := callbackMap[funcName]; ok { 171 panic(ErrDuplicatedHandlerName) 172 } 173 callbackMap[a.appName][funcName] = fn 174 175 typ, embed := wrapParams(fn) 176 a.ServeMux.Handle(funcName, a.adaptAsynqHandlerFunc(fn, typ, embed)) 177 } 178 179 func (a *asynqConsumer) Serve() (err error) { 180 if !a.c.Consumer { 181 return ErrConsumerDisabled 182 } 183 defer a.info(context.Background(), "consumer started") 184 185 a.ServeMux.Use(a.gatewayMiddleware) 186 a.ServeMux.Use(a.mws...) 187 return a.consumer.Run(a.ServeMux) 188 } 189 190 func (a *asynqConsumer) Start() (err error) { 191 if !a.c.Consumer { 192 return ErrConsumerDisabled 193 } 194 195 defer a.info(context.Background(), "consumer started") 196 197 a.ServeMux.Use(a.gatewayMiddleware) 198 a.ServeMux.Use(a.mws...) 199 200 return a.consumer.Start(a.ServeMux) 201 } 202 203 func (a *asynqConsumer) shutdown() (err error) { 204 if a.consumer != nil { 205 _, catchErr := utils.Catch(a.consumer.Shutdown) 206 err = multierr.Append(err, errors.Cause(catchErr)) 207 } 208 return 209 } 210 211 func (a *asynqConsumer) gatewayMiddleware(next asynq.Handler) asynq.Handler { 212 return asynq.HandlerFunc(func(ctx context.Context, raw *asynq.Task) (err error) { 213 taskName := a.unformatTaskName(raw.Type()) 214 inspect.SetField(raw, asyncqTaskTypenameField, taskName) 215 return next.ProcessTask(ctx, raw) 216 }) 217 } 218 219 func (a *asynqConsumer) adaptMiddleware(mw routerMiddleware) asynq.MiddlewareFunc { 220 return func(asynqNext asynq.Handler) asynq.Handler { 221 next := mw(a.adaptRouterHandlerFunc(asynqNext)) 222 return asynq.HandlerFunc(func(ctx context.Context, t *asynq.Task) error { 223 return next(ctx, a.newTask(t)) 224 }) 225 } 226 } 227 228 func (a *asynqConsumer) adaptAsynqHandlerFunc(h any, typ reflect.Type, embed bool) asynq.HandlerFunc { 229 fn := utils.WrapFunc1[error](h) 230 return func(ctx context.Context, task *asynq.Task) (err error) { 231 ctx, data, _, err := pd.Unseal(task.Payload(), pd.Type(typ)) 232 if err != nil { 233 return 234 } 235 params := unwrapParams(typ, embed, data) 236 return fn(append([]any{ctx}, params...)...) 237 } 238 } 239 240 func (a *asynqConsumer) adaptRouterHandlerFunc(h asynq.Handler) routerMiddlewareFunc { 241 return func(ctx context.Context, raw Task) (err error) { 242 return h.ProcessTask(ctx, a.newAsynqTask(raw)) 243 } 244 } 245 246 func (a *asynqConsumer) unformatTaskName(taskName string) (result string) { 247 return strings.TrimPrefix(taskName, fmt.Sprintf("%s:async:", config.Use(a.appName).AppName())) 248 } 249 250 func (a *asynqConsumer) newTask(raw *asynq.Task) (t Task) { 251 return &task{ 252 id: raw.Type(), 253 name: raw.Type(), 254 payload: raw.Payload(), 255 rawMessage: raw, 256 } 257 } 258 259 func (a *asynqConsumer) newAsynqTask(raw Task) (t *asynq.Task) { 260 return raw.RawMessage().(*asynq.Task) 261 } 262 263 type asynqProducer struct { 264 *asynq.Client 265 266 appName string 267 n string 268 c *Conf 269 270 compressType compress.Algorithm 271 serializeType serialize.Algorithm 272 } 273 274 func newAsynqProducer(ctx context.Context, appName, name string, conf *Conf) Producable { 275 var rdsCli rdsDrv.UniversalClient 276 switch conf.InstanceType { 277 case instanceTypeRedis: 278 rdsCli = redis.Use(ctx, conf.Instance, redis.AppName(appName)) 279 case instanceTypeDB: 280 fallthrough 281 default: 282 panic(errors.Errorf("unknown instance type: %s", conf.InstanceType)) 283 } 284 285 producer := &asynqProducer{ 286 appName: appName, 287 n: name, 288 c: conf, 289 Client: asynq.NewClient(&asynqRedisConnOpt{UniversalClient: rdsCli}), 290 compressType: compress.ParseAlgorithm(conf.MessageCompressType), 291 serializeType: serialize.ParseAlgorithm(conf.MessageSerializeType), 292 } 293 // default serialize type 294 if !producer.serializeType.IsValid() { 295 producer.serializeType = serialize.AlgorithmGob 296 } 297 return producer 298 } 299 300 func (a *asynqProducer) Go(fn any, opts ...utils.OptionExtender) (err error) { 301 var data any 302 opt := utils.ApplyOptions[produceOption](opts...) 303 if len(opt.args) > 0 { 304 argType, embed := wrapParams(fn) 305 data = setParams(argType, embed, opt.args...) 306 } 307 308 // get task name by func name 309 funcName := formatTaskName(a.appName, utils.GetFuncName(fn)) 310 callbackMapLock.RLock() 311 if mappingName, ok := funcNameToTaskName[a.appName][funcName]; ok { 312 funcName = mappingName 313 } 314 callbackMapLock.RUnlock() 315 316 ctx := context.Background() 317 task, err := a.newTask(ctx, funcName, data) 318 if err != nil { 319 return 320 } 321 322 _, err = a.Client.EnqueueContext(ctx, task, a.parseOption(opt)...) 323 if err != nil { 324 return 325 } 326 return 327 } 328 329 func (a *asynqProducer) Goc(ctx context.Context, fn any, opts ...utils.OptionExtender) (err error) { 330 var data any 331 opt := utils.ApplyOptions[produceOption](opts...) 332 if len(opt.args) > 0 { 333 argType, embed := wrapParams(fn) 334 data = setParams(argType, embed, opt.args...) 335 } 336 337 // get task name by func name 338 funcName := formatTaskName(a.appName, utils.GetFuncName(fn)) 339 callbackMapLock.RLock() 340 if mappingName, ok := funcNameToTaskName[a.appName][funcName]; ok { 341 funcName = mappingName 342 } 343 callbackMapLock.RUnlock() 344 345 task, err := a.newTask(ctx, funcName, data) 346 if err != nil { 347 return 348 } 349 350 _, err = a.Client.EnqueueContext(ctx, task, a.parseOption(opt)...) 351 if err != nil { 352 return 353 } 354 return 355 } 356 357 func (a *asynqProducer) Send(ctx context.Context, taskName string, data any, opts ...utils.OptionExtender) (err error) { 358 opt := utils.ApplyOptions[produceOption](opts...) 359 task, err := a.newTask(ctx, formatTaskName(a.appName, taskName), data) 360 if err != nil { 361 return 362 } 363 364 _, err = a.Client.EnqueueContext(ctx, task, a.parseOption(opt)...) 365 if err != nil { 366 return 367 } 368 return 369 } 370 371 func (a *asynqProducer) parseOption(src *produceOption) (dst []asynq.Option) { 372 if utils.IsStrNotBlank(src.id) { 373 dst = append(dst, asynq.TaskID(src.id)) 374 } 375 if utils.IsStrNotBlank(src.queue) { 376 dst = append(dst, asynq.Queue(src.queue)) 377 } else if len(a.c.Queues) == 1 { 378 dst = append(dst, asynq.Queue(a.c.Queues[0].Name)) 379 } else { 380 dst = append(dst, asynq.Queue(defaultQueue(a.appName))) 381 } 382 if src.maxRetry > 0 { 383 dst = append(dst, asynq.MaxRetry(src.maxRetry)) 384 } 385 if !src.deadline.IsZero() { 386 dst = append(dst, asynq.Deadline(src.deadline)) 387 } 388 if src.timeout > 0 { 389 dst = append(dst, asynq.Timeout(src.timeout)) 390 } 391 if src.delayDuration > 0 { 392 dst = append(dst, asynq.ProcessIn(src.timeout)) 393 } 394 if !src.delayTime.IsZero() { 395 dst = append(dst, asynq.ProcessAt(src.delayTime)) 396 } 397 if src.retentionDuration > 0 { 398 dst = append(dst, asynq.Retention(src.retentionDuration)) 399 } 400 401 return 402 } 403 404 func (a *asynqProducer) newTask(ctx context.Context, taskName string, data any) (task *asynq.Task, err error) { 405 payload, err := pd.Seal(data, pd.Context(ctx), pd.Serialize(a.serializeType), pd.Compress(a.compressType)) 406 if err != nil { 407 return 408 } 409 410 task = asynq.NewTask(taskName, payload) 411 return 412 } 413 414 type asynqRedisConnOpt struct{ rdsDrv.UniversalClient } 415 416 func (a *asynqRedisConnOpt) MakeRedisClient() any { return a.UniversalClient } 417 418 func formatTaskName(appName, taskName string) (result string) { 419 return fmt.Sprintf("%s:async:%s", config.Use(appName).AppName(), taskName) 420 } 421 422 func defaultQueue(appName string) (result string) { 423 return fmt.Sprintf("%s:async", config.Use(appName).AppName()) 424 }