github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/taskrunner/preloadedtaskrunner.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  )
     7  
     8  // PreloadedTaskRunner is a task runner that invokes a series of preloaded tasks,
     9  // running until the tasks are completed, the context is canceled or an error is
    10  // returned by one of the tasks (which cancels the context).
    11  type PreloadedTaskRunner struct {
    12  	// ctx holds the context given to the task runner and annotated with the cancel
    13  	// function.
    14  	ctx    context.Context
    15  	cancel func()
    16  
    17  	// sem is a chan of length `concurrencyLimit` used to ensure the task runner does
    18  	// not exceed the concurrencyLimit with spawned goroutines.
    19  	sem chan struct{}
    20  
    21  	wg    sync.WaitGroup
    22  	err   error
    23  	lock  sync.Mutex
    24  	tasks []TaskFunc
    25  }
    26  
    27  func NewPreloadedTaskRunner(ctx context.Context, concurrencyLimit uint16, initialCapacity int) *PreloadedTaskRunner {
    28  	// Ensure a concurrency level of at least 1.
    29  	if concurrencyLimit <= 0 {
    30  		concurrencyLimit = 1
    31  	}
    32  
    33  	ctxWithCancel, cancel := context.WithCancel(ctx)
    34  	return &PreloadedTaskRunner{
    35  		ctx:    ctxWithCancel,
    36  		cancel: cancel,
    37  		sem:    make(chan struct{}, concurrencyLimit),
    38  		tasks:  make([]TaskFunc, 0, initialCapacity),
    39  	}
    40  }
    41  
    42  // Add adds the given task function to be run.
    43  func (tr *PreloadedTaskRunner) Add(f TaskFunc) {
    44  	tr.tasks = append(tr.tasks, f)
    45  	tr.wg.Add(1)
    46  }
    47  
    48  // Start starts running the tasks in the task runner. This does *not* wait for the tasks
    49  // to complete, but rather returns immediately.
    50  func (tr *PreloadedTaskRunner) Start() {
    51  	for range tr.tasks {
    52  		tr.spawnIfAvailable()
    53  	}
    54  }
    55  
    56  // StartAndWait starts running the tasks in the task runner and waits for them to complete.
    57  func (tr *PreloadedTaskRunner) StartAndWait() error {
    58  	tr.Start()
    59  	tr.wg.Wait()
    60  
    61  	tr.lock.Lock()
    62  	defer tr.lock.Unlock()
    63  
    64  	return tr.err
    65  }
    66  
    67  func (tr *PreloadedTaskRunner) spawnIfAvailable() {
    68  	// To spawn a runner, write a struct{} to the sem channel. If the task runner
    69  	// is already at the concurrency limit, then this chan write will fail,
    70  	// and nothing will be spawned. This also checks if the context has already
    71  	// been canceled, in which case nothing needs to be done.
    72  	select {
    73  	case tr.sem <- struct{}{}:
    74  		go tr.runner()
    75  
    76  	case <-tr.ctx.Done():
    77  		// If the context was canceled, nothing more to do.
    78  		tr.emptyForCancel()
    79  		return
    80  
    81  	default:
    82  		return
    83  	}
    84  }
    85  
    86  func (tr *PreloadedTaskRunner) runner() {
    87  	for {
    88  		select {
    89  		case <-tr.ctx.Done():
    90  			// If the context was canceled, nothing more to do.
    91  			tr.emptyForCancel()
    92  			return
    93  
    94  		default:
    95  			// Select a task from the list, if any.
    96  			task := tr.selectTask()
    97  			if task == nil {
    98  				return
    99  			}
   100  
   101  			// Run the task. If an error occurs, store it and cancel any further tasks.
   102  			err := task(tr.ctx)
   103  			if err != nil {
   104  				tr.storeErrorAndCancel(err)
   105  			}
   106  			tr.wg.Done()
   107  		}
   108  	}
   109  }
   110  
   111  func (tr *PreloadedTaskRunner) selectTask() TaskFunc {
   112  	tr.lock.Lock()
   113  	defer tr.lock.Unlock()
   114  
   115  	if len(tr.tasks) == 0 {
   116  		return nil
   117  	}
   118  
   119  	task := tr.tasks[0]
   120  	tr.tasks[0] = nil // to free the reference once the task completes.
   121  	tr.tasks = tr.tasks[1:]
   122  	return task
   123  }
   124  
   125  func (tr *PreloadedTaskRunner) storeErrorAndCancel(err error) {
   126  	tr.lock.Lock()
   127  	defer tr.lock.Unlock()
   128  
   129  	if tr.err == nil {
   130  		tr.err = err
   131  		tr.cancel()
   132  	}
   133  }
   134  
   135  func (tr *PreloadedTaskRunner) emptyForCancel() {
   136  	tr.lock.Lock()
   137  	defer tr.lock.Unlock()
   138  
   139  	if tr.err == nil {
   140  		tr.err = tr.ctx.Err()
   141  	}
   142  
   143  	for {
   144  		if len(tr.tasks) == 0 {
   145  			break
   146  		}
   147  
   148  		tr.tasks[0] = nil // to free the reference
   149  		tr.tasks = tr.tasks[1:]
   150  		tr.wg.Done()
   151  	}
   152  }