github.com/wfusion/gofusion@v1.1.14/async/asynq.go (about)

     1  package async
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"reflect"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/pkg/errors"
    11  	"github.com/wfusion/gofusion/log"
    12  	"go.uber.org/multierr"
    13  
    14  	"github.com/wfusion/gofusion/common/infra/asynq"
    15  	"github.com/wfusion/gofusion/common/utils"
    16  	"github.com/wfusion/gofusion/common/utils/compress"
    17  	"github.com/wfusion/gofusion/common/utils/inspect"
    18  	"github.com/wfusion/gofusion/common/utils/serialize"
    19  	"github.com/wfusion/gofusion/config"
    20  	"github.com/wfusion/gofusion/redis"
    21  
    22  	rdsDrv "github.com/redis/go-redis/v9"
    23  
    24  	pd "github.com/wfusion/gofusion/internal/util/payload"
    25  )
    26  
    27  const (
    28  	asyncqTaskTypenameField = "typename"
    29  )
    30  
    31  var (
    32  	asynqLoggerType = reflect.TypeOf((*asynq.Logger)(nil)).Elem()
    33  )
    34  
    35  type asynqConsumer struct {
    36  	*asynq.ServeMux
    37  
    38  	appName string
    39  	n       string
    40  	c       *Conf
    41  
    42  	mws      []asynq.MiddlewareFunc
    43  	logger   asynq.Logger
    44  	consumer *asynq.Server
    45  }
    46  
    47  func newAsynqConsumer(ctx context.Context, appName, name string, conf *Conf) Consumable {
    48  	consumer := &asynqConsumer{appName: appName, n: name, c: conf}
    49  
    50  	var rdsCli rdsDrv.UniversalClient
    51  	switch conf.InstanceType {
    52  	case instanceTypeRedis:
    53  		rdsCli = redis.Use(ctx, conf.Instance, redis.AppName(appName))
    54  	case instanceTypeDB:
    55  		fallthrough
    56  	default:
    57  		panic(errors.Errorf("unknown instance type: %s", conf.InstanceType))
    58  	}
    59  
    60  	if consumer.logger == nil && utils.IsStrNotBlank(conf.Logger) {
    61  		loggerType := inspect.TypeOf(conf.Logger)
    62  		loggerValue := reflect.New(loggerType)
    63  		if loggerValue.Type().Implements(customLoggerType) {
    64  			logger := log.Use(conf.LogInstance, log.AppName(appName))
    65  			loggerValue.Interface().(customLogger).Init(logger, appName, name)
    66  		}
    67  		consumer.logger = loggerValue.Convert(asynqLoggerType).Interface().(asynq.Logger)
    68  	}
    69  
    70  	logLevel := asynq.LogLevel(0)
    71  	utils.MustSuccess(logLevel.Set(conf.LogLevel))
    72  
    73  	consumer.ServeMux = asynq.NewServeMux()
    74  	asynqCfg := asynq.Config{
    75  		Concurrency:    conf.ConsumerConcurrency,
    76  		BaseContext:    context.Background,
    77  		RetryDelayFunc: asynq.DefaultRetryDelayFunc,
    78  		IsFailure:      nil,
    79  		Queues:         nil,
    80  		StrictPriority: conf.StrictPriority,
    81  		ErrorHandler: asynq.ErrorHandlerFunc(func(ctx context.Context, task *asynq.Task, err error) {
    82  			taskName := "unknown"
    83  			if task != nil {
    84  				taskName = consumer.unformatTaskName(task.Type())
    85  			}
    86  			consumer.info(ctx, "handle task %s message error %s", taskName, err)
    87  		}),
    88  		Logger:                   consumer.logger,
    89  		LogLevel:                 logLevel,
    90  		ShutdownTimeout:          8 * time.Second,
    91  		HealthCheckFunc:          func(err error) { consumer.warn(ctx, "health check check failed: %s", err) },
    92  		HealthCheckInterval:      15 * time.Second,
    93  		DelayedTaskCheckInterval: 5 * time.Second,
    94  		GroupGracePeriod:         1 * time.Minute,
    95  		GroupMaxDelay:            0,
    96  		GroupMaxSize:             0,
    97  		GroupAggregator:          nil,
    98  		DisableRedisConnClose:    true,
    99  	}
   100  	if len(conf.Queues) > 0 {
   101  		asynqCfg.Queues = make(map[string]int, len(conf.Queues))
   102  		for _, queue := range conf.Queues {
   103  			if _, ok := asynqCfg.Queues[queue.Name]; ok {
   104  				panic(ErrDuplicatedQueueName)
   105  			}
   106  			if utils.IsStrBlank(queue.Name) {
   107  				queue.Name = defaultQueue(appName)
   108  			}
   109  			asynqCfg.Queues[queue.Name] = queue.Level
   110  		}
   111  	} else {
   112  		asynqCfg.Queues = map[string]int{defaultQueue(appName): 3}
   113  	}
   114  
   115  	consumer.consumer = asynq.NewServer(&asynqRedisConnOpt{UniversalClient: rdsCli}, asynqCfg)
   116  	return consumer
   117  }
   118  
   119  func (a *asynqConsumer) Use(mws ...routerMiddleware) {
   120  	for _, mw := range mws {
   121  		a.mws = append(a.mws, a.adaptMiddleware(mw))
   122  	}
   123  }
   124  
   125  func (a *asynqConsumer) Handle(pattern string, fn any, _ ...utils.OptionExtender) {
   126  	if !a.c.Consumer {
   127  		a.debug(context.Background(), "cannot handle task: consumer is not enabled")
   128  		return
   129  	}
   130  	name := formatTaskName(a.appName, pattern)
   131  	funcName := formatTaskName(a.appName, utils.GetFuncName(fn))
   132  
   133  	callbackMapLock.Lock()
   134  	defer callbackMapLock.Unlock()
   135  	if callbackMap[a.appName] == nil {
   136  		callbackMap[a.appName] = make(map[string]any)
   137  	}
   138  	if funcNameToTaskName[a.appName] == nil {
   139  		funcNameToTaskName[a.appName] = make(map[string]string)
   140  	}
   141  	if _, ok := callbackMap[a.appName][name]; ok {
   142  		panic(ErrDuplicatedHandlerName)
   143  	}
   144  	callbackMap[a.appName][name] = fn
   145  	callbackMap[a.appName][funcName] = fn
   146  	funcNameToTaskName[a.appName][funcName] = name
   147  
   148  	typ, embed := wrapParams(fn)
   149  	a.ServeMux.Handle(name, a.adaptAsynqHandlerFunc(fn, typ, embed))
   150  	if name != funcName {
   151  		a.ServeMux.Handle(funcName, a.adaptAsynqHandlerFunc(fn, typ, embed))
   152  	}
   153  }
   154  
   155  func (a *asynqConsumer) HandleFunc(fn any, _ ...utils.OptionExtender) {
   156  	if !a.c.Consumer {
   157  		a.debug(context.Background(), "cannot handle task: consumer is not enabled")
   158  		return
   159  	}
   160  	funcName := formatTaskName(a.appName, utils.GetFuncName(fn))
   161  
   162  	callbackMapLock.Lock()
   163  	defer callbackMapLock.Unlock()
   164  	if callbackMap[a.appName] == nil {
   165  		callbackMap[a.appName] = make(map[string]any)
   166  	}
   167  	if funcNameToTaskName[a.appName] == nil {
   168  		funcNameToTaskName[a.appName] = make(map[string]string)
   169  	}
   170  	if _, ok := callbackMap[funcName]; ok {
   171  		panic(ErrDuplicatedHandlerName)
   172  	}
   173  	callbackMap[a.appName][funcName] = fn
   174  
   175  	typ, embed := wrapParams(fn)
   176  	a.ServeMux.Handle(funcName, a.adaptAsynqHandlerFunc(fn, typ, embed))
   177  }
   178  
   179  func (a *asynqConsumer) Serve() (err error) {
   180  	if !a.c.Consumer {
   181  		return ErrConsumerDisabled
   182  	}
   183  	defer a.info(context.Background(), "consumer started")
   184  
   185  	a.ServeMux.Use(a.gatewayMiddleware)
   186  	a.ServeMux.Use(a.mws...)
   187  	return a.consumer.Run(a.ServeMux)
   188  }
   189  
   190  func (a *asynqConsumer) Start() (err error) {
   191  	if !a.c.Consumer {
   192  		return ErrConsumerDisabled
   193  	}
   194  
   195  	defer a.info(context.Background(), "consumer started")
   196  
   197  	a.ServeMux.Use(a.gatewayMiddleware)
   198  	a.ServeMux.Use(a.mws...)
   199  
   200  	return a.consumer.Start(a.ServeMux)
   201  }
   202  
   203  func (a *asynqConsumer) shutdown() (err error) {
   204  	if a.consumer != nil {
   205  		_, catchErr := utils.Catch(a.consumer.Shutdown)
   206  		err = multierr.Append(err, errors.Cause(catchErr))
   207  	}
   208  	return
   209  }
   210  
   211  func (a *asynqConsumer) gatewayMiddleware(next asynq.Handler) asynq.Handler {
   212  	return asynq.HandlerFunc(func(ctx context.Context, raw *asynq.Task) (err error) {
   213  		taskName := a.unformatTaskName(raw.Type())
   214  		inspect.SetField(raw, asyncqTaskTypenameField, taskName)
   215  		return next.ProcessTask(ctx, raw)
   216  	})
   217  }
   218  
   219  func (a *asynqConsumer) adaptMiddleware(mw routerMiddleware) asynq.MiddlewareFunc {
   220  	return func(asynqNext asynq.Handler) asynq.Handler {
   221  		next := mw(a.adaptRouterHandlerFunc(asynqNext))
   222  		return asynq.HandlerFunc(func(ctx context.Context, t *asynq.Task) error {
   223  			return next(ctx, a.newTask(t))
   224  		})
   225  	}
   226  }
   227  
   228  func (a *asynqConsumer) adaptAsynqHandlerFunc(h any, typ reflect.Type, embed bool) asynq.HandlerFunc {
   229  	fn := utils.WrapFunc1[error](h)
   230  	return func(ctx context.Context, task *asynq.Task) (err error) {
   231  		ctx, data, _, err := pd.Unseal(task.Payload(), pd.Type(typ))
   232  		if err != nil {
   233  			return
   234  		}
   235  		params := unwrapParams(typ, embed, data)
   236  		return fn(append([]any{ctx}, params...)...)
   237  	}
   238  }
   239  
   240  func (a *asynqConsumer) adaptRouterHandlerFunc(h asynq.Handler) routerMiddlewareFunc {
   241  	return func(ctx context.Context, raw Task) (err error) {
   242  		return h.ProcessTask(ctx, a.newAsynqTask(raw))
   243  	}
   244  }
   245  
   246  func (a *asynqConsumer) unformatTaskName(taskName string) (result string) {
   247  	return strings.TrimPrefix(taskName, fmt.Sprintf("%s:async:", config.Use(a.appName).AppName()))
   248  }
   249  
   250  func (a *asynqConsumer) newTask(raw *asynq.Task) (t Task) {
   251  	return &task{
   252  		id:         raw.Type(),
   253  		name:       raw.Type(),
   254  		payload:    raw.Payload(),
   255  		rawMessage: raw,
   256  	}
   257  }
   258  
   259  func (a *asynqConsumer) newAsynqTask(raw Task) (t *asynq.Task) {
   260  	return raw.RawMessage().(*asynq.Task)
   261  }
   262  
   263  type asynqProducer struct {
   264  	*asynq.Client
   265  
   266  	appName string
   267  	n       string
   268  	c       *Conf
   269  
   270  	compressType  compress.Algorithm
   271  	serializeType serialize.Algorithm
   272  }
   273  
   274  func newAsynqProducer(ctx context.Context, appName, name string, conf *Conf) Producable {
   275  	var rdsCli rdsDrv.UniversalClient
   276  	switch conf.InstanceType {
   277  	case instanceTypeRedis:
   278  		rdsCli = redis.Use(ctx, conf.Instance, redis.AppName(appName))
   279  	case instanceTypeDB:
   280  		fallthrough
   281  	default:
   282  		panic(errors.Errorf("unknown instance type: %s", conf.InstanceType))
   283  	}
   284  
   285  	producer := &asynqProducer{
   286  		appName:       appName,
   287  		n:             name,
   288  		c:             conf,
   289  		Client:        asynq.NewClient(&asynqRedisConnOpt{UniversalClient: rdsCli}),
   290  		compressType:  compress.ParseAlgorithm(conf.MessageCompressType),
   291  		serializeType: serialize.ParseAlgorithm(conf.MessageSerializeType),
   292  	}
   293  	// default serialize type
   294  	if !producer.serializeType.IsValid() {
   295  		producer.serializeType = serialize.AlgorithmGob
   296  	}
   297  	return producer
   298  }
   299  
   300  func (a *asynqProducer) Go(fn any, opts ...utils.OptionExtender) (err error) {
   301  	var data any
   302  	opt := utils.ApplyOptions[produceOption](opts...)
   303  	if len(opt.args) > 0 {
   304  		argType, embed := wrapParams(fn)
   305  		data = setParams(argType, embed, opt.args...)
   306  	}
   307  
   308  	// get task name by func name
   309  	funcName := formatTaskName(a.appName, utils.GetFuncName(fn))
   310  	callbackMapLock.RLock()
   311  	if mappingName, ok := funcNameToTaskName[a.appName][funcName]; ok {
   312  		funcName = mappingName
   313  	}
   314  	callbackMapLock.RUnlock()
   315  
   316  	ctx := context.Background()
   317  	task, err := a.newTask(ctx, funcName, data)
   318  	if err != nil {
   319  		return
   320  	}
   321  
   322  	_, err = a.Client.EnqueueContext(ctx, task, a.parseOption(opt)...)
   323  	if err != nil {
   324  		return
   325  	}
   326  	return
   327  }
   328  
   329  func (a *asynqProducer) Goc(ctx context.Context, fn any, opts ...utils.OptionExtender) (err error) {
   330  	var data any
   331  	opt := utils.ApplyOptions[produceOption](opts...)
   332  	if len(opt.args) > 0 {
   333  		argType, embed := wrapParams(fn)
   334  		data = setParams(argType, embed, opt.args...)
   335  	}
   336  
   337  	// get task name by func name
   338  	funcName := formatTaskName(a.appName, utils.GetFuncName(fn))
   339  	callbackMapLock.RLock()
   340  	if mappingName, ok := funcNameToTaskName[a.appName][funcName]; ok {
   341  		funcName = mappingName
   342  	}
   343  	callbackMapLock.RUnlock()
   344  
   345  	task, err := a.newTask(ctx, funcName, data)
   346  	if err != nil {
   347  		return
   348  	}
   349  
   350  	_, err = a.Client.EnqueueContext(ctx, task, a.parseOption(opt)...)
   351  	if err != nil {
   352  		return
   353  	}
   354  	return
   355  }
   356  
   357  func (a *asynqProducer) Send(ctx context.Context, taskName string, data any, opts ...utils.OptionExtender) (err error) {
   358  	opt := utils.ApplyOptions[produceOption](opts...)
   359  	task, err := a.newTask(ctx, formatTaskName(a.appName, taskName), data)
   360  	if err != nil {
   361  		return
   362  	}
   363  
   364  	_, err = a.Client.EnqueueContext(ctx, task, a.parseOption(opt)...)
   365  	if err != nil {
   366  		return
   367  	}
   368  	return
   369  }
   370  
   371  func (a *asynqProducer) parseOption(src *produceOption) (dst []asynq.Option) {
   372  	if utils.IsStrNotBlank(src.id) {
   373  		dst = append(dst, asynq.TaskID(src.id))
   374  	}
   375  	if utils.IsStrNotBlank(src.queue) {
   376  		dst = append(dst, asynq.Queue(src.queue))
   377  	} else if len(a.c.Queues) == 1 {
   378  		dst = append(dst, asynq.Queue(a.c.Queues[0].Name))
   379  	} else {
   380  		dst = append(dst, asynq.Queue(defaultQueue(a.appName)))
   381  	}
   382  	if src.maxRetry > 0 {
   383  		dst = append(dst, asynq.MaxRetry(src.maxRetry))
   384  	}
   385  	if !src.deadline.IsZero() {
   386  		dst = append(dst, asynq.Deadline(src.deadline))
   387  	}
   388  	if src.timeout > 0 {
   389  		dst = append(dst, asynq.Timeout(src.timeout))
   390  	}
   391  	if src.delayDuration > 0 {
   392  		dst = append(dst, asynq.ProcessIn(src.timeout))
   393  	}
   394  	if !src.delayTime.IsZero() {
   395  		dst = append(dst, asynq.ProcessAt(src.delayTime))
   396  	}
   397  	if src.retentionDuration > 0 {
   398  		dst = append(dst, asynq.Retention(src.retentionDuration))
   399  	}
   400  
   401  	return
   402  }
   403  
   404  func (a *asynqProducer) newTask(ctx context.Context, taskName string, data any) (task *asynq.Task, err error) {
   405  	payload, err := pd.Seal(data, pd.Context(ctx), pd.Serialize(a.serializeType), pd.Compress(a.compressType))
   406  	if err != nil {
   407  		return
   408  	}
   409  
   410  	task = asynq.NewTask(taskName, payload)
   411  	return
   412  }
   413  
   414  type asynqRedisConnOpt struct{ rdsDrv.UniversalClient }
   415  
   416  func (a *asynqRedisConnOpt) MakeRedisClient() any { return a.UniversalClient }
   417  
   418  func formatTaskName(appName, taskName string) (result string) {
   419  	return fmt.Sprintf("%s:async:%s", config.Use(appName).AppName(), taskName)
   420  }
   421  
   422  func defaultQueue(appName string) (result string) {
   423  	return fmt.Sprintf("%s:async", config.Use(appName).AppName())
   424  }