github.com/ngicks/gokugen@v0.0.5/scheduler/worker_pool.go (about)

     1  package scheduler
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"log"
     7  	"runtime/debug"
     8  	"sync"
     9  	"sync/atomic"
    10  )
    11  
    12  // WorkerConstructor is aliased type of constructor.
    13  // id must be, as its name says, unique value.
    14  // onTaskReceived, onTaskDone can be nil.
    15  type WorkerConstructor = func(id int, onTaskReceived func(), onTaskDone func()) *Worker[int]
    16  
    17  // BuildWorkerConstructor is helper function for WorkerConstructor.
    18  // taskCh must not be nil. onTaskReceived_, onTaskDone_ can be nil.
    19  func BuildWorkerConstructor(taskCh <-chan *Task, onTaskReceived_ func(), onTaskDone_ func()) WorkerConstructor {
    20  	return func(id int, onTaskReceived__ func(), onTaskDone__ func()) *Worker[int] {
    21  		onTaskReceived := func() {
    22  			if onTaskReceived_ != nil {
    23  				onTaskReceived_()
    24  			}
    25  			if onTaskReceived__ != nil {
    26  				onTaskReceived__()
    27  			}
    28  		}
    29  		onTaskDone := func() {
    30  			if onTaskDone_ != nil {
    31  				onTaskDone_()
    32  			}
    33  			if onTaskDone__ != nil {
    34  				onTaskDone__()
    35  			}
    36  		}
    37  		w, err := NewWorker(id, taskCh, onTaskReceived, onTaskDone)
    38  		if err != nil {
    39  			panic(err)
    40  		}
    41  		return w
    42  	}
    43  }
    44  
    45  // WorkerPool is container for workers.
    46  type WorkerPool struct {
    47  	mu     sync.RWMutex
    48  	status workingState
    49  	wg     sync.WaitGroup
    50  
    51  	activeWorkerNum int64
    52  
    53  	workerConstructor WorkerConstructor
    54  	workerIdx         int
    55  	workers           map[int]*Worker[int]
    56  	sleepingWorkers   map[int]*Worker[int]
    57  }
    58  
    59  func NewWorkerPool(
    60  	workerConstructor WorkerConstructor,
    61  ) *WorkerPool {
    62  	w := WorkerPool{
    63  		workerConstructor: workerConstructor,
    64  		workers:           make(map[int]*Worker[int], 0),
    65  		sleepingWorkers:   make(map[int]*Worker[int], 0),
    66  	}
    67  	return &w
    68  }
    69  
    70  func (p *WorkerPool) Add(delta uint32) (newAliveLen int) {
    71  	p.mu.Lock()
    72  	for i := uint32(0); i < delta; i++ {
    73  		workerId := p.workerIdx
    74  		p.workerIdx++
    75  		worker := p.workerConstructor(
    76  			workerId,
    77  			func() { atomic.AddInt64(&p.activeWorkerNum, 1) },
    78  			func() { atomic.AddInt64(&p.activeWorkerNum, -1) },
    79  		)
    80  		// callWorkerStart calls wg.Done().
    81  		p.wg.Add(1)
    82  		go p.callWorkerStart(worker, true, func(err error) { log.Println(err) })
    83  
    84  		p.workers[worker.Id()] = worker
    85  	}
    86  	p.mu.Unlock()
    87  	alive, _ := p.Len()
    88  	return alive
    89  }
    90  
    91  var (
    92  	errGoexit = errors.New("runtime.Goexit was called")
    93  )
    94  
    95  type panicErr struct {
    96  	err   interface{}
    97  	stack []byte
    98  }
    99  
   100  // Error implements error interface.
   101  func (p *panicErr) Error() string {
   102  	return fmt.Sprintf("%v\n\n%s", p.err, p.stack)
   103  }
   104  
