github.com/songzhibin97/gkit@v1.2.13/distributed/service.go (about)

     1  package distributed
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"time"
     7  
     8  	"github.com/songzhibin97/gkit/options"
     9  
    10  	"github.com/songzhibin97/gkit/log"
    11  
    12  	"github.com/pkg/errors"
    13  
    14  	"github.com/songzhibin97/gkit/distributed/backend/result"
    15  
    16  	"github.com/songzhibin97/gkit/distributed/task"
    17  
    18  	"github.com/robfig/cron/v3"
    19  
    20  	"github.com/songzhibin97/gkit/distributed/backend"
    21  	"github.com/songzhibin97/gkit/distributed/controller"
    22  	"github.com/songzhibin97/gkit/distributed/locker"
    23  )
    24  
// Config is the configuration of the distributed service. NewServer starts
// from built-in defaults and then applies the caller's options to it.
type Config struct {
	// NoUnixSignals is not read anywhere in this file.
	// NOTE(review): presumably disables Unix signal handling elsewhere (e.g. in the worker) — confirm.
	NoUnixSignals bool   `json:"no_unix_signals"`
	// ResultExpire is passed to backend.SetResultExpire by EnforcementConf (default 0).
	// NOTE(review): the unit is defined by the backend implementation — confirm.
	ResultExpire  int64  `json:"result_expire"`
	// Concurrency defaults to 1 in NewServer; it is not read anywhere in this file.
	Concurrency   int64  `json:"concurrency"`
	// ConsumeQueue is passed to controller.SetConsumingQueue by EnforcementConf (default "consume_queue").
	ConsumeQueue  string `json:"consume_queue"`
	// DelayedQueue is passed to controller.SetDelayedQueue by EnforcementConf (default "delayed_queue").
	DelayedQueue  string `json:"delayed_queue"`
}
    33  
