github.com/songzhibin97/gkit@v1.2.13/distributed/service.go (about) 1 package distributed 2 3 import ( 4 "context" 5 "sync" 6 "time" 7 8 "github.com/songzhibin97/gkit/options" 9 10 "github.com/songzhibin97/gkit/log" 11 12 "github.com/pkg/errors" 13 14 "github.com/songzhibin97/gkit/distributed/backend/result" 15 16 "github.com/songzhibin97/gkit/distributed/task" 17 18 "github.com/robfig/cron/v3" 19 20 "github.com/songzhibin97/gkit/distributed/backend" 21 "github.com/songzhibin97/gkit/distributed/controller" 22 "github.com/songzhibin97/gkit/distributed/locker" 23 ) 24 25 // Config distributed service配置文件 26 type Config struct { 27 NoUnixSignals bool `json:"no_unix_signals"` 28 ResultExpire int64 `json:"result_expire"` 29 Concurrency int64 `json:"concurrency"` 30 ConsumeQueue string `json:"consume_queue"` 31 DelayedQueue string `json:"delayed_queue"` 32 } 33 34 type Server struct { 35 config *Config 36 registeredTasks *sync.Map // registeredTasks 注册任务处理函数 37 controller controller.Controller // controller 控制器 38 backend backend.Backend // backend 后端引擎 39 lock locker.Locker // lock 锁 40 scheduler *cron.Cron // scheduler 调度器 41 prePublishHandler func(task *task.Signature) // prePublishHandler 预处理器 42 helper *log.Helper 43 } 44 45 // GetConfig 获取配置文件 46 func (s *Server) GetConfig() *Config { 47 return s.config 48 } 49 50 // GetController 获取 Controller 51 func (s *Server) GetController() controller.Controller { 52 return s.controller 53 } 54 55 // GetBackend 获取 Backend 56 func (s *Server) GetBackend() backend.Backend { 57 return s.backend 58 } 59 60 // GetLocker 获取 Locker 61 func (s *Server) GetLocker() locker.Locker { 62 return s.lock 63 } 64 65 // RegisteredTasks 注册多个任务 66 // handelTaskMap map[string]interface{} 67 // interface 规则: 必须是func且必须有返回参数,最后一个出参是error 68 func (s *Server) RegisteredTasks(handelTaskMap map[string]interface{}) error { 69 for name, fn := range handelTaskMap { 70 if err := task.ValidateTask(fn); err != nil { 71 return err 72 } 73 s.registeredTasks.Store(name, fn) 74 s.controller.RegisterTask(name) 75 } 76 return nil 77 } 78 79 // RegisteredTask 注册多个任务 80 // interface 规则: 必须是func且必须有返回参数,最后一个出参是error 81 func (s *Server) RegisteredTask(name string, fn interface{}) error { 82 if err := task.ValidateTask(fn); err != nil { 83 return err 84 } 85 s.registeredTasks.Store(name, fn) 86 s.controller.RegisterTask(name) 87 return nil 88 } 89 90 // IsRegisteredTask 判断任务是否注册 91 func (s *Server) IsRegisteredTask(name string) bool { 92 _, ok := s.registeredTasks.Load(name) 93 return ok 94 } 95 96 // GetRegisteredTask 获取注册的任务 97 func (s *Server) GetRegisteredTask(name string) (interface{}, bool) { 98 return s.registeredTasks.Load(name) 99 } 100 101 // SendTaskWithContext 发送任务,可以传入ctx 102 func (s *Server) SendTaskWithContext(ctx context.Context, signature *task.Signature) (*result.AsyncResult, error) { 103 // 设置任务状态为pending 104 if err := s.backend.SetStatePending(signature); err != nil { 105 return nil, errors.Wrap(err, "set state pending") 106 } 107 // 是否预处理 108 if s.prePublishHandler != nil { 109 s.prePublishHandler(signature) 110 } 111 // 任务发布 112 if err := s.controller.Publish(ctx, signature); err != nil { 113 return nil, errors.Wrap(err, "publish err") 114 } 115 return result.NewAsyncResult(signature, s.backend), nil 116 } 117 118 // SendTask 发送任务 119 func (s *Server) SendTask(signature *task.Signature) (*result.AsyncResult, error) { 120 return s.SendTaskWithContext(context.Background(), signature) 121 } 122 123 // SendChain 发送链式调用任务 124 func (s *Server) SendChain(chain *task.Chain) (*result.ChainAsyncResult, error) { 125 _, err := s.SendTask(chain.Tasks[0]) 126 if err != nil { 127 return nil, err 128 } 129 return result.NewChainAsyncResult(chain.Tasks, s.backend), nil 130 } 131 132 // SendGroupWithContext 发送并行执行的任务组 133 func (s *Server) SendGroupWithContext(ctx context.Context, group *task.Group, concurrency int) ([]*result.AsyncResult, error) { 134 if concurrency < 0 { 135 concurrency = 1 136 } 137 var ( 138 asyncResults = make([]*result.AsyncResult, len(group.Tasks)) 139 wg sync.WaitGroup 140 ln = len(group.Tasks) 141 errChan = make(chan error, ln*2) 142 pool = make(chan struct{}, concurrency) 143 done = make(chan struct{}) 144 ) 145 146 // 接管任务 147 err := s.backend.GroupTakeOver(group.GroupID, group.Name, group.GetTaskIDs()...) 148 if err != nil { 149 return nil, err 150 } 151 152 // 初始化任务 153 for _, signature := range group.Tasks { 154 if err = s.backend.SetStatePending(signature); err != nil { 155 errChan <- err 156 continue 157 } 158 } 159 160 // 初始化并发池 161 go func() { 162 for i := 0; i < concurrency; i++ { 163 pool <- struct{}{} 164 } 165 }() 166 167 wg.Add(ln) 168 // 执行任务 169 for i, signature := range group.Tasks { 170 <-pool 171 go func(t *task.Signature, index int) { 172 defer wg.Done() 173 // 发布任务 174 err := s.controller.Publish(ctx, t) 175 pool <- struct{}{} 176 if err != nil { 177 errChan <- errors.Wrap(err, "set state pending") 178 return 179 } 180 asyncResults[index] = result.NewAsyncResult(t, s.backend) 181 }(signature, i) 182 } 183 go func() { 184 wg.Wait() 185 close(done) 186 }() 187 188 select { 189 case <-ctx.Done(): 190 return asyncResults, ctx.Err() 191 case err = <-errChan: 192 return asyncResults, err 193 case <-done: 194 return asyncResults, nil 195 } 196 } 197 198 // SendGroup 发送并行任务组 199 func (s *Server) SendGroup(group *task.Group, concurrency int) ([]*result.AsyncResult, error) { 200 return s.SendGroupWithContext(context.Background(), group, concurrency) 201 } 202 203 // SendGroupCallbackWithContext 发送具有回调任务的任务组 204 func (s *Server) SendGroupCallbackWithContext(ctx context.Context, groupCallback *task.GroupCallback, concurrency int) (*result.GroupCallbackAsyncResult, error) { 205 _, err := s.SendGroupWithContext(ctx, groupCallback.Group, concurrency) 206 if err != nil { 207 return nil, err 208 } 209 return result.NewGroupCallbackAsyncResult(groupCallback.Group.Tasks, groupCallback.Callback, s.backend), nil 210 } 211 212 // SendGroupCallback 发送具有回调任务的任务组 213 func (s *Server) SendGroupCallback(groupCallback *task.GroupCallback, concurrency int) (*result.GroupCallbackAsyncResult, error) { 214 return s.SendGroupCallbackWithContext(context.Background(), groupCallback, concurrency) 215 } 216 217 // RegisteredTimedTask 注册定时任务 218 func (s *Server) RegisteredTimedTask(spec, name string, signature *task.Signature) error { 219 // 检查spec是否合法 220 schedule, err := cron.ParseStandard(spec) 221 if err != nil { 222 return err 223 } 224 f := func() { 225 key := getLockName(name, spec) 226 err := s.lock.Lock(key, int(schedule.Next(time.Now()).UnixNano()-1), key) 227 if err != nil { 228 return 229 } 230 defer s.lock.UnLock(key, key) 231 232 // send task 233 _, err = s.SendTask(task.CopySignature(signature)) 234 if err != nil { 235 s.helper.Errorf("timed task failed. task name is: %s. error is %s\", name, err.Error()\n", name, err.Error()) 236 } 237 } 238 _, err = s.scheduler.AddFunc(spec, f) 239 return err 240 } 241 242 // RegisteredTimedChain 注册定时链式任务 243 func (s *Server) RegisteredTimedChain(spec, name string, signatures ...*task.Signature) error { 244 // 检查spec是否合法 245 schedule, err := cron.ParseStandard(spec) 246 if err != nil { 247 return err 248 } 249 f := func() { 250 chain, _ := task.NewChain(name, task.CopySignatures(signatures...)...) 251 252 // get lock 253 key := getLockName(name, spec) 254 err := s.lock.Lock(key, int(schedule.Next(time.Now()).UnixNano()-1), key) 255 if err != nil { 256 return 257 } 258 defer s.lock.UnLock(key, key) 259 260 // send task 261 _, err = s.SendChain(chain) 262 if err != nil { 263 s.helper.Errorf("timed task failed. task name is: %s. error is %s\", name, err.Error()\n", name, err.Error()) 264 } 265 } 266 _, err = s.scheduler.AddFunc(spec, f) 267 return err 268 } 269 270 // RegisteredTimedGroup 注册定时任务组 271 func (s *Server) RegisteredTimedGroup(spec, name string, groupID string, concurrency int, signatures ...*task.Signature) error { 272 // 检查spec是否合法 273 schedule, err := cron.ParseStandard(spec) 274 if err != nil { 275 return err 276 } 277 f := func() { 278 group, _ := task.NewGroup(groupID, name, task.CopySignatures(signatures...)...) 279 // get lock 280 key := getLockName(name, spec) 281 err := s.lock.Lock(key, int(schedule.Next(time.Now()).UnixNano()-1), key) 282 if err != nil { 283 return 284 } 285 defer s.lock.UnLock(key, key) 286 287 _, err = s.SendGroup(group, concurrency) 288 if err != nil { 289 s.helper.Errorf("timed task failed. task name is: %s. error is %s\", name, err.Error()\n", name, err.Error()) 290 } 291 } 292 _, err = s.scheduler.AddFunc(spec, f) 293 return err 294 } 295 296 // RegisteredTimedGroupCallback 注册具有回调的组任务 297 func (s *Server) RegisteredTimedGroupCallback(spec, name string, groupID string, concurrency int, callback *task.Signature, signatures ...*task.Signature) error { 298 // 检查spec是否合法 299 schedule, err := cron.ParseStandard(spec) 300 if err != nil { 301 return err 302 } 303 f := func() { 304 group, _ := task.NewGroup(groupID, name, task.CopySignatures(signatures...)...) 305 c, _ := task.NewGroupCallback(group, name, callback) 306 // get lock 307 key := getLockName(name, spec) 308 err := s.lock.Lock(key, int(schedule.Next(time.Now()).UnixNano()-1), key) 309 if err != nil { 310 return 311 } 312 defer s.lock.UnLock(key, key) 313 314 _, err = s.SendGroupCallback(c, concurrency) 315 if err != nil { 316 s.helper.Errorf("timed task failed. task name is: %s. error is %s\", name, err.Error()\n", name, err.Error()) 317 } 318 } 319 _, err = s.scheduler.AddFunc(spec, f) 320 return err 321 } 322 323 func getLockName(name, spec string) string { 324 return name + spec 325 } 326 327 // NewServer 创建服务 328 func NewServer(controller controller.Controller, backend backend.Backend, lock locker.Locker, helper *log.Helper, prePublishHandler func(task *task.Signature), options ...options.Option) *Server { 329 server := &Server{ 330 config: &Config{ 331 NoUnixSignals: false, 332 ResultExpire: 0, 333 Concurrency: 1, 334 ConsumeQueue: "consume_queue", 335 DelayedQueue: "delayed_queue", 336 }, 337 registeredTasks: &sync.Map{}, 338 controller: controller, 339 backend: backend, 340 lock: lock, 341 scheduler: cron.New(), 342 prePublishHandler: prePublishHandler, 343 helper: helper, 344 } 345 for _, option := range options { 346 option(server.config) 347 } 348 server.EnforcementConf() 349 go server.scheduler.Run() 350 return server 351 } 352 353 func (s *Server) EnforcementConf() { 354 s.backend.SetResultExpire(s.config.ResultExpire) 355 s.controller.SetConsumingQueue(s.config.ConsumeQueue) 356 s.controller.SetDelayedQueue(s.config.DelayedQueue) 357 } 358 359 func (s *Server) NewWorker(consumerTag string, concurrency int, queue string) *Worker { 360 return &Worker{ 361 bindService: s, 362 Concurrency: concurrency, 363 ConsumerTag: consumerTag, 364 Queue: queue, 365 } 366 }