github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/timex/timingwheel.go (about)

     1  package timex
     2  
     3  import (
     4  	"container/list"
     5  	"errors"
     6  	"fmt"
     7  	"log"
     8  	"runtime/debug"
     9  	"time"
    10  
    11  	"github.com/bingoohuang/gg/pkg/mapp"
    12  )
    13  
    14  const drainWorkers = 8
    15  
    16  var (
    17  	ErrClosed   = errors.New("TimingWheel is closed already")
    18  	ErrArgument = errors.New("incorrect task argument")
    19  )
    20  
    21  type (
    22  	// Execute defines the method to execute the task.
    23  	Execute func(key, value interface{})
    24  
    25  	// A TimingWheel is a timing wheel object to schedule tasks.
    26  	TimingWheel struct {
    27  		interval      time.Duration
    28  		ticker        *time.Ticker
    29  		slots         []*list.List
    30  		timers        *mapp.SafeMap
    31  		tickedPos     int
    32  		numSlots      int
    33  		execute       Execute
    34  		setChannel    chan timingEntry
    35  		moveChannel   chan baseEntry
    36  		removeChannel chan interface{}
    37  		drainChannel  chan func(key, value interface{})
    38  		stopChannel   chan struct{}
    39  	}
    40  
    41  	timingEntry struct {
    42  		baseEntry
    43  		value   interface{}
    44  		circle  int
    45  		diff    int
    46  		removed bool
    47  	}
    48  
    49  	baseEntry struct {
    50  		delay time.Duration
    51  		key   interface{}
    52  	}
    53  
    54  	positionEntry struct {
    55  		pos  int
    56  		item *timingEntry
    57  	}
    58  
    59  	timingTask struct {
    60  		key   interface{}
    61  		value interface{}
    62  	}
    63  )
    64  
    65  // NewTimingWheel returns a TimingWheel.
    66  func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) {
    67  	if interval <= 0 || numSlots <= 0 || execute == nil {
    68  		return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p",
    69  			interval, numSlots, execute)
    70  	}
    71  
    72  	return newTimingWheelWithClock(interval, numSlots, execute, time.NewTicker(interval))
    73  }
    74  
    75  func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute,
    76  	ticker *time.Ticker,
    77  ) (*TimingWheel, error) {
    78  	tw := &TimingWheel{
    79  		interval:      interval,
    80  		ticker:        ticker,
    81  		slots:         make([]*list.List, numSlots),
    82  		timers:        mapp.NewSafeMap(),
    83  		tickedPos:     numSlots - 1, // at previous virtual circle
    84  		execute:       execute,
    85  		numSlots:      numSlots,
    86  		setChannel:    make(chan timingEntry),
    87  		moveChannel:   make(chan baseEntry),
    88  		removeChannel: make(chan interface{}),
    89  		drainChannel:  make(chan func(key, value interface{})),
    90  		stopChannel:   make(chan struct{}),
    91  	}
    92  
    93  	tw.initSlots()
    94  	go tw.run()
    95  
    96  	return tw, nil
    97  }
    98  
    99  // Drain drains all items and executes them.
   100  func (tw *TimingWheel) Drain(fn func(key, value interface{})) error {
   101  	select {
   102  	case tw.drainChannel <- fn:
   103  		return nil
   104  	case <-tw.stopChannel:
   105  		return ErrClosed
   106  	}
   107  }
   108  
   109  // MoveTimer moves the task with the given key to the given delay.
   110  func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) error {
   111  	if delay <= 0 || key == nil {
   112  		return ErrArgument
   113  	}
   114  
   115  	select {
   116  	case tw.moveChannel <- baseEntry{
   117  		delay: delay,
   118  		key:   key,
   119  	}:
   120  		return nil
   121  	case <-tw.stopChannel:
   122  		return ErrClosed
   123  	}
   124  }
   125  
   126  // RemoveTimer removes the task with the given key.
   127  func (tw *TimingWheel) RemoveTimer(key interface{}) error {
   128  	if key == nil {
   129  		return ErrArgument
   130  	}
   131  
   132  	select {
   133  	case tw.removeChannel <- key:
   134  		return nil
   135  	case <-tw.stopChannel:
   136  		return ErrClosed
   137  	}
   138  }
   139  
   140  // SetTimer sets the task value with the given key to the delay.
   141  func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) error {
   142  	if delay <= 0 || key == nil {
   143  		return ErrArgument
   144  	}
   145  
   146  	select {
   147  	case tw.setChannel <- timingEntry{
   148  		baseEntry: baseEntry{
   149  			delay: delay,
   150  			key:   key,
   151  		},
   152  		value: value,
   153  	}:
   154  		return nil
   155  	case <-tw.stopChannel:
   156  		return ErrClosed
   157  	}
   158  }
   159  
   160  // Stop stops tw. No more actions after stopping a TimingWheel.
   161  func (tw *TimingWheel) Stop() {
   162  	close(tw.stopChannel)
   163  }
   164  
   165  // A TaskRunner is used to control the concurrency of goroutines.
   166  type TaskRunner struct {
   167  	limitChan chan struct{}
   168  }
   169  
   170  // NewTaskRunner returns a TaskRunner.
   171  func NewTaskRunner(concurrency int) *TaskRunner {
   172  	return &TaskRunner{
   173  		limitChan: make(chan struct{}, concurrency),
   174  	}
   175  }
   176  
   177  // Recover is used with defer to do cleanup on panics.
   178  // Use it like:
   179  //
   180  //	defer Recover(func() {})
   181  func Recover(cleanups ...func()) {
   182  	for _, cleanup := range cleanups {
   183  		cleanup()
   184  	}
   185  
   186  	if p := recover(); p != nil {
   187  		log.Printf("Panic recovered from %+v, stack: %s", p, debug.Stack())
   188  	}
   189  }
   190  
   191  // GoSafe runs the given fn using another goroutine, recovers if fn panics.
   192  func GoSafe(fn func()) {
   193  	go RunSafe(fn)
   194  }
   195  
   196  // RunSafe runs the given fn, recovers if fn panics.
   197  func RunSafe(fn func()) {
   198  	defer Recover()
   199  
   200  	fn()
   201  }
   202  
   203  // Schedule schedules a task to run under concurrency control.
   204  func (rp *TaskRunner) Schedule(task func()) {
   205  	rp.limitChan <- struct{}{}
   206  
   207  	go func() {
   208  		defer Recover(func() {
   209  			<-rp.limitChan
   210  		})
   211  
   212  		task()
   213  	}()
   214  }
   215  
   216  func (tw *TimingWheel) drainAll(fn func(key, value interface{})) {
   217  	runner := NewTaskRunner(drainWorkers)
   218  	for _, slot := range tw.slots {
   219  		for e := slot.Front(); e != nil; {
   220  			task := e.Value.(*timingEntry)
   221  			next := e.Next()
   222  			slot.Remove(e)
   223  			e = next
   224  			if !task.removed {
   225  				runner.Schedule(func() {
   226  					fn(task.key, task.value)
   227  				})
   228  			}
   229  		}
   230  	}
   231  }
   232  
   233  func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) {
   234  	steps := int(d / tw.interval)
   235  	pos = (tw.tickedPos + steps) % tw.numSlots
   236  	circle = (steps - 1) / tw.numSlots
   237  
   238  	return
   239  }
   240  
   241  func (tw *TimingWheel) initSlots() {
   242  	for i := 0; i < tw.numSlots; i++ {
   243  		tw.slots[i] = list.New()
   244  	}
   245  }
   246  
   247  func (tw *TimingWheel) moveTask(task baseEntry) {
   248  	val, ok := tw.timers.Get(task.key)
   249  	if !ok {
   250  		return
   251  	}
   252  
   253  	timer := val.(*positionEntry)
   254  	if task.delay < tw.interval {
   255  		GoSafe(func() {
   256  			tw.execute(timer.item.key, timer.item.value)
   257  		})
   258  		return
   259  	}
   260  
   261  	pos, circle := tw.getPositionAndCircle(task.delay)
   262  	if pos >= timer.pos {
   263  		timer.item.circle = circle
   264  		timer.item.diff = pos - timer.pos
   265  	} else if circle > 0 {
   266  		circle--
   267  		timer.item.circle = circle
   268  		timer.item.diff = tw.numSlots + pos - timer.pos
   269  	} else {
   270  		timer.item.removed = true
   271  		newItem := &timingEntry{
   272  			baseEntry: task,
   273  			value:     timer.item.value,
   274  		}
   275  		tw.slots[pos].PushBack(newItem)
   276  		tw.setTimerPosition(pos, newItem)
   277  	}
   278  }
   279  
   280  func (tw *TimingWheel) onTick() {
   281  	tw.tickedPos = (tw.tickedPos + 1) % tw.numSlots
   282  	l := tw.slots[tw.tickedPos]
   283  	tw.scanAndRunTasks(l)
   284  }
   285  
   286  func (tw *TimingWheel) removeTask(key interface{}) {
   287  	val, ok := tw.timers.Get(key)
   288  	if !ok {
   289  		return
   290  	}
   291  
   292  	timer := val.(*positionEntry)
   293  	timer.item.removed = true
   294  	tw.timers.Del(key)
   295  }
   296  
   297  func (tw *TimingWheel) run() {
   298  	for {
   299  		select {
   300  		case <-tw.ticker.C:
   301  			tw.onTick()
   302  		case task := <-tw.setChannel:
   303  			tw.setTask(&task)
   304  		case key := <-tw.removeChannel:
   305  			tw.removeTask(key)
   306  		case task := <-tw.moveChannel:
   307  			tw.moveTask(task)
   308  		case fn := <-tw.drainChannel:
   309  			tw.drainAll(fn)
   310  		case <-tw.stopChannel:
   311  			tw.ticker.Stop()
   312  			return
   313  		}
   314  	}
   315  }
   316  
   317  func (tw *TimingWheel) runTasks(tasks []timingTask) {
   318  	if len(tasks) == 0 {
   319  		return
   320  	}
   321  
   322  	go func() {
   323  		for i := range tasks {
   324  			RunSafe(func() {
   325  				tw.execute(tasks[i].key, tasks[i].value)
   326  			})
   327  		}
   328  	}()
   329  }
   330  
   331  func (tw *TimingWheel) scanAndRunTasks(l *list.List) {
   332  	var tasks []timingTask
   333  
   334  	for e := l.Front(); e != nil; {
   335  		task := e.Value.(*timingEntry)
   336  		if task.removed {
   337  			next := e.Next()
   338  			l.Remove(e)
   339  			e = next
   340  			continue
   341  		} else if task.circle > 0 {
   342  			task.circle--
   343  			e = e.Next()
   344  			continue
   345  		} else if task.diff > 0 {
   346  			next := e.Next()
   347  			l.Remove(e)
   348  			// (tw.tickedPos+task.diff)%tw.numSlots
   349  			// cannot be the same value of tw.tickedPos
   350  			pos := (tw.tickedPos + task.diff) % tw.numSlots
   351  			tw.slots[pos].PushBack(task)
   352  			tw.setTimerPosition(pos, task)
   353  			task.diff = 0
   354  			e = next
   355  			continue
   356  		}
   357  
   358  		tasks = append(tasks, timingTask{
   359  			key:   task.key,
   360  			value: task.value,
   361  		})
   362  		next := e.Next()
   363  		l.Remove(e)
   364  		tw.timers.Del(task.key)
   365  		e = next
   366  	}
   367  
   368  	tw.runTasks(tasks)
   369  }
   370  
   371  func (tw *TimingWheel) setTask(task *timingEntry) {
   372  	if task.delay < tw.interval {
   373  		task.delay = tw.interval
   374  	}
   375  
   376  	if val, ok := tw.timers.Get(task.key); ok {
   377  		entry := val.(*positionEntry)
   378  		entry.item.value = task.value
   379  		tw.moveTask(task.baseEntry)
   380  	} else {
   381  		pos, circle := tw.getPositionAndCircle(task.delay)
   382  		task.circle = circle
   383  		tw.slots[pos].PushBack(task)
   384  		tw.setTimerPosition(pos, task)
   385  	}
   386  }
   387  
   388  func (tw *TimingWheel) setTimerPosition(pos int, task *timingEntry) {
   389  	if val, ok := tw.timers.Get(task.key); ok {
   390  		timer := val.(*positionEntry)
   391  		timer.item = task
   392  		timer.pos = pos
   393  	} else {
   394  		tw.timers.Set(task.key, &positionEntry{
   395  			pos:  pos,
   396  			item: task,
   397  		})
   398  	}
   399  }