github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/timex/timingwheel.go (about) 1 package timex 2 3 import ( 4 "container/list" 5 "errors" 6 "fmt" 7 "log" 8 "runtime/debug" 9 "time" 10 11 "github.com/bingoohuang/gg/pkg/mapp" 12 ) 13 14 const drainWorkers = 8 15 16 var ( 17 ErrClosed = errors.New("TimingWheel is closed already") 18 ErrArgument = errors.New("incorrect task argument") 19 ) 20 21 type ( 22 // Execute defines the method to execute the task. 23 Execute func(key, value interface{}) 24 25 // A TimingWheel is a timing wheel object to schedule tasks. 26 TimingWheel struct { 27 interval time.Duration 28 ticker *time.Ticker 29 slots []*list.List 30 timers *mapp.SafeMap 31 tickedPos int 32 numSlots int 33 execute Execute 34 setChannel chan timingEntry 35 moveChannel chan baseEntry 36 removeChannel chan interface{} 37 drainChannel chan func(key, value interface{}) 38 stopChannel chan struct{} 39 } 40 41 timingEntry struct { 42 baseEntry 43 value interface{} 44 circle int 45 diff int 46 removed bool 47 } 48 49 baseEntry struct { 50 delay time.Duration 51 key interface{} 52 } 53 54 positionEntry struct { 55 pos int 56 item *timingEntry 57 } 58 59 timingTask struct { 60 key interface{} 61 value interface{} 62 } 63 ) 64 65 // NewTimingWheel returns a TimingWheel. 66 func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) { 67 if interval <= 0 || numSlots <= 0 || execute == nil { 68 return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p", 69 interval, numSlots, execute) 70 } 71 72 return newTimingWheelWithClock(interval, numSlots, execute, time.NewTicker(interval)) 73 } 74 75 func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute, 76 ticker *time.Ticker, 77 ) (*TimingWheel, error) { 78 tw := &TimingWheel{ 79 interval: interval, 80 ticker: ticker, 81 slots: make([]*list.List, numSlots), 82 timers: mapp.NewSafeMap(), 83 tickedPos: numSlots - 1, // at previous virtual circle 84 execute: execute, 85 numSlots: numSlots, 86 setChannel: make(chan timingEntry), 87 moveChannel: make(chan baseEntry), 88 removeChannel: make(chan interface{}), 89 drainChannel: make(chan func(key, value interface{})), 90 stopChannel: make(chan struct{}), 91 } 92 93 tw.initSlots() 94 go tw.run() 95 96 return tw, nil 97 } 98 99 // Drain drains all items and executes them. 100 func (tw *TimingWheel) Drain(fn func(key, value interface{})) error { 101 select { 102 case tw.drainChannel <- fn: 103 return nil 104 case <-tw.stopChannel: 105 return ErrClosed 106 } 107 } 108 109 // MoveTimer moves the task with the given key to the given delay. 110 func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) error { 111 if delay <= 0 || key == nil { 112 return ErrArgument 113 } 114 115 select { 116 case tw.moveChannel <- baseEntry{ 117 delay: delay, 118 key: key, 119 }: 120 return nil 121 case <-tw.stopChannel: 122 return ErrClosed 123 } 124 } 125 126 // RemoveTimer removes the task with the given key. 127 func (tw *TimingWheel) RemoveTimer(key interface{}) error { 128 if key == nil { 129 return ErrArgument 130 } 131 132 select { 133 case tw.removeChannel <- key: 134 return nil 135 case <-tw.stopChannel: 136 return ErrClosed 137 } 138 } 139 140 // SetTimer sets the task value with the given key to the delay. 141 func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) error { 142 if delay <= 0 || key == nil { 143 return ErrArgument 144 } 145 146 select { 147 case tw.setChannel <- timingEntry{ 148 baseEntry: baseEntry{ 149 delay: delay, 150 key: key, 151 }, 152 value: value, 153 }: 154 return nil 155 case <-tw.stopChannel: 156 return ErrClosed 157 } 158 } 159 160 // Stop stops tw. No more actions after stopping a TimingWheel. 161 func (tw *TimingWheel) Stop() { 162 close(tw.stopChannel) 163 } 164 165 // A TaskRunner is used to control the concurrency of goroutines. 166 type TaskRunner struct { 167 limitChan chan struct{} 168 } 169 170 // NewTaskRunner returns a TaskRunner. 171 func NewTaskRunner(concurrency int) *TaskRunner { 172 return &TaskRunner{ 173 limitChan: make(chan struct{}, concurrency), 174 } 175 } 176 177 // Recover is used with defer to do cleanup on panics. 178 // Use it like: 179 // 180 // defer Recover(func() {}) 181 func Recover(cleanups ...func()) { 182 for _, cleanup := range cleanups { 183 cleanup() 184 } 185 186 if p := recover(); p != nil { 187 log.Printf("Panic recovered from %+v, stack: %s", p, debug.Stack()) 188 } 189 } 190 191 // GoSafe runs the given fn using another goroutine, recovers if fn panics. 192 func GoSafe(fn func()) { 193 go RunSafe(fn) 194 } 195 196 // RunSafe runs the given fn, recovers if fn panics. 197 func RunSafe(fn func()) { 198 defer Recover() 199 200 fn() 201 } 202 203 // Schedule schedules a task to run under concurrency control. 204 func (rp *TaskRunner) Schedule(task func()) { 205 rp.limitChan <- struct{}{} 206 207 go func() { 208 defer Recover(func() { 209 <-rp.limitChan 210 }) 211 212 task() 213 }() 214 } 215 216 func (tw *TimingWheel) drainAll(fn func(key, value interface{})) { 217 runner := NewTaskRunner(drainWorkers) 218 for _, slot := range tw.slots { 219 for e := slot.Front(); e != nil; { 220 task := e.Value.(*timingEntry) 221 next := e.Next() 222 slot.Remove(e) 223 e = next 224 if !task.removed { 225 runner.Schedule(func() { 226 fn(task.key, task.value) 227 }) 228 } 229 } 230 } 231 } 232 233 func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) { 234 steps := int(d / tw.interval) 235 pos = (tw.tickedPos + steps) % tw.numSlots 236 circle = (steps - 1) / tw.numSlots 237 238 return 239 } 240 241 func (tw *TimingWheel) initSlots() { 242 for i := 0; i < tw.numSlots; i++ { 243 tw.slots[i] = list.New() 244 } 245 } 246 247 func (tw *TimingWheel) moveTask(task baseEntry) { 248 val, ok := tw.timers.Get(task.key) 249 if !ok { 250 return 251 } 252 253 timer := val.(*positionEntry) 254 if task.delay < tw.interval { 255 GoSafe(func() { 256 tw.execute(timer.item.key, timer.item.value) 257 }) 258 return 259 } 260 261 pos, circle := tw.getPositionAndCircle(task.delay) 262 if pos >= timer.pos { 263 timer.item.circle = circle 264 timer.item.diff = pos - timer.pos 265 } else if circle > 0 { 266 circle-- 267 timer.item.circle = circle 268 timer.item.diff = tw.numSlots + pos - timer.pos 269 } else { 270 timer.item.removed = true 271 newItem := &timingEntry{ 272 baseEntry: task, 273 value: timer.item.value, 274 } 275 tw.slots[pos].PushBack(newItem) 276 tw.setTimerPosition(pos, newItem) 277 } 278 } 279 280 func (tw *TimingWheel) onTick() { 281 tw.tickedPos = (tw.tickedPos + 1) % tw.numSlots 282 l := tw.slots[tw.tickedPos] 283 tw.scanAndRunTasks(l) 284 } 285 286 func (tw *TimingWheel) removeTask(key interface{}) { 287 val, ok := tw.timers.Get(key) 288 if !ok { 289 return 290 } 291 292 timer := val.(*positionEntry) 293 timer.item.removed = true 294 tw.timers.Del(key) 295 } 296 297 func (tw *TimingWheel) run() { 298 for { 299 select { 300 case <-tw.ticker.C: 301 tw.onTick() 302 case task := <-tw.setChannel: 303 tw.setTask(&task) 304 case key := <-tw.removeChannel: 305 tw.removeTask(key) 306 case task := <-tw.moveChannel: 307 tw.moveTask(task) 308 case fn := <-tw.drainChannel: 309 tw.drainAll(fn) 310 case <-tw.stopChannel: 311 tw.ticker.Stop() 312 return 313 } 314 } 315 } 316 317 func (tw *TimingWheel) runTasks(tasks []timingTask) { 318 if len(tasks) == 0 { 319 return 320 } 321 322 go func() { 323 for i := range tasks { 324 RunSafe(func() { 325 tw.execute(tasks[i].key, tasks[i].value) 326 }) 327 } 328 }() 329 } 330 331 func (tw *TimingWheel) scanAndRunTasks(l *list.List) { 332 var tasks []timingTask 333 334 for e := l.Front(); e != nil; { 335 task := e.Value.(*timingEntry) 336 if task.removed { 337 next := e.Next() 338 l.Remove(e) 339 e = next 340 continue 341 } else if task.circle > 0 { 342 task.circle-- 343 e = e.Next() 344 continue 345 } else if task.diff > 0 { 346 next := e.Next() 347 l.Remove(e) 348 // (tw.tickedPos+task.diff)%tw.numSlots 349 // cannot be the same value of tw.tickedPos 350 pos := (tw.tickedPos + task.diff) % tw.numSlots 351 tw.slots[pos].PushBack(task) 352 tw.setTimerPosition(pos, task) 353 task.diff = 0 354 e = next 355 continue 356 } 357 358 tasks = append(tasks, timingTask{ 359 key: task.key, 360 value: task.value, 361 }) 362 next := e.Next() 363 l.Remove(e) 364 tw.timers.Del(task.key) 365 e = next 366 } 367 368 tw.runTasks(tasks) 369 } 370 371 func (tw *TimingWheel) setTask(task *timingEntry) { 372 if task.delay < tw.interval { 373 task.delay = tw.interval 374 } 375 376 if val, ok := tw.timers.Get(task.key); ok { 377 entry := val.(*positionEntry) 378 entry.item.value = task.value 379 tw.moveTask(task.baseEntry) 380 } else { 381 pos, circle := tw.getPositionAndCircle(task.delay) 382 task.circle = circle 383 tw.slots[pos].PushBack(task) 384 tw.setTimerPosition(pos, task) 385 } 386 } 387 388 func (tw *TimingWheel) setTimerPosition(pos int, task *timingEntry) { 389 if val, ok := tw.timers.Get(task.key); ok { 390 timer := val.(*positionEntry) 391 timer.item = task 392 timer.pos = pos 393 } else { 394 tw.timers.Set(task.key, &positionEntry{ 395 pos: pos, 396 item: task, 397 }) 398 } 399 }