github.com/benz9527/toy-box/algo@v0.0.0-20240221120937-66c0c6bd5abd/timer/x_timing_wheels.go (about)

     1  package timer
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"github.com/benz9527/toy-box/algo/queue"
     8  	"log/slog"
     9  	"runtime/debug"
    10  	"sync/atomic"
    11  	"time"
    12  	"unsafe"
    13  )
    14  
    15  var (
    16  	_ TimingWheel  = (*timingWheel)(nil)
    17  	_ TimingWheels = (*xTimingWheels)(nil)
    18  )
    19  
    20  // 112
    21  type timingWheel struct {
    22  	slots []TimingWheelSlot // alignment 8, size 24; in kafka it is buckets
    23  	// ctx is used to shut down the timing wheel and pass
    24  	// value to control debug info.
    25  	ctx                  context.Context                   // alignment 8, size 16
    26  	globalDqRef          queue.DelayQueue[TimingWheelSlot] // alignment 8, size 16
    27  	overflowWheelRef     unsafe.Pointer                    //  alignment 8, size 8; same as kafka TimingWheel(*timingWheel)
    28  	tickMs               int64                             // alignment 8, size 8
    29  	startMs              int64                             // alignment 8, size 8; baseline startup timestamp
    30  	interval             int64                             // alignment 8, size 8
    31  	currentTimeMs        int64                             // alignment 8, size 8
    32  	slotSize             int64                             // alignment 8, size 8; in kafka it is wheelSize
    33  	globalSlotCounterRef *atomic.Int64                     // alignment 8, size 8
    34  }
    35  
    36  type TimingWheelOptions func(tw *timingWheel)
    37  
    38  func WithTimingWheelTickMs(basicTickMs time.Duration) TimingWheelOptions {
    39  	return func(tw *timingWheel) {
    40  		tw.tickMs = basicTickMs.Milliseconds()
    41  	}
    42  }
    43  
    44  func WithTimingWheelSlotSize(slotSize int64) TimingWheelOptions {
    45  	return func(tw *timingWheel) {
    46  		tw.slotSize = slotSize
    47  	}
    48  }
    49  
    50  func newTimingWheel(
    51  	ctx context.Context,
    52  	tickMs int64,
    53  	slotSize int64,
    54  	startMs int64,
    55  	slotCounter *atomic.Int64,
    56  	dq queue.DelayQueue[TimingWheelSlot],
    57  ) TimingWheel {
    58  	tw := &timingWheel{
    59  		ctx:                  ctx,
    60  		tickMs:               tickMs,
    61  		startMs:              startMs,
    62  		slotSize:             slotSize,
    63  		globalSlotCounterRef: slotCounter,
    64  		interval:             tickMs * slotSize,
    65  		currentTimeMs:        startMs - (startMs % tickMs), // truncate the remainder as startMs left boundary
    66  		slots:                make([]TimingWheelSlot, slotSize),
    67  		globalDqRef:          dq,
    68  	}
    69  	// Slot initialize by doubly linked list.
    70  	for i := int64(0); i < slotSize; i++ {
    71  		tw.slots[i] = NewXSlot()
    72  	}
    73  	tw.globalSlotCounterRef.Add(slotSize)
    74  	tw.setOverflowTimingWheel(nil)
    75  	return tw
    76  }
    77  
    78  func (tw *timingWheel) GetTickMs() int64 {
    79  	return atomic.LoadInt64(&tw.tickMs)
    80  }
    81  
    82  func (tw *timingWheel) GetStartMs() int64 {
    83  	return atomic.LoadInt64(&tw.startMs)
    84  }
    85  
    86  func (tw *timingWheel) GetCurrentTimeMs() int64 {
    87  	return atomic.LoadInt64(&tw.currentTimeMs)
    88  }
    89  
    90  func (tw *timingWheel) GetInterval() int64 {
    91  	return atomic.LoadInt64(&tw.interval)
    92  }
    93  
    94  func (tw *timingWheel) GetSlotSize() int64 {
    95  	return atomic.LoadInt64(&tw.slotSize)
    96  }
    97  
    98  func (tw *timingWheel) getOverflowTimingWheel() TimingWheel {
    99  	return *(*TimingWheel)(atomic.LoadPointer(&tw.overflowWheelRef))
   100  }
   101  
   102  func (tw *timingWheel) setOverflowTimingWheel(oftw TimingWheel) {
   103  	atomic.StorePointer(&tw.overflowWheelRef, unsafe.Pointer(&oftw))
   104  }
   105  
   106  // Here related to slot level upgrade and downgrade.
   107  func (tw *timingWheel) advanceClock(slotExpiredMs int64) {
   108  	currentTimeMs := tw.GetCurrentTimeMs()
   109  	tickMs := tw.GetTickMs()
   110  	if slotExpiredMs >= currentTimeMs+tickMs {
   111  		currentTimeMs = slotExpiredMs - (slotExpiredMs % tickMs) // truncate the remainder as slot expiredMs left boundary
   112  		atomic.StoreInt64(&tw.currentTimeMs, currentTimeMs)      // update the current time
   113  		oftw := tw.getOverflowTimingWheel()
   114  		if oftw != nil {
   115  			oftw.(*timingWheel).advanceClock(currentTimeMs)
   116  		}
   117  	}
   118  }
   119  
   120  func (tw *timingWheel) addTask(task Task, level int64) error {
   121  	if len(task.GetJobID()) <= 0 {
   122  		return ErrTimingWheelTaskEmptyJobID
   123  	}
   124  	if task.GetJob() == nil {
   125  		return ErrTimingWheelEmptyJob
   126  	}
   127  
   128  	taskExpiredMs := task.GetExpiredMs()
   129  	currentTimeMs := tw.GetCurrentTimeMs()
   130  	tickMs := tw.GetTickMs()
   131  	interval := tw.GetInterval()
   132  	slotSize := tw.GetSlotSize()
   133  
   134  	if task.Cancelled() {
   135  		return fmt.Errorf("[timing wheel] task %s is cancelled, %w",
   136  			task.GetJobID(), ErrTimingWheelTaskCancelled)
   137  	} else if taskExpiredMs < currentTimeMs+tickMs {
   138  		task.setSlot(immediateExpiredSlot)
   139  		tw.globalSlotCounterRef.Add(1)
   140  		return fmt.Errorf("[timing wheel] task task expired ms  %d is before %d, %w",
   141  			taskExpiredMs, currentTimeMs+tickMs, ErrTimingWheelTaskIsExpired)
   142  	} else if taskExpiredMs < currentTimeMs+interval {
   143  		virtualID := taskExpiredMs / tickMs
   144  		slotID := virtualID % slotSize
   145  		slot := tw.slots[slotID]
   146  		slotMs := slot.GetExpirationMs()
   147  		if slot.GetExpirationMs() != (virtualID*tickMs) && !slot.setExpirationMs(virtualID*tickMs) { // FIXME data race
   148  			err := fmt.Errorf("[timing wheel] slot (level:%d) (old:%d<->new:%d) unable update the expiration, %w",
   149  				level, slotMs, virtualID*tickMs, ErrTimingWheelTaskUnableToBeAddedToSlot)
   150  			slog.Error("[timing wheel] add task error", "error", err)
   151  			return err
   152  		}
   153  
   154  		slot.setSlotID(slotID)
   155  		slot.setLevel(level)
   156  		slot.AddTask(task)
   157  		if err := tw.globalDqRef.Offer(slot, slot.GetExpirationMs()); err != nil {
   158  			slog.Error("[timing wheel] offer slot to delay queue error", "error", err)
   159  		}
   160  		return nil
   161  	} else {
   162  		// Out of the interval. Put it into the higher interval timing wheel
   163  		oftw := tw.getOverflowTimingWheel()
   164  		if oftw == nil {
   165  			tw.setOverflowTimingWheel(newTimingWheel(
   166  				tw.ctx,
   167  				interval,
   168  				slotSize,
   169  				currentTimeMs,
   170  				tw.globalSlotCounterRef,
   171  				tw.globalDqRef,
   172  			))
   173  		}
   174  		// Tail recursive call, it will be free the previous stack frame.
   175  		return tw.getOverflowTimingWheel().(*timingWheel).addTask(task, level+1)
   176  	}
   177  }
   178  
   179  const (
   180  	disableTimingWheelsSchedulePoll        = "disableTWSPoll"
   181  	disableTimingWheelsScheduleCancelTask  = "disableTWSCancelTask"
   182  	disableTimingWheelsScheduleExpiredSlot = "disableTWSExpSlot"
   183  )
   184  
   185  // size: 112
   186  type xTimingWheels struct {
   187  	tw           TimingWheel                       // alignment 8, size 16
   188  	ctx          context.Context                   // alignment 8, size 16
   189  	dq           queue.DelayQueue[TimingWheelSlot] // alignment 8, size 16; Do not use the timer.Ticker
   190  	tasksMap     map[JobID]Task                    // alignment 8, size 8
   191  	stopC        chan struct{}                     // alignment 8, size 8
   192  	expiredSlotC chan TimingWheelSlot              // alignment 8, size 8
   193  	twEventC     chan *timingWheelEvent            // alignment 8, size 8
   194  	twEventPool  *timingWheelEventsPool            // alignment 8, size 8
   195  	taskCounter  *atomic.Int64                     // alignment 8, size 8
   196  	slotCounter  *atomic.Int64                     // alignment 8, size 8
   197  	isRunning    *atomic.Bool                      // alignment 8, size 8
   198  	// FIXME goroutine pool
   199  }
   200  
   201  // NewTimingWheels creates a new timing wheel.
   202  // @param startMs the start time in milliseconds, example value time.Now().UnixMilli().
   203  //
   204  //	Same as the kafka, Time.SYSTEM.hiResClockMs() is used.
   205  func NewTimingWheels(ctx context.Context, startMs int64, opts ...TimingWheelOptions) TimingWheels {
   206  	if ctx == nil {
   207  		return nil
   208  	}
   209  
   210  	xtw := &xTimingWheels{
   211  		ctx:          ctx,
   212  		taskCounter:  &atomic.Int64{},
   213  		slotCounter:  &atomic.Int64{},
   214  		twEventC:     make(chan *timingWheelEvent, 256),
   215  		stopC:        make(chan struct{}),
   216  		expiredSlotC: make(chan TimingWheelSlot, 128),
   217  		tasksMap:     make(map[JobID]Task),
   218  		isRunning:    &atomic.Bool{},
   219  		twEventPool:  newTimingWheelEventsPool(),
   220  	}
   221  	xtw.isRunning.Store(false)
   222  	tw := &timingWheel{
   223  		startMs: startMs,
   224  	}
   225  	for _, o := range opts {
   226  		if o != nil {
   227  			o(tw)
   228  		}
   229  	}
   230  
   231  	if tw.tickMs <= 0 {
   232  		tw.tickMs = time.Millisecond.Milliseconds()
   233  	}
   234  	if tw.slotSize <= 0 {
   235  		tw.slotSize = 20
   236  	}
   237  	xtw.dq = queue.NewArrayDelayQueue[TimingWheelSlot](128)
   238  	xtw.tw = newTimingWheel(
   239  		ctx,
   240  		tw.tickMs,
   241  		tw.slotSize,
   242  		tw.startMs,
   243  		xtw.slotCounter,
   244  		xtw.dq,
   245  	)
   246  	xtw.schedule(ctx)
   247  	return xtw
   248  }
   249  
   250  func (xtw *xTimingWheels) GetTickMs() int64 {
   251  	return xtw.tw.GetTickMs()
   252  }
   253  
   254  func (xtw *xTimingWheels) GetStartMs() int64 {
   255  	return xtw.tw.GetStartMs()
   256  }
   257  
   258  func (xtw *xTimingWheels) GetTaskCounter() int64 {
   259  	return xtw.taskCounter.Load()
   260  }
   261  
   262  func (xtw *xTimingWheels) GetSlotSize() int64 {
   263  	return xtw.slotCounter.Load()
   264  }
   265  
   266  func (xtw *xTimingWheels) Shutdown() {
   267  	if old := xtw.isRunning.Swap(false); !old {
   268  		slog.Warn("[timing wheel] timing wheel is already shutdown")
   269  		return
   270  	}
   271  	xtw.dq = nil
   272  	xtw.isRunning.Store(false)
   273  
   274  	// FIXME close on channel is no empty and will cause panic.
   275  	close(xtw.stopC)
   276  	close(xtw.expiredSlotC)
   277  	close(xtw.twEventC)
   278  
   279  	// FIXME map clear data race
   280  }
   281  
   282  func (xtw *xTimingWheels) AddTask(task Task) error {
   283  	if len(task.GetJobID()) <= 0 {
   284  		return ErrTimingWheelTaskEmptyJobID
   285  	}
   286  	if task.GetJob() == nil {
   287  		return ErrTimingWheelEmptyJob
   288  	}
   289  	if !xtw.isRunning.Load() {
   290  		return ErrTimingWheelStopped
   291  	}
   292  	event := xtw.twEventPool.Get()
   293  	event.AddTask(task)
   294  	xtw.twEventC <- event
   295  	return nil
   296  }
   297  
   298  func (xtw *xTimingWheels) AfterFunc(delayMs time.Duration, fn Job) (Task, error) {
   299  	if delayMs.Milliseconds() < xtw.GetTickMs() {
   300  		return nil, fmt.Errorf("[timing wheel] delay ms %d is less than tick ms %d, %w",
   301  			delayMs.Milliseconds(), xtw.GetTickMs(), ErrTimingWheelTaskTooShortExpiration)
   302  	}
   303  	if fn == nil {
   304  		return nil, ErrTimingWheelEmptyJob
   305  	}
   306  
   307  	now := time.Now().UTC()
   308  	task := NewOnceTask(
   309  		xtw.ctx,
   310  		JobID(fmt.Sprintf("%d", now.UnixNano())), // FIXME UUID
   311  		now.Add(delayMs).UnixMilli(),
   312  		fn,
   313  	)
   314  
   315  	if !xtw.isRunning.Load() {
   316  		return nil, ErrTimingWheelStopped
   317  	}
   318  	if err := xtw.AddTask(task); err != nil {
   319  		return nil, err
   320  	}
   321  	return task, nil
   322  }
   323  
   324  func (xtw *xTimingWheels) ScheduleFunc(schedFn func() Scheduler, fn Job) (Task, error) {
   325  	if schedFn == nil {
   326  		return nil, ErrTimingWheelUnknownScheduler
   327  	}
   328  	if fn == nil {
   329  		return nil, ErrTimingWheelEmptyJob
   330  	}
   331  
   332  	now := time.Now()
   333  	task := NewRepeatTask(
   334  		xtw.ctx,
   335  		JobID(fmt.Sprintf("%d", now.UnixNano())), // FIXME UUID
   336  		now.UnixMilli(), schedFn(),
   337  		fn,
   338  	)
   339  
   340  	if !xtw.isRunning.Load() {
   341  		return nil, ErrTimingWheelStopped
   342  	}
   343  	if err := xtw.AddTask(task); err != nil {
   344  		return nil, err
   345  	}
   346  	return task, nil
   347  }
   348  
   349  func (xtw *xTimingWheels) CancelTask(jobID JobID) error {
   350  	if len(jobID) <= 0 {
   351  		return ErrTimingWheelTaskEmptyJobID
   352  	}
   353  
   354  	if xtw.isRunning.Load() {
   355  		return ErrTimingWheelStopped
   356  	}
   357  	task, ok := xtw.tasksMap[jobID]
   358  	if !ok {
   359  		return ErrTimingWheelTaskNotFound
   360  	}
   361  
   362  	event := xtw.twEventPool.Get()
   363  	event.CancelTaskJobID(task.GetJobID())
   364  	xtw.twEventC <- event
   365  	return nil
   366  }
   367  
   368  func (xtw *xTimingWheels) schedule(ctx context.Context) {
   369  	if ctx == nil {
   370  		return
   371  	}
   372  	// FIXME Block error mainly caused by producer and consumer speed mismatch, lock data race.
   373  	//  Is there any limitation mechanism could gradually  control different interval task‘s execution timeout timestamp?
   374  	//  Tasks piling up in the same slot will cause the timing wheel to be blocked or delayed.
   375  	go func() {
   376  		defer func() {
   377  			if err := recover(); err != nil {
   378  				slog.Error("[timing wheel] event schedule panic recover", "error", err, "stack", debug.Stack())
   379  			}
   380  		}()
   381  		cancelDisabled := ctx.Value(disableTimingWheelsScheduleCancelTask)
   382  		if cancelDisabled == nil {
   383  			cancelDisabled = false
   384  		}
   385  		for {
   386  			select {
   387  			case <-ctx.Done():
   388  				xtw.Shutdown()
   389  				return
   390  			case <-xtw.stopC:
   391  				return
   392  			case event, ok := <-xtw.twEventC:
   393  				if !ok {
   394  					slog.Warn("[timing wheel] event channel has been closed")
   395  					continue
   396  				}
   397  				switch op := event.GetOperation(); op {
   398  				case addTask, reAddTask:
   399  					task, ok := event.GetTask()
   400  					if !ok {
   401  						goto recycle
   402  					}
   403  					err := xtw.addTask(task)
   404  					if errors.Is(err, ErrTimingWheelTaskIsExpired) {
   405  						xtw.handleTask(task)
   406  					}
   407  					if op == addTask {
   408  						xtw.taskCounter.Add(1)
   409  					}
   410  				case cancelTask:
   411  					jobID, ok := event.GetCancelTaskJobID()
   412  					if !ok || cancelDisabled.(bool) {
   413  						goto recycle
   414  					}
   415  					if err := xtw.cancelTask(jobID); err == nil {
   416  						xtw.taskCounter.Add(-1)
   417  					}
   418  				case unknown:
   419  					fallthrough
   420  				default:
   421  
   422  				}
   423  			recycle:
   424  				xtw.twEventPool.Put(event)
   425  			}
   426  		}
   427  	}()
   428  	go func(disabled any) {
   429  		if disabled != nil && disabled.(bool) {
   430  			return
   431  		}
   432  		defer func() {
   433  			if err := recover(); err != nil {
   434  				slog.Error("[timing wheel] expired slot schedule panic recover", "error", err, "stack", debug.Stack())
   435  			}
   436  		}()
   437  		for {
   438  			select {
   439  			case <-ctx.Done():
   440  				xtw.Shutdown()
   441  				return
   442  			case <-xtw.stopC:
   443  				return
   444  			case slot, ok := <-xtw.expiredSlotC:
   445  				if !ok {
   446  					continue
   447  				}
   448  				xtw.advanceClock(slot.GetExpirationMs())
   449  				// Here related to slot level upgrade and downgrade.
   450  				slot.Flush(xtw.handleTask)
   451  			}
   452  		}
   453  	}(ctx.Value(disableTimingWheelsScheduleExpiredSlot))
   454  	go func(disabled any) {
   455  		if disabled != nil && disabled.(bool) {
   456  			return
   457  		}
   458  		defer func() {
   459  			if err := recover(); err != nil {
   460  				slog.Error("[timing wheel] poll schedule panic recover", "error", err, "stack", debug.Stack())
   461  			}
   462  		}()
   463  		err := xtw.dq.PollToChannel(xtw.ctx, func() int64 {
   464  			return time.Now().UTC().UnixMilli()
   465  		}, xtw.expiredSlotC)
   466  		if err != nil {
   467  			slog.Error("[timing wheel] delay queue poll error", "error", err)
   468  		}
   469  		slog.Warn("[timing wheel] delay queue exit")
   470  	}(ctx.Value(disableTimingWheelsSchedulePoll))
   471  	xtw.isRunning.Store(true)
   472  }
   473  
   474  // Update all wheels' current time, in order to simulate the time is continuously incremented.
   475  // Here related to slot level upgrade and downgrade.
   476  func (xtw *xTimingWheels) advanceClock(timeoutMs int64) {
   477  	xtw.tw.(*timingWheel).advanceClock(timeoutMs)
   478  }
   479  
   480  func (xtw *xTimingWheels) addTask(task Task) error {
   481  	if task == nil || task.Cancelled() || !xtw.isRunning.Load() {
   482  		return ErrTimingWheelStopped
   483  	}
   484  	// FIXME Recursive function to addTask a task, need to measure the performance.
   485  	err := xtw.tw.(*timingWheel).addTask(task, 0)
   486  	if err == nil || errors.Is(err, ErrTimingWheelTaskIsExpired) {
   487  		// FIXME map data race
   488  		xtw.tasksMap[task.GetJobID()] = task
   489  	}
   490  	return err
   491  }
   492  
   493  // handleTask all tasks which are called by this method
   494  // will mean that the task must be in a slot ever and related slot
   495  // has been expired.
   496  func (xtw *xTimingWheels) handleTask(t Task) {
   497  	if t == nil || !xtw.isRunning.Load() {
   498  		slog.Info("[timing wheel] task is nil or timing wheel is stopped")
   499  		return
   500  	}
   501  
   502  	// FIXME goroutine pool to run this.
   503  	// [slotExpMs, slotExpMs+interval)
   504  	var (
   505  		prevSlotMetadata = t.GetPreviousSlotMetadata()
   506  		slot             = t.GetSlot()
   507  		taskLevel        int64
   508  		runNow           bool
   509  	)
   510  	if prevSlotMetadata == nil {
   511  		// Unknown task
   512  		return
   513  	} else {
   514  		taskLevel = prevSlotMetadata.GetLevel()
   515  		runNow = prevSlotMetadata.GetExpirationMs() == sentinelSlotExpiredMs
   516  		runNow = runNow || taskLevel == 0 && t.GetExpiredMs() <= prevSlotMetadata.GetExpirationMs()+xtw.GetTickMs()
   517  		runNow = runNow || t.GetExpiredMs() <= time.Now().UTC().UnixMilli()
   518  	}
   519  	if runNow && !t.Cancelled() {
   520  		go t.GetJob()(xtw.ctx, t.GetJobMetadata())
   521  	} else if t.Cancelled() {
   522  		if slot != nil {
   523  			slot.RemoveTask(t)
   524  		}
   525  		t.setSlot(nil)
   526  		t.setSlotMetadata(nil)
   527  		return
   528  	}
   529  
   530  	// Re-addTask loop job to timing wheel.
   531  	// Upgrade and downgrade (move) the t from one slot to another slot.
   532  	// Lock free.
   533  	switch t.GetJobType() {
   534  	case OnceJob:
   535  		event := xtw.twEventPool.Get()
   536  		if runNow {
   537  			event.CancelTaskJobID(t.GetJobID())
   538  			xtw.twEventC <- event
   539  		} else {
   540  			event.ReAddTask(t)
   541  			xtw.twEventC <- event
   542  		}
   543  	case RepeatedJob:
   544  		var sTask Task
   545  		if !runNow {
   546  			sTask = t
   547  		} else {
   548  			if t.GetRestLoopCount() == 0 {
   549  				event := xtw.twEventPool.Get()
   550  				event.CancelTaskJobID(t.GetJobID())
   551  				xtw.twEventC <- event
   552  				return
   553  			}
   554  			_sTask, ok := t.(ScheduledTask)
   555  			if !ok {
   556  				return
   557  			}
   558  			_sTask.UpdateNextScheduledMs()
   559  			sTask = _sTask
   560  			if sTask.GetExpiredMs() < 0 {
   561  				return
   562  			}
   563  		}
   564  		if sTask != nil {
   565  			event := xtw.twEventPool.Get()
   566  			event.ReAddTask(sTask)
   567  			xtw.twEventC <- event
   568  		}
   569  	}
   570  	return
   571  }
   572  
   573  func (xtw *xTimingWheels) cancelTask(jobID JobID) error {
   574  	if !xtw.isRunning.Load() {
   575  		return ErrTimingWheelStopped
   576  	}
   577  
   578  	task, ok := xtw.tasksMap[jobID]
   579  	if !ok {
   580  		return ErrTimingWheelTaskNotFound
   581  	}
   582  
   583  	if task.GetSlot() != nil && !task.GetSlot().RemoveTask(task) {
   584  		return ErrTimingWheelTaskUnableToBeRemoved
   585  	}
   586  	task.Cancel()
   587  
   588  	delete(xtw.tasksMap, jobID)
   589  	return nil
   590  }