github.com/wfusion/gofusion@v1.1.14/common/infra/asynq/aggregator.go (about) 1 // Copyright 2022 Kentaro Hibino. All rights reserved. 2 // Use of this source code is governed by a MIT license 3 // that can be found in the LICENSE file. 4 5 package asynq 6 7 import ( 8 "context" 9 "sync" 10 "time" 11 12 "github.com/wfusion/gofusion/common/infra/asynq/pkg/base" 13 "github.com/wfusion/gofusion/common/infra/asynq/pkg/log" 14 ) 15 16 // An aggregator is responsible for checking groups and aggregate into one task 17 // if any of the grouping condition is met. 18 type aggregator struct { 19 logger *log.Logger 20 broker base.Broker 21 client *Client 22 23 // channel to communicate back to the long running "aggregator" goroutine. 24 done chan struct{} 25 26 // list of queue names to check and aggregate. 27 queues []string 28 29 // Group configurations 30 gracePeriod time.Duration 31 maxDelay time.Duration 32 maxSize int 33 34 // User provided group aggregator. 35 ga GroupAggregator 36 37 // interval used to check for aggregation 38 interval time.Duration 39 40 // sema is a counting semaphore to ensure the number of active aggregating function 41 // does not exceed the limit. 42 sema chan struct{} 43 } 44 45 type aggregatorParams struct { 46 logger *log.Logger 47 broker base.Broker 48 queues []string 49 gracePeriod time.Duration 50 maxDelay time.Duration 51 maxSize int 52 groupAggregator GroupAggregator 53 } 54 55 const ( 56 // Maximum number of aggregation checks in flight concurrently. 57 maxConcurrentAggregationChecks = 3 58 59 // Default interval used for aggregation checks. If the provided gracePeriod is less than 60 // the default, use the gracePeriod. 61 defaultAggregationCheckInterval = 7 * time.Second 62 ) 63 64 func newAggregator(params aggregatorParams) *aggregator { 65 interval := defaultAggregationCheckInterval 66 if params.gracePeriod < interval { 67 interval = params.gracePeriod 68 } 69 return &aggregator{ 70 logger: params.logger, 71 broker: params.broker, 72 client: &Client{broker: params.broker}, 73 done: make(chan struct{}), 74 queues: params.queues, 75 gracePeriod: params.gracePeriod, 76 maxDelay: params.maxDelay, 77 maxSize: params.maxSize, 78 ga: params.groupAggregator, 79 sema: make(chan struct{}, maxConcurrentAggregationChecks), 80 interval: interval, 81 } 82 } 83 84 func (a *aggregator) shutdown() { 85 if a.ga == nil { 86 return 87 } 88 a.logger.Debug("[Common] asynq aggregator shutting down...") 89 // Signal the aggregator goroutine to stop. 90 a.done <- struct{}{} 91 } 92 93 func (a *aggregator) start(wg *sync.WaitGroup) { 94 if a.ga == nil { 95 return 96 } 97 wg.Add(1) 98 go func() { 99 defer wg.Done() 100 ticker := time.NewTicker(a.interval) 101 for { 102 select { 103 case <-a.done: 104 a.logger.Debug("[Common] asynq waiting for all aggregation checks to finish...") 105 // block until all aggregation checks released the token 106 for i := 0; i < cap(a.sema); i++ { 107 a.sema <- struct{}{} 108 } 109 a.logger.Debug("[Common] asynq aggregator done") 110 ticker.Stop() 111 return 112 case t := <-ticker.C: 113 a.exec(t) 114 } 115 } 116 }() 117 } 118 119 func (a *aggregator) exec(t time.Time) { 120 select { 121 case a.sema <- struct{}{}: // acquire token 122 go a.aggregate(t) 123 default: 124 // If the semaphore blocks, then we are currently running max number of 125 // aggregation checks. Skip this round and log warning. 126 a.logger.Warnf("[Common] asynq max number of aggregation checks in flight. Skipping") 127 } 128 } 129 130 func (a *aggregator) aggregate(t time.Time) { 131 defer func() { <-a.sema /* release token */ }() 132 for _, qname := range a.queues { 133 groups, err := a.broker.ListGroups(qname) 134 if err != nil { 135 a.logger.Errorf("[Common] asynq failed to list groups in queue: %q", qname) 136 continue 137 } 138 for _, gname := range groups { 139 aggregationSetID, err := a.broker.AggregationCheck( 140 qname, gname, t, a.gracePeriod, a.maxDelay, a.maxSize) 141 if err != nil { 142 a.logger.Errorf("[Common] asynq failed to run aggregation check: queue=%q group=%q", qname, gname) 143 continue 144 } 145 if aggregationSetID == "" { 146 a.logger.Debugf("[Common] asynq no aggregation needed at this time: queue=%q group=%q", qname, gname) 147 continue 148 } 149 150 // Aggregate and enqueue. 151 msgs, deadline, err := a.broker.ReadAggregationSet(qname, gname, aggregationSetID) 152 if err != nil { 153 a.logger.Errorf("[Common] asynq failed to read aggregation set: queue=%q, group=%q, setID=%q", 154 qname, gname, aggregationSetID) 155 continue 156 } 157 tasks := make([]*Task, len(msgs)) 158 for i, m := range msgs { 159 tasks[i] = NewTask(m.Type, m.Payload) 160 } 161 aggregatedTask := a.ga.Aggregate(gname, tasks) 162 ctx, cancel := context.WithDeadline(context.Background(), deadline) 163 if _, err := a.client.EnqueueContext(ctx, aggregatedTask, Queue(qname)); err != nil { 164 a.logger.Errorf("[Common] asynq failed to enqueue aggregated task (queue=%q, group=%q, setID=%q): %v", 165 qname, gname, aggregationSetID, err) 166 cancel() 167 continue 168 } 169 if err := a.broker.DeleteAggregationSet(ctx, qname, gname, aggregationSetID); err != nil { 170 a.logger.Warnf("[Common] asynq failed to delete aggregation set: queue=%q, group=%q, setID=%q", 171 qname, gname, aggregationSetID) 172 } 173 cancel() 174 } 175 } 176 }