github.com/Jeffail/benthos/v3@v3.65.0/lib/input/batcher.go (about)

     1  package input
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"time"
     7  
     8  	"github.com/Jeffail/benthos/v3/internal/shutdown"
     9  	"github.com/Jeffail/benthos/v3/internal/transaction"
    10  	"github.com/Jeffail/benthos/v3/lib/log"
    11  	"github.com/Jeffail/benthos/v3/lib/message/batch"
    12  	"github.com/Jeffail/benthos/v3/lib/metrics"
    13  	"github.com/Jeffail/benthos/v3/lib/types"
    14  )
    15  
    16  //------------------------------------------------------------------------------
    17  
    18  // Batcher wraps an input with a batch policy.
    19  type Batcher struct {
    20  	stats metrics.Type
    21  	log   log.Modular
    22  
    23  	child   Type
    24  	batcher *batch.Policy
    25  
    26  	messagesOut chan types.Transaction
    27  
    28  	shutSig *shutdown.Signaller
    29  }
    30  
    31  // NewBatcher creates a new Batcher around an input.
    32  func NewBatcher(
    33  	batcher *batch.Policy,
    34  	child Type,
    35  	log log.Modular,
    36  	stats metrics.Type,
    37  ) Type {
    38  	b := Batcher{
    39  		stats:       stats,
    40  		log:         log,
    41  		child:       child,
    42  		batcher:     batcher,
    43  		messagesOut: make(chan types.Transaction),
    44  		shutSig:     shutdown.NewSignaller(),
    45  	}
    46  	go b.loop()
    47  	return &b
    48  }
    49  
    50  //------------------------------------------------------------------------------
    51  
