github.com/songzhibin97/gkit@v1.2.13/delayed/do.go (about)

     1  package delayed
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"github.com/songzhibin97/gkit/goroutine"
     7  	"github.com/songzhibin97/gkit/options"
     8  	"os"
     9  	"os/signal"
    10  	"sync"
    11  	"sync/atomic"
    12  	"syscall"
    13  	"time"
    14  )
    15  
    16  var BadDelayed = &badDelayed{}
    17  var ErrorRepeatShutdown = errors.New("重复关闭")
    18  
    19  type badDelayed struct{}
    20  
    21  func (b badDelayed) Do() {
    22  	return
    23  }
    24  
    25  func (b badDelayed) ExecTime() int64 {
    26  	return 0
    27  }
    28  
    29  func (b badDelayed) Identify() string {
    30  	return "badDelayed"
    31  }
    32  
// DispatchingDelayed dispatches delayed tasks.
// It is intended to be safe for concurrent use: the embedded RWMutex is
// taken by the methods that modify delays.
//
// delays holds the pending tasks and is kept in heap order by
// siftupDelayed/siftdownDelayed. isClose is flipped from 0 to 1 atomically
// by Close.
type DispatchingDelayed struct {
	sync.RWMutex
	delays         []Delayed
	checkTime      time.Duration                                 // check interval
	Worker         int64                                         // concurrency (workers that actually run tasks)
	signal         []os.Signal                                   // OS signals to register for
	signalCallback func(signal os.Signal, d *DispatchingDelayed) // callback invoked when a signal is received
	close          chan struct{}                                 // internal use (marks whether already closed)
	isClose        int32
	pool           goroutine.GGroup // pool used internally for concurrent execution // min 1
	refresh        chan struct{}    // force refresh
}
    47  
    48  // AddDelayed 添加延时任务
    49  func (d *DispatchingDelayed) AddDelayed(delayed Delayed) {
    50  	if delayed.ExecTime() <= 0 {
    51  		// 无效任务
    52  		return
    53  	}
    54  	if atomic.LoadInt32(&d.isClose) == 1 {
    55  		return
    56  	}
    57  
    58  	d.Lock()
    59  	defer d.Unlock()
    60  
    61  	i := len(d.delays)
    62  	d.delays = append(d.delays, delayed)
    63  	siftupDelayed(d.delays, i)
    64  }
    65  
    66  func (d *DispatchingDelayed) delDelayed(i int) Delayed {
    67  	d.Lock()
    68  	defer d.Unlock()
    69  
    70  	if i >= len(d.delays) {
    71  		return BadDelayed
    72  	}
    73  
    74  	last := len(d.delays) - 1
    75  	if i != last {
    76  		d.delays[i] = d.delays[last]
    77  	}
    78  	d.delays[last] = nil
    79  	d.delays = d.delays[:last]
    80  	smallestChanged := i
    81  	if i != last {
    82  		// Moving to i may have moved the last timer to a new parent,
    83  		// so sift up to preserve the heap guarantee.
    84  		smallestChanged = siftupDelayed(d.delays, i)
    85  	}
    86  
    87  	return d.delays[smallestChanged]
    88  }
    89  
    90  // delDelayedTop pop最小时间
    91  func (d *DispatchingDelayed) delDelayedTop() Delayed {
    92  	d.Lock()
    93  	defer d.Unlock()
    94  
    95  	if len(d.delays) == 0 {
    96  		return BadDelayed
    97  	}
    98  
    99  	ret := d.delays[0]
   100  	last := len(d.delays) - 1
   101  	if last > 0 {
   102  		d.delays[0] = d.delays[last]
   103  	}
   104  
   105  	d.delays[last] = nil
   106  	d.delays = d.delays[:last]
   107  	if last > 0 {
   108  		siftdownDelayed(d.delays, 0)
   109  	}
   110  	return ret
   111  }
   112  
   113  // getTopDelayed 获取下一个需要执行的任务
   114  func (d *DispatchingDelayed) getTopDelayed() Delayed {
   115  	d.RLock()
   116  	d.RUnlock()
   117  	if len(d.delays) == 0 {
   118  		return BadDelayed
   119  	}
   120  	return d.delays[0]
   121  }
   122  
   123  // IsInvalid 判断任务是否有效
   124  func (d *DispatchingDelayed) IsInvalid(delayed Delayed) bool {
   125  	return delayed == badDelayed{}
   126  }
   127  
   128  // Close 关闭
   129  func (d *DispatchingDelayed) Close() error {
   130  	if !atomic.CompareAndSwapInt32(&d.isClose, 0, 1) {
   131  		return ErrorRepeatShutdown
   132  	}
   133  	close(d.close)
   134  	return d.pool.Shutdown()
   135  }
   136  
   137  // Refresh 刷新
   138  func (d *DispatchingDelayed) Refresh() {
   139  	select {
   140  	case d.refresh <- struct{}{}:
   141  	}
   142  }
   143  
   144  // sentinel 启动
   145  func (d *DispatchingDelayed) sentinel() {
   146  	go func() {
   147  		timer := time.NewTicker(d.checkTime)
   148  		for {
   149  			select {
   150  			case <-timer.C:
   151  			case <-d.refresh:
   152  			case <-d.close:
   153  				// 关闭流程
   154  				ln := len(d.delays)
   155  				for i := 0; i < ln; i++ {
   156  					pop := d.delDelayedTop()
   157  					if d.IsInvalid(pop) {
   158  						continue
   159  					}
   160  					d.pool.AddTask(pop.Do)
   161  				}
   162  				return
   163  			}
   164  			now := time.Now().Unix()
   165  			for i := 0; i < len(d.delays); i++ {
   166  				top := d.getTopDelayed()
   167  
   168  				// 还没到达执行时间
   169  				if top.ExecTime() > now {
   170  					break
   171  				}
   172  
   173  				d.delDelayedTop()
   174  				d.pool.AddTask(top.Do)
   175  			}
   176  		}
   177  	}()
   178  }
   179  
   180  // NewDispatchingDelayed 初始化调度实例
   181  func NewDispatchingDelayed(o ...options.Option) *DispatchingDelayed {
   182  	dispatchingDelayed := &DispatchingDelayed{
   183  		checkTime: time.Second,
   184  		Worker:    1,
   185  		signal:    []os.Signal{syscall.SIGINT},
   186  		signalCallback: func(signal os.Signal, d *DispatchingDelayed) {
   187  			_ = d.Close()
   188  		},
   189  		close:   make(chan struct{}, 1),
   190  		pool:    goroutine.NewGoroutine(context.Background(), goroutine.SetMax(1), goroutine.SetIdle(1)),
   191  		refresh: make(chan struct{}, 1),
   192  	}
   193  	for _, option := range o {
   194  		option(dispatchingDelayed)
   195  	}
   196  	if dispatchingDelayed.checkTime <= 0 {
   197  		dispatchingDelayed.checkTime = time.Second
   198  	}
   199  	if dispatchingDelayed.Worker <= 0 {
   200  		dispatchingDelayed.Worker = 1
   201  	}
   202  	if dispatchingDelayed.Worker != 1 {
   203  		dispatchingDelayed.pool = goroutine.NewGoroutine(context.Background(), goroutine.SetMax(dispatchingDelayed.Worker), goroutine.SetIdle(dispatchingDelayed.Worker))
   204  	}
   205  	if len(dispatchingDelayed.signal) != 0 && dispatchingDelayed.signalCallback != nil {
   206  		sign := make(chan os.Signal, 1)
   207  		signal.Notify(sign, dispatchingDelayed.signal...)
   208  		go func() {
   209  			for {
   210  				select {
   211  				case <-dispatchingDelayed.close:
   212  					return
   213  				case v := <-sign:
   214  					dispatchingDelayed.signalCallback(v, dispatchingDelayed)
   215  				}
   216  			}
   217  		}()
   218  	}
   219  	dispatchingDelayed.sentinel()
   220  	return dispatchingDelayed
   221  }