github.com/dubbogo/gost@v1.14.0/sync/base_worker_pool.go (about)

     1  /*
     2   * Licensed to the Apache Software Foundation (ASF) under one or more
     3   * contributor license agreements.  See the NOTICE file distributed with
     4   * this work for additional information regarding copyright ownership.
     5   * The ASF licenses this file to You under the Apache License, Version 2.0
     6   * (the "License"); you may not use this file except in compliance with
     7   * the License.  You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   */
    17  
    18  package gxsync
    19  
    20  import (
    21  	"fmt"
    22  	"runtime/debug"
    23  	"sync"
    24  )
    25  
    26  import (
    27  	"go.uber.org/atomic"
    28  )
    29  
    30  import (
    31  	gxlog "github.com/dubbogo/gost/log"
    32  )
    33  
    34  type WorkerPoolConfig struct {
    35  	NumWorkers int
    36  	NumQueues  int
    37  	QueueSize  int
    38  	Logger     gxlog.Logger
    39  	Enable     bool
    40  }
    41  
// baseWorkerPool is a worker pool with multiple queues.
//
// The picture below shows the baseWorkerPool architecture.
// Note that:
//   - Each TaskQueueX is a buffered channel; see taskQueues.
//   - A worker consumes tasks only from the queue it was assigned: worker i
//     reads from queue i % len(taskQueues) (see worker and dispatch).
//   - taskId is incremented by 1 after a task is enqueued.
//     NOTE(review): the increment and queue selection are done by the
//     concrete Submit implementations, not by this base type — confirm there.
//
// ┌───────┐  ┌───────┐  ┌───────┐                 ┌─────────────────────────┐
// │worker0│  │worker2│  │worker4│               ┌─┤ taskId % NumQueues == 0 │
// └───────┘  └───────┘  └───────┘               │ └─────────────────────────┘
//     │          │          │                   │
//     └───────consume───────┘                enqueue
//                ▼                             task    ╔══════════════════╗
//              ┌──┬──┬──┬──┬──┬──┬──┬──┬──┬──┐  │      ║ baseWorkerPool:  ║
//  TaskQueue0  │t0│t1│t2│t3│t4│t5│t6│t7│t8│t9│◀─┘      ║                  ║
//              ├──┼──┼──┼──┼──┼──┼──┼──┼──┼──┤         ║ *NumWorkers=6    ║
//  TaskQueue1  │t0│t1│t2│t3│t4│t5│t6│t7│t8│t9│◀┐       ║ *NumQueues=2     ║
//              └──┴──┴──┴──┴──┴──┴──┴──┴──┴──┘ │       ║ *QueueSize=10    ║
//                ▲                          enqueue    ╚══════════════════╝
//     ┌───────consume───────┐                 task
//     │          │          │                  │
// ┌───────┐  ┌───────┐  ┌───────┐              │  ┌─────────────────────────┐
// │worker1│  │worker3│  │worker5│              └──│ taskId % NumQueues == 1 │
// └───────┘  └───────┘  └───────┘                 └─────────────────────────┘
type baseWorkerPool struct {
	// logger receives lifecycle messages and recovered task panics;
	// nil disables logging.
	logger gxlog.Logger

	// taskId is a counter for picking a queue; presumably used as
	// taskId % len(taskQueues), as in the diagram — not read in this file.
	taskId     uint32
	// taskQueues holds one channel per queue, each created with capacity
	// WorkerPoolConfig.QueueSize.
	taskQueues []chan task

	// numWorkers counts live worker goroutines: incremented in newWorker,
	// decremented when a worker returns.
	numWorkers *atomic.Int32
	// enable mirrors WorkerPoolConfig.Enable; when false no workers are started.
	enable     bool

	// wg tracks live worker goroutines; Close waits on it.
	wg *sync.WaitGroup
}
    77  
    78  func newBaseWorkerPool(config WorkerPoolConfig) *baseWorkerPool {
    79  	if config.NumWorkers < 1 {
    80  		config.NumWorkers = 1
    81  	}
    82  	if config.NumQueues < 1 {
    83  		config.NumQueues = 1
    84  	}
    85  	if config.QueueSize < 0 {
    86  		config.QueueSize = 0
    87  	}
    88  
    89  	taskQueues := make([]chan task, config.NumQueues)
    90  	for i := range taskQueues {
    91  		taskQueues[i] = make(chan task, config.QueueSize)
    92  	}
    93  
    94  	p := &baseWorkerPool{
    95  		logger:     config.Logger,
    96  		taskQueues: taskQueues,
    97  		numWorkers: new(atomic.Int32),
    98  		wg:         new(sync.WaitGroup),
    99  		enable:     config.Enable,
   100  	}
   101  
   102  	if !config.Enable {
   103  		return p
   104  	}
   105  
   106  	initWg := new(sync.WaitGroup)
   107  	initWg.Add(config.NumWorkers)
   108  
   109  	p.dispatch(config.NumWorkers, initWg)
   110  
   111  	initWg.Wait()
   112  	if p.logger != nil {
   113  		p.logger.Infof("all %d workers are started", p.NumWorkers())
   114  	}
   115  
   116  	return p
   117  }
   118  
   119  func (p *baseWorkerPool) dispatch(numWorkers int, wg *sync.WaitGroup) {
   120  	for i := 0; i < numWorkers; i++ {
   121  		p.newWorker(i, wg)
   122  	}
   123  }
   124  
