github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/taskrunner/taskrunner.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  )
     7  
// TaskRunner is a helper which runs a series of scheduled tasks against a defined
// limit of goroutines. Tasks are queued with Schedule and drained by runner
// goroutines; Wait blocks until every queued task has finished or been dropped.
type TaskRunner struct {
	// ctx is the context handed to every task. It is derived from the context
	// given to NewTaskRunner and is canceled through cancel.
	ctx    context.Context
	cancel func()

	// sem is a buffered chan whose capacity is the concurrency limit. Each live
	// runner goroutine holds one element in it, which ensures the task runner
	// does not exceed the limit with spawned goroutines.
	sem chan struct{}

	// err holds the first error returned by any task, if any. If the context is
	// canceled before a task error is recorded, err holds the cancelation error
	// instead. Guarded by lock.
	err error

	// wg counts tasks that have been queued but have neither finished running nor
	// been dropped. lock guards err and tasks. tasks is the FIFO queue of tasks
	// that have not yet been picked up by a runner.
	wg    sync.WaitGroup
	lock  sync.Mutex
	tasks []TaskFunc
}
    28  
// TaskFunc defines functions representing tasks. A task receives the task
// runner's cancelable context; returning a non-nil error records that error on
// the runner (if it is the first) and cancels the context shared by all tasks.
type TaskFunc func(ctx context.Context) error
    31  
    32  // NewTaskRunner creates a new task runner with the given starting context and
    33  // concurrency limit. The TaskRunner will schedule no more goroutines that the
    34  // specified concurrencyLimit. If the given context is canceled, then all tasks
    35  // started after that point will also be canceled and the error returned. If
    36  // a task returns an error, the context provided to all tasks is also canceled.
    37  func NewTaskRunner(ctx context.Context, concurrencyLimit uint16) *TaskRunner {
    38  	if concurrencyLimit < 1 {
    39  		concurrencyLimit = 1
    40  	}
    41  
    42  	ctxWithCancel, cancel := context.WithCancel(ctx)
    43  	return &TaskRunner{
    44  		ctx:    ctxWithCancel,
    45  		cancel: cancel,
    46  		sem:    make(chan struct{}, concurrencyLimit),
    47  		tasks:  make([]TaskFunc, 0),
    48  	}
    49  }
    50  
    51  // Schedule schedules a task to be run. This is safe to call from within another
    52  // task handler function and immediately returns.
    53  func (tr *TaskRunner) Schedule(f TaskFunc) {
    54  	if tr.addTask(f) {
    55  		tr.spawnIfAvailable()
    56  	}
    57  }
    58  
    59  func (tr *TaskRunner) spawnIfAvailable() {
    60  	// To spawn a runner, write a struct{} to the sem channel. If the task runner
    61  	// is already at the concurrency limit, then this chan write will fail,
    62  	// and nothing will be spawned. This also checks if the context has already
    63  	// been canceled, in which case nothing needs to be done.
    64  	select {
    65  	case tr.sem <- struct{}{}:
    66  		go tr.runner()
    67  
    68  	case <-tr.ctx.Done():
    69  		return
    70  
    71  	default:
    72  		return
    73  	}
    74  }
    75  
    76  func (tr *TaskRunner) runner() {
    77  	for {
    78  		select {
    79  		case <-tr.ctx.Done():
    80  			// If the context was canceled, mark all the remaining tasks as "Done".
    81  			tr.emptyForCancel()
    82  			return
    83  
    84  		default:
    85  			// Select a task from the list, if any.
    86  			task := tr.selectTask()
    87  			if task == nil {
    88  				// If there are no further tasks, then "return" the struct{} by reading
    89  				// it from the channel (freeing a slot potentially for another worker
    90  				// to be spawned later).
    91  				<-tr.sem
    92  				return
    93  			}
    94  
    95  			// Run the task. If an error occurs, store it and cancel any further tasks.
    96  			err := task(tr.ctx)
    97  			if err != nil {
    98  				tr.storeErrorAndCancel(err)
    99  			}
   100  			tr.wg.Done()
   101  		}
   102  	}
   103  }
   104  
   105  func (tr *TaskRunner) addTask(f TaskFunc) bool {
   106  	tr.lock.Lock()
   107  	defer tr.lock.Unlock()
   108  
   109  	if tr.err != nil {
   110  		return false
   111  	}
   112  
   113  	tr.wg.Add(1)
   114  	tr.tasks = append(tr.tasks, f)
   115  	return true
   116  }
   117  
   118  func (tr *TaskRunner) selectTask() TaskFunc {
   119  	tr.lock.Lock()
   120  	defer tr.lock.Unlock()
   121  
   122  	if len(tr.tasks) == 0 {
   123  		return nil
   124  	}
   125  
   126  	task := tr.tasks[0]
   127  	tr.tasks = tr.tasks[1:]
   128  	return task
   129  }
   130  
   131  func (tr *TaskRunner) storeErrorAndCancel(err error) {
   132  	tr.lock.Lock()
   133  	defer tr.lock.Unlock()
   134  
   135  	if tr.err == nil {
   136  		tr.err = err
   137  		tr.cancel()
   138  	}
   139  }
   140  
   141  func (tr *TaskRunner) emptyForCancel() {
   142  	tr.lock.Lock()
   143  	defer tr.lock.Unlock()
   144  
   145  	if tr.err == nil {
   146  		tr.err = tr.ctx.Err()
   147  	}
   148  
   149  	for {
   150  		if len(tr.tasks) == 0 {
   151  			break
   152  		}
   153  
   154  		tr.tasks = tr.tasks[1:]
   155  		tr.wg.Done()
   156  	}
   157  }
   158  
   159  // Wait waits for all tasks to be completed, or a task to raise an error,
   160  // or the parent context to have been canceled.
   161  func (tr *TaskRunner) Wait() error {
   162  	tr.wg.Wait()
   163  
   164  	tr.lock.Lock()
   165  	defer tr.lock.Unlock()
   166  	return tr.err
   167  }