github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/taskrunner/taskrunner.go (about) 1 package taskrunner 2 3 import ( 4 "context" 5 "sync" 6 ) 7 8 // TaskRunner is a helper which runs a series of scheduled tasks against a defined 9 // limit of goroutines. 10 type TaskRunner struct { 11 // ctx holds the context given to the task runner and annotated with the cancel 12 // function. 13 ctx context.Context 14 cancel func() 15 16 // sem is a chan of length `concurrencyLimit` used to ensure the task runner does 17 // not exceed the concurrencyLimit with spawned goroutines. 18 sem chan struct{} 19 20 // err holds the error returned by any task, if any. If the context is canceled, 21 // this err will hold the cancelation error. 22 err error 23 24 wg sync.WaitGroup 25 lock sync.Mutex 26 tasks []TaskFunc 27 } 28 29 // TaskFunc defines functions representing tasks. 30 type TaskFunc func(ctx context.Context) error 31 32 // NewTaskRunner creates a new task runner with the given starting context and 33 // concurrency limit. The TaskRunner will schedule no more goroutines that the 34 // specified concurrencyLimit. If the given context is canceled, then all tasks 35 // started after that point will also be canceled and the error returned. If 36 // a task returns an error, the context provided to all tasks is also canceled. 37 func NewTaskRunner(ctx context.Context, concurrencyLimit uint16) *TaskRunner { 38 if concurrencyLimit < 1 { 39 concurrencyLimit = 1 40 } 41 42 ctxWithCancel, cancel := context.WithCancel(ctx) 43 return &TaskRunner{ 44 ctx: ctxWithCancel, 45 cancel: cancel, 46 sem: make(chan struct{}, concurrencyLimit), 47 tasks: make([]TaskFunc, 0), 48 } 49 } 50 51 // Schedule schedules a task to be run. This is safe to call from within another 52 // task handler function and immediately returns. 53 func (tr *TaskRunner) Schedule(f TaskFunc) { 54 if tr.addTask(f) { 55 tr.spawnIfAvailable() 56 } 57 } 58 59 func (tr *TaskRunner) spawnIfAvailable() { 60 // To spawn a runner, write a struct{} to the sem channel. If the task runner 61 // is already at the concurrency limit, then this chan write will fail, 62 // and nothing will be spawned. This also checks if the context has already 63 // been canceled, in which case nothing needs to be done. 64 select { 65 case tr.sem <- struct{}{}: 66 go tr.runner() 67 68 case <-tr.ctx.Done(): 69 return 70 71 default: 72 return 73 } 74 } 75 76 func (tr *TaskRunner) runner() { 77 for { 78 select { 79 case <-tr.ctx.Done(): 80 // If the context was canceled, mark all the remaining tasks as "Done". 81 tr.emptyForCancel() 82 return 83 84 default: 85 // Select a task from the list, if any. 86 task := tr.selectTask() 87 if task == nil { 88 // If there are no further tasks, then "return" the struct{} by reading 89 // it from the channel (freeing a slot potentially for another worker 90 // to be spawned later). 91 <-tr.sem 92 return 93 } 94 95 // Run the task. If an error occurs, store it and cancel any further tasks. 96 err := task(tr.ctx) 97 if err != nil { 98 tr.storeErrorAndCancel(err) 99 } 100 tr.wg.Done() 101 } 102 } 103 } 104 105 func (tr *TaskRunner) addTask(f TaskFunc) bool { 106 tr.lock.Lock() 107 defer tr.lock.Unlock() 108 109 if tr.err != nil { 110 return false 111 } 112 113 tr.wg.Add(1) 114 tr.tasks = append(tr.tasks, f) 115 return true 116 } 117 118 func (tr *TaskRunner) selectTask() TaskFunc { 119 tr.lock.Lock() 120 defer tr.lock.Unlock() 121 122 if len(tr.tasks) == 0 { 123 return nil 124 } 125 126 task := tr.tasks[0] 127 tr.tasks = tr.tasks[1:] 128 return task 129 } 130 131 func (tr *TaskRunner) storeErrorAndCancel(err error) { 132 tr.lock.Lock() 133 defer tr.lock.Unlock() 134 135 if tr.err == nil { 136 tr.err = err 137 tr.cancel() 138 } 139 } 140 141 func (tr *TaskRunner) emptyForCancel() { 142 tr.lock.Lock() 143 defer tr.lock.Unlock() 144 145 if tr.err == nil { 146 tr.err = tr.ctx.Err() 147 } 148 149 for { 150 if len(tr.tasks) == 0 { 151 break 152 } 153 154 tr.tasks = tr.tasks[1:] 155 tr.wg.Done() 156 } 157 } 158 159 // Wait waits for all tasks to be completed, or a task to raise an error, 160 // or the parent context to have been canceled. 161 func (tr *TaskRunner) Wait() error { 162 tr.wg.Wait() 163 164 tr.lock.Lock() 165 defer tr.lock.Unlock() 166 return tr.err 167 }