github.com/songzhibin97/gkit@v1.2.13/goroutine/goroutine.go (about)

     1  package goroutine
     2  
     3  // package goroutine: 管理goroutine并发量托管任务以及兜底
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"fmt"
     9  	"runtime"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/songzhibin97/gkit/cache/buffer"
    15  
    16  	"github.com/songzhibin97/gkit/log"
    17  	"github.com/songzhibin97/gkit/options"
    18  	"github.com/songzhibin97/gkit/timeout"
    19  )
    20  
    21  var ErrRepeatClose = errors.New("goroutine/goroutine :重复关闭")
    22  
    23  type Goroutine struct {
    24  	close int32
    25  
    26  	// n: 当前goroutine的数量
    27  	n int64
    28  	// 参数选项
    29  	config
    30  	// wait
    31  	wait sync.WaitGroup
    32  	// ctx context
    33  	ctx context.Context
    34  	// cancel
    35  	cancel context.CancelFunc
    36  	// task
    37  	task chan func()
    38  }
    39  
    40  // _go 封装goroutine 使其安全执行
    41  func (g *Goroutine) _go() {
    42  	atomic.AddInt64(&g.n, 1)
    43  	g.wait.Add(1)
    44  	go func() {
    45  		// recover 避免野生goroutine panic后主程退出
    46  		defer func() {
    47  			if err := recover(); err != nil {
    48  				buf := buffer.GetBytes(64 << 10)
    49  				n := runtime.Stack(*buf, false)
    50  				defer buffer.PutBytes(buf)
    51  				// recover panic
    52  				if g.logger == nil {
    53  					fmt.Println("\nrecover go func,", "panic:", err, "\n\npanic stack:\n", string((*buf)[:n]))
    54  					return
    55  				}
    56  				g.logger.Log(log.LevelError, "panic err:", err, "panic stack:", string((*buf)[:n]))
    57  				return
    58  			}
    59  		}()
    60  		defer atomic.AddInt64(&g.n, -1)
    61  		defer g.wait.Done()
    62  		t := time.NewTicker(g.checkTime)
    63  		defer t.Stop()
    64  		for {
    65  			select {
    66  			case <-t.C:
    67  				// 当前的g的个数大于设置的闲置值,则退出
    68  				if atomic.LoadInt64(&g.n) > atomic.LoadInt64(&g.idle) {
    69  					// 闲置数超过预期
    70  					return
    71  				}
    72  			case f, ok := <-g.task:
    73  				// channel已经被关闭
    74  				if !ok {
    75  					return
    76  				}
    77  				// 执行函数
    78  				f()
    79  				if atomic.LoadInt64(&g.n) > atomic.LoadInt64(&g.max) {
    80  					// 如果已经超出预定值,则该goroutine退出
    81  					return
    82  				}
    83  				t.Reset(g.checkTime)
    84  			case <-g.ctx.Done():
    85  				// 触发ctx退出
    86  				return
    87  			}
    88  		}
    89  	}()
    90  }
    91  
    92  // AddTask 添加任务
    93  // 直到添加成功为止
    94  func (g *Goroutine) AddTask(f func()) bool {
    95  	// 判断channel是否关闭
    96  	if atomic.LoadInt32(&g.close) != 0 {
    97  		return false
    98  	}
    99  	// 尝试直接塞入
   100  	// 如果阻塞尝试进行扩容
   101  	select {
   102  	case g.task <- f:
   103  	default:
   104  		if atomic.LoadInt64(&g.n) < atomic.LoadInt64(&g.max) {
   105  			g._go()
   106  		}
   107  		g.task <- f
   108  	}
   109  	return true
   110  }
   111  
   112  // AddTaskN 添加任务 有超时时间
   113  func (g *Goroutine) AddTaskN(ctx context.Context, f func()) bool {
   114  	// 判断channel是否关闭
   115  	if atomic.LoadInt32(&g.close) != 0 {
   116  		return false
   117  	}
   118  	if atomic.LoadInt64(&g.max) > atomic.LoadInt64(&g.n) {
   119  		g._go()
   120  	}
   121  	select {
   122  	case <-ctx.Done():
   123  		return false
   124  	case g.task <- f:
   125  		return true
   126  	}
   127  }
   128  
   129  // ChangeMax 修改pool上限值
   130  func (g *Goroutine) ChangeMax(m int64) {
   131  	atomic.StoreInt64(&g.max, m)
   132  }
   133  
   134  // Shutdown 优雅关闭
   135  // 符合幂等性
   136  func (g *Goroutine) Shutdown() error {
   137  	if atomic.SwapInt32(&g.close, 1) == 1 {
   138  		return ErrRepeatClose
   139  	}
   140  	g.cancel()
   141  	close(g.task)
   142  	err := Delegate(context.TODO(), g.stopTimeout, func(context.Context) error {
   143  		g.wait.Wait()
   144  		return nil
   145  	})
   146  	if g.logger != nil {
   147  		g.logger.Log(log.LevelDebug, err)
   148  	}
   149  	return err
   150  }
   151  
   152  // Trick Debug使用
   153  func (g *Goroutine) Trick() string {
   154  	if g.logger != nil {
   155  		g.logger.Log(log.LevelDebug, "max:", atomic.LoadInt64(&g.max), "idle:", atomic.LoadInt64(&g.idle), "now goroutine", atomic.LoadInt64(&g.n), "task len:", len(g.task))
   156  	}
   157  	return fmt.Sprintln("max:", atomic.LoadInt64(&g.max), "idle:", atomic.LoadInt64(&g.idle), "now goroutine:", atomic.LoadInt64(&g.n), "task len:", len(g.task))
   158  }
   159  
   160  // Delegate 委托执行 一般用于回收函数超时控制
   161  func Delegate(c context.Context, t time.Duration, f func(ctx context.Context) error) error {
   162  	ch := make(chan error, 1)
   163  	go func() {
   164  		defer func() {
   165  			if err := recover(); err != nil {
   166  				// panic兜底
   167  				switch e := err.(type) {
   168  				case string:
   169  					ch <- errors.New(e)
   170  				case error:
   171  					ch <- e
   172  				default:
   173  					ch <- errors.New(fmt.Sprintf("%+v\n", err))
   174  				}
   175  				return
   176  			}
   177  		}()
   178  		ch <- f(c)
   179  	}()
   180  	// 增加优雅退出超时控制
   181  	var (
   182  		cancel context.CancelFunc
   183  	)
   184  	if t > 0 {
   185  		_, c, cancel = timeout.Shrink(c, t)
   186  	} else {
   187  		c, cancel = context.WithCancel(c)
   188  	}
   189  	defer cancel()
   190  	select {
   191  	case <-c.Done():
   192  		return c.Err()
   193  	case err := <-ch:
   194  		return err
   195  	}
   196  }
   197  
   198  // NewGoroutine 实例化方法
   199  func NewGoroutine(ctx context.Context, opts ...options.Option) GGroup {
   200  	ctx, cancel := context.WithCancel(ctx)
   201  	o := config{
   202  		stopTimeout: 10 * time.Second,
   203  		max:         1000,
   204  		idle:        1000,
   205  		checkTime:   10 * time.Minute,
   206  	}
   207  	for _, opt := range opts {
   208  		opt(&o)
   209  	}
   210  	g := &Goroutine{
   211  		ctx:    ctx,
   212  		cancel: cancel,
   213  		// 为什么设置0
   214  		// task buffer 理论上如果比较大,调度可能会延迟
   215  		task:   make(chan func(), 0),
   216  		config: o,
   217  	}
   218  	if o.idle > o.max {
   219  		o.idle = o.max
   220  	}
   221  	// 预加载出idle池,避免阻塞在buffer中
   222  	for i := int64(0); i < o.idle; i++ {
   223  		g._go()
   224  	}
   225  	return g
   226  }