// loop is the Batcher's background goroutine. It reads transactions from the
// child input, feeds their parts into the batch policy, and emits a batched
// transaction on messagesOut whenever the policy reports a full batch or its
// period timer fires. Acknowledgements for an emitted batch are propagated back
// to every original transaction that contributed to it.
//
// Shutdown: deferred functions run in LIFO order, so on return the final
// flush + ack wait (registered second) runs before the teardown of the child,
// the policy and messagesOut (registered first).
func (m *Batcher) loop() {
	// Teardown, runs last. Closes the child input and the batch policy, waiting
	// up to shutdown.MaximumShutdownWait() for each, then closes messagesOut
	// and marks shutdown complete.
	defer func() {
		// If a forced close (CloseNow) happens while the waits below are in
		// progress, call WaitForClose(0) on the child and the policy.
		// NOTE(review): presumably this escalates them to an immediate
		// shutdown — confirm against their implementations. The goroutine
		// exits once HasClosedChan fires.
		go func() {
			select {
			case <-m.shutSig.CloseNowChan():
				_ = m.child.WaitForClose(0)
				_ = m.batcher.WaitForClose(0)
			case <-m.shutSig.HasClosedChan():
			}
		}()

		m.child.CloseAsync()
		_ = m.child.WaitForClose(shutdown.MaximumShutdownWait())

		m.batcher.CloseAsync()
		_ = m.batcher.WaitForClose(shutdown.MaximumShutdownWait())

		close(m.messagesOut)
		m.shutSig.ShutdownComplete()
	}()

	// Channel that fires when the policy's period elapses. It stays nil (and
	// so never selects) when UntilNext reports a negative duration.
	var nextTimedBatchChan <-chan time.Time
	if tNext := m.batcher.UntilNext(); tNext >= 0 {
		nextTimedBatchChan = time.After(tNext)
	}

	// Transactions whose parts are in the current (unflushed) batch.
	pendingTrans := []*transaction.Tracked{}
	// Counts ack-forwarding goroutines that have not yet finished.
	pendingAcks := sync.WaitGroup{}

	// flushBatchFn flushes the policy and, if it produced a message, sends it
	// downstream and spawns a goroutine that forwards the downstream response
	// to every transaction in pendingTrans.
	//
	// If Flush returns nil, pendingTrans is left untouched, so any tracked
	// transactions carry over into the next flushed batch. If CloseNow fires
	// while sending, the function returns without resetting pendingTrans and
	// those transactions are never acked by this loop.
	flushBatchFn := func() {
		sendMsg := m.batcher.Flush()
		if sendMsg == nil {
			return
		}

		resChan := make(chan types.Response)
		select {
		case m.messagesOut <- types.NewTransaction(sendMsg, resChan):
		case <-m.shutSig.CloseNowChan():
			return
		}

		pendingAcks.Add(1)
		// The current pendingTrans slice is handed to the goroutine and the
		// outer variable is reset below, so the goroutine owns its copy.
		go func(rChan <-chan types.Response, aggregatedTransactions []*transaction.Tracked) {
			defer pendingAcks.Done()

			select {
			case <-m.shutSig.CloseNowChan():
				return
			case res, open := <-rChan:
				if !open {
					return
				}
				// The ack context is cancelled on CloseNow, bounding how long
				// acks may block during a forced shutdown.
				closeNowCtx, done := m.shutSig.CloseNowCtx(context.Background())
				for _, c := range aggregatedTransactions {
					// Stop acking the remaining transactions on the first error.
					if err := c.Ack(closeNowCtx, res.Error()); err != nil {
						done()
						return
					}
				}
				done()
			}
		}(resChan, pendingTrans)
		pendingTrans = nil
	}

	// Runs first on return (registered after the teardown defer).
	defer func() {
		// Final flush of remaining documents.
		m.log.Debugln("Flushing remaining messages of batch.")
		flushBatchFn()

		// Wait for all pending acks to resolve.
		m.log.Debugln("Waiting for pending acks to resolve before shutting down.")
		pendingAcks.Wait()
		m.log.Debugln("Pending acks resolved.")
	}()

	for {
		// Re-arm the period timer after it has fired (it is set to nil below).
		// NOTE(review): a size-triggered flush does not reset an already-armed
		// timer, so that timer still fires on its original schedule and may
		// flush the next batch early — confirm whether this is intended.
		if nextTimedBatchChan == nil {
			if tNext := m.batcher.UntilNext(); tNext >= 0 {
				nextTimedBatchChan = time.After(tNext)
			}
		}

		var flushBatch bool
		select {
		case tran, open := <-m.child.TransactionChan():
			if !open {
				// If we're waiting for a timed batch then we will respect it.
				if nextTimedBatchChan != nil {
					select {
					case <-nextTimedBatchChan:
					case <-m.shutSig.CloseAtLeisureChan():
						return
					}
				}
				flushBatchFn()
				return
			}

			// Add every part of the incoming message to the policy. Add
			// returning true means the batch is ready to flush.
			trackedTran := transaction.NewTracked(tran.Payload, tran.ResponseChan)
			trackedTran.Message().Iter(func(i int, p types.Part) error {
				if m.batcher.Add(p) {
					flushBatch = true
				}
				return nil
			})
			pendingTrans = append(pendingTrans, trackedTran)
		case <-nextTimedBatchChan:
			flushBatch = true
			nextTimedBatchChan = nil
		case <-m.shutSig.CloseAtLeisureChan():
			return
		}

		if flushBatch {
			flushBatchFn()
		}
	}
}
   172  
   173  // Connected returns true if the underlying input is connected.
   174  func (m *Batcher) Connected() bool {
   175  	return m.child.Connected()
   176  }
   177  
   178  // TransactionChan returns the channel used for consuming messages from this
   179  // buffer.
   180  func (m *Batcher) TransactionChan() <-chan types.Transaction {
   181  	return m.messagesOut
   182  }
   183  
   184  // CloseAsync shuts down the Batcher and stops processing messages.
   185  func (m *Batcher) CloseAsync() {
   186  	m.shutSig.CloseAtLeisure()
   187  }
   188  
   189  // WaitForClose blocks until the Batcher output has closed down.
   190  func (m *Batcher) WaitForClose(timeout time.Duration) error {
   191  	go func() {
   192  		<-time.After(timeout - time.Second)
   193  		m.shutSig.CloseNow()
   194  	}()
   195  	select {
   196  	case <-m.shutSig.HasClosedChan():
   197  	case <-time.After(timeout):
   198  		return types.ErrTimeout
   199  	}
   200  	return nil
   201  }
   202  
   203  //------------------------------------------------------------------------------