github.com/sunvim/utils@v0.1.0/workpool/workpool.go (about)

     1  package workpool
     2  
     3  import (
     4  	"context"
     5  	"runtime"
     6  	"sync"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/sunvim/utils/queue"
    11  )
    12  
    13  // TaskHandler Define function callbacks
    14  type TaskHandler func() error
    15  
    16  // WorkPool serves incoming connections via a pool of workers
    17  type WorkPool struct {
    18  	closed       int32
    19  	isQueTask    int32         // Mark whether queue retrieval is task.
    20  	errChan      chan error    // error chan
    21  	timeout      time.Duration // max timeout
    22  	wg           sync.WaitGroup
    23  	task         chan TaskHandler
    24  	waitingQueue *queue.Queue
    25  	workerNum    int
    26  }
    27  
    28  // New new workpool and set the max number of concurrencies
    29  func New(max int) *WorkPool {
    30  	if max < 1 {
    31  		max = 1
    32  	}
    33  
    34  	p := &WorkPool{
    35  		task:         make(chan TaskHandler, 2*max),
    36  		errChan:      make(chan error, 1),
    37  		waitingQueue: queue.New(),
    38  		workerNum:    max,
    39  	}
    40  
    41  	go p.loop(max)
    42  	return p
    43  }
    44  
    45  // SetTimeout Setting timeout time
    46  func (p *WorkPool) SetTimeout(timeout time.Duration) {
    47  	p.timeout = timeout
    48  }
    49  
    50  // Do Add to the workpool and return immediately
    51  func (p *WorkPool) Do(fn TaskHandler) {
    52  	if p.IsClosed() {
    53  		return
    54  	}
    55  	p.waitingQueue.Push(fn)
    56  }
    57  
    58  // DoWait Add to the workpool and wait for execution to complete before returning
    59  func (p *WorkPool) DoWait(task TaskHandler) {
    60  	if p.IsClosed() { // closed
    61  		return
    62  	}
    63  
    64  	doneChan := make(chan struct{})
    65  	p.waitingQueue.Push(TaskHandler(func() error {
    66  		defer close(doneChan)
    67  		return task()
    68  	}))
    69  	<-doneChan
    70  }
    71  
    72  // Wait Waiting for the worker thread to finish executing
    73  func (p *WorkPool) Wait() error {
    74  	p.waitingQueue.Wait()
    75  	p.waitingQueue.Close()
    76  	p.waitTask() // wait que down
    77  	close(p.task)
    78  	p.wg.Wait() // wait all task finished
    79  	select {
    80  	case err := <-p.errChan:
    81  		p.waitingQueue = queue.New()
    82  		p.task = make(chan TaskHandler, p.workerNum*2)
    83  		return err
    84  	default:
    85  		p.waitingQueue = queue.New()
    86  		p.task = make(chan TaskHandler, p.workerNum*2)
    87  		return nil
    88  	}
    89  }
    90  
    91  // IsDone Determine whether it is complete (non-blocking)
    92  func (p *WorkPool) IsDone() bool {
    93  	if p == nil || p.task == nil {
    94  		return true
    95  	}
    96  
    97  	return p.waitingQueue.Len() == 0 && len(p.task) == 0
    98  }
    99  
   100  // IsClosed Has it been closed?
   101  func (p *WorkPool) IsClosed() bool {
   102  	if atomic.LoadInt32(&p.closed) == 1 { // closed
   103  		return true
   104  	}
   105  	return false
   106  }
   107  
   108  func (p *WorkPool) startQueue() {
   109  	p.isQueTask = 1
   110  	for {
   111  		tmp := p.waitingQueue.Pop()
   112  		if p.IsClosed() { // closed
   113  			p.waitingQueue.Close()
   114  			break
   115  		}
   116  		if tmp != nil {
   117  			fn := tmp.(TaskHandler)
   118  			if fn != nil {
   119  				p.task <- fn
   120  			}
   121  		} else {
   122  			break
   123  		}
   124  
   125  	}
   126  	atomic.StoreInt32(&p.isQueTask, 0)
   127  }
   128  
   129  func (p *WorkPool) waitTask() {
   130  	for {
   131  		runtime.Gosched()
   132  		if p.IsDone() {
   133  			if atomic.LoadInt32(&p.isQueTask) == 0 {
   134  				break
   135  			}
   136  		}
   137  	}
   138  }
   139  
   140  func (p *WorkPool) loop(maxWorkersCount int) {
   141  	go p.startQueue() // Startup queue
   142  
   143  	p.wg.Add(maxWorkersCount) // Maximum number of work cycles
   144  	// Start Max workers
   145  	for i := 0; i < maxWorkersCount; i++ {
   146  		go func() {
   147  			defer p.wg.Done()
   148  
   149  			for wt := range p.task {
   150  				if wt == nil || atomic.LoadInt32(&p.closed) == 1 { // returns immediately
   151  					continue // It needs to be consumed before returning.
   152  				}
   153  
   154  				closed := make(chan struct{}, 1)
   155  				// Set timeout, priority task timeout.
   156  				if p.timeout > 0 {
   157  					ct, cancel := context.WithTimeout(context.Background(), p.timeout)
   158  					go func() {
   159  						select {
   160  						case <-ct.Done():
   161  							p.errChan <- ct.Err()
   162  							atomic.StoreInt32(&p.closed, 1)
   163  							cancel()
   164  						case <-closed:
   165  						}
   166  					}()
   167  				}
   168  
   169  				err := wt() // Points of Execution.
   170  				close(closed)
   171  				if err != nil {
   172  					select {
   173  					case p.errChan <- err:
   174  						atomic.StoreInt32(&p.closed, 1)
   175  					default:
   176  					}
   177  				}
   178  			}
   179  		}()
   180  	}
   181  }