// callWorkerStart runs worker.Start on the calling goroutine and cleans up
// after it however Start ends.
//
// On exit it removes the worker from both p.workers and p.sleepingWorkers and
// then calls p.wg.Done; the matching wg.Add is expected from the caller (Add
// does it). If Start did not return normally, abnormalReturnCb is called with
// either a *panicErr (Start panicked) or errGoexit (Start called
// runtime.Goexit). abnormalReturnCb is invoked unconditionally in those cases,
// so it must be non-nil. When a panic was recovered and shouldRecover is
// false, the panic is re-raised with the *panicErr (as an error) as its value.
//
// workerErr is the error returned by worker.Start; it stays nil when Start did
// not return normally.
//
// NOTE(review): before Go 1.21, panic(nil) makes recover() return nil, so in
// that case abnormalReturnCb receives a nil error.
func (p *WorkerPool) callWorkerStart(worker *Worker[int], shouldRecover bool, abnormalReturnCb func(error)) (workerErr error) {
	var normalReturn, recovered bool
	var abnormalReturnErr error
	// see https://cs.opensource.google/go/x/sync/+/0de741cf:singleflight/singleflight.go;l=138-200;drc=0de741cfad7ff3874b219dfbc1b9195b58c7c490
	// The inner func below recovers panics. This outer defer also runs on
	// runtime.Goexit (which cannot be recovered) and tells the cases apart via
	// normalReturn / recovered.
	defer func() {
		// Done will be done right before the exit.
		defer p.wg.Done()
		p.mu.Lock()
		delete(p.workers, worker.Id())
		delete(p.sleepingWorkers, worker.Id())
		p.mu.Unlock()

		// Neither returned normally nor recovered from a panic: only
		// runtime.Goexit gets here.
		if !normalReturn && !recovered {
			abnormalReturnErr = errGoexit
		}
		if !normalReturn {
			abnormalReturnCb(abnormalReturnErr)
		}
		// Propagate the recovered panic when the caller does not want it
		// swallowed.
		if recovered && !shouldRecover {
			panic(abnormalReturnErr)
		}
	}()

	func() {
		defer func() {
			// Capture the panic value and stack; the panic stops here.
			if err := recover(); err != nil {
				abnormalReturnErr = &panicErr{
					err:   err,
					stack: debug.Stack(),
				}
			}
		}()
		workerErr = worker.Start()
		normalReturn = true
	}()
	// Reaching this line without normalReturn means the inner recover stopped
	// a panic (runtime.Goexit would have skipped this line entirely).
	if !normalReturn {
		recovered = true
	}
	return
}
   145  
   146  func (p *WorkerPool) Remove(delta uint32) (alive int, sleeping int) {
   147  	p.mu.Lock()
   148  	var count uint32
   149  	for _, worker := range p.workers {
   150  		if count < delta {
   151  			worker.Stop()
   152  			delete(p.workers, worker.Id())
   153  			p.sleepingWorkers[worker.Id()] = worker
   154  			count++
   155  		} else {
   156  			break
   157  		}
   158  	}
   159  	p.mu.Unlock()
   160  	return p.Len()
   161  }
   162  
   163  func (p *WorkerPool) Len() (alive int, sleeping int) {
   164  	p.mu.Lock()
   165  	defer p.mu.Unlock()
   166  	return len(p.workers), len(p.sleepingWorkers)
   167  }
   168  
   169  func (p *WorkerPool) ActiveWorkerNum() int64 {
   170  	return atomic.LoadInt64(&p.activeWorkerNum)
   171  }
   172  
   173  // Kill kills all worker.
   174  func (p *WorkerPool) Kill() {
   175  	p.mu.Lock()
   176  	defer p.mu.Unlock()
   177  	for _, w := range p.workers {
   178  		w.Kill()
   179  	}
   180  }
   181  
   182  // Wait waits for all workers to stop.
   183  // Calling this without sleeping or removing all worker may block forever.
   184  func (p *WorkerPool) Wait() {
   185  	p.wg.Wait()
   186  }