github.com/shuguocloud/go-zero@v1.3.0/core/collection/timingwheel.go (about) 1 package collection 2 3 import ( 4 "container/list" 5 "fmt" 6 "time" 7 8 "github.com/shuguocloud/go-zero/core/lang" 9 "github.com/shuguocloud/go-zero/core/threading" 10 "github.com/shuguocloud/go-zero/core/timex" 11 ) 12 13 const drainWorkers = 8 14 15 type ( 16 // Execute defines the method to execute the task. 17 Execute func(key, value interface{}) 18 19 // A TimingWheel is a timing wheel object to schedule tasks. 20 TimingWheel struct { 21 interval time.Duration 22 ticker timex.Ticker 23 slots []*list.List 24 timers *SafeMap 25 tickedPos int 26 numSlots int 27 execute Execute 28 setChannel chan timingEntry 29 moveChannel chan baseEntry 30 removeChannel chan interface{} 31 drainChannel chan func(key, value interface{}) 32 stopChannel chan lang.PlaceholderType 33 } 34 35 timingEntry struct { 36 baseEntry 37 value interface{} 38 circle int 39 diff int 40 removed bool 41 } 42 43 baseEntry struct { 44 delay time.Duration 45 key interface{} 46 } 47 48 positionEntry struct { 49 pos int 50 item *timingEntry 51 } 52 53 timingTask struct { 54 key interface{} 55 value interface{} 56 } 57 ) 58 59 // NewTimingWheel returns a TimingWheel. 60 func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) { 61 if interval <= 0 || numSlots <= 0 || execute == nil { 62 return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p", interval, numSlots, execute) 63 } 64 65 return newTimingWheelWithClock(interval, numSlots, execute, timex.NewTicker(interval)) 66 } 67 68 func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute, ticker timex.Ticker) ( 69 *TimingWheel, error) { 70 tw := &TimingWheel{ 71 interval: interval, 72 ticker: ticker, 73 slots: make([]*list.List, numSlots), 74 timers: NewSafeMap(), 75 tickedPos: numSlots - 1, // at previous virtual circle 76 execute: execute, 77 numSlots: numSlots, 78 setChannel: make(chan timingEntry), 79 moveChannel: make(chan baseEntry), 80 removeChannel: make(chan interface{}), 81 drainChannel: make(chan func(key, value interface{})), 82 stopChannel: make(chan lang.PlaceholderType), 83 } 84 85 tw.initSlots() 86 go tw.run() 87 88 return tw, nil 89 } 90 91 // Drain drains all items and executes them. 92 func (tw *TimingWheel) Drain(fn func(key, value interface{})) { 93 tw.drainChannel <- fn 94 } 95 96 // MoveTimer moves the task with the given key to the given delay. 97 func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) { 98 if delay <= 0 || key == nil { 99 return 100 } 101 102 tw.moveChannel <- baseEntry{ 103 delay: delay, 104 key: key, 105 } 106 } 107 108 // RemoveTimer removes the task with the given key. 109 func (tw *TimingWheel) RemoveTimer(key interface{}) { 110 if key == nil { 111 return 112 } 113 114 tw.removeChannel <- key 115 } 116 117 // SetTimer sets the task value with the given key to the delay. 118 func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) { 119 if delay <= 0 || key == nil { 120 return 121 } 122 123 tw.setChannel <- timingEntry{ 124 baseEntry: baseEntry{ 125 delay: delay, 126 key: key, 127 }, 128 value: value, 129 } 130 } 131 132 // Stop stops tw. 133 func (tw *TimingWheel) Stop() { 134 close(tw.stopChannel) 135 } 136 137 func (tw *TimingWheel) drainAll(fn func(key, value interface{})) { 138 runner := threading.NewTaskRunner(drainWorkers) 139 for _, slot := range tw.slots { 140 for e := slot.Front(); e != nil; { 141 task := e.Value.(*timingEntry) 142 next := e.Next() 143 slot.Remove(e) 144 e = next 145 if !task.removed { 146 runner.Schedule(func() { 147 fn(task.key, task.value) 148 }) 149 } 150 } 151 } 152 } 153 154 func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) { 155 steps := int(d / tw.interval) 156 pos = (tw.tickedPos + steps) % tw.numSlots 157 circle = (steps - 1) / tw.numSlots 158 159 return 160 } 161 162 func (tw *TimingWheel) initSlots() { 163 for i := 0; i < tw.numSlots; i++ { 164 tw.slots[i] = list.New() 165 } 166 } 167 168 func (tw *TimingWheel) moveTask(task baseEntry) { 169 val, ok := tw.timers.Get(task.key) 170 if !ok { 171 return 172 } 173 174 timer := val.(*positionEntry) 175 if task.delay < tw.interval { 176 threading.GoSafe(func() { 177 tw.execute(timer.item.key, timer.item.value) 178 }) 179 return 180 } 181 182 pos, circle := tw.getPositionAndCircle(task.delay) 183 if pos >= timer.pos { 184 timer.item.circle = circle 185 timer.item.diff = pos - timer.pos 186 } else if circle > 0 { 187 circle-- 188 timer.item.circle = circle 189 timer.item.diff = tw.numSlots + pos - timer.pos 190 } else { 191 timer.item.removed = true 192 newItem := &timingEntry{ 193 baseEntry: task, 194 value: timer.item.value, 195 } 196 tw.slots[pos].PushBack(newItem) 197 tw.setTimerPosition(pos, newItem) 198 } 199 } 200 201 func (tw *TimingWheel) onTick() { 202 tw.tickedPos = (tw.tickedPos + 1) % tw.numSlots 203 l := tw.slots[tw.tickedPos] 204 tw.scanAndRunTasks(l) 205 } 206 207 func (tw *TimingWheel) removeTask(key interface{}) { 208 val, ok := tw.timers.Get(key) 209 if !ok { 210 return 211 } 212 213 timer := val.(*positionEntry) 214 timer.item.removed = true 215 tw.timers.Del(key) 216 } 217 218 func (tw *TimingWheel) run() { 219 for { 220 select { 221 case <-tw.ticker.Chan(): 222 tw.onTick() 223 case task := <-tw.setChannel: 224 tw.setTask(&task) 225 case key := <-tw.removeChannel: 226 tw.removeTask(key) 227 case task := <-tw.moveChannel: 228 tw.moveTask(task) 229 case fn := <-tw.drainChannel: 230 tw.drainAll(fn) 231 case <-tw.stopChannel: 232 tw.ticker.Stop() 233 return 234 } 235 } 236 } 237 238 func (tw *TimingWheel) runTasks(tasks []timingTask) { 239 if len(tasks) == 0 { 240 return 241 } 242 243 go func() { 244 for i := range tasks { 245 threading.RunSafe(func() { 246 tw.execute(tasks[i].key, tasks[i].value) 247 }) 248 } 249 }() 250 } 251 252 func (tw *TimingWheel) scanAndRunTasks(l *list.List) { 253 var tasks []timingTask 254 255 for e := l.Front(); e != nil; { 256 task := e.Value.(*timingEntry) 257 if task.removed { 258 next := e.Next() 259 l.Remove(e) 260 e = next 261 continue 262 } else if task.circle > 0 { 263 task.circle-- 264 e = e.Next() 265 continue 266 } else if task.diff > 0 { 267 next := e.Next() 268 l.Remove(e) 269 // (tw.tickedPos+task.diff)%tw.numSlots 270 // cannot be the same value of tw.tickedPos 271 pos := (tw.tickedPos + task.diff) % tw.numSlots 272 tw.slots[pos].PushBack(task) 273 tw.setTimerPosition(pos, task) 274 task.diff = 0 275 e = next 276 continue 277 } 278 279 tasks = append(tasks, timingTask{ 280 key: task.key, 281 value: task.value, 282 }) 283 next := e.Next() 284 l.Remove(e) 285 tw.timers.Del(task.key) 286 e = next 287 } 288 289 tw.runTasks(tasks) 290 } 291 292 func (tw *TimingWheel) setTask(task *timingEntry) { 293 if task.delay < tw.interval { 294 task.delay = tw.interval 295 } 296 297 if val, ok := tw.timers.Get(task.key); ok { 298 entry := val.(*positionEntry) 299 entry.item.value = task.value 300 tw.moveTask(task.baseEntry) 301 } else { 302 pos, circle := tw.getPositionAndCircle(task.delay) 303 task.circle = circle 304 tw.slots[pos].PushBack(task) 305 tw.setTimerPosition(pos, task) 306 } 307 } 308 309 func (tw *TimingWheel) setTimerPosition(pos int, task *timingEntry) { 310 if val, ok := tw.timers.Get(task.key); ok { 311 timer := val.(*positionEntry) 312 timer.item = task 313 timer.pos = pos 314 } else { 315 tw.timers.Set(task.key, &positionEntry{ 316 pos: pos, 317 item: task, 318 }) 319 } 320 }