
     1  package cluster
     3  import (
     4  	"log/slog"
     5  	"sync"
     6  	"sync/atomic"
     7  	"time"
     9  	""
    10  	""
    11  	""
    12  	""
    13  )
// PublishingErrorHandler decides how a BatchingProducer reacts when publishing a batch fails.
// It is called with the number of publish attempts made so far for the batch (starting at 1),
// the error returned by the publisher, and the batch itself, and returns the decision to apply.
type PublishingErrorHandler func(retries int, e error, batch *PubSubBatch) *PublishingErrorDecision
// BatchingProducerConfig holds the tunable settings of a BatchingProducer.
type BatchingProducerConfig struct {
	// Maximum size of the published batch. Default: 2000.
	BatchSize int
	// Maximum number of messages waiting in the queue. If a positive value is provided, Produce
	// returns a ProducerQueueFullException when the queue is full. If 0 or unset, the queue is unbounded.
	// Note that bounded queue has better performance than unbounded queue.
	// Default: 0 (unbounded)
	MaxQueueSize int

	// How long to wait for the publishing to complete.
	// Default: 5s
	PublishTimeout time.Duration

	// Error handler that can decide what to do with an error when publishing a batch.
	// Default: Fail and stop the BatchingProducer
	OnPublishingError PublishingErrorHandler

	// A throttle for logging from this producer. By default, a throttle that allows for
	// 10 events in 1 second is created for each configuration.
	LogThrottle actor.ShouldThrottle

	// Optional idle timeout, handed to the Publisher on Initialize as PublisherConfig.IdleTimeout.
	// It specifies how long the publisher should wait before invoking clean up code to recover resources.
	PublisherIdleTimeout time.Duration
}
    44  func newBatchingProducerConfig(logger *slog.Logger, opts ...BatchingProducerConfigOption) *BatchingProducerConfig {
    45  	config := &BatchingProducerConfig{
    46  		BatchSize:      2000,
    47  		PublishTimeout: 5 * time.Second,
    48  		OnPublishingError: func(retries int, e error, batch *PubSubBatch) *PublishingErrorDecision {
    49  			return FailBatchAndStop
    50  		},
    51  		LogThrottle: actor.NewThrottleWithLogger(logger, 10, time.Second, func(logger *slog.Logger, i int32) {
    52  			logger.Info("[BatchingProducer] Throttled logs", slog.Int("count", int(i)))
    53  		}),
    54  	}
    56  	for _, opt := range opts {
    57  		opt(config)
    58  	}
    60  	return config
    61  }
// BatchingProducer queues messages for a single topic and publishes them in batches from a
// background goroutine (publishLoop) started by NewBatchingProducer.
//
// Fields:
//   - config: effective settings (defaults plus options).
//   - topic: the topic every batch is published to.
//   - publisher: performs Initialize/PublishBatch and supplies the logger.
//   - publisherChannel: the queue between Produce and publishLoop; bounded when
//     config.MaxQueueSize > 0, unbounded otherwise.
//   - loopCancel: cancels the context publishLoop runs under.
//   - loopDone: closed by publishLoop when it returns; Dispose waits on it.
//   - msgLeft: set to 0 at construction; not read anywhere in the code visible here.
type BatchingProducer struct {
	config           *BatchingProducerConfig
	topic            string
	publisher        Publisher
	publisherChannel channel[produceMessage]
	loopCancel       context.CancelFunc
	loopDone         chan struct{}
	msgLeft          uint32
}
    73  func NewBatchingProducer(publisher Publisher, topic string, opts ...BatchingProducerConfigOption) *BatchingProducer {
    74  	config := newBatchingProducerConfig(publisher.Logger(), opts...)
    75  	p := &BatchingProducer{
    76  		config:    config,
    77  		topic:     topic,
    78  		publisher: publisher,
    79  		msgLeft:   0,
    80  		loopDone:  make(chan struct{}),
    81  	}
    82  	if config.MaxQueueSize > 0 {
    83  		p.publisherChannel = newBoundedChannel[produceMessage](config.MaxQueueSize)
    84  	} else {
    85  		p.publisherChannel = newUnboundedChannel[produceMessage]()
    86  	}
    87  	ctx, cancelFunc := context.WithCancel(context.Background())
    88  	p.loopCancel = cancelFunc
    89  	go p.publishLoop(ctx)
    91  	return p
    92  }
