github.com/lingyao2333/mo-zero@v1.4.1/core/collection/timingwheel.go (about)

     1  package collection
     2  
     3  import (
     4  	"container/list"
     5  	"errors"
     6  	"fmt"
     7  	"time"
     8  
     9  	"github.com/lingyao2333/mo-zero/core/lang"
    10  	"github.com/lingyao2333/mo-zero/core/threading"
    11  	"github.com/lingyao2333/mo-zero/core/timex"
    12  )
    13  
    14  const drainWorkers = 8
    15  
    16  var (
    17  	ErrClosed   = errors.New("TimingWheel is closed already")
    18  	ErrArgument = errors.New("incorrect task argument")
    19  )
    20  
    21  type (
    22  	// Execute defines the method to execute the task.
    23  	Execute func(key, value interface{})
    24  
    25  	// A TimingWheel is a timing wheel object to schedule tasks.
    26  	TimingWheel struct {
    27  		interval      time.Duration
    28  		ticker        timex.Ticker
    29  		slots         []*list.List
    30  		timers        *SafeMap
    31  		tickedPos     int
    32  		numSlots      int
    33  		execute       Execute
    34  		setChannel    chan timingEntry
    35  		moveChannel   chan baseEntry
    36  		removeChannel chan interface{}
    37  		drainChannel  chan func(key, value interface{})
    38  		stopChannel   chan lang.PlaceholderType
    39  	}
    40  
    41  	timingEntry struct {
    42  		baseEntry
    43  		value   interface{}
    44  		circle  int
    45  		diff    int
    46  		removed bool
    47  	}
    48  
    49  	baseEntry struct {
    50  		delay time.Duration
    51  		key   interface{}
    52  	}
    53  
    54  	positionEntry struct {
    55  		pos  int
    56  		item *timingEntry
    57  	}
    58  
    59  	timingTask struct {
    60  		key   interface{}
    61  		value interface{}
    62  	}
    63  )
    64  
    65  // NewTimingWheel returns a TimingWheel.
    66  func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) {
    67  	if interval <= 0 || numSlots <= 0 || execute == nil {
    68  		return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p",
    69  			interval, numSlots, execute)
    70  	}
    71  
    72  	return newTimingWheelWithClock(interval, numSlots, execute, timex.NewTicker(interval))
    73  }
    74  
    75  func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute,
    76  	ticker timex.Ticker) (*TimingWheel, error) {
    77  	tw := &TimingWheel{
    78  		interval:      interval,
    79  		ticker:        ticker,
    80  		slots:         make([]*list.List, numSlots),
    81  		timers:        NewSafeMap(),
    82  		tickedPos:     numSlots - 1, // at previous virtual circle
    83  		execute:       execute,
    84  		numSlots:      numSlots,
    85  		setChannel:    make(chan timingEntry),
    86  		moveChannel:   make(chan baseEntry),
    87  		removeChannel: make(chan interface{}),
    88  		drainChannel:  make(chan func(key, value interface{})),
    89  		stopChannel:   make(chan lang.PlaceholderType),
    90  	}
    91  
    92  	tw.initSlots()
    93  	go tw.run()
    94  
    95  	return tw, nil
    96  }
    97  
    98  // Drain drains all items and executes them.
    99  func (tw *TimingWheel) Drain(fn func(key, value interface{})) error {
   100  	select {
   101  	case tw.drainChannel <- fn:
   102  		return nil
   103  	case <-tw.stopChannel:
   104  		return ErrClosed
   105  	}
   106  }
   107  
   108  // MoveTimer moves the task with the given key to the given delay.
   109  func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) error {
   110  	if delay <= 0 || key == nil {
   111  		return ErrArgument
   112  	}
   113  
   114  	select {
   115  	case tw.moveChannel <- baseEntry{
   116  		delay: delay,
   117  		key:   key,
   118  	}:
   119  		return nil
   120  	case <-tw.stopChannel:
   121  		return ErrClosed
   122  	}
   123  }
   124  
   125  // RemoveTimer removes the task with the given key.
   126  func (tw *TimingWheel) RemoveTimer(key interface{}) error {
   127  	if key == nil {
   128  		return ErrArgument
   129  	}
   130  
   131  	select {
   132  	case tw.removeChannel <- key:
   133  		return nil
   134  	case <-tw.stopChannel:
   135  		return ErrClosed
   136  	}
   137  }
   138  
   139  // SetTimer sets the task value with the given key to the delay.
   140  func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) error {
   141  	if delay <= 0 || key == nil {
   142  		return ErrArgument
   143  	}
   144  
   145  	select {
   146  	case tw.setChannel <- timingEntry{
   147  		baseEntry: baseEntry{
   148  			delay: delay,
   149  			key:   key,
   150  		},
   151  		value: value,
   152  	}:
   153  		return nil
   154  	case <-tw.stopChannel:
   155  		return ErrClosed
   156  	}
   157  }
   158  
   159  // Stop stops tw. No more actions after stopping a TimingWheel.
   160  func (tw *TimingWheel) Stop() {
   161  	close(tw.stopChannel)
   162  }
   163  
   164  func (tw *TimingWheel) drainAll(fn func(key, value interface{})) {
   165  	runner := threading.NewTaskRunner(drainWorkers)
   166  	for _, slot := range tw.slots {
   167  		for e := slot.Front(); e != nil; {
   168  			task := e.Value.(*timingEntry)
   169  			next := e.Next()
   170  			slot.Remove(e)
   171  			e = next
   172  			if !task.removed {
   173  				runner.Schedule(func() {
   174  					fn(task.key, task.value)
   175  				})
   176  			}
   177  		}
   178  	}
   179  }
   180  
   181  func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) {
   182  	steps := int(d / tw.interval)
   183  	pos = (tw.tickedPos + steps) % tw.numSlots
   184  	circle = (steps - 1) / tw.numSlots
   185  
   186  	return
   187  }
   188  
   189  func (tw *TimingWheel) initSlots() {
   190  	for i := 0; i < tw.numSlots; i++ {
   191  		tw.slots[i] = list.New()
   192  	}
   193  }
   194  
   195  func (tw *TimingWheel) moveTask(task baseEntry) {
   196  	val, ok := tw.timers.Get(task.key)
   197  	if !ok {
   198  		return
   199  	}
   200  
   201  	timer := val.(*positionEntry)
   202  	if task.delay < tw.interval {
   203  		threading.GoSafe(func() {
   204  			tw.execute(timer.item.key, timer.item.value)
   205  		})
   206  		return
   207  	}
   208  
   209  	pos, circle := tw.getPositionAndCircle(task.delay)
   210  	if pos >= timer.pos {
   211  		timer.item.circle = circle
   212  		timer.item.diff = pos - timer.pos
   213  	} else if circle > 0 {
   214  		circle--
   215  		timer.item.circle = circle
   216  		timer.item.diff = tw.numSlots + pos - timer.pos
   217  	} else {
   218  		timer.item.removed = true
   219  		newItem := &timingEntry{
   220  			baseEntry: task,
   221  			value:     timer.item.value,
   222  		}
   223  		tw.slots[pos].PushBack(newItem)
   224  		tw.setTimerPosition(pos, newItem)
   225  	}
   226  }
   227  
   228  func (tw *TimingWheel) onTick() {
   229  	tw.tickedPos = (tw.tickedPos + 1) % tw.numSlots
   230  	l := tw.slots[tw.tickedPos]
   231  	tw.scanAndRunTasks(l)
   232  }
   233  
   234  func (tw *TimingWheel) removeTask(key interface{}) {
   235  	val, ok := tw.timers.Get(key)
   236  	if !ok {
   237  		return
   238  	}
   239  
   240  	timer := val.(*positionEntry)
   241  	timer.item.removed = true
   242  	tw.timers.Del(key)
   243  }
   244  
   245  func (tw *TimingWheel) run() {
   246  	for {
   247  		select {
   248  		case <-tw.ticker.Chan():
   249  			tw.onTick()
   250  		case task := <-tw.setChannel:
   251  			tw.setTask(&task)
   252  		case key := <-tw.removeChannel:
   253  			tw.removeTask(key)
   254  		case task := <-tw.moveChannel:
   255  			tw.moveTask(task)
   256  		case fn := <-tw.drainChannel:
   257  			tw.drainAll(fn)
   258  		case <-tw.stopChannel:
   259  			tw.ticker.Stop()
   260  			return
   261  		}
   262  	}
   263  }
   264  
   265  func (tw *TimingWheel) runTasks(tasks []timingTask) {
   266  	if len(tasks) == 0 {
   267  		return
   268  	}
   269  
   270  	go func() {
   271  		for i := range tasks {
   272  			threading.RunSafe(func() {
   273  				tw.execute(tasks[i].key, tasks[i].value)
   274  			})
   275  		}
   276  	}()
   277  }
   278  
   279  func (tw *TimingWheel) scanAndRunTasks(l *list.List) {
   280  	var tasks []timingTask
   281  
   282  	for e := l.Front(); e != nil; {
   283  		task := e.Value.(*timingEntry)
   284  		if task.removed {
   285  			next := e.Next()
   286  			l.Remove(e)
   287  			e = next
   288  			continue
   289  		} else if task.circle > 0 {
   290  			task.circle--
   291  			e = e.Next()
   292  			continue
   293  		} else if task.diff > 0 {
   294  			next := e.Next()
   295  			l.Remove(e)
   296  			// (tw.tickedPos+task.diff)%tw.numSlots
   297  			// cannot be the same value of tw.tickedPos
   298  			pos := (tw.tickedPos + task.diff) % tw.numSlots
   299  			tw.slots[pos].PushBack(task)
   300  			tw.setTimerPosition(pos, task)
   301  			task.diff = 0
   302  			e = next
   303  			continue
   304  		}
   305  
   306  		tasks = append(tasks, timingTask{
   307  			key:   task.key,
   308  			value: task.value,
   309  		})
   310  		next := e.Next()
   311  		l.Remove(e)
   312  		tw.timers.Del(task.key)
   313  		e = next
   314  	}
   315  
   316  	tw.runTasks(tasks)
   317  }
   318  
   319  func (tw *TimingWheel) setTask(task *timingEntry) {
   320  	if task.delay < tw.interval {
   321  		task.delay = tw.interval
   322  	}
   323  
   324  	if val, ok := tw.timers.Get(task.key); ok {
   325  		entry := val.(*positionEntry)
   326  		entry.item.value = task.value
   327  		tw.moveTask(task.baseEntry)
   328  	} else {
   329  		pos, circle := tw.getPositionAndCircle(task.delay)
   330  		task.circle = circle
   331  		tw.slots[pos].PushBack(task)
   332  		tw.setTimerPosition(pos, task)
   333  	}
   334  }
   335  
   336  func (tw *TimingWheel) setTimerPosition(pos int, task *timingEntry) {
   337  	if val, ok := tw.timers.Get(task.key); ok {
   338  		timer := val.(*positionEntry)
   339  		timer.item = task
   340  		timer.pos = pos
   341  	} else {
   342  		tw.timers.Set(task.key, &positionEntry{
   343  			pos:  pos,
   344  			item: task,
   345  		})
   346  	}
   347  }