github.com/status-im/status-go@v1.1.0/services/wallet/async/scheduler.go (about)

     1  package async
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"sync"
     8  
     9  	orderedmap "github.com/wk8/go-ordered-map/v2"
    10  )
    11  
    12  var ErrTaskOverwritten = errors.New("task overwritten")
    13  
    14  // Scheduler ensures that only one task of a type is running at a time.
    15  type Scheduler struct {
    16  	queue      *orderedmap.OrderedMap[TaskType, *taskContext]
    17  	queueMutex sync.Mutex
    18  
    19  	context                context.Context
    20  	cancelFn               context.CancelFunc
    21  	doNotDeleteCurrentTask bool
    22  }
    23  
    24  type ReplacementPolicy = int
    25  
    26  const (
    27  	// ReplacementPolicyCancelOld for when the task arguments might change the result
    28  	ReplacementPolicyCancelOld ReplacementPolicy = iota
    29  	// ReplacementPolicyIgnoreNew for when the task arguments doesn't change the result
    30  	ReplacementPolicyIgnoreNew
    31  )
    32  
    33  type TaskType struct {
    34  	ID     int64
    35  	Policy ReplacementPolicy
    36  }
    37  
    38  type taskFunction func(context.Context) (interface{}, error)
    39  type resultFunction func(interface{}, TaskType, error)
    40  
    41  type taskContext struct {
    42  	taskType TaskType
    43  	policy   ReplacementPolicy
    44  
    45  	taskFn taskFunction
    46  	resFn  resultFunction
    47  }
    48  
    49  func NewScheduler() *Scheduler {
    50  	return &Scheduler{
    51  		queue: orderedmap.New[TaskType, *taskContext](),
    52  	}
    53  }
    54  
    55  // Enqueue provides a queue of task types allowing only one task at a time of the corresponding type. The running task is the first one in the queue (s.queue.Oldest())
    56  //
    57  // Schedule policy for new tasks
    58  //   - pushed at the back of the queue (s.queue.PushBack()) if none of the same time already scheduled
    59  //   - overwrite the queued one of the same type, depending on the policy
    60  //   - In case of ReplacementPolicyIgnoreNew, the new task will be ignored
    61  //   - In case of ReplacementPolicyCancelOld, the old running task will be canceled or if not yet run overwritten and the new one will be executed when its turn comes.
    62  //
    63  // The task function (taskFn) might not be executed if
    64  //   - the task is ignored
    65  //   - the task is overwritten. The result function (resFn) will be called with ErrTaskOverwritten
    66  //
    67  // The result function (resFn) will always be called if the task is not ignored
    68  func (s *Scheduler) Enqueue(taskType TaskType, taskFn taskFunction, resFn resultFunction) (ignored bool) {
    69  	s.queueMutex.Lock()
    70  	defer s.queueMutex.Unlock()
    71  
    72  	taskRunning := s.queue.Len() > 0
    73  	existingTask, typeInQueue := s.queue.Get(taskType)
    74  
    75  	// we need wrap the original resFn to ensure it is called only once
    76  	// otherwise, there's a chance that it will be called twice if we
    77  	// call Stop() quickly after Enqueue while task is running
    78  	var invokeResFnOnce sync.Once
    79  	onceResFn := func(res interface{}, taskType TaskType, err error) {
    80  		invokeResFnOnce.Do(func() {
    81  			resFn(res, taskType, err)
    82  		})
    83  	}
    84  	newTask := &taskContext{
    85  		taskType: taskType,
    86  		policy:   taskType.Policy,
    87  		taskFn:   taskFn,
    88  		resFn:    onceResFn,
    89  	}
    90  
    91  	if taskRunning {
    92  		if typeInQueue {
    93  			if s.queue.Oldest().Value.taskType == taskType {
    94  				// If same task type is running
    95  				if existingTask.policy == ReplacementPolicyCancelOld {
    96  					// If a previous task is running, cancel it
    97  					if s.cancelFn != nil {
    98  						s.cancelFn()
    99  						s.cancelFn = nil
   100  					} else {
   101  						// In case of multiple tasks of the same type, the previous one is overwritten
   102  						go func() {
   103  							existingTask.resFn(nil, existingTask.taskType, ErrTaskOverwritten)
   104  						}()
   105  					}
   106  
   107  					s.doNotDeleteCurrentTask = true
   108  
   109  					// Add it again to refresh the order of the task
   110  					s.queue.Delete(taskType)
   111  					s.queue.Set(taskType, newTask)
   112  				} else {
   113  					ignored = true
   114  				}
   115  			} else {
   116  				// if other task type is running
   117  				// notify the queued one that it is overwritten or ignored
   118  				if existingTask.policy == ReplacementPolicyCancelOld {
   119  					oldResFn := existingTask.resFn
   120  					go func() {
   121  						oldResFn(nil, existingTask.taskType, ErrTaskOverwritten)
   122  					}()
   123  					// Overwrite the queued one of the same type
   124  					existingTask.taskFn = taskFn
   125  					existingTask.resFn = onceResFn
   126  				} else {
   127  					ignored = true
   128  				}
   129  			}
   130  		} else {
   131  			// Policy does not matter for the fist enqueued task of a type
   132  			s.queue.Set(taskType, newTask)
   133  		}
   134  	} else {
   135  		// If no task is running add and run it. The worker will take care of scheduling new tasks added while running
   136  		s.queue.Set(taskType, newTask)
   137  		existingTask = newTask
   138  		s.runTask(existingTask, taskFn, func(res interface{}, runningTask *taskContext, err error) {
   139  			s.finishedTask(res, runningTask, onceResFn, err)
   140  		})
   141  	}
   142  
   143  	return ignored
   144  }
   145  
   146  func (s *Scheduler) runTask(tc *taskContext, taskFn taskFunction, resFn func(interface{}, *taskContext, error)) {
   147  	thisContext, thisCancelFn := context.WithCancel(context.Background())
   148  	s.cancelFn = thisCancelFn
   149  	s.context = thisContext
   150  
   151  	go func() {
   152  		res, err := taskFn(thisContext)
   153  
   154  		// Release context resources
   155  		thisCancelFn()
   156  
   157  		if errors.Is(err, context.Canceled) {
   158  			resFn(res, tc, fmt.Errorf("task canceled: %w", err))
   159  		} else {
   160  			resFn(res, tc, err)
   161  		}
   162  	}()
   163  }
   164  
   165  // finishedTask is the only one that can remove a task from the queue
   166  // if the current running task completed (doNotDeleteCurrentTask is true)
   167  func (s *Scheduler) finishedTask(finishedRes interface{}, doneTask *taskContext, finishedResFn resultFunction, finishedErr error) {
   168  	s.queueMutex.Lock()
   169  
   170  	current := s.queue.Oldest()
   171  	// Delete current task if not overwritten
   172  	if s.doNotDeleteCurrentTask {
   173  		s.doNotDeleteCurrentTask = false
   174  	} else {
   175  		// current maybe nil if Stop() is called
   176  		if current != nil {
   177  			s.queue.Delete(current.Value.taskType)
   178  		}
   179  	}
   180  
   181  	// Run next task
   182  	if pair := s.queue.Oldest(); pair != nil {
   183  		nextTask := pair.Value
   184  		s.runTask(nextTask, nextTask.taskFn, func(res interface{}, runningTask *taskContext, err error) {
   185  			s.finishedTask(res, runningTask, runningTask.resFn, err)
   186  		})
   187  	} else {
   188  		s.cancelFn = nil
   189  	}
   190  	s.queueMutex.Unlock()
   191  
   192  	// Report result
   193  	finishedResFn(finishedRes, doneTask.taskType, finishedErr)
   194  }
   195  
   196  func (s *Scheduler) Stop() {
   197  	s.queueMutex.Lock()
   198  	defer s.queueMutex.Unlock()
   199  
   200  	if s.cancelFn != nil {
   201  		s.cancelFn()
   202  		s.cancelFn = nil
   203  	}
   204  
   205  	// Empty the queue so the running task will not be restarted
   206  	for pair := s.queue.Oldest(); pair != nil; pair = pair.Next() {
   207  		// Notify the queued one that they are canceled
   208  		if pair.Value.policy == ReplacementPolicyCancelOld {
   209  			go func(val *taskContext) {
   210  				val.resFn(nil, val.taskType, context.Canceled)
   211  			}(pair.Value)
   212  		}
   213  		s.queue.Delete(pair.Value.taskType)
   214  	}
   215  }