// pubsubBatchWithReceipts pairs a batch under construction with the per-message contexts of the
// messages it contains. batch.Envelopes and ctxArr are appended to and trimmed together, so the
// context at index i belongs to the envelope at index i; each context carries the message's
// ProduceProcessInfo.
type pubsubBatchWithReceipts struct {
	batch  *PubSubBatch
	ctxArr []context.Context
}
    99  // newPubSubBatchWithReceipts creates a new pubsubBatchWithReceipts
   100  func newPubSubBatchWithReceipts() *pubsubBatchWithReceipts {
   101  	return &pubsubBatchWithReceipts{
   102  		batch:  &PubSubBatch{Envelopes: make([]proto.Message, 0, 10)},
   103  		ctxArr: make([]context.Context, 0, 10),
   104  	}
   105  }
// produceMessage is the unit stored in the producer queue: the message to publish together with
// the per-call context created by Produce, which carries the call's ProduceProcessInfo.
type produceMessage struct {
	message proto.Message
	ctx     context.Context
}
// Dispose stops the producer and blocks until the publish loop has exited.
// It cancels the loop context, wakes the loop in case it is parked waiting for messages, and
// then waits for loopDone, which publishLoop closes when it returns.
func (p *BatchingProducer) Dispose() {
	p.loopCancel()
	p.publisherChannel.broadcast()
	<-p.loopDone
}
   119  // ProduceProcessInfo is the context for a Produce call
   120  type ProduceProcessInfo struct {
   121  	Finished   chan struct{}
   122  	Err        error
   123  	cancelFunc context.CancelFunc
   124  	cancelled  chan struct{}
   125  }
   127  // IsCancelled returns true if the context has been cancelled
   128  func (p *ProduceProcessInfo) IsCancelled() bool {
   129  	select {
   130  	case <-p.cancelled:
   131  		return true
   132  	default:
   133  		return false
   134  	}
   135  }
   137  // IsFinished returns true if the context has been finished
   138  func (p *ProduceProcessInfo) IsFinished() bool {
   139  	select {
   140  	case <-p.Finished:
   141  		return true
   142  	default:
   143  		return false
   144  	}
   145  }
   147  // setErr sets the error for the ProduceProcessInfo
   148  func (p *ProduceProcessInfo) setErr(err error) {
   149  	p.Err = err
   150  	p.cancelFunc()
   151  	close(p.Finished)
   152  }
   154  // cancel the ProduceProcessInfo context
   155  func (p *ProduceProcessInfo) cancel() {
   156  	p.cancelFunc()
   157  	close(p.Finished)
   158  	close(p.cancelled)
   159  }
   161  // success closes the ProduceProcessInfo Finished channel
   162  func (p *ProduceProcessInfo) success() {
   163  	p.cancelFunc()
   164  	close(p.Finished)
   165  }
// produceProcessInfoKey is the context key under which Produce stores the call's
// *ProduceProcessInfo. Being an unexported struct type, it cannot collide with keys from other packages.
type produceProcessInfoKey struct{}
   169  // GetProduceProcessInfo adds a new produce info to the BatchingProducer.Produce context
   170  func (p *BatchingProducer) getProduceProcessInfo(ctx context.Context) *ProduceProcessInfo {
   171  	return ctx.Value(produceProcessInfoKey{}).(*ProduceProcessInfo)
   172  }
   174  // Produce a message to producer queue. The return info can be used to wait for the message to be published.
   175  func (p *BatchingProducer) Produce(ctx context.Context, message proto.Message) (*ProduceProcessInfo, error) {
   176  	ctx, cancel := context.WithCancel(ctx)
   177  	info := &ProduceProcessInfo{
   178  		Finished:   make(chan struct{}),
   179  		cancelled:  make(chan struct{}),
   180  		cancelFunc: cancel,
   181  	}
   182  	ctx = context.WithValue(ctx, produceProcessInfoKey{}, info)
   183  	if !p.publisherChannel.tryWrite(produceMessage{
   184  		message: message,
   185  		ctx:     ctx,
   186  	}) {
   187  		if p.publisherChannel.isComplete() {
   188  			return info, &InvalidOperationException{Topic: p.topic}
   189  		}
   190  		return info, &ProducerQueueFullException{topic: p.topic}
   191  	}
   192  	return info, nil
   193  }