// Submit is a placeholder on the base type and always panics with
// "implement me". NOTE(review): presumably every concrete pool embedding
// baseWorkerPool supplies its own Submit — confirm.
func (p *baseWorkerPool) Submit(_ task) error {
	panic("implement me")
}
   128  
// SubmitSync is a placeholder on the base type and always panics with
// "implement me". NOTE(review): presumably every concrete pool embedding
// baseWorkerPool supplies its own SubmitSync — confirm.
func (p *baseWorkerPool) SubmitSync(_ task) error {
	panic("implement me")
}
   132  
   133  func (p *baseWorkerPool) Close() {
   134  	if p.IsClosed() {
   135  		return
   136  	}
   137  
   138  	for _, q := range p.taskQueues {
   139  		close(q)
   140  	}
   141  	p.wg.Wait()
   142  	if p.logger != nil {
   143  		p.logger.Infof("there are %d workers remained, all workers are closed", p.NumWorkers())
   144  	}
   145  }
   146  
   147  func (p *baseWorkerPool) IsClosed() bool {
   148  	return p.NumWorkers() == 0
   149  }
   150  
   151  func (p *baseWorkerPool) NumWorkers() int32 {
   152  	return p.numWorkers.Load()
   153  }
   154  
   155  func (p *baseWorkerPool) newWorker(workerId int, wg *sync.WaitGroup) {
   156  	p.wg.Add(1)
   157  	p.numWorkers.Add(1)
   158  	go p.worker(workerId, wg)
   159  }
   160  
   161  func (p *baseWorkerPool) worker(workerId int, wg *sync.WaitGroup) {
   162  	defer func() {
   163  		if n := p.numWorkers.Add(-1); n < 0 {
   164  			panic(fmt.Sprintf("numWorkers should be greater or equal to 0, but the value is %d", n))
   165  		}
   166  		p.wg.Done()
   167  	}()
   168  
   169  	chanId := workerId % len(p.taskQueues)
   170  
   171  	wg.Done()
   172  	for {
   173  		select {
   174  		case t, ok := <-p.taskQueues[chanId]:
   175  			if !ok {
   176  				return
   177  			}
   178  			if t != nil {
   179  				func() {
   180  					// prevent from goroutine panic
   181  					defer func() {
   182  						if r := recover(); r != nil {
   183  							if p.logger != nil {
   184  								p.logger.Errorf("goroutine panic: %v\n%s", r, string(debug.Stack()))
   185  							}
   186  						}
   187  					}()
   188  					// execute task
   189  					t()
   190  				}()
   191  			}
   192  		}
   193  	}
   194  }