github.com/shuguocloud/go-zero@v1.3.0/core/collection/timingwheel.go (about)

     1  package collection
     2  
     3  import (
     4  	"container/list"
     5  	"fmt"
     6  	"time"
     7  
     8  	"github.com/shuguocloud/go-zero/core/lang"
     9  	"github.com/shuguocloud/go-zero/core/threading"
    10  	"github.com/shuguocloud/go-zero/core/timex"
    11  )
    12  
    13  const drainWorkers = 8
    14  
    15  type (
    16  	// Execute defines the method to execute the task.
    17  	Execute func(key, value interface{})
    18  
    19  	// A TimingWheel is a timing wheel object to schedule tasks.
    20  	TimingWheel struct {
    21  		interval      time.Duration
    22  		ticker        timex.Ticker
    23  		slots         []*list.List
    24  		timers        *SafeMap
    25  		tickedPos     int
    26  		numSlots      int
    27  		execute       Execute
    28  		setChannel    chan timingEntry
    29  		moveChannel   chan baseEntry
    30  		removeChannel chan interface{}
    31  		drainChannel  chan func(key, value interface{})
    32  		stopChannel   chan lang.PlaceholderType
    33  	}
    34  
    35  	timingEntry struct {
    36  		baseEntry
    37  		value   interface{}
    38  		circle  int
    39  		diff    int
    40  		removed bool
    41  	}
    42  
    43  	baseEntry struct {
    44  		delay time.Duration
    45  		key   interface{}
    46  	}
    47  
    48  	positionEntry struct {
    49  		pos  int
    50  		item *timingEntry
    51  	}
    52  
    53  	timingTask struct {
    54  		key   interface{}
    55  		value interface{}
    56  	}
    57  )
    58  
    59  // NewTimingWheel returns a TimingWheel.
    60  func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) {
    61  	if interval <= 0 || numSlots <= 0 || execute == nil {
    62  		return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p", interval, numSlots, execute)
    63  	}
    64  
    65  	return newTimingWheelWithClock(interval, numSlots, execute, timex.NewTicker(interval))
    66  }
    67  
    68  func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute, ticker timex.Ticker) (
    69  	*TimingWheel, error) {
    70  	tw := &TimingWheel{
    71  		interval:      interval,
    72  		ticker:        ticker,
    73  		slots:         make([]*list.List, numSlots),
    74  		timers:        NewSafeMap(),
    75  		tickedPos:     numSlots - 1, // at previous virtual circle
    76  		execute:       execute,
    77  		numSlots:      numSlots,
    78  		setChannel:    make(chan timingEntry),
    79  		moveChannel:   make(chan baseEntry),
    80  		removeChannel: make(chan interface{}),
    81  		drainChannel:  make(chan func(key, value interface{})),
    82  		stopChannel:   make(chan lang.PlaceholderType),
    83  	}
    84  
    85  	tw.initSlots()
    86  	go tw.run()
    87  
    88  	return tw, nil
    89  }
    90  
    91  // Drain drains all items and executes them.
    92  func (tw *TimingWheel) Drain(fn func(key, value interface{})) {
    93  	tw.drainChannel <- fn
    94  }
    95  
    96  // MoveTimer moves the task with the given key to the given delay.
    97  func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) {
    98  	if delay <= 0 || key == nil {
    99  		return
   100  	}
   101  
   102  	tw.moveChannel <- baseEntry{
   103  		delay: delay,
   104  		key:   key,
   105  	}
   106  }
   107  
   108  // RemoveTimer removes the task with the given key.
   109  func (tw *TimingWheel) RemoveTimer(key interface{}) {
   110  	if key == nil {
   111  		return
   112  	}
   113  
   114  	tw.removeChannel <- key
   115  }
   116  
   117  // SetTimer sets the task value with the given key to the delay.
   118  func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) {
   119  	if delay <= 0 || key == nil {
   120  		return
   121  	}
   122  
   123  	tw.setChannel <- timingEntry{
   124  		baseEntry: baseEntry{
   125  			delay: delay,
   126  			key:   key,
   127  		},
   128  		value: value,
   129  	}
   130  }
   131  
   132  // Stop stops tw.
   133  func (tw *TimingWheel) Stop() {
   134  	close(tw.stopChannel)
   135  }
   136  
   137  func (tw *TimingWheel) drainAll(fn func(key, value interface{})) {
   138  	runner := threading.NewTaskRunner(drainWorkers)
   139  	for _, slot := range tw.slots {
   140  		for e := slot.Front(); e != nil; {
   141  			task := e.Value.(*timingEntry)
   142  			next := e.Next()
   143  			slot.Remove(e)
   144  			e = next
   145  			if !task.removed {
   146  				runner.Schedule(func() {
   147  					fn(task.key, task.value)
   148  				})
   149  			}
   150  		}
   151  	}
   152  }
   153  
   154  func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) {
   155  	steps := int(d / tw.interval)
   156  	pos = (tw.tickedPos + steps) % tw.numSlots
   157  	circle = (steps - 1) / tw.numSlots
   158  
   159  	return
   160  }
   161  
   162  func (tw *TimingWheel) initSlots() {
   163  	for i := 0; i < tw.numSlots; i++ {
   164  		tw.slots[i] = list.New()
   165  	}
   166  }
   167  
   168  func (tw *TimingWheel) moveTask(task baseEntry) {
   169  	val, ok := tw.timers.Get(task.key)
   170  	if !ok {
   171  		return
   172  	}
   173  
   174  	timer := val.(*positionEntry)
   175  	if task.delay < tw.interval {
   176  		threading.GoSafe(func() {
   177  			tw.execute(timer.item.key, timer.item.value)
   178  		})
   179  		return
   180  	}
   181  
   182  	pos, circle := tw.getPositionAndCircle(task.delay)
   183  	if pos >= timer.pos {
   184  		timer.item.circle = circle
   185  		timer.item.diff = pos - timer.pos
   186  	} else if circle > 0 {
   187  		circle--
   188  		timer.item.circle = circle
   189  		timer.item.diff = tw.numSlots + pos - timer.pos
   190  	} else {
   191  		timer.item.removed = true
   192  		newItem := &timingEntry{
   193  			baseEntry: task,
   194  			value:     timer.item.value,
   195  		}
   196  		tw.slots[pos].PushBack(newItem)
   197  		tw.setTimerPosition(pos, newItem)
   198  	}
   199  }
   200  
   201  func (tw *TimingWheel) onTick() {
   202  	tw.tickedPos = (tw.tickedPos + 1) % tw.numSlots
   203  	l := tw.slots[tw.tickedPos]
   204  	tw.scanAndRunTasks(l)
   205  }
   206  
   207  func (tw *TimingWheel) removeTask(key interface{}) {
   208  	val, ok := tw.timers.Get(key)
   209  	if !ok {
   210  		return
   211  	}
   212  
   213  	timer := val.(*positionEntry)
   214  	timer.item.removed = true
   215  	tw.timers.Del(key)
   216  }
   217  
   218  func (tw *TimingWheel) run() {
   219  	for {
   220  		select {
   221  		case <-tw.ticker.Chan():
   222  			tw.onTick()
   223  		case task := <-tw.setChannel:
   224  			tw.setTask(&task)
   225  		case key := <-tw.removeChannel:
   226  			tw.removeTask(key)
   227  		case task := <-tw.moveChannel:
   228  			tw.moveTask(task)
   229  		case fn := <-tw.drainChannel:
   230  			tw.drainAll(fn)
   231  		case <-tw.stopChannel:
   232  			tw.ticker.Stop()
   233  			return
   234  		}
   235  	}
   236  }
   237  
   238  func (tw *TimingWheel) runTasks(tasks []timingTask) {
   239  	if len(tasks) == 0 {
   240  		return
   241  	}
   242  
   243  	go func() {
   244  		for i := range tasks {
   245  			threading.RunSafe(func() {
   246  				tw.execute(tasks[i].key, tasks[i].value)
   247  			})
   248  		}
   249  	}()
   250  }
   251  
   252  func (tw *TimingWheel) scanAndRunTasks(l *list.List) {
   253  	var tasks []timingTask
   254  
   255  	for e := l.Front(); e != nil; {
   256  		task := e.Value.(*timingEntry)
   257  		if task.removed {
   258  			next := e.Next()
   259  			l.Remove(e)
   260  			e = next
   261  			continue
   262  		} else if task.circle > 0 {
   263  			task.circle--
   264  			e = e.Next()
   265  			continue
   266  		} else if task.diff > 0 {
   267  			next := e.Next()
   268  			l.Remove(e)
   269  			// (tw.tickedPos+task.diff)%tw.numSlots
   270  			// cannot be the same value of tw.tickedPos
   271  			pos := (tw.tickedPos + task.diff) % tw.numSlots
   272  			tw.slots[pos].PushBack(task)
   273  			tw.setTimerPosition(pos, task)
   274  			task.diff = 0
   275  			e = next
   276  			continue
   277  		}
   278  
   279  		tasks = append(tasks, timingTask{
   280  			key:   task.key,
   281  			value: task.value,
   282  		})
   283  		next := e.Next()
   284  		l.Remove(e)
   285  		tw.timers.Del(task.key)
   286  		e = next
   287  	}
   288  
   289  	tw.runTasks(tasks)
   290  }
   291  
   292  func (tw *TimingWheel) setTask(task *timingEntry) {
   293  	if task.delay < tw.interval {
   294  		task.delay = tw.interval
   295  	}
   296  
   297  	if val, ok := tw.timers.Get(task.key); ok {
   298  		entry := val.(*positionEntry)
   299  		entry.item.value = task.value
   300  		tw.moveTask(task.baseEntry)
   301  	} else {
   302  		pos, circle := tw.getPositionAndCircle(task.delay)
   303  		task.circle = circle
   304  		tw.slots[pos].PushBack(task)
   305  		tw.setTimerPosition(pos, task)
   306  	}
   307  }
   308  
   309  func (tw *TimingWheel) setTimerPosition(pos int, task *timingEntry) {
   310  	if val, ok := tw.timers.Get(task.key); ok {
   311  		timer := val.(*positionEntry)
   312  		timer.item = task
   313  		timer.pos = pos
   314  	} else {
   315  		tw.timers.Set(task.key, &positionEntry{
   316  			pos:  pos,
   317  			item: task,
   318  		})
   319  	}
   320  }