// publishLoop is the main loop of the producer and runs in its own goroutine until ctx is
// cancelled or an unrecoverable publishing error occurs.
//
// It drains the queue into a batch and publishes the batch when it reaches config.BatchSize or
// when the queue runs empty, then parks in waitToRead until more messages arrive. Messages whose
// own context is already done when they are read are cancelled instead of batched. On exit, the
// unfinished batch and everything still queued are cancelled, and loopDone is closed.
func (p *BatchingProducer) publishLoop(ctx context.Context) {
	// Lets Dispose, which waits on loopDone, observe that the loop has fully exited.
	defer close(p.loopDone)

	p.publisher.Logger().Debug("Producer is starting the publisher loop for topic", slog.String("topic", p.topic))
	batchWrapper := newPubSubBatchWithReceipts()

	// handleUnrecoverableError stops intake, logs (subject to the throttle), and fails both the
	// current batch and every message still queued with err. It captures the batchWrapper
	// variable, so it always acts on the batch currently being built.
	handleUnrecoverableError := func(err error) {
		p.stopAcceptingNewMessages()
		if p.config.LogThrottle() == actor.Open {
			p.publisher.Logger().Error("Error in the publisher loop of Producer for topic", slog.String("topic", p.topic), slog.Any("error", err))
		}
		p.failBatch(batchWrapper, err)
		p.failPendingMessages(err)
	}

	_, err := p.publisher.Initialize(ctx, p.topic, PublisherConfig{IdleTimeout: p.config.PublisherIdleTimeout})
	if err != nil && err != context.Canceled {
		// NOTE(review): after a failed Initialize the function does not return; it falls through
		// into the loop below and keeps running until ctx is cancelled. Confirm this is intended.
		handleUnrecoverableError(err)
	}

loop:
	for {
		select {
		case <-ctx.Done():
			p.stopAcceptingNewMessages()
			break loop
		default:
			if msg, ok := p.publisherChannel.tryRead(); ok {

				// Batch the message only if its own context is still live; otherwise cancel it.
				select {
				case <-msg.ctx.Done():
					p.getProduceProcessInfo(msg.ctx).cancel()
				default:
					batchWrapper.batch.Envelopes = append(batchWrapper.batch.Envelopes, msg.message)
					batchWrapper.ctxArr = append(batchWrapper.ctxArr, msg.ctx)
				}

				// Keep filling the batch until it is full.
				if len(batchWrapper.batch.Envelopes) < p.config.BatchSize {
					continue
				}

				err := p.publishBatch(ctx, batchWrapper)
				if err != nil {
					handleUnrecoverableError(err)
					break loop
				}
				batchWrapper = newPubSubBatchWithReceipts()
			} else {
				// Queue is empty: flush whatever has been collected, then wait for more work.
				if len(batchWrapper.batch.Envelopes) > 0 {
					err := p.publishBatch(ctx, batchWrapper)
					if err != nil {
						handleUnrecoverableError(err)
						break loop
					}
					batchWrapper = newPubSubBatchWithReceipts()
				}
				// Returns when a message is available or when Dispose calls broadcast.
				p.publisherChannel.waitToRead()
			}
		}
	}
	// Nothing else will be published: cancel what is left in the batch and in the queue.
	p.cancelBatch(batchWrapper)
	p.cancelPendingMessages()
}
   261  // cancelPendingMessages cancels all pending messages
   262  func (p *BatchingProducer) cancelPendingMessages() {
   263  	for {
   264  		if msg, ok := p.publisherChannel.tryRead(); ok {
   265  			p.getProduceProcessInfo(msg.ctx).cancel()
   266  		} else {
   267  			break
   268  		}
   269  	}
   270  }
   272  // cancelBatch cancels all contexts in the batch wrapper
   273  func (p *BatchingProducer) cancelBatch(batchWrapper *pubsubBatchWithReceipts) {
   274  	for _, ctx := range batchWrapper.ctxArr {
   275  		p.getProduceProcessInfo(ctx).cancel()
   276  	}
   278  	// ensure once cancelled, we won't touch the batch anymore
   279  	p.clearBatch(batchWrapper)
   280  }
   282  // failPendingMessages fails all pending messages
   283  func (p *BatchingProducer) failPendingMessages(err error) {
   284  	for {
   285  		if msg, ok := p.publisherChannel.tryRead(); ok {
   286  			p.getProduceProcessInfo(msg.ctx).setErr(err)
   287  		} else {
   288  			break
   289  		}
   290  	}
   291  }
   293  // failBatch marks all contexts in the batch wrapper as failed
   294  func (p *BatchingProducer) failBatch(batchWrapper *pubsubBatchWithReceipts, err error) {
   295  	for _, ctx := range batchWrapper.ctxArr {
   296  		p.getProduceProcessInfo(ctx).setErr(err)
   297  	}
   299  	// ensure once failed, we won't touch the batch anymore
   300  	p.clearBatch(batchWrapper)
   301  }
   303  // clearBatch clears the batch wrapper
   304  func (p *BatchingProducer) clearBatch(batchWrapper *pubsubBatchWithReceipts) {
   305  	batchWrapper.batch = &PubSubBatch{Envelopes: make([]proto.Message, 0, 10)}
   306  	batchWrapper.ctxArr = batchWrapper.ctxArr[:0]
   307  }
   309  // completeBatch marks all contexts in the batch wrapper as completed
   310  func (p *BatchingProducer) completeBatch(batchWrapper *pubsubBatchWithReceipts) {
   311  	for _, ctx := range batchWrapper.ctxArr {
   312  		p.getProduceProcessInfo(ctx).success()
   313  	}
   315  	// ensure once completed, we won't touch the batch anymore
   316  	p.clearBatch(batchWrapper)
   317  }
   319  // removeCancelledFromBatch removes all cancelled contexts from the batch wrapper
   320  func (p *BatchingProducer) removeCancelledFromBatch(batchWrapper *pubsubBatchWithReceipts) {
   321  	for i := len(batchWrapper.ctxArr) - 1; i >= 0; i-- {
   322  		select {
   323  		case <-batchWrapper.ctxArr[i].Done():
   324  			info := p.getProduceProcessInfo(batchWrapper.ctxArr[i])
   325  			select {
   326  			case <-info.Finished:
   327  				// if the message is already finished, we don't need to do anything
   328  			default:
   329  				info.cancel()
   330  			}
   332  			batchWrapper.batch.Envelopes = append(batchWrapper.batch.Envelopes[:i], batchWrapper.batch.Envelopes[i+1:]...)
   333  			batchWrapper.ctxArr = append(batchWrapper.ctxArr[:i], batchWrapper.ctxArr[i+1:]...)
   334  		default:
   335  			continue
   336  		}
   337  	}
   338  }
