
     1  // Copyright 2022 Kentaro Hibino. All rights reserved.
     2  // Use of this source code is governed by a MIT license
     3  // that can be found in the LICENSE file.
     5  package asynq
import (
	"context"
	"sync"
	"time"

	"github.com/hibiken/asynq/internal/base"
	"github.com/hibiken/asynq/internal/log"
)
    16  // An aggregator is responsible for checking groups and aggregate into one task
    17  // if any of the grouping condition is met.
    18  type aggregator struct {
    19  	logger *log.Logger
    20  	broker base.Broker
    21  	client *Client
    23  	// channel to communicate back to the long running "aggregator" goroutine.
    24  	done chan struct{}
    26  	// list of queue names to check and aggregate.
    27  	queues []string
    29  	// Group configurations
    30  	gracePeriod time.Duration
    31  	maxDelay    time.Duration
    32  	maxSize     int
    34  	// User provided group aggregator.
    35  	ga GroupAggregator
    37  	// interval used to check for aggregation
    38  	interval time.Duration
    40  	// sema is a counting semaphore to ensure the number of active aggregating function
    41  	// does not exceed the limit.
    42  	sema chan struct{}
    43  }
    45  type aggregatorParams struct {
    46  	logger          *log.Logger
    47  	broker          base.Broker
    48  	queues          []string
    49  	gracePeriod     time.Duration
    50  	maxDelay        time.Duration
    51  	maxSize         int
    52  	groupAggregator GroupAggregator
    53  }
    55  const (
    56  	// Maximum number of aggregation checks in flight concurrently.
    57  	maxConcurrentAggregationChecks = 3
    59  	// Default interval used for aggregation checks. If the provided gracePeriod is less than
    60  	// the default, use the gracePeriod.
    61  	defaultAggregationCheckInterval = 7 * time.Second
    62  )
    64  func newAggregator(params aggregatorParams) *aggregator {
    65  	interval := defaultAggregationCheckInterval
    66  	if params.gracePeriod < interval {
    67  		interval = params.gracePeriod
    68  	}
    69  	return &aggregator{
    70  		logger:      params.logger,
    71  		broker:,
    72  		client:      &Client{broker:},
    73  		done:        make(chan struct{}),
    74  		queues:      params.queues,
    75  		gracePeriod: params.gracePeriod,
    76  		maxDelay:    params.maxDelay,
    77  		maxSize:     params.maxSize,
    78  		ga:          params.groupAggregator,
    79  		sema:        make(chan struct{}, maxConcurrentAggregationChecks),
    80  		interval:    interval,
    81  	}
    82  }
    84  func (a *aggregator) shutdown() {
    85  	if == nil {
    86  		return
    87  	}
    88  	a.logger.Debug("[Common] asynq aggregator shutting down...")
    89  	// Signal the aggregator goroutine to stop.
    90  	a.done <- struct{}{}
    91  }
    93  func (a *aggregator) start(wg *sync.WaitGroup) {
    94  	if == nil {
    95  		return
    96  	}
    97  	wg.Add(1)
    98  	go func() {
    99  		defer wg.Done()
   100  		ticker := time.NewTicker(a.interval)
   101  		for {
   102  			select {
   103  			case <-a.done:
   104  				a.logger.Debug("[Common] asynq waiting for all aggregation checks to finish...")
   105  				// block until all aggregation checks released the token
   106  				for i := 0; i < cap(a.sema); i++ {
   107  					a.sema <- struct{}{}
   108  				}
   109  				a.logger.Debug("[Common] asynq aggregator done")
   110  				ticker.Stop()
   111  				return
   112  			case t := <-ticker.C:
   113  				a.exec(t)
   114  			}
   115  		}
   116  	}()
   117  }
   119  func (a *aggregator) exec(t time.Time) {
   120  	select {
   121  	case a.sema <- struct{}{}: // acquire token
   122  		go a.aggregate(t)
   123  	default:
   124  		// If the semaphore blocks, then we are currently running max number of
   125  		// aggregation checks. Skip this round and log warning.
   126  		a.logger.Warnf("[Common] asynq max number of aggregation checks in flight. Skipping")
   127  	}
   128  }
   130  func (a *aggregator) aggregate(t time.Time) {
   131  	defer func() { <-a.sema /* release token */ }()
   132  	for _, qname := range a.queues {
   133  		groups, err :=
   134  		if err != nil {
   135  			a.logger.Errorf("[Common] asynq failed to list groups in queue: %q", qname)
   136  			continue
   137  		}
   138  		for _, gname := range groups {
   139  			aggregationSetID, err :=
   140  				qname, gname, t, a.gracePeriod, a.maxDelay, a.maxSize)
   141  			if err != nil {
   142  				a.logger.Errorf("[Common] asynq failed to run aggregation check: queue=%q group=%q", qname, gname)
   143  				continue
   144  			}
   145  			if aggregationSetID == "" {
   146  				a.logger.Debugf("[Common] asynq no aggregation needed at this time: queue=%q group=%q", qname, gname)
   147  				continue
   148  			}
   150  			// Aggregate and enqueue.
   151  			msgs, deadline, err :=, gname, aggregationSetID)
   152  			if err != nil {
   153  				a.logger.Errorf("[Common] asynq failed to read aggregation set: queue=%q, group=%q, setID=%q",
   154  					qname, gname, aggregationSetID)
   155  				continue
   156  			}
   157  			tasks := make([]*Task, len(msgs))
   158  			for i, m := range msgs {
   159  				tasks[i] = NewTask(m.Type, m.Payload)
   160  			}
   161  			aggregatedTask :=, tasks)
   162  			ctx, cancel := context.WithDeadline(context.Background(), deadline)
   163  			if _, err := a.client.EnqueueContext(ctx, aggregatedTask, Queue(qname)); err != nil {
   164  				a.logger.Errorf("[Common] asynq failed to enqueue aggregated task (queue=%q, group=%q, setID=%q): %v",
   165  					qname, gname, aggregationSetID, err)
   166  				cancel()
   167  				continue
   168  			}
   169  			if err :=, qname, gname, aggregationSetID); err != nil {
   170  				a.logger.Warnf("[Common] asynq failed to delete aggregation set: queue=%q, group=%q, setID=%q",
   171  					qname, gname, aggregationSetID)
   172  			}
   173  			cancel()
   174  		}
   175  	}
   176  }