github.com/songzhibin97/gkit@v1.2.13/distributed/controller/controller_redis/redis.go (about) 1 package controller_redis 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "fmt" 8 "runtime" 9 "strconv" 10 "sync" 11 "time" 12 13 json "github.com/json-iterator/go" 14 15 "github.com/songzhibin97/gkit/distributed/broker" 16 17 "github.com/songzhibin97/gkit/distributed/task" 18 19 "github.com/songzhibin97/gkit/distributed/controller" 20 21 "github.com/go-redis/redis/v8" 22 "github.com/go-redsync/redsync/v4" 23 "github.com/go-redsync/redsync/v4/redis/goredis/v8" 24 ) 25 26 type ControllerRedis struct { 27 *broker.Broker 28 // client redis客户端 29 client redis.UniversalClient 30 // lock 分布式锁 31 lock *redsync.Redsync 32 33 // wg 34 35 // consumingWg 确保消费组并发完成 36 consumingWg sync.WaitGroup 37 // processingWg 确保正在处理的任务并发完成 38 processingWg sync.WaitGroup 39 // delayedWg 确保延时任务并发完成 40 delayedWg sync.WaitGroup 41 // consumingQueue 消费队列名称 42 consumingQueue string 43 // delayedQueue 延迟队列名称 44 delayedQueue string 45 } 46 47 // SetConsumingQueue 设置消费队列名称 48 func (c *ControllerRedis) SetConsumingQueue(consumingQueue string) { 49 c.consumingQueue = consumingQueue 50 } 51 52 // SetDelayedQueue 设置延迟队列名称 53 func (c *ControllerRedis) SetDelayedQueue(delayedQueue string) { 54 c.delayedQueue = delayedQueue 55 } 56 57 func (c *ControllerRedis) RegisterTask(name ...string) { 58 c.RegisterList(name...) 59 } 60 61 func (c *ControllerRedis) IsRegisterTask(name string) bool { 62 return c.IsRegister(name) 63 } 64 65 func (c *ControllerRedis) StartConsuming(concurrency int, handler task.Processor) (bool, error) { 66 c.consumingWg.Add(1) 67 defer c.consumingWg.Done() 68 69 // 设置阈值,如果并发数 < 1, 默认设置成 2*cpu 70 if concurrency < 1 { 71 concurrency = runtime.NumCPU() * 2 72 } 73 _, err := c.client.Ping(context.Background()).Result() 74 if err != nil { 75 // 重试 76 c.Broker.GetRetryFn()(c.Broker.GetRetryCtx()) 77 if c.Broker.GetRetry() { 78 return true, err 79 } 80 return false, controller.ErrorConnectClose 81 } 82 83 // 初始化工作区 84 pool := make(chan struct{}, concurrency) 85 worker := make(chan []byte, concurrency) 86 87 // 填满并发池 88 for i := 0; i < concurrency; i++ { 89 pool <- struct{}{} 90 } 91 go func() { 92 fmt.Println("[*] Waiting for messages. To exit press CTRL+C") 93 for { 94 select { 95 case <-c.GetStopCtx().Done(): 96 close(worker) 97 return 98 case _, ok := <-pool: 99 if !ok { 100 return 101 } 102 tByte, err := c.popTask(c.consumingQueue, 0) 103 if err != nil && !errors.Is(err, redis.Nil) { 104 fmt.Println("popTask err:", err) 105 } 106 // 如果是有效数据,发送给worker 107 if len(tByte) > 0 { 108 worker <- tByte 109 } 110 pool <- struct{}{} 111 } 112 } 113 }() 114 c.delayedWg.Add(1) 115 go func() { 116 defer c.delayedWg.Done() 117 for { 118 select { 119 case <-c.GetStopCtx().Done(): 120 return 121 default: 122 tBody, err := c.popDelayedTask(c.delayedQueue, 0) 123 if err != nil { 124 fmt.Println("popDelayedTask err:", err) 125 continue 126 } 127 if tBody == nil { 128 continue 129 } 130 t := task.Signature{} 131 if err = json.Unmarshal(tBody, &t); err != nil { 132 fmt.Println("unmarshal err:", err) 133 continue 134 } 135 if err = c.Publish(c.GetStopCtx(), &t); err != nil { 136 fmt.Println("publish err:", err) 137 continue 138 } 139 } 140 } 141 }() 142 143 if err = c.consume(worker, concurrency, handler); err != nil { 144 return c.GetRetry(), err 145 } 146 c.processingWg.Wait() 147 return c.GetRetry(), err 148 } 149 150 // popTask 弹出任务 151 func (c *ControllerRedis) popTask(queue string, blockTime int64) ([]byte, error) { 152 if blockTime <= 0 { 153 blockTime = int64(1000 * time.Millisecond) 154 } 155 items, err := c.client.BLPop(context.Background(), time.Duration(blockTime), queue).Result() 156 if err != nil { 157 return nil, err 158 } 159 // items[0] - the name of the key where an element was popped 160 // items[1] - the value of the popped element 161 if len(items) != 2 { 162 return nil, redis.Nil 163 } 164 result := []byte(items[1]) 165 return result, nil 166 } 167 168 // popDelayedTask 弹出延时任务,因为延时任务是使用Redis ZSet 169 func (c *ControllerRedis) popDelayedTask(queue string, blockTime int64) ([]byte, error) { 170 if blockTime <= 0 { 171 blockTime = int64(1000 * time.Millisecond) 172 } 173 var result []byte 174 for { 175 time.Sleep(time.Duration(blockTime)) 176 watchFn := func(tx *redis.Tx) error { 177 now := time.Now().Local().UnixNano() 178 max := strconv.FormatInt(now, 10) 179 items, err := tx.ZRevRangeByScore(c.GetStopCtx(), queue, &redis.ZRangeBy{Min: "0", Max: max, Offset: 0, Count: 1}).Result() 180 if err != nil { 181 return err 182 } 183 if len(items) != 1 { 184 return redis.Nil 185 } 186 _, err = tx.TxPipelined(c.GetStopCtx(), func(pipe redis.Pipeliner) error { 187 pipe.ZRem(c.GetStopCtx(), queue, items[0]) 188 result = []byte(items[0]) 189 return nil 190 }) 191 return err 192 } 193 if err := c.client.Watch(c.GetStopCtx(), watchFn, queue); err != nil { 194 break 195 } 196 } 197 return result, nil 198 } 199 200 // consume 消费 201 func (c *ControllerRedis) consume(worker <-chan []byte, concurrency int, handler task.Processor) error { 202 // 初始化工作区 203 pool := make(chan struct{}, concurrency) 204 errorsChan := make(chan error, concurrency*2) 205 206 // 填满并发池 207 for i := 0; i < concurrency; i++ { 208 pool <- struct{}{} 209 } 210 for { 211 select { 212 case <-c.GetStopCtx().Done(): 213 return c.GetStopCtx().Err() 214 case err := <-errorsChan: 215 return err 216 case v, ok := <-worker: 217 if !ok { 218 return nil 219 } 220 // 阻塞等待 221 select { 222 case <-pool: 223 case <-c.GetStopCtx().Done(): 224 return c.GetStopCtx().Err() 225 } 226 c.processingWg.Add(1) 227 go func() { 228 if err := c.consumeOne(v, c.consumingQueue, handler); err != nil { 229 errorsChan <- err 230 } 231 c.processingWg.Done() 232 233 pool <- struct{}{} 234 }() 235 } 236 } 237 } 238 239 func (c *ControllerRedis) consumeOne(taskBody []byte, queue string, handler task.Processor) error { 240 t := task.Signature{} 241 decoder := json.NewDecoder(bytes.NewReader(taskBody)) 242 decoder.UseNumber() 243 if err := decoder.Decode(&t); err != nil { 244 return err 245 } 246 247 if !c.IsRegisterTask(t.Name) { 248 if t.IgnoreNotRegisteredTask { 249 return nil 250 } 251 c.client.RPush(c.GetStopCtx(), queue, handler) 252 return nil 253 } 254 return handler.Process(&t) 255 } 256 257 func (c *ControllerRedis) StopConsuming() { 258 c.Broker.StopConsuming() 259 c.consumingWg.Wait() 260 c.delayedWg.Wait() 261 c.client.Close() 262 } 263 264 func (c *ControllerRedis) Publish(ctx context.Context, t *task.Signature) error { 265 tBody, err := json.Marshal(t) 266 if err != nil { 267 return err 268 } 269 if t.ETA != nil { 270 now := time.Now().Local() 271 if t.ETA.After(now) { 272 score := t.ETA.UnixNano() 273 return c.client.ZAdd(c.GetStopCtx(), c.delayedQueue, &redis.Z{Score: float64(score), Member: tBody}).Err() 274 } 275 } 276 return c.client.RPush(c.GetStopCtx(), t.Router, tBody).Err() 277 } 278 279 func (c *ControllerRedis) GetPendingTasks(queue string) ([]*task.Signature, error) { 280 results, err := c.client.LRange(c.GetStopCtx(), queue, 0, -1).Result() 281 if err != nil { 282 return nil, err 283 } 284 taskSlice := make([]*task.Signature, 0, len(results)) 285 for _, result := range results { 286 var t task.Signature 287 err = json.Unmarshal([]byte(result), &t) 288 if err != nil { 289 return nil, err 290 } 291 taskSlice = append(taskSlice, &t) 292 } 293 return taskSlice, nil 294 } 295 296 func (c *ControllerRedis) GetDelayedTasks() ([]*task.Signature, error) { 297 results, err := c.client.LRange(c.GetStopCtx(), c.delayedQueue, 0, -1).Result() 298 if err != nil { 299 return nil, err 300 } 301 taskSlice := make([]*task.Signature, 0, len(results)) 302 for _, result := range results { 303 var t task.Signature 304 err = json.Unmarshal([]byte(result), &t) 305 if err != nil { 306 return nil, err 307 } 308 taskSlice = append(taskSlice, &t) 309 } 310 return taskSlice, nil 311 } 312 313 func NewControllerRedis(broker *broker.Broker, client redis.UniversalClient, consumingQueue, delayedQueue string) controller.Controller { 314 return &ControllerRedis{ 315 Broker: broker, 316 client: client, 317 lock: redsync.New(goredis.NewPool(client)), 318 consumingQueue: consumingQueue, 319 delayedQueue: delayedQueue, 320 } 321 }