// stopAcceptingNewMessages marks the producer queue as complete. The queue consults this flag
// when writing, and Produce reports InvalidOperationException for a rejected message once it is
// set. Messages already queued are left in place for the caller to drain.
func (p *BatchingProducer) stopAcceptingNewMessages() {
	p.publisherChannel.complete()
}
   345  // publishBatch publishes a batch of messages using Publisher.
   346  func (p *BatchingProducer) publishBatch(ctx context.Context, batchWrapper *pubsubBatchWithReceipts) error {
   347  	retries := 0
   348  	retry := true
   350  loop:
   351  	for retry {
   352  		select {
   353  		case <-ctx.Done():
   354  			p.cancelBatch(batchWrapper)
   355  			break loop
   356  		default:
   357  			retries++
   358  			_, err := p.publisher.PublishBatch(ctx, p.topic, batchWrapper.batch, WithTimeout(p.config.PublishTimeout))
   359  			if err != nil {
   360  				decision := p.config.OnPublishingError(retries, err, batchWrapper.batch)
   361  				if decision == FailBatchAndStop {
   362  					p.stopAcceptingNewMessages()
   363  					p.failBatch(batchWrapper, err)
   364  					return err // let the main producer loop exit
   365  				}
   367  				if p.config.LogThrottle() == actor.Open {
   368  					p.publisher.Logger().Warn("Error while publishing batch", slog.Any("error", err))
   369  				}
   371  				if decision == FailBatchAndContinue {
   372  					p.failBatch(batchWrapper, err)
   373  					return nil
   374  				}
   376  				// the decision is to retry
   377  				// if any of the messages have been canceled in the meantime, remove them and cancel the delivery report
   378  				p.removeCancelledFromBatch(batchWrapper)
   380  				if len(batchWrapper.batch.Envelopes) == 0 {
   381  					retry = false
   382  				} else if decision.Delay > 0 {
   383  					time.Sleep(decision.Delay)
   384  				}
   386  				continue
   387  			}
   389  			retry = false
   390  			p.completeBatch(batchWrapper)
   391  		}
   392  	}
   394  	return nil
   395  }
   397  type ProducerQueueFullException struct {
   398  	topic string
   399  }
   401  func (p *ProducerQueueFullException) Error() string {
   402  	return "Producer for topic " + p.topic + " has full queue"
   403  }
   405  func (p *ProducerQueueFullException) Is(target error) bool {
   406  	_, ok := target.(*ProducerQueueFullException)
   407  	return ok
   408  }
   410  type InvalidOperationException struct {
   411  	Topic string
   412  }
   414  func (i *InvalidOperationException) Is(err error) bool {
   415  	_, ok := err.(*InvalidOperationException)
   416  	return ok
   417  }
   419  func (i *InvalidOperationException) Error() string {
   420  	return "Producer for topic " + i.Topic + " is stopped, cannot produce more messages."
   421  }
