
     1  package redis
     3  import (
     4  	"context"
     5  	"sync"
     6  	"time"
     8  	""
     9  	""
    10  	""
    12  	""
    13  	""
    14  	""
    15  )
    17  const (
    18  	groupStartid   = ">"
    19  	redisBusyGroup = "BUSYGROUP Consumer Group name already exists"
    20  )
    22  const (
    23  	// NoSleep can be set to SubscriberConfig.NackResendSleep
    24  	NoSleep time.Duration = -1
    26  	DefaultBlockTime = time.Millisecond * 100
    28  	DefaultClaimInterval = time.Second * 5
    30  	DefaultClaimBatchSize = int64(100)
    32  	DefaultMaxIdleTime = time.Second * 60
    34  	DefaultCheckConsumersInterval = time.Second * 300
    35  	DefaultConsumerTimeout        = time.Second * 600
    36  )
    38  type Subscriber struct {
    39  	config        SubscriberConfig
    40  	client        redis.UniversalClient
    41  	logger        watermill.LoggerAdapter
    42  	closing       chan struct{}
    43  	subscribersWg sync.WaitGroup
    45  	closed     bool
    46  	closeMutex sync.Mutex
    47  }
    49  // NewSubscriber creates a new redis stream Subscriber.
    50  func NewSubscriber(config SubscriberConfig, logger watermill.LoggerAdapter) (*Subscriber, error) {
    51  	config.setDefaults()
    53  	if err := config.Validate(); err != nil {
    54  		return nil, err
    55  	}
    57  	if logger == nil {
    58  		logger = &watermill.NopLogger{}
    59  	}
    61  	return &Subscriber{
    62  		config:  config,
    63  		client:  config.Client,
    64  		logger:  logger,
    65  		closing: make(chan struct{}),
    66  	}, nil
    67  }
