github.com/XiaoMi/Gaea@v1.2.5/util/time_wheel.go (about)

     1  // Copyright 2019 The Gaea Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package util
    16  
    17  import (
    18  	"errors"
    19  	"time"
    20  )
    21  
    22  // Task means handle unit in time wheel
    23  type Task struct {
    24  	delay    time.Duration
    25  	key      interface{}
    26  	round    int // optimize time wheel to handle delay  bigger than bucketsNum * tick
    27  	callback func()
    28  }
    29  
    30  // TimeWheel means time wheel
    31  type TimeWheel struct {
    32  	tick   time.Duration
    33  	ticker *time.Ticker
    34  
    35  	bucketsNum    int
    36  	buckets       []map[interface{}]*Task // key: added item, value: *Task
    37  	bucketIndexes map[interface{}]int     // key: added item, value: bucket position
    38  
    39  	currentIndex int
    40  
    41  	addC    chan *Task
    42  	removeC chan interface{}
    43  	stopC   chan struct{}
    44  }
    45  
    46  // NewTimeWheel create new time wheel
    47  func NewTimeWheel(tick time.Duration, bucketsNum int) (*TimeWheel, error) {
    48  	if bucketsNum <= 0 {
    49  		return nil, errors.New("bucket number must be greater than 0")
    50  	}
    51  	if int(tick.Seconds()) < 1 {
    52  		return nil, errors.New("tick cannot be less than 1s")
    53  	}
    54  
    55  	tw := &TimeWheel{
    56  		tick:          tick,
    57  		bucketsNum:    bucketsNum,
    58  		bucketIndexes: make(map[interface{}]int, 1024),
    59  		buckets:       make([]map[interface{}]*Task, bucketsNum),
    60  		currentIndex:  0,
    61  		addC:          make(chan *Task, 1024),
    62  		removeC:       make(chan interface{}, 1024),
    63  		stopC:         make(chan struct{}),
    64  	}
    65  
    66  	for i := 0; i < bucketsNum; i++ {
    67  		tw.buckets[i] = make(map[interface{}]*Task, 16)
    68  	}
    69  
    70  	return tw, nil
    71  }
    72  
    73  // Start start the time wheel
    74  func (tw *TimeWheel) Start() {
    75  	tw.ticker = time.NewTicker(tw.tick)
    76  	go tw.start()
    77  }
    78  
    79  func (tw *TimeWheel) start() {
    80  	for {
    81  		select {
    82  		case <-tw.ticker.C:
    83  			tw.handleTick()
    84  		case task := <-tw.addC:
    85  			tw.add(task)
    86  		case key := <-tw.removeC:
    87  			tw.remove(key)
    88  		case <-tw.stopC:
    89  			tw.ticker.Stop()
    90  			return
    91  		}
    92  	}
    93  }
    94  
    95  // Stop stop the time wheel
    96  func (tw *TimeWheel) Stop() {
    97  	tw.stopC <- struct{}{}
    98  }
    99  
   100  func (tw *TimeWheel) handleTick() {
   101  	bucket := tw.buckets[tw.currentIndex]
   102  	for k := range bucket {
   103  		if bucket[k].round > 0 {
   104  			bucket[k].round--
   105  			continue
   106  		}
   107  		go bucket[k].callback()
   108  		delete(bucket, k)
   109  		delete(tw.bucketIndexes, k)
   110  	}
   111  	if tw.currentIndex == tw.bucketsNum-1 {
   112  		tw.currentIndex = 0
   113  		return
   114  	}
   115  	tw.currentIndex++
   116  }
   117  
   118  // Add add an item into time wheel
   119  func (tw *TimeWheel) Add(delay time.Duration, key interface{}, callback func()) error {
   120  	if delay <= 0 || key == nil {
   121  		return errors.New("invalid params")
   122  	}
   123  	tw.addC <- &Task{delay: delay, key: key, callback: callback}
   124  	return nil
   125  }
   126  
   127  func (tw *TimeWheel) add(task *Task) {
   128  	round := tw.calculateRound(task.delay)
   129  	index := tw.calculateIndex(task.delay)
   130  	task.round = round
   131  	if originIndex, ok := tw.bucketIndexes[task.key]; ok {
   132  		delete(tw.buckets[originIndex], task.key)
   133  	}
   134  	tw.bucketIndexes[task.key] = index
   135  	tw.buckets[index][task.key] = task
   136  }
   137  
   138  func (tw *TimeWheel) calculateRound(delay time.Duration) (round int) {
   139  	delaySeconds := int(delay.Seconds())
   140  	tickSeconds := int(tw.tick.Seconds())
   141  	round = delaySeconds / tickSeconds / tw.bucketsNum
   142  	return
   143  }
   144  
   145  func (tw *TimeWheel) calculateIndex(delay time.Duration) (index int) {
   146  	delaySeconds := int(delay.Seconds())
   147  	tickSeconds := int(tw.tick.Seconds())
   148  	index = (tw.currentIndex + delaySeconds/tickSeconds) % tw.bucketsNum
   149  	return
   150  }
   151  
   152  // Remove remove an item from time wheel
   153  func (tw *TimeWheel) Remove(key interface{}) error {
   154  	if key == nil {
   155  		return errors.New("invalid params")
   156  	}
   157  	tw.removeC <- key
   158  	return nil
   159  }
   160  
   161  // don't need to call callback
   162  func (tw *TimeWheel) remove(key interface{}) {
   163  	if index, ok := tw.bucketIndexes[key]; ok {
   164  		delete(tw.bucketIndexes, key)
   165  		delete(tw.buckets[index], key)
   166  	}
   167  	return
   168  }