// channel is the queue abstraction between Produce (writers) and the publish loop (reader).
// It has a bounded and an unbounded implementation.
// NOTE(review): an earlier comment required messages to be pointers, but BatchingProducer
// instantiates it with the struct type produceMessage; confirm the restriction no longer applies.
type channel[T any] interface {
	// tryWrite enqueues msg without blocking and reports whether it was accepted.
	tryWrite(msg T) bool
	// tryRead dequeues a message without blocking; the bool is false when none is available.
	tryRead() (T, bool)
	// isComplete reports whether complete has been called.
	isComplete() bool
	// complete marks the channel as no longer accepting writes.
	complete()
	// empty reports whether no message is currently queued.
	empty() bool
	// waitToRead blocks until a message is queued or broadcast is called.
	waitToRead()
	// broadcast releases a reader blocked in waitToRead even if no message is queued.
	broadcast()
}
   435  // BoundedChannel is a bounded channel with the given capacity.
   436  type boundedChannel[T any] struct {
   437  	capacity int
   438  	c        chan T
   439  	quit     chan struct{}
   440  	once     *sync.Once
   441  	cond     *sync.Cond
   442  	left     *atomic.Bool
   443  }
   445  func (b *boundedChannel[T]) tryWrite(msg T) bool {
   446  	select {
   447  	case b.c <- msg:
   448  		b.cond.Broadcast()
   449  		return true
   450  	case <-b.quit:
   451  		return false
   452  	default:
   453  		return false
   454  	}
   455  }
   457  func (b *boundedChannel[T]) tryRead() (msg T, ok bool) {
   458  	var msgDefault T
   459  	select {
   460  	case msg, ok = <-b.c:
   461  		return
   462  	default:
   463  		return msgDefault, false
   464  	}
   465  }
   467  func (b *boundedChannel[T]) isComplete() bool {
   468  	select {
   469  	case <-b.quit:
   470  		return true
   471  	default:
   472  		return false
   473  	}
   474  }
   476  func (b *boundedChannel[T]) complete() {
   477  	b.once.Do(func() {
   478  		close(b.quit)
   479  	})
   480  }
   482  func (b *boundedChannel[T]) empty() bool {
   483  	return len(b.c) == 0
   484  }
   486  func (b *boundedChannel[T]) waitToRead() {
   487  	b.cond.L.Lock()
   488  	defer b.cond.L.Unlock()
   489  	for b.empty() && !b.left.Load() {
   490  		b.cond.Wait()
   491  	}
   492  	b.left.Store(false)
   493  }
   495  func (b *boundedChannel[T]) broadcast() {
   496  	b.left.Store(true)
   497  	b.cond.Broadcast()
   498  }
   500  // newBoundedChannel creates a new bounded channel with the given capacity.
   501  func newBoundedChannel[T any](capacity int) channel[T] {
   502  	return &boundedChannel[T]{
   503  		capacity: capacity,
   504  		c:        make(chan T, capacity),
   505  		quit:     make(chan struct{}),
   506  		cond:     sync.NewCond(&sync.Mutex{}),
   507  		once:     &sync.Once{},
   508  		left:     &atomic.Bool{},
   509  	}
   510  }
   512  // UnboundedChannel is an unbounded channel.
   513  type unboundedChannel[T any] struct {
   514  	queue *mpsc.Queue
   515  	quit  chan struct{}
   516  	once  *sync.Once
   517  	cond  *sync.Cond
   518  	left  *atomic.Bool
   519  }
   521  func (u *unboundedChannel[T]) tryWrite(msg T) bool {
   522  	select {
   523  	case <-u.quit:
   524  		return false
   525  	default:
   526  		u.queue.Push(msg)
   527  		u.cond.Broadcast()
   528  		return true
   529  	}
   530  }
   532  func (u *unboundedChannel[T]) tryRead() (T, bool) {
   533  	var msg T
   534  	tmp := u.queue.Pop()
   535  	if tmp == nil {
   536  		return msg, false
   537  	} else {
   538  		u.cond.Broadcast()
   539  		return tmp.(T), true
   540  	}
   541  }
   543  func (u *unboundedChannel[T]) complete() {
   544  	u.once.Do(func() {
   545  		close(u.quit)
   546  	})
   547  }
   549  func (u *unboundedChannel[T]) isComplete() bool {
   550  	select {
   551  	case <-u.quit:
   552  		return true
   553  	default:
   554  		return false
   555  	}
   556  }
   558  func (u *unboundedChannel[T]) empty() bool {
   559  	return u.queue.Empty()
   560  }
   562  func (u *unboundedChannel[T]) waitToRead() {
   563  	u.cond.L.Lock()
   564  	defer u.cond.L.Unlock()
   565  	for u.empty() && !u.left.Load() {
   566  		u.cond.Wait()
   567  	}
   568  	u.left.Store(false)
   569  }
   571  func (u *unboundedChannel[T]) broadcast() {
   572  	u.left.Store(true)
   573  	u.cond.Broadcast()
   574  }
   576  // newUnboundedChannel creates a new unbounded channel.
   577  func newUnboundedChannel[T any]() channel[T] {
   578  	return &unboundedChannel[T]{
   579  		queue: mpsc.New(),
   580  		quit:  make(chan struct{}),
   581  		cond:  sync.NewCond(&sync.Mutex{}),
   582  		once:  &sync.Once{},
   583  		left:  &atomic.Bool{},
   584  	}
   585  }