github.com/ngicks/gokugen@v0.0.5/scheduler/worker_pool.go (about) 1 package scheduler 2 3 import ( 4 "errors" 5 "fmt" 6 "log" 7 "runtime/debug" 8 "sync" 9 "sync/atomic" 10 ) 11 12 // WorkerConstructor is aliased type of constructor. 13 // id must be, as its name says, unique value. 14 // onTaskReceived, onTaskDone can be nil. 15 type WorkerConstructor = func(id int, onTaskReceived func(), onTaskDone func()) *Worker[int] 16 17 // BuildWorkerConstructor is helper function for WorkerConstructor. 18 // taskCh must not be nil. onTaskReceived_, onTaskDone_ can be nil. 19 func BuildWorkerConstructor(taskCh <-chan *Task, onTaskReceived_ func(), onTaskDone_ func()) WorkerConstructor { 20 return func(id int, onTaskReceived__ func(), onTaskDone__ func()) *Worker[int] { 21 onTaskReceived := func() { 22 if onTaskReceived_ != nil { 23 onTaskReceived_() 24 } 25 if onTaskReceived__ != nil { 26 onTaskReceived__() 27 } 28 } 29 onTaskDone := func() { 30 if onTaskDone_ != nil { 31 onTaskDone_() 32 } 33 if onTaskDone__ != nil { 34 onTaskDone__() 35 } 36 } 37 w, err := NewWorker(id, taskCh, onTaskReceived, onTaskDone) 38 if err != nil { 39 panic(err) 40 } 41 return w 42 } 43 } 44 45 // WorkerPool is container for workers. 46 type WorkerPool struct { 47 mu sync.RWMutex 48 status workingState 49 wg sync.WaitGroup 50 51 activeWorkerNum int64 52 53 workerConstructor WorkerConstructor 54 workerIdx int 55 workers map[int]*Worker[int] 56 sleepingWorkers map[int]*Worker[int] 57 } 58 59 func NewWorkerPool( 60 workerConstructor WorkerConstructor, 61 ) *WorkerPool { 62 w := WorkerPool{ 63 workerConstructor: workerConstructor, 64 workers: make(map[int]*Worker[int], 0), 65 sleepingWorkers: make(map[int]*Worker[int], 0), 66 } 67 return &w 68 } 69 70 func (p *WorkerPool) Add(delta uint32) (newAliveLen int) { 71 p.mu.Lock() 72 for i := uint32(0); i < delta; i++ { 73 workerId := p.workerIdx 74 p.workerIdx++ 75 worker := p.workerConstructor( 76 workerId, 77 func() { atomic.AddInt64(&p.activeWorkerNum, 1) }, 78 func() { atomic.AddInt64(&p.activeWorkerNum, -1) }, 79 ) 80 // callWorkerStart calls wg.Done(). 81 p.wg.Add(1) 82 go p.callWorkerStart(worker, true, func(err error) { log.Println(err) }) 83 84 p.workers[worker.Id()] = worker 85 } 86 p.mu.Unlock() 87 alive, _ := p.Len() 88 return alive 89 } 90 91 var ( 92 errGoexit = errors.New("runtime.Goexit was called") 93 ) 94 95 type panicErr struct { 96 err interface{} 97 stack []byte 98 } 99 100 // Error implements error interface. 101 func (p *panicErr) Error() string { 102 return fmt.Sprintf("%v\n\n%s", p.err, p.stack) 103 } 104 105 func (p *WorkerPool) callWorkerStart(worker *Worker[int], shouldRecover bool, abnormalReturnCb func(error)) (workerErr error) { 106 var normalReturn, recovered bool 107 var abnormalReturnErr error 108 // see https://cs.opensource.google/go/x/sync/+/0de741cf:singleflight/singleflight.go;l=138-200;drc=0de741cfad7ff3874b219dfbc1b9195b58c7c490 109 defer func() { 110 // Done will be done right before the exit. 111 defer p.wg.Done() 112 p.mu.Lock() 113 delete(p.workers, worker.Id()) 114 delete(p.sleepingWorkers, worker.Id()) 115 p.mu.Unlock() 116 117 if !normalReturn && !recovered { 118 abnormalReturnErr = errGoexit 119 } 120 if !normalReturn { 121 abnormalReturnCb(abnormalReturnErr) 122 } 123 if recovered && !shouldRecover { 124 panic(abnormalReturnErr) 125 } 126 }() 127 128 func() { 129 defer func() { 130 if err := recover(); err != nil { 131 abnormalReturnErr = &panicErr{ 132 err: err, 133 stack: debug.Stack(), 134 } 135 } 136 }() 137 workerErr = worker.Start() 138 normalReturn = true 139 }() 140 if !normalReturn { 141 recovered = true 142 } 143 return 144 } 145 146 func (p *WorkerPool) Remove(delta uint32) (alive int, sleeping int) { 147 p.mu.Lock() 148 var count uint32 149 for _, worker := range p.workers { 150 if count < delta { 151 worker.Stop() 152 delete(p.workers, worker.Id()) 153 p.sleepingWorkers[worker.Id()] = worker 154 count++ 155 } else { 156 break 157 } 158 } 159 p.mu.Unlock() 160 return p.Len() 161 } 162 163 func (p *WorkerPool) Len() (alive int, sleeping int) { 164 p.mu.Lock() 165 defer p.mu.Unlock() 166 return len(p.workers), len(p.sleepingWorkers) 167 } 168 169 func (p *WorkerPool) ActiveWorkerNum() int64 { 170 return atomic.LoadInt64(&p.activeWorkerNum) 171 } 172 173 // Kill kills all worker. 174 func (p *WorkerPool) Kill() { 175 p.mu.Lock() 176 defer p.mu.Unlock() 177 for _, w := range p.workers { 178 w.Kill() 179 } 180 } 181 182 // Wait waits for all workers to stop. 183 // Calling this without sleeping or removing all worker may block forever. 184 func (p *WorkerPool) Wait() { 185 p.wg.Wait() 186 }