github.com/songzhibin97/gkit@v1.2.13/distributed/controller/controller_redis/redis.go (about)

     1  package controller_redis
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"runtime"
     9  	"strconv"
    10  	"sync"
    11  	"time"
    12  
    13  	json "github.com/json-iterator/go"
    14  
    15  	"github.com/songzhibin97/gkit/distributed/broker"
    16  
    17  	"github.com/songzhibin97/gkit/distributed/task"
    18  
    19  	"github.com/songzhibin97/gkit/distributed/controller"
    20  
    21  	"github.com/go-redis/redis/v8"
    22  	"github.com/go-redsync/redsync/v4"
    23  	"github.com/go-redsync/redsync/v4/redis/goredis/v8"
    24  )
    25  
    26  type ControllerRedis struct {
    27  	*broker.Broker
    28  	// client redis客户端
    29  	client redis.UniversalClient
    30  	// lock 分布式锁
    31  	lock *redsync.Redsync
    32  
    33  	// wg
    34  
    35  	// consumingWg 确保消费组并发完成
    36  	consumingWg sync.WaitGroup
    37  	// processingWg 确保正在处理的任务并发完成
    38  	processingWg sync.WaitGroup
    39  	// delayedWg 确保延时任务并发完成
    40  	delayedWg sync.WaitGroup
    41  	// consumingQueue 消费队列名称
    42  	consumingQueue string
    43  	// delayedQueue  延迟队列名称
    44  	delayedQueue string
    45  }
    46  
    47  // SetConsumingQueue 设置消费队列名称
    48  func (c *ControllerRedis) SetConsumingQueue(consumingQueue string) {
    49  	c.consumingQueue = consumingQueue
    50  }
    51  
    52  // SetDelayedQueue 设置延迟队列名称
    53  func (c *ControllerRedis) SetDelayedQueue(delayedQueue string) {
    54  	c.delayedQueue = delayedQueue
    55  }
    56  
    57  func (c *ControllerRedis) RegisterTask(name ...string) {
    58  	c.RegisterList(name...)
    59  }
    60  
    61  func (c *ControllerRedis) IsRegisterTask(name string) bool {
    62  	return c.IsRegister(name)
    63  }
    64  
    65  func (c *ControllerRedis) StartConsuming(concurrency int, handler task.Processor) (bool, error) {
    66  	c.consumingWg.Add(1)
    67  	defer c.consumingWg.Done()
    68  
    69  	// 设置阈值,如果并发数 < 1, 默认设置成 2*cpu
    70  	if concurrency < 1 {
    71  		concurrency = runtime.NumCPU() * 2
    72  	}
    73  	_, err := c.client.Ping(context.Background()).Result()
    74  	if err != nil {
    75  		// 重试
    76  		c.Broker.GetRetryFn()(c.Broker.GetRetryCtx())
    77  		if c.Broker.GetRetry() {
    78  			return true, err
    79  		}
    80  		return false, controller.ErrorConnectClose
    81  	}
    82  
    83  	// 初始化工作区
    84  	pool := make(chan struct{}, concurrency)
    85  	worker := make(chan []byte, concurrency)
    86  
    87  	// 填满并发池
    88  	for i := 0; i < concurrency; i++ {
    89  		pool <- struct{}{}
    90  	}
    91  	go func() {
    92  		fmt.Println("[*] Waiting for messages. To exit press CTRL+C")
    93  		for {
    94  			select {
    95  			case <-c.GetStopCtx().Done():
    96  				close(worker)
    97  				return
    98  			case _, ok := <-pool:
    99  				if !ok {
   100  					return
   101  				}
   102  				tByte, err := c.popTask(c.consumingQueue, 0)
   103  				if err != nil && !errors.Is(err, redis.Nil) {
   104  					fmt.Println("popTask err:", err)
   105  				}
   106  				// 如果是有效数据,发送给worker
   107  				if len(tByte) > 0 {
   108  					worker <- tByte
   109  				}
   110  				pool <- struct{}{}
   111  			}
   112  		}
   113  	}()
   114  	c.delayedWg.Add(1)
   115  	go func() {
   116  		defer c.delayedWg.Done()
   117  		for {
   118  			select {
   119  			case <-c.GetStopCtx().Done():
   120  				return
   121  			default:
   122  				tBody, err := c.popDelayedTask(c.delayedQueue, 0)
   123  				if err != nil {
   124  					fmt.Println("popDelayedTask err:", err)
   125  					continue
   126  				}
   127  				if tBody == nil {
   128  					continue
   129  				}
   130  				t := task.Signature{}
   131  				if err = json.Unmarshal(tBody, &t); err != nil {
   132  					fmt.Println("unmarshal err:", err)
   133  					continue
   134  				}
   135  				if err = c.Publish(c.GetStopCtx(), &t); err != nil {
   136  					fmt.Println("publish err:", err)
   137  					continue
   138  				}
   139  			}
   140  		}
   141  	}()
   142  
   143  	if err = c.consume(worker, concurrency, handler); err != nil {
   144  		return c.GetRetry(), err
   145  	}
   146  	c.processingWg.Wait()
   147  	return c.GetRetry(), err
   148  }
   149  
   150  // popTask 弹出任务
   151  func (c *ControllerRedis) popTask(queue string, blockTime int64) ([]byte, error) {
   152  	if blockTime <= 0 {
   153  		blockTime = int64(1000 * time.Millisecond)
   154  	}
   155  	items, err := c.client.BLPop(context.Background(), time.Duration(blockTime), queue).Result()
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  	// items[0] - the name of the key where an element was popped
   160  	// items[1] - the value of the popped element
   161  	if len(items) != 2 {
   162  		return nil, redis.Nil
   163  	}
   164  	result := []byte(items[1])
   165  	return result, nil
   166  }
   167  
   168  // popDelayedTask 弹出延时任务,因为延时任务是使用Redis ZSet
   169  func (c *ControllerRedis) popDelayedTask(queue string, blockTime int64) ([]byte, error) {
   170  	if blockTime <= 0 {
   171  		blockTime = int64(1000 * time.Millisecond)
   172  	}
   173  	var result []byte
   174  	for {
   175  		time.Sleep(time.Duration(blockTime))
   176  		watchFn := func(tx *redis.Tx) error {
   177  			now := time.Now().Local().UnixNano()
   178  			max := strconv.FormatInt(now, 10)
   179  			items, err := tx.ZRevRangeByScore(c.GetStopCtx(), queue, &redis.ZRangeBy{Min: "0", Max: max, Offset: 0, Count: 1}).Result()
   180  			if err != nil {
   181  				return err
   182  			}
   183  			if len(items) != 1 {
   184  				return redis.Nil
   185  			}
   186  			_, err = tx.TxPipelined(c.GetStopCtx(), func(pipe redis.Pipeliner) error {
   187  				pipe.ZRem(c.GetStopCtx(), queue, items[0])
   188  				result = []byte(items[0])
   189  				return nil
   190  			})
   191  			return err
   192  		}
   193  		if err := c.client.Watch(c.GetStopCtx(), watchFn, queue); err != nil {
   194  			break
   195  		}
   196  	}
   197  	return result, nil
   198  }
   199  
   200  // consume 消费
   201  func (c *ControllerRedis) consume(worker <-chan []byte, concurrency int, handler task.Processor) error {
   202  	// 初始化工作区
   203  	pool := make(chan struct{}, concurrency)
   204  	errorsChan := make(chan error, concurrency*2)
   205  
   206  	// 填满并发池
   207  	for i := 0; i < concurrency; i++ {
   208  		pool <- struct{}{}
   209  	}
   210  	for {
   211  		select {
   212  		case <-c.GetStopCtx().Done():
   213  			return c.GetStopCtx().Err()
   214  		case err := <-errorsChan:
   215  			return err
   216  		case v, ok := <-worker:
   217  			if !ok {
   218  				return nil
   219  			}
   220  			// 阻塞等待
   221  			select {
   222  			case <-pool:
   223  			case <-c.GetStopCtx().Done():
   224  				return c.GetStopCtx().Err()
   225  			}
   226  			c.processingWg.Add(1)
   227  			go func() {
   228  				if err := c.consumeOne(v, c.consumingQueue, handler); err != nil {
   229  					errorsChan <- err
   230  				}
   231  				c.processingWg.Done()
   232  
   233  				pool <- struct{}{}
   234  			}()
   235  		}
   236  	}
   237  }
   238  
   239  func (c *ControllerRedis) consumeOne(taskBody []byte, queue string, handler task.Processor) error {
   240  	t := task.Signature{}
   241  	decoder := json.NewDecoder(bytes.NewReader(taskBody))
   242  	decoder.UseNumber()
   243  	if err := decoder.Decode(&t); err != nil {
   244  		return err
   245  	}
   246  
   247  	if !c.IsRegisterTask(t.Name) {
   248  		if t.IgnoreNotRegisteredTask {
   249  			return nil
   250  		}
   251  		c.client.RPush(c.GetStopCtx(), queue, handler)
   252  		return nil
   253  	}
   254  	return handler.Process(&t)
   255  }
   256  
   257  func (c *ControllerRedis) StopConsuming() {
   258  	c.Broker.StopConsuming()
   259  	c.consumingWg.Wait()
   260  	c.delayedWg.Wait()
   261  	c.client.Close()
   262  }
   263  
   264  func (c *ControllerRedis) Publish(ctx context.Context, t *task.Signature) error {
   265  	tBody, err := json.Marshal(t)
   266  	if err != nil {
   267  		return err
   268  	}
   269  	if t.ETA != nil {
   270  		now := time.Now().Local()
   271  		if t.ETA.After(now) {
   272  			score := t.ETA.UnixNano()
   273  			return c.client.ZAdd(c.GetStopCtx(), c.delayedQueue, &redis.Z{Score: float64(score), Member: tBody}).Err()
   274  		}
   275  	}
   276  	return c.client.RPush(c.GetStopCtx(), t.Router, tBody).Err()
   277  }
   278  
   279  func (c *ControllerRedis) GetPendingTasks(queue string) ([]*task.Signature, error) {
   280  	results, err := c.client.LRange(c.GetStopCtx(), queue, 0, -1).Result()
   281  	if err != nil {
   282  		return nil, err
   283  	}
   284  	taskSlice := make([]*task.Signature, 0, len(results))
   285  	for _, result := range results {
   286  		var t task.Signature
   287  		err = json.Unmarshal([]byte(result), &t)
   288  		if err != nil {
   289  			return nil, err
   290  		}
   291  		taskSlice = append(taskSlice, &t)
   292  	}
   293  	return taskSlice, nil
   294  }
   295  
   296  func (c *ControllerRedis) GetDelayedTasks() ([]*task.Signature, error) {
   297  	results, err := c.client.LRange(c.GetStopCtx(), c.delayedQueue, 0, -1).Result()
   298  	if err != nil {
   299  		return nil, err
   300  	}
   301  	taskSlice := make([]*task.Signature, 0, len(results))
   302  	for _, result := range results {
   303  		var t task.Signature
   304  		err = json.Unmarshal([]byte(result), &t)
   305  		if err != nil {
   306  			return nil, err
   307  		}
   308  		taskSlice = append(taskSlice, &t)
   309  	}
   310  	return taskSlice, nil
   311  }
   312  
   313  func NewControllerRedis(broker *broker.Broker, client redis.UniversalClient, consumingQueue, delayedQueue string) controller.Controller {
   314  	return &ControllerRedis{
   315  		Broker:         broker,
   316  		client:         client,
   317  		lock:           redsync.New(goredis.NewPool(client)),
   318  		consumingQueue: consumingQueue,
   319  		delayedQueue:   delayedQueue,
   320  	}
   321  }