github.com/benz9527/toy-box/algo@v0.0.0-20240221120937-66c0c6bd5abd/queue/delay_queue.go (about)

     1  package queue
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"log/slog"
     7  	"sync"
     8  	"sync/atomic"
     9  	"time"
    10  )
    11  
    12  var (
    13  	errEmptyContext = errors.New("empty context")
    14  )
    15  
// dqItem is the delay-queue entry. It embeds a PQItem, so every PQItem method
// is promoted onto it, and the item's priority doubles as its expiration
// (see GetExpiration).
type dqItem[E comparable] struct {
	PQItem[E]
}
    19  
    20  func NewDQItem[E comparable](value E, expiration int64) DQItem[E] {
    21  	return &dqItem[E]{
    22  		PQItem: NewPQItem[E](value, expiration),
    23  	}
    24  }
    25  
    26  func (d *dqItem[E]) GetPQItem() PQItem[E] {
    27  	return d
    28  }
    29  
    30  func (d *dqItem[E]) GetExpiration() int64 {
    31  	return d.GetPriority()
    32  }
    33  
    34  type sleepEnum = int32
    35  
    36  const (
    37  	wakeUp sleepEnum = iota
    38  	fallAsleep
    39  )
    40  
// arrayDQ is a delay queue backed by a priority queue ordered by expiration.
// Producers add items through Offer; a single poller goroutine (see poll)
// pops items once they have expired.
//
// size: 56 (NOTE(review): presumably measured on a 64-bit platform — confirm)
type arrayDQ[E comparable] struct {
	// pq holds the pending items; guarded by lock.
	pq PriorityQueue[E] // alignment size: 8; size: 16
	// wakeUpC is signalled by Offer when a new head item arrives while the
	// poller is asleep.
	wakeUpC chan struct{} // alignment size: 8; size: 8
	// waitNextExpiredC is signalled by the poller's timer callback when the
	// current head item is due.
	waitNextExpiredC chan struct{} // alignment size: 8; size: 8
	// lock guards pq against concurrent Offer and poll access.
	lock *sync.RWMutex // alignment size: 8; size: 8
	// mu serializes poll so that only one poller runs at a time.
	mu *sync.Mutex // alignment size: 8; size: 8
	// sleeping holds wakeUp or fallAsleep and is accessed atomically.
	sleeping int32 // alignment size: 4; size: 4
}
    50  
    51  func NewArrayDelayQueue[E comparable](
    52  	capacity int,
    53  	comparator ...LessThan[E],
    54  ) DelayQueue[E] {
    55  	if capacity <= 0 {
    56  		capacity = 32
    57  	}
    58  	if len(comparator) <= 0 {
    59  		comparator = []LessThan[E]{
    60  			func(i, j PQItem[E]) bool {
    61  				return i.GetPriority() < j.GetPriority()
    62  			},
    63  		}
    64  	}
    65  	return &arrayDQ[E]{
    66  		wakeUpC:          make(chan struct{}, 1),
    67  		waitNextExpiredC: make(chan struct{}, 1),
    68  		pq:               NewArrayPriorityQueue[E](capacity, comparator[0]),
    69  		mu:               &sync.Mutex{},
    70  		lock:             &sync.RWMutex{},
    71  	}
    72  }
    73  
    74  func (dq *arrayDQ[E]) popIfExpired(expiredBoundary int64) (item PQItem[E], deltaMs int64) {
    75  	if (*dq).pq.Len() == 0 {
    76  		return nil, 0
    77  	}
    78  
    79  	item = (*dq).pq.Peek()
    80  	expiration := item.(DQItem[E]).GetExpiration()
    81  	if expiration > expiredBoundary {
    82  		// not matched
    83  		return nil, expiration - expiredBoundary
    84  	}
    85  	item = (*dq).pq.Pop()
    86  	return item, 0
    87  }
    88  
    89  func (dq *arrayDQ[E]) Offer(item E, expiration int64) error {
    90  	e := NewDQItem[E](item, expiration)
    91  	dq.lock.Lock()
    92  	dq.pq.Push(e.GetPQItem())
    93  	dq.lock.Unlock()
    94  
    95  	if e.GetPQItem().GetIndex() == 0 {
    96  		// Highest priority item, wake up the consumer
    97  		if atomic.CompareAndSwapInt32(&dq.sleeping, fallAsleep, wakeUp) {
    98  			dq.wakeUpC <- struct{}{}
    99  		}
   100  	}
   101  	return nil
   102  }
   103  
   104  func (dq *arrayDQ[E]) poll(ctx context.Context, nowFn func() int64, sender chan<- E, closeChAfterFinish bool) {
   105  	dq.mu.Lock() // Avoid concurrent execution of Poll()
   106  	var timer *time.Timer
   107  	defer func() {
   108  		// FIXME recover defer execution order
   109  		if err := recover(); err != nil {
   110  			slog.Error("delay queue panic recover", "error", err)
   111  		}
   112  		// before exit
   113  		atomic.StoreInt32(&dq.sleeping, wakeUp)
   114  		if closeChAfterFinish && sender != nil {
   115  			close(sender)
   116  		}
   117  		dq.mu.Unlock()
   118  		if timer != nil {
   119  			timer.Stop()
   120  			timer = nil
   121  		}
   122  	}()
   123  	for {
   124  		now := nowFn()
   125  		dq.lock.RLock() // Concurrency control, avoid long time lock, block Offer()
   126  		item, deltaMs := dq.popIfExpired(now)
   127  		if item == nil {
   128  			// No expired item in the queue
   129  			// 1. without any item in the queue
   130  			// 2. all items in the queue are not expired
   131  			atomic.StoreInt32(&dq.sleeping, fallAsleep)
   132  		}
   133  		dq.lock.RUnlock()
   134  		if item == nil && deltaMs > 0 {
   135  			if timer != nil {
   136  				timer.Stop()
   137  			}
   138  			// Avoid to use time.After(), it will create a new timer every time
   139  			// what's worse the underlay timer will not be GC.
   140  			timer = time.AfterFunc(time.Duration(deltaMs)*time.Millisecond, func() {
   141  				if atomic.SwapInt32(&dq.sleeping, wakeUp) == fallAsleep {
   142  					dq.waitNextExpiredC <- struct{}{}
   143  				}
   144  			})
   145  		}
   146  
   147  		if item == nil {
   148  			if deltaMs == 0 {
   149  				// Queue is empty, waiting for new item
   150  				select {
   151  				case <-ctx.Done():
   152  					return
   153  				case <-dq.wakeUpC:
   154  					// Waiting for an immediately executed item
   155  					continue
   156  				}
   157  			} else if deltaMs > 0 {
   158  				select {
   159  				case <-ctx.Done():
   160  					return
   161  				case <-dq.wakeUpC:
   162  					continue
   163  				case <-dq.waitNextExpiredC:
   164  					// Waiting for this item to be expired
   165  					if timer != nil {
   166  						timer.Stop()
   167  						timer = nil
   168  					}
   169  					continue
   170  				}
   171  			}
   172  		}
   173  
   174  		// Wakeup, stop wait next expired timer
   175  		if timer != nil {
   176  			timer.Stop()
   177  			timer = nil
   178  		}
   179  
   180  		select {
   181  		case <-ctx.Done():
   182  			return
   183  		case sender <- item.GetValue():
   184  			// Waiting for the consumer to consume this item
   185  			// If external channel is closed, here will be panic
   186  		}
   187  	}
   188  }
   189  
   190  func (dq *arrayDQ[E]) Poll(ctx context.Context, nowFn func() int64) (<-chan E, error) {
   191  	if ctx == nil {
   192  		return nil, errEmptyContext
   193  	}
   194  	resultC := make(chan E)
   195  	// FIXME using goroutine pool
   196  	go dq.poll(ctx, nowFn, resultC, true)
   197  	return resultC, nil
   198  }
   199  
   200  func (dq *arrayDQ[E]) PollToChannel(ctx context.Context, nowFn func() int64, C chan<- E) error {
   201  	if ctx == nil {
   202  		return errEmptyContext
   203  	}
   204  	dq.poll(ctx, nowFn, C, false)
   205  	return nil
   206  }