github.com/sandwich-go/boost@v1.3.29/xpool/goroutine.go (about)

     1  package xpool
     2  
     3  import (
     4  	"errors"
     5  	"github.com/sandwich-go/boost/xerror"
     6  	"github.com/sandwich-go/boost/xsync"
     7  	"github.com/sandwich-go/boost/xtime"
     8  	"time"
     9  )
    10  
    11  type worker struct {
    12  	jobChan chan Job
    13  
    14  	// closeChan can be closed in order to cleanly shut down this worker.
    15  	closeChan chan struct{}
    16  	// closedChan is closed by the run() goroutine when it exits.
    17  	closedChan chan struct{}
    18  }
    19  
    20  func (w *worker) Start(jobQueue chan Job) {
    21  	defer func() {
    22  		close(w.closedChan)
    23  	}()
    24  
    25  	go func() {
    26  		var job Job
    27  		for {
    28  			select {
    29  			case job = <-jobQueue:
    30  				if job == nil {
    31  					return
    32  				}
    33  				job()
    34  			case <-w.closeChan:
    35  				return
    36  			}
    37  		}
    38  	}()
    39  }
    40  
    41  func (w *worker) stop() { close(w.closeChan) }
    42  func (w *worker) join() { <-w.closedChan }
    43  
    44  func newWorker() *worker {
    45  	return &worker{
    46  		jobChan:    make(chan Job),
    47  		closeChan:  make(chan struct{}),
    48  		closedChan: make(chan struct{}),
    49  	}
    50  }
    51  
    52  // Job 被 worker 竞争的工作
    53  type Job func()
    54  
    55  // GoroutinePool 线程池,numWorkers 数量的 worker 竞争 Job
    56  type GoroutinePool struct {
    57  	queuedJobs xsync.AtomicInt64
    58  	jobQueue   chan Job
    59  	workers    []*worker
    60  	closeFlag  xsync.AtomicInt32
    61  	timeout    time.Duration
    62  }
    63  
    64  // NewGoroutinePool 创建新的协程竞争池
    65  // numWorkers 数量的 worker 竞争 Job
    66  // jobQueueLen 设置 job 队列长度
    67  // timeout 若 job 队列满,Push job 的超时时间
    68  func NewGoroutinePool(numWorkers int, jobQueueLen int, timeout time.Duration) *GoroutinePool {
    69  	pool := &GoroutinePool{jobQueue: make(chan Job, jobQueueLen), timeout: timeout}
    70  	pool.SetSize(numWorkers)
    71  	return pool
    72  }
    73  
    74  var poolTimeWheel = xtime.NewWheel(time.Second, 20)
    75  
    76  // Push 放入 job 至job 队列
    77  // 若设置了 timeout,当 job 队列满,Push 阻塞 timeout 会报错
    78  func (p *GoroutinePool) Push(job Job) error {
    79  	if p.IsClosed() {
    80  		return errors.New("pool closed")
    81  	}
    82  	if p.timeout == 0 {
    83  		p.jobQueue <- job
    84  	} else {
    85  		select {
    86  		case <-poolTimeWheel.After(p.timeout):
    87  			return xerror.NewText("goroutine pool job queue blocked with %s", p.timeout)
    88  		case p.jobQueue <- job:
    89  		}
    90  	}
    91  	return nil
    92  }
    93  
    94  func (p *GoroutinePool) SetSize(n int) {
    95  	lWorkers := len(p.workers)
    96  	if lWorkers == n {
    97  		return
    98  	}
    99  
   100  	// Add extra workers if N > len(workers)
   101  	for i := lWorkers; i < n; i++ {
   102  		w := newWorker()
   103  		w.Start(p.jobQueue)
   104  		p.workers = append(p.workers, w)
   105  	}
   106  
   107  	// Asynchronously stop all workers > N
   108  	for i := n; i < lWorkers; i++ {
   109  		p.workers[i].stop()
   110  	}
   111  
   112  	// Synchronously wait for all workers > N to stop
   113  	for i := n; i < lWorkers; i++ {
   114  		p.workers[i].join()
   115  	}
   116  
   117  	// Remove stopped workers from slice
   118  	p.workers = p.workers[:n]
   119  }
   120  
   121  // IsClosed 协程竞争池是否已关闭
   122  func (p *GoroutinePool) IsClosed() bool {
   123  	return p.closeFlag.Get() == 1
   124  }
   125  
   126  // Close 关闭协程竞争池
   127  func (p *GoroutinePool) Close() {
   128  	if p.closeFlag.CompareAndSwap(0, 1) {
   129  		p.SetSize(0)
   130  		close(p.jobQueue)
   131  	}
   132  }