// Server is the distributed task service. It keeps the registered task
// handlers, sends tasks/chains/groups through its controller and backend,
// and runs timed jobs on its cron scheduler.
type Server struct {
	config            *Config                    // config is the service configuration (defaults are set in NewServer)
	registeredTasks   *sync.Map                  // registeredTasks maps task name -> handler func
	controller        controller.Controller      // controller publishes tasks and is told about every registered task name
	backend           backend.Backend            // backend receives task/group state updates (SetStatePending, GroupTakeOver, SetResultExpire) and backs the AsyncResults
	lock              locker.Locker              // lock guards each timed-job run (Lock/UnLock keyed by getLockName)
	scheduler         *cron.Cron                 // scheduler runs the jobs added by the RegisteredTimed* methods
	prePublishHandler func(task *task.Signature) // prePublishHandler, if non-nil, is called by SendTaskWithContext before publishing (SendGroupWithContext does not call it)
	helper            *log.Helper                // helper is the logger used to report timed-job failures
}
    44  
    45  // GetConfig 获取配置文件
    46  func (s *Server) GetConfig() *Config {
    47  	return s.config
    48  }
    49  
    50  // GetController 获取 Controller
    51  func (s *Server) GetController() controller.Controller {
    52  	return s.controller
    53  }
    54  
    55  // GetBackend 获取 Backend
    56  func (s *Server) GetBackend() backend.Backend {
    57  	return s.backend
    58  }
    59  
    60  // GetLocker 获取 Locker
    61  func (s *Server) GetLocker() locker.Locker {
    62  	return s.lock
    63  }
    64  
    65  // RegisteredTasks 注册多个任务
    66  // handelTaskMap map[string]interface{}
    67  // interface 规则: 必须是func且必须有返回参数,最后一个出参是error
    68  func (s *Server) RegisteredTasks(handelTaskMap map[string]interface{}) error {
    69  	for name, fn := range handelTaskMap {
    70  		if err := task.ValidateTask(fn); err != nil {
    71  			return err
    72  		}
    73  		s.registeredTasks.Store(name, fn)
    74  		s.controller.RegisterTask(name)
    75  	}
    76  	return nil
    77  }
    78  
    79  // RegisteredTask 注册多个任务
    80  // interface 规则: 必须是func且必须有返回参数,最后一个出参是error
    81  func (s *Server) RegisteredTask(name string, fn interface{}) error {
    82  	if err := task.ValidateTask(fn); err != nil {
    83  		return err
    84  	}
    85  	s.registeredTasks.Store(name, fn)
    86  	s.controller.RegisterTask(name)
    87  	return nil
    88  }
    89  
    90  // IsRegisteredTask 判断任务是否注册
    91  func (s *Server) IsRegisteredTask(name string) bool {
    92  	_, ok := s.registeredTasks.Load(name)
    93  	return ok
    94  }
    95  
    96  // GetRegisteredTask 获取注册的任务
    97  func (s *Server) GetRegisteredTask(name string) (interface{}, bool) {
    98  	return s.registeredTasks.Load(name)
    99  }
   100  
   101  // SendTaskWithContext 发送任务,可以传入ctx
   102  func (s *Server) SendTaskWithContext(ctx context.Context, signature *task.Signature) (*result.AsyncResult, error) {
   103  	// 设置任务状态为pending
   104  	if err := s.backend.SetStatePending(signature); err != nil {
   105  		return nil, errors.Wrap(err, "set state pending")
   106  	}
   107  	// 是否预处理
   108  	if s.prePublishHandler != nil {
   109  		s.prePublishHandler(signature)
   110  	}
   111  	// 任务发布
   112  	if err := s.controller.Publish(ctx, signature); err != nil {
   113  		return nil, errors.Wrap(err, "publish err")
   114  	}
   115  	return result.NewAsyncResult(signature, s.backend), nil
   116  }
   117  
   118  // SendTask 发送任务
   119  func (s *Server) SendTask(signature *task.Signature) (*result.AsyncResult, error) {
   120  	return s.SendTaskWithContext(context.Background(), signature)
   121  }
   122  
   123  // SendChain 发送链式调用任务
   124  func (s *Server) SendChain(chain *task.Chain) (*result.ChainAsyncResult, error) {
   125  	_, err := s.SendTask(chain.Tasks[0])
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  	return result.NewChainAsyncResult(chain.Tasks, s.backend), nil
   130  }
   131  
   132  // SendGroupWithContext 发送并行执行的任务组
   133  func (s *Server) SendGroupWithContext(ctx context.Context, group *task.Group, concurrency int) ([]*result.AsyncResult, error) {
   134  	if concurrency < 0 {
   135  		concurrency = 1
   136  	}
   137  	var (
   138  		asyncResults = make([]*result.AsyncResult, len(group.Tasks))
   139  		wg           sync.WaitGroup
   140  		ln           = len(group.Tasks)
   141  		errChan      = make(chan error, ln*2)
   142  		pool         = make(chan struct{}, concurrency)
   143  		done         = make(chan struct{})
   144  	)
   145  
   146  	// 接管任务
   147  	err := s.backend.GroupTakeOver(group.GroupID, group.Name, group.GetTaskIDs()...)
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  
   152  	// 初始化任务
   153  	for _, signature := range group.Tasks {
   154  		if err = s.backend.SetStatePending(signature); err != nil {
   155  			errChan <- err
   156  			continue
   157  		}
   158  	}
   159  
   160  	// 初始化并发池
   161  	go func() {
   162  		for i := 0; i < concurrency; i++ {
   163  			pool <- struct{}{}
   164  		}
   165  	}()
   166  
   167  	wg.Add(ln)
   168  	// 执行任务
   169  	for i, signature := range group.Tasks {
   170  		<-pool
   171  		go func(t *task.Signature, index int) {
   172  			defer wg.Done()
   173  			// 发布任务
   174  			err := s.controller.Publish(ctx, t)
   175  			pool <- struct{}{}
   176  			if err != nil {
   177  				errChan <- errors.Wrap(err, "set state pending")
   178  				return
   179  			}
   180  			asyncResults[index] = result.NewAsyncResult(t, s.backend)
   181  		}(signature, i)
   182  	}
   183  	go func() {
   184  		wg.Wait()
   185  		close(done)
   186  	}()
   187  
   188  	select {
   189  	case <-ctx.Done():
   190  		return asyncResults, ctx.Err()
   191  	case err = <-errChan:
   192  		return asyncResults, err
   193  	case <-done:
   194  		return asyncResults, nil
   195  	}
   196  }
   197  
   198  // SendGroup 发送并行任务组
   199  func (s *Server) SendGroup(group *task.Group, concurrency int) ([]*result.AsyncResult, error) {
   200  	return s.SendGroupWithContext(context.Background(), group, concurrency)
   201  }
   202  
   203  // SendGroupCallbackWithContext 发送具有回调任务的任务组
   204  func (s *Server) SendGroupCallbackWithContext(ctx context.Context, groupCallback *task.GroupCallback, concurrency int) (*result.GroupCallbackAsyncResult, error) {
   205  	_, err := s.SendGroupWithContext(ctx, groupCallback.Group, concurrency)
   206  	if err != nil {
   207  		return nil, err
   208  	}
   209  	return result.NewGroupCallbackAsyncResult(groupCallback.Group.Tasks, groupCallback.Callback, s.backend), nil
   210  }
   211  
   212  // SendGroupCallback 发送具有回调任务的任务组
   213  func (s *Server) SendGroupCallback(groupCallback *task.GroupCallback, concurrency int) (*result.GroupCallbackAsyncResult, error) {
   214  	return s.SendGroupCallbackWithContext(context.Background(), groupCallback, concurrency)
   215  }
   216  
   217  // RegisteredTimedTask 注册定时任务
   218  func (s *Server) RegisteredTimedTask(spec, name string, signature *task.Signature) error {
   219  	// 检查spec是否合法
   220  	schedule, err := cron.ParseStandard(spec)
   221  	if err != nil {
   222  		return err
   223  	}
   224  	f := func() {
   225  		key := getLockName(name, spec)
   226  		err := s.lock.Lock(key, int(schedule.Next(time.Now()).UnixNano()-1), key)
   227  		if err != nil {
   228  			return
   229  		}
   230  		defer s.lock.UnLock(key, key)
   231  
   232  		// send task
   233  		_, err = s.SendTask(task.CopySignature(signature))
   234  		if err != nil {
   235  			s.helper.Errorf("timed task failed. task name is: %s. error is %s\", name, err.Error()\n", name, err.Error())
   236  		}
   237  	}
   238  	_, err = s.scheduler.AddFunc(spec, f)
   239  	return err
   240  }
   241  
   242  // RegisteredTimedChain 注册定时链式任务
   243  func (s *Server) RegisteredTimedChain(spec, name string, signatures ...*task.Signature) error {
   244  	// 检查spec是否合法
   245  	schedule, err := cron.ParseStandard(spec)
   246  	if err != nil {
   247  		return err
   248  	}
   249  	f := func() {
   250  		chain, _ := task.NewChain(name, task.CopySignatures(signatures...)...)
   251  
   252  		// get lock
   253  		key := getLockName(name, spec)
   254  		err := s.lock.Lock(key, int(schedule.Next(time.Now()).UnixNano()-1), key)
   255  		if err != nil {
   256  			return
   257  		}
   258  		defer s.lock.UnLock(key, key)
   259  
   260  		// send task
   261  		_, err = s.SendChain(chain)
   262  		if err != nil {
   263  			s.helper.Errorf("timed task failed. task name is: %s. error is %s\", name, err.Error()\n", name, err.Error())
   264  		}
   265  	}
   266  	_, err = s.scheduler.AddFunc(spec, f)
   267  	return err
   268  }
   269  
   270  // RegisteredTimedGroup 注册定时任务组
   271  func (s *Server) RegisteredTimedGroup(spec, name string, groupID string, concurrency int, signatures ...*task.Signature) error {
   272  	// 检查spec是否合法
   273  	schedule, err := cron.ParseStandard(spec)
   274  	if err != nil {
   275  		return err
   276  	}
   277  	f := func() {
   278  		group, _ := task.NewGroup(groupID, name, task.CopySignatures(signatures...)...)
   279  		// get lock
   280  		key := getLockName(name, spec)
   281  		err := s.lock.Lock(key, int(schedule.Next(time.Now()).UnixNano()-1), key)
   282  		if err != nil {
   283  			return
   284  		}
   285  		defer s.lock.UnLock(key, key)
   286  
   287  		_, err = s.SendGroup(group, concurrency)
   288  		if err != nil {
   289  			s.helper.Errorf("timed task failed. task name is: %s. error is %s\", name, err.Error()\n", name, err.Error())
   290  		}
   291  	}
   292  	_, err = s.scheduler.AddFunc(spec, f)
   293  	return err
   294  }
   295  
   296  // RegisteredTimedGroupCallback 注册具有回调的组任务
   297  func (s *Server) RegisteredTimedGroupCallback(spec, name string, groupID string, concurrency int, callback *task.Signature, signatures ...*task.Signature) error {
   298  	// 检查spec是否合法
   299  	schedule, err := cron.ParseStandard(spec)
   300  	if err != nil {
   301  		return err
   302  	}
   303  	f := func() {
   304  		group, _ := task.NewGroup(groupID, name, task.CopySignatures(signatures...)...)
   305  		c, _ := task.NewGroupCallback(group, name, callback)
   306  		// get lock
   307  		key := getLockName(name, spec)
   308  		err := s.lock.Lock(key, int(schedule.Next(time.Now()).UnixNano()-1), key)
   309  		if err != nil {
   310  			return
   311  		}
   312  		defer s.lock.UnLock(key, key)
   313  
   314  		_, err = s.SendGroupCallback(c, concurrency)
   315  		if err != nil {
   316  			s.helper.Errorf("timed task failed. task name is: %s. error is %s\", name, err.Error()\n", name, err.Error())
   317  		}
   318  	}
   319  	_, err = s.scheduler.AddFunc(spec, f)
   320  	return err
   321  }
   322  
   323  func getLockName(name, spec string) string {
   324  	return name + spec
   325  }
   326  
   327  // NewServer 创建服务
   328  func NewServer(controller controller.Controller, backend backend.Backend, lock locker.Locker, helper *log.Helper, prePublishHandler func(task *task.Signature), options ...options.Option) *Server {
   329  	server := &Server{
   330  		config: &Config{
   331  			NoUnixSignals: false,
   332  			ResultExpire:  0,
   333  			Concurrency:   1,
   334  			ConsumeQueue:  "consume_queue",
   335  			DelayedQueue:  "delayed_queue",
   336  		},
   337  		registeredTasks:   &sync.Map{},
   338  		controller:        controller,
   339  		backend:           backend,
   340  		lock:              lock,
   341  		scheduler:         cron.New(),
   342  		prePublishHandler: prePublishHandler,
   343  		helper:            helper,
   344  	}
   345  	for _, option := range options {
   346  		option(server.config)
   347  	}
   348  	server.EnforcementConf()
   349  	go server.scheduler.Run()
   350  	return server
   351  }
   352  
   353  func (s *Server) EnforcementConf() {
   354  	s.backend.SetResultExpire(s.config.ResultExpire)
   355  	s.controller.SetConsumingQueue(s.config.ConsumeQueue)
   356  	s.controller.SetDelayedQueue(s.config.DelayedQueue)
   357  }
   358  
   359  func (s *Server) NewWorker(consumerTag string, concurrency int, queue string) *Worker {
   360  	return &Worker{
   361  		bindService: s,
   362  		Concurrency: concurrency,
   363  		ConsumerTag: consumerTag,
   364  		Queue:       queue,
   365  	}
   366  }