github.com/machinefi/w3bstream@v1.6.5-rc9.0.20240426031326-b8c7c4876e72/pkg/depends/kit/mq/task_worker.go (about)

     1  package mq
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"os"
     7  	"os/signal"
     8  	"syscall"
     9  
    10  	"github.com/fatih/color"
    11  	"github.com/pkg/errors"
    12  
    13  	"github.com/machinefi/w3bstream/pkg/depends/base/consts"
    14  	"github.com/machinefi/w3bstream/pkg/depends/conf/logger"
    15  	"github.com/machinefi/w3bstream/pkg/depends/kit/kit"
    16  	"github.com/machinefi/w3bstream/pkg/depends/kit/metax"
    17  	"github.com/machinefi/w3bstream/pkg/depends/kit/mq/worker"
    18  	"github.com/machinefi/w3bstream/pkg/depends/x/contextx"
    19  	"github.com/machinefi/w3bstream/pkg/depends/x/mapx"
    20  )
    21  
    22  type TaskWorkerOption func(*taskWorkerOption)
    23  
    24  type taskWorkerOption struct {
    25  	Channel     string
    26  	WorkerCount int
    27  	OnFinished  func(ctx context.Context, t Task)
    28  }
    29  
    30  func WithChannel(ch string) TaskWorkerOption {
    31  	return func(o *taskWorkerOption) { o.Channel = ch }
    32  }
    33  
    34  func WithWorkerCount(cnt int) TaskWorkerOption {
    35  	return func(o *taskWorkerOption) { o.WorkerCount = cnt }
    36  }
    37  
    38  func WithFinishFunc(fn func(ctx context.Context, t Task)) TaskWorkerOption {
    39  	return func(o *taskWorkerOption) { o.OnFinished = fn }
    40  }
    41  
    42  func NewTaskWorker(tm TaskManager, options ...TaskWorkerOption) *TaskWorker {
    43  	tw := &TaskWorker{mgr: tm, ops: mapx.New[string, any]()}
    44  	for _, opt := range options {
    45  		opt(&tw.taskWorkerOption)
    46  	}
    47  	return tw
    48  }
    49  
    50  type TaskWorker struct {
    51  	taskWorkerOption
    52  	mgr    TaskManager
    53  	ops    *mapx.Map[string, any]
    54  	worker *worker.Worker
    55  	with   contextx.WithContext
    56  }
    57  
    58  func (w *TaskWorker) SetDefault() {
    59  	if w.Channel == "" {
    60  		w.Channel = "unknown"
    61  		if name := os.Getenv(consts.EnvProjectName); name != "" {
    62  			w.Channel = name
    63  		}
    64  	}
    65  	if w.WorkerCount == 0 {
    66  		w.WorkerCount = 5
    67  	}
    68  	if w.ops == nil {
    69  		w.ops = mapx.New[string, any]()
    70  	}
    71  }
    72  
    73  func (w *TaskWorker) Context() context.Context {
    74  	if w.with != nil {
    75  		return w.with(context.Background())
    76  	}
    77  	return context.Background()
    78  }
    79  
    80  func (w TaskWorker) WithContextInjector(with contextx.WithContext) *TaskWorker {
    81  	w.with = with
    82  	return &w
    83  }
    84  
    85  func (w *TaskWorker) Register(router *kit.Router) {
    86  	fmt.Printf("[Kit] TASK\n")
    87  	routes := router.Routes()
    88  	for i := range routes {
    89  		factories := routes[i].OperatorFactories()
    90  		if len(factories) != 1 {
    91  			continue
    92  		}
    93  		f := factories[0]
    94  		w.ops.Store(f.Type.Name(), f)
    95  		fmt.Println("[Kit]\t" + color.GreenString(f.String()))
    96  	}
    97  }
    98  
    99  func (w *TaskWorker) Serve(router *kit.Router) error {
   100  	w.Register(router)
   101  
   102  	stopCh := make(chan os.Signal, 1)
   103  	signal.Notify(stopCh, os.Interrupt, syscall.SIGTERM)
   104  
   105  	w.worker = worker.New(w.proc, w.WorkerCount)
   106  	go func() {
   107  		w.worker.Start(w.Context())
   108  	}()
   109  
   110  	<-stopCh
   111  	return errors.New("TaskWorker server closed")
   112  }
   113  
   114  func (w *TaskWorker) LivenessCheck() map[string]string {
   115  	m := map[string]string{}
   116  	w.ops.Range(func(k string, _ any) bool {
   117  		m[k] = "ok"
   118  		return true
   119  	})
   120  	return m
   121  }
   122  
   123  func (w *TaskWorker) operatorFactory(ch string) (*kit.OperatorFactory, error) {
   124  	op, ok := w.ops.Load(ch)
   125  	if !ok || op == nil {
   126  		return nil, errors.Errorf("missed operator %s", ch)
   127  	}
   128  	return op.(*kit.OperatorFactory), nil
   129  }
   130  
   131  func (w *TaskWorker) proc(ctx context.Context) (err error) {
   132  	var (
   133  		t  Task
   134  		se error // shadowed
   135  	)
   136  	t, err = w.mgr.Pop(w.Channel)
   137  	if err != nil {
   138  		return err
   139  	}
   140  	if t == nil {
   141  		return nil
   142  	}
   143  	ctx, l := logger.NewSpanContext(ctx, "TaskWorker.proc")
   144  	defer l.End()
   145  
   146  	l = l.WithValues("task_subject", t.Subject(), "task_id", t.ID())
   147  	defer func() {
   148  		if e := recover(); e != nil {
   149  			err = errors.Errorf("panic: %v", e)
   150  		}
   151  
   152  		state := TASK_STATE__SUCCEEDED
   153  		if err != nil {
   154  			state = TASK_STATE__FAILED
   155  		}
   156  		t.SetState(state)
   157  		l = l.WithValues("task_state", state)
   158  
   159  		if w.OnFinished != nil {
   160  			w.OnFinished(ctx, t)
   161  		}
   162  		l.Debug("task processed")
   163  	}()
   164  
   165  	opf, se := w.operatorFactory(t.Subject())
   166  	if se != nil {
   167  		err = se
   168  		return
   169  	}
   170  
   171  	op := opf.New()
   172  	if with, ok := t.(WithArg); ok {
   173  		if setter, ok := op.(SetArg); ok {
   174  			if se = setter.SetArg(with.Arg()); se != nil {
   175  				err = se
   176  				return
   177  			}
   178  		}
   179  	}
   180  
   181  	meta := metax.ParseMeta(t.ID())
   182  	meta.Add("task", w.Channel+"#"+t.Subject())
   183  
   184  	if _, se = op.Output(metax.ContextWithMeta(ctx, meta)); se != nil {
   185  		err = se
   186  	}
   187  	return
   188  }