github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/taskrunner/preloadedtaskrunner.go (about) 1 package taskrunner 2 3 import ( 4 "context" 5 "sync" 6 ) 7 8 // PreloadedTaskRunner is a task runner that invokes a series of preloaded tasks, 9 // running until the tasks are completed, the context is canceled or an error is 10 // returned by one of the tasks (which cancels the context). 11 type PreloadedTaskRunner struct { 12 // ctx holds the context given to the task runner and annotated with the cancel 13 // function. 14 ctx context.Context 15 cancel func() 16 17 // sem is a chan of length `concurrencyLimit` used to ensure the task runner does 18 // not exceed the concurrencyLimit with spawned goroutines. 19 sem chan struct{} 20 21 wg sync.WaitGroup 22 err error 23 lock sync.Mutex 24 tasks []TaskFunc 25 } 26 27 func NewPreloadedTaskRunner(ctx context.Context, concurrencyLimit uint16, initialCapacity int) *PreloadedTaskRunner { 28 // Ensure a concurrency level of at least 1. 29 if concurrencyLimit <= 0 { 30 concurrencyLimit = 1 31 } 32 33 ctxWithCancel, cancel := context.WithCancel(ctx) 34 return &PreloadedTaskRunner{ 35 ctx: ctxWithCancel, 36 cancel: cancel, 37 sem: make(chan struct{}, concurrencyLimit), 38 tasks: make([]TaskFunc, 0, initialCapacity), 39 } 40 } 41 42 // Add adds the given task function to be run. 43 func (tr *PreloadedTaskRunner) Add(f TaskFunc) { 44 tr.tasks = append(tr.tasks, f) 45 tr.wg.Add(1) 46 } 47 48 // Start starts running the tasks in the task runner. This does *not* wait for the tasks 49 // to complete, but rather returns immediately. 50 func (tr *PreloadedTaskRunner) Start() { 51 for range tr.tasks { 52 tr.spawnIfAvailable() 53 } 54 } 55 56 // StartAndWait starts running the tasks in the task runner and waits for them to complete. 57 func (tr *PreloadedTaskRunner) StartAndWait() error { 58 tr.Start() 59 tr.wg.Wait() 60 61 tr.lock.Lock() 62 defer tr.lock.Unlock() 63 64 return tr.err 65 } 66 67 func (tr *PreloadedTaskRunner) spawnIfAvailable() { 68 // To spawn a runner, write a struct{} to the sem channel. If the task runner 69 // is already at the concurrency limit, then this chan write will fail, 70 // and nothing will be spawned. This also checks if the context has already 71 // been canceled, in which case nothing needs to be done. 72 select { 73 case tr.sem <- struct{}{}: 74 go tr.runner() 75 76 case <-tr.ctx.Done(): 77 // If the context was canceled, nothing more to do. 78 tr.emptyForCancel() 79 return 80 81 default: 82 return 83 } 84 } 85 86 func (tr *PreloadedTaskRunner) runner() { 87 for { 88 select { 89 case <-tr.ctx.Done(): 90 // If the context was canceled, nothing more to do. 91 tr.emptyForCancel() 92 return 93 94 default: 95 // Select a task from the list, if any. 96 task := tr.selectTask() 97 if task == nil { 98 return 99 } 100 101 // Run the task. If an error occurs, store it and cancel any further tasks. 102 err := task(tr.ctx) 103 if err != nil { 104 tr.storeErrorAndCancel(err) 105 } 106 tr.wg.Done() 107 } 108 } 109 } 110 111 func (tr *PreloadedTaskRunner) selectTask() TaskFunc { 112 tr.lock.Lock() 113 defer tr.lock.Unlock() 114 115 if len(tr.tasks) == 0 { 116 return nil 117 } 118 119 task := tr.tasks[0] 120 tr.tasks[0] = nil // to free the reference once the task completes. 121 tr.tasks = tr.tasks[1:] 122 return task 123 } 124 125 func (tr *PreloadedTaskRunner) storeErrorAndCancel(err error) { 126 tr.lock.Lock() 127 defer tr.lock.Unlock() 128 129 if tr.err == nil { 130 tr.err = err 131 tr.cancel() 132 } 133 } 134 135 func (tr *PreloadedTaskRunner) emptyForCancel() { 136 tr.lock.Lock() 137 defer tr.lock.Unlock() 138 139 if tr.err == nil { 140 tr.err = tr.ctx.Err() 141 } 142 143 for { 144 if len(tr.tasks) == 0 { 145 break 146 } 147 148 tr.tasks[0] = nil // to free the reference 149 tr.tasks = tr.tasks[1:] 150 tr.wg.Done() 151 } 152 }