// SubscriberConfig configures a Subscriber. NewSubscriber replaces zero-valued
// fields with defaults (see setDefaults) and then calls Validate.
type SubscriberConfig struct {
	// Client is the Redis client used for every stream command. It is required
	// (see Validate) and is closed by Subscriber.Close unless
	// DisableRedisConnClose is set.
	Client redis.UniversalClient

	// Unmarshaller converts stream entry values into messages.
	// Defaults to DefaultMarshallerUnmarshaller{} when nil.
	Unmarshaller Unmarshaller

	// Redis stream consumer id, paired with ConsumerGroup.
	// Defaults to utils.ShortUUID() when empty.
	Consumer string
	// When empty, fan-out mode will be used.
	ConsumerGroup string

	// How long after Nack message should be redelivered.
	// Defaults to NoSleep (redeliver without waiting) when zero.
	NackResendSleep time.Duration

	// Block to wait next redis stream message.
	// Defaults to DefaultBlockTime when zero.
	BlockTime time.Duration

	// Claim idle pending message interval. Only used with a ConsumerGroup.
	// Defaults to DefaultClaimInterval when zero.
	ClaimInterval time.Duration

	// Maximum number of pending entries fetched by a single XPENDING call
	// while looking for messages to claim. Only used with a ConsumerGroup.
	// Defaults to DefaultClaimBatchSize when zero.
	ClaimBatchSize int64

	// How long should we treat a pending message as claimable: the minimum
	// idle time of pending entries considered for claiming, also passed to
	// XCLAIM as MinIdle. Defaults to DefaultMaxIdleTime when zero.
	MaxIdleTime time.Duration

	// Check consumer status interval. Only used with a ConsumerGroup.
	// Defaults to DefaultCheckConsumersInterval when zero.
	CheckConsumersInterval time.Duration

	// After this timeout an idle consumer with no pending messages will be removed from the consumer group.
	// Defaults to DefaultConsumerTimeout when zero.
	ConsumerTimeout time.Duration

	// Start consumption from the specified message ID.
	// When using "0", the consumer group will consume from the very first message.
	// When using "$", the consumer group will consume from the latest message.
	// It is only used when the consumer group is created, so it has no effect
	// if the group already exists or in fan-out mode. Defaults to "0" when empty.
	OldestId string

	// If this is set, it will be called to decide whether a pending message that
	// has been idle for more than MaxIdleTime should actually be claimed.
	// If this is not set, then all pending messages that have been idle for more than MaxIdleTime will be claimed.
	// This can be useful e.g. for tasks where the processing time can be very variable -
	// so we can't just use a short MaxIdleTime; but at the same time dead
	// consumers should be spotted quickly - so we can't just use a long MaxIdleTime either.
	// In such cases, if we have another way for checking consumers' health, then we can
	// leverage that in this callback.
	ShouldClaimPendingMessage func(redis.XPendingExt) bool

	// DisableRedisConnClose, when true, makes Subscriber.Close leave Client open.
	DisableRedisConnClose bool
}
   118  func (sc *SubscriberConfig) setDefaults() {
   119  	if sc.Unmarshaller == nil {
   120  		sc.Unmarshaller = DefaultMarshallerUnmarshaller{}
   121  	}
   122  	if sc.Consumer == "" {
   123  		sc.Consumer = utils.ShortUUID()
   124  	}
   125  	if sc.NackResendSleep == 0 {
   126  		sc.NackResendSleep = NoSleep
   127  	}
   128  	if sc.BlockTime == 0 {
   129  		sc.BlockTime = DefaultBlockTime
   130  	}
   131  	if sc.ClaimInterval == 0 {
   132  		sc.ClaimInterval = DefaultClaimInterval
   133  	}
   134  	if sc.ClaimBatchSize == 0 {
   135  		sc.ClaimBatchSize = DefaultClaimBatchSize
   136  	}
   137  	if sc.MaxIdleTime == 0 {
   138  		sc.MaxIdleTime = DefaultMaxIdleTime
   139  	}
   140  	if sc.CheckConsumersInterval == 0 {
   141  		sc.CheckConsumersInterval = DefaultCheckConsumersInterval
   142  	}
   143  	if sc.ConsumerTimeout == 0 {
   144  		sc.ConsumerTimeout = DefaultConsumerTimeout
   145  	}
   146  	// Consume from scratch by default
   147  	if sc.OldestId == "" {
   148  		sc.OldestId = "0"
   149  	}
   150  }
   152  func (sc *SubscriberConfig) Validate() error {
   153  	if sc.Client == nil {
   154  		return errors.New("redis client is empty")
   155  	}
   156  	return nil
   157  }
   159  func (s *Subscriber) Subscribe(ctx context.Context, topic string) (<-chan *message.Message, error) {
   160  	if s.closed {
   161  		return nil, errors.New("subscriber closed")
   162  	}
   164  	s.subscribersWg.Add(1)
   166  	logFields := watermill.LogFields{
   167  		"provider":       "redis",
   168  		"topic":          topic,
   169  		"consumer_group": s.config.ConsumerGroup,
   170  		"consumer_uuid":  s.config.Consumer,
   171  	}
   172  	s.logger.Info("[Common] watermill redis subscribing to redis stream topic", logFields)
   174  	// we don't want to have buffered channel to not consume messsage from redis stream when consumer is not consuming
   175  	output := make(chan *message.Message)
   177  	consumeClosed, err := s.consumeMessages(ctx, topic, output, logFields)
   178  	if err != nil {
   179  		s.subscribersWg.Done()
   180  		return nil, err
   181  	}
   183  	go func() {
   184  		<-consumeClosed
   185  		close(output)
   186  		s.subscribersWg.Done()
   187  	}()
   189  	return output, nil
   190  }
   192  func (s *Subscriber) consumeMessages(ctx context.Context, topic string,
   193  	output chan *message.Message, logFields watermill.LogFields) (consumeMessageClosed chan struct{}, err error) {
   194  	s.logger.Info("Starting consuming", logFields)
   196  	ctx, cancel := context.WithCancel(ctx)
   197  	go func() {
   198  		select {
   199  		case <-s.closing:
   200  			s.logger.Debug("[Common] watermill redis closing subscriber, cancelling consumeMessages", logFields)
   201  			cancel()
   202  		case <-ctx.Done():
   203  			// avoid goroutine leak
   204  		}
   205  	}()
   206  	if s.config.ConsumerGroup != "" {
   207  		// create consumer group
   208  		if _, err := s.client.XGroupCreateMkStream(ctx, topic,
   209  			s.config.ConsumerGroup, s.config.OldestId).Result(); err != nil && err.Error() != redisBusyGroup {
   210  			return nil, err
   211  		}
   212  	}
   214  	consumeMessageClosed, err = s.consumeStreams(ctx, topic, output, logFields)
   215  	if err != nil {
   216  		s.logger.Debug(
   217  			"[Common] watermill redis starting consume failed, cancelling context",
   218  			logFields.Add(watermill.LogFields{"err": err}),
   219  		)
   220  		cancel()
   221  		return nil, err
   222  	}
   224  	return consumeMessageClosed, nil
   225  }
   227  func (s *Subscriber) consumeStreams(ctx context.Context, stream string,
   228  	output chan *message.Message, logFields watermill.LogFields) (chan struct{}, error) {
   229  	messageHandler := s.createMessageHandler(output)
   230  	consumeMessageClosed := make(chan struct{})
   232  	go func() {
   233  		defer close(consumeMessageClosed)
   235  		readChannel := make(chan *redis.XStream, 1)
   236  		go, stream, readChannel, logFields)
   238  		for {
   239  			select {
   240  			case xs := <-readChannel:
   241  				if xs == nil {
   242  					s.logger.Debug(
   243  						"[Common] watermill redis readStreamChannel is closed, stopping readStream", logFields)
   244  					return
   245  				}
   246  				if err := messageHandler.processMessage(ctx, xs.Stream, &xs.Messages[0], logFields); err != nil {
   247  					s.logger.Error("[Common] watermill redis processMessage fail", err, logFields)
   248  					return
   249  				}
   250  			case <-s.closing:
   251  				s.logger.Debug("[Common] watermill redis subscriber is closing, stopping readStream", logFields)
   252  				return
   253  			case <-ctx.Done():
   254  				s.logger.Debug("[Common] watermill redis ctx was cancelled, stopping readStream", logFields)
   255  				return
   256  			}
   257  		}
   258  	}()
   260  	return consumeMessageClosed, nil
   261  }
   263  func (s *Subscriber) read(ctx context.Context, stream string,
   264  	readChannel chan<- *redis.XStream, logFields watermill.LogFields) {
   265  	wg := &sync.WaitGroup{}
   266  	subCtx, subCancel := context.WithCancel(ctx)
   267  	defer func() {
   268  		subCancel()
   269  		wg.Wait()
   270  		close(readChannel)
   271  	}()
   272  	var (
   273  		streamsGroup = []string{stream, groupStartid}
   275  		fanOutStartid               = "$"
   276  		countFanOut   int64         = 0
   277  		blockTime     time.Duration = 0
   279  		xss []redis.XStream
   280  		xs  *redis.XStream
   281  		err error
   282  	)
   284  	if s.config.ConsumerGroup != "" {
   285  		// 1. get pending message from idle consumer
   286  		wg.Add(1)
   287  		s.claim(subCtx, stream, readChannel, false, wg, logFields)
   289  		// 2. background
   290  		wg.Add(1)
   291  		go s.claim(subCtx, stream, readChannel, true, wg, logFields)
   293  		// check consumer status and remove idling consumers if possible
   294  		wg.Add(1)
   295  		go s.checkConsumers(subCtx, stream, wg, logFields)
   296  	}
   298  	for {
   299  		select {
   300  		case <-s.closing:
   301  			return
   302  		case <-ctx.Done():
   303  			return
   304  		default:
   305  			if s.config.ConsumerGroup != "" {
   306  				xss, err = s.client.XReadGroup(
   307  					ctx,
   308  					&redis.XReadGroupArgs{
   309  						Group:    s.config.ConsumerGroup,
   310  						Consumer: s.config.Consumer,
   311  						Streams:  streamsGroup,
   312  						Count:    1,
   313  						Block:    blockTime,
   314  					}).Result()
   315  			} else {
   316  				xss, err = s.client.XRead(
   317  					ctx,
   318  					&redis.XReadArgs{
   319  						Streams: []string{stream, fanOutStartid},
   320  						Count:   countFanOut,
   321  						Block:   blockTime,
   322  					}).Result()
   323  			}
   324  			if errors.Is(err, redis.Nil) {
   325  				continue
   326  			} else if err != nil {
   327  				if _, ok := utils.IsChannelClosed(s.closing); !ok &&
   328  					!errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
   329  					s.logger.Error("[Common] watermill redis read fail", err, logFields)
   330  				}
   331  			}
   332  			if len(xss) < 1 || len(xss[0].Messages) < 1 {
   333  				continue
   334  			}
   335  			// update last delivered message
   336  			xs = &xss[0]
   337  			if s.config.ConsumerGroup == "" {
   338  				fanOutStartid = xs.Messages[0].ID
   339  				countFanOut = 1
   340  			}
   342  			blockTime = s.config.BlockTime
   344  			select {
   345  			case <-s.closing:
   346  				return
   347  			case <-ctx.Done():
   348  				return
   349  			case readChannel <- xs:
   350  			}
   351  		}
   352  	}
   353  }
   355  func (s *Subscriber) claim(ctx context.Context, stream string,
   356  	readChannel chan<- *redis.XStream, keep bool, wg *sync.WaitGroup, logFields watermill.LogFields) {
   357  	var (
   358  		xps    []redis.XPendingExt
   359  		err    error
   360  		xp     redis.XPendingExt
   361  		xm     []redis.XMessage
   362  		tick   = time.NewTicker(s.config.ClaimInterval)
   363  		initCh = make(chan byte, 1)
   364  	)
   365  	defer func() {
   366  		tick.Stop()
   367  		close(initCh)
   368  		wg.Done()
   369  	}()
   370  	if !keep { // if not keep, run immediately
   371  		initCh <- 1
   372  	}
   374  OUTER_LOOP:
   375  	for {
   376  		select {
   377  		case <-s.closing:
   378  			return
   379  		case <-ctx.Done():
   380  			return
   381  		case <-tick.C:
   382  		case <-initCh:
   383  		}
   385  		xps, err = s.client.XPendingExt(ctx, &redis.XPendingExtArgs{
   386  			Stream:   stream,
   387  			Group:    s.config.ConsumerGroup,
   388  			Idle:     s.config.MaxIdleTime,
   389  			Start:    "0",
   390  			End:      "+",
   391  			Count:    s.config.ClaimBatchSize,
   392  			Consumer: "",
   393  		}).Result()
   394  		if err != nil {
   395  			s.logger.Error(
   396  				"[Common] watermill redis xpendingext fail",
   397  				err,
   398  				logFields,
   399  			)
   400  			continue
   401  		}
   402  		for _, xp = range xps {
   403  			shouldClaim := xp.Idle >= s.config.MaxIdleTime
   404  			if shouldClaim && s.config.ShouldClaimPendingMessage != nil {
   405  				shouldClaim = s.config.ShouldClaimPendingMessage(xp)
   406  			}
   408  			if shouldClaim {
   409  				// assign the ownership of a pending message to the current consumer
   410  				xm, err = s.client.XClaim(ctx, &redis.XClaimArgs{
   411  					Stream:   stream,
   412  					Group:    s.config.ConsumerGroup,
   413  					Consumer: s.config.Consumer,
   414  					// this is important: it ensures that 2 concurrent subscribers
   415  					// won't claim the same pending message at the same time
   416  					MinIdle:  s.config.MaxIdleTime,
   417  					Messages: []string{xp.ID},
   418  				}).Result()
   419  				if err != nil {
   420  					s.logger.Error(
   421  						"[Common] watermill redis xclaim fail",
   422  						err,
   423  						logFields.Add(watermill.LogFields{"xp": xp}),
   424  					)
   425  					continue OUTER_LOOP
   426  				}
   427  				if len(xm) > 0 {
   428  					select {
   429  					case <-s.closing:
   430  						return
   431  					case <-ctx.Done():
   432  						return
   433  					case readChannel <- &redis.XStream{Stream: stream, Messages: xm}:
   434  					}
   435  				}
   436  			}
   437  		}
   438  		if len(xps) == 0 || int64(len(xps)) < s.config.ClaimBatchSize { // done
   439  			if !keep {
   440  				return
   441  			}
   442  			continue
   443  		}
   444  	}
   445  }
   447  func (s *Subscriber) checkConsumers(ctx context.Context, stream string,
   448  	wg *sync.WaitGroup, logFields watermill.LogFields) {
   449  	tick := time.NewTicker(s.config.CheckConsumersInterval)
   450  	defer func() {
   451  		tick.Stop()
   452  		wg.Done()
   453  	}()
   455  	for {
   456  		select {
   457  		case <-s.closing:
   458  			return
   459  		case <-ctx.Done():
   460  			return
   461  		case <-tick.C:
   462  		}
   463  		xics, err := s.client.XInfoConsumers(ctx, stream, s.config.ConsumerGroup).Result()
   464  		if err != nil {
   465  			s.logger.Error(
   466  				"[Common] watermill redis xinfoconsumers failed",
   467  				err,
   468  				logFields,
   469  			)
   470  		}
   471  		for _, xic := range xics {
   472  			if xic.Idle < s.config.ConsumerTimeout {
   473  				continue
   474  			}
   475  			if xic.Pending == 0 {
   476  				if err = s.client.XGroupDelConsumer(ctx, stream, s.config.ConsumerGroup, xic.Name).Err(); err != nil {
   477  					s.logger.Error(
   478  						"[Common] watermill redis xgroupdelconsumer failed",
   479  						err,
   480  						logFields,
   481  					)
   482  				}
   483  			}
   484  		}
   485  	}
   486  }
   488  func (s *Subscriber) createMessageHandler(output chan *message.Message) messageHandler {
   489  	return messageHandler{
   490  		outputChannel:   output,
   491  		rc:              s.client,
   492  		consumerGroup:   s.config.ConsumerGroup,
   493  		unmarshaller:    s.config.Unmarshaller,
   494  		nackResendSleep: s.config.NackResendSleep,
   495  		logger:          s.logger,
   496  		closing:         s.closing,
   497  	}
   498  }
   500  func (s *Subscriber) Close() error {
   501  	s.closeMutex.Lock()
   502  	defer s.closeMutex.Unlock()
   504  	if s.closed {
   505  		return nil
   506  	}
   508  	s.closed = true
   509  	close(s.closing)
   510  	s.subscribersWg.Wait()
   512  	if !s.config.DisableRedisConnClose {
   513  		if err := s.client.Close(); err != nil {
   514  			return err
   515  		}
   516  	}
   518  	s.logger.Debug("[Common] watermill redis stream subscriber closed", nil)
   520  	return nil
   521  }
   523  type messageHandler struct {
   524  	outputChannel chan<- *message.Message
   525  	rc            redis.UniversalClient
   526  	consumerGroup string
   527  	unmarshaller  Unmarshaller
   529  	nackResendSleep time.Duration
   531  	logger  watermill.LoggerAdapter
   532  	closing chan struct{}
   533  }
   535  func (h *messageHandler) processMessage(ctx context.Context, stream string,
   536  	xm *redis.XMessage, messageLogFields watermill.LogFields) error {
   537  	receivedMsgLogFields := messageLogFields.Add(watermill.LogFields{
   538  		"xadd_id":        xm.ID,
   539  		"stream":         stream,
   540  		"message_raw_id": xm.ID,
   541  	})
   543  	h.logger.Trace("[Common] watermill received message from redis stream", receivedMsgLogFields)
   545  	msg, err := h.unmarshaller.Unmarshal(xm.Values)
   546  	if err != nil {
   547  		return errors.Wrapf(err, "message unmarshal failed")
   548  	}
   550  	ctx = context.WithValue(ctx, watermill.ContextKeyMessageUUID, msg.UUID)
   551  	ctx = context.WithValue(ctx, watermill.ContextKeyRawMessageID, xm.ID)
   552  	ctx, cancelCtx := context.WithCancel(ctx)
   553  	msg.SetContext(ctx)
   554  	defer cancelCtx()
   556  	receivedMsgLogFields = receivedMsgLogFields.Add(watermill.LogFields{
   557  		"message_uuid": msg.UUID,
   558  	})
   560  ResendLoop:
   561  	for {
   562  		select {
   563  		case h.outputChannel <- msg:
   564  			h.logger.Trace("[Common] watermill redis message sent to consumer", receivedMsgLogFields)
   565  		case <-h.closing:
   566  			h.logger.Trace("[Common] watermill redis closing, message discarded", receivedMsgLogFields)
   567  			return nil
   568  		case <-ctx.Done():
   569  			h.logger.Trace("[Common] watermill redis closing, ctx cancelled before sent to consumer",
   570  				receivedMsgLogFields)
   571  			return nil
   572  		}
   574  		select {
   575  		case <-msg.Acked():
   576  			if h.consumerGroup != "" {
   577  				// deadly retry ack
   578  				err := retry.Retry(func(attempt uint) error {
   579  					err := h.rc.XAck(ctx, stream, h.consumerGroup, xm.ID).Err()
   580  					return err
   581  				}, func(attempt uint) bool {
   582  					if attempt != 0 {
   583  						time.Sleep(time.Millisecond * 100)
   584  					}
   585  					return true
   586  				}, func(attempt uint) bool {
   587  					select {
   588  					case <-h.closing:
   589  					case <-ctx.Done():
   590  					default:
   591  						return true
   592  					}
   593  					return false
   594  				})
   595  				if err != nil {
   596  					h.logger.Error("[Common] watermill redis message acked fail", err, receivedMsgLogFields)
   597  				}
   598  			}
   599  			h.logger.Trace("[Common] watermill redis message acked", receivedMsgLogFields)
   600  			break ResendLoop
   601  		case <-msg.Nacked():
   602  			h.logger.Trace("[Common] watermill redis message nacked", receivedMsgLogFields)
   604  			// reset acks, etc.
   605  			msg = msg.Copy()
   606  			msg.SetContext(ctx)
   607  			if h.nackResendSleep != NoSleep {
   608  				time.Sleep(h.nackResendSleep)
   609  			}
   611  			continue ResendLoop
   612  		case <-h.closing:
   613  			h.logger.Trace("[Common] watermill redis closing, message discarded before ack", receivedMsgLogFields)
   614  			return nil
   615  		case <-ctx.Done():
   616  			h.logger.Trace("[Common] watermill redis closing, ctx cancelled before ack", receivedMsgLogFields)
   617  			return nil
   618  		}
   619  	}
   621  	return nil
   622  }