github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/workerpool/pool_impl.go (about)

     1  // Copyright 2020 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package workerpool
    15  
    16  import (
    17  	"context"
    18  	"sync"
    19  	"sync/atomic"
    20  	"time"
    21  
    22  	"github.com/pingcap/errors"
    23  	"github.com/pingcap/failpoint"
    24  	"github.com/pingcap/log"
    25  	cerrors "github.com/pingcap/tiflow/pkg/errors"
    26  	"github.com/pingcap/tiflow/pkg/notify"
    27  	"go.uber.org/zap"
    28  	"golang.org/x/sync/errgroup"
    29  	"golang.org/x/time/rate"
    30  )
    31  
    32  const (
    33  	workerPoolDefaultClockSourceInterval = time.Millisecond * 100
    34  )
    35  
// defaultPoolImpl is the WorkerPool implementation returned by NewDefaultWorkerPool.
// It owns a fixed set of workers and assigns every registered handle to one of
// them according to the value produced by hasher.
type defaultPoolImpl struct {
	// assume the hasher to be the trivial hasher for now
	hasher Hasher
	// do not resize this slice after creating the pool
	workers []*worker
	// used to generate handler IDs, must be accessed atomically
	nextHandlerID int64
}
    44  
// NewDefaultWorkerPool creates a new WorkerPool that uses the default implementation.
// The pool is built with numWorkers workers and a defaultHasher. Creating the
// pool does not start any goroutine; the workers only begin processing once
// Run is called on the returned pool.
func NewDefaultWorkerPool(numWorkers int) WorkerPool {
	return newDefaultPoolImpl(&defaultHasher{}, numWorkers)
}
    49  
    50  func newDefaultPoolImpl(hasher Hasher, numWorkers int) *defaultPoolImpl {
    51  	workers := make([]*worker, numWorkers)
    52  	for i := 0; i < numWorkers; i++ {
    53  		workers[i] = newWorker()
    54  	}
    55  	return &defaultPoolImpl{
    56  		hasher:  hasher,
    57  		workers: workers,
    58  	}
    59  }
    60  
    61  func (p *defaultPoolImpl) Run(ctx context.Context) error {
    62  	errg, ctx := errgroup.WithContext(ctx)
    63  
    64  	for _, worker := range p.workers {
    65  		workerFinal := worker
    66  		errg.Go(func() error {
    67  			err := workerFinal.run(ctx)
    68  			if err != nil {
    69  				return errors.Trace(err)
    70  			}
    71  			return nil
    72  		})
    73  	}
    74  
    75  	return errg.Wait()
    76  }
    77  
    78  func (p *defaultPoolImpl) RegisterEvent(f func(ctx context.Context, event interface{}) error) EventHandle {
    79  	handler := &defaultEventHandle{
    80  		f:     f,
    81  		errCh: make(chan error, 1),
    82  		id:    atomic.AddInt64(&p.nextHandlerID, 1) - 1,
    83  	}
    84  
    85  	workerID := p.hasher.Hash(handler) % int64(len(p.workers))
    86  	p.workers[workerID].addHandle(handler)
    87  	handler.worker = p.workers[workerID]
    88  
    89  	return handler
    90  }
    91  
    92  type handleStatus = int32
    93  
    94  const (
    95  	handleRunning = handleStatus(iota)
    96  	handleCancelling
    97  	handleCancelled
    98  )
    99  
// defaultEventHandle is the handle created by defaultPoolImpl.RegisterEvent.
// It is bound to a single worker, whose loop executes the event function and
// the optional timer handler of the handle.
type defaultEventHandle struct {
	// the function to be run each time the event is triggered
	f func(ctx context.Context, event interface{}) error
	// one of handleRunning, handleCancelling or handleCancelled;
	// must be accessed atomically
	status handleStatus
	// channel on which doCancel delivers the error that terminated the handle
	// (an error returned by f or by the timer handler, or the "cancelled" error
	// on unregistration); it is closed right after that error is sent
	errCh chan error
	// the worker that the handle is associated with
	worker *worker
	// identifier for this handle. No significant usage for now.
	// Might be used to support consistent hashing in the future,
	// so that the pool can be resized efficiently.
	id int64

	// whether there is a valid timer handler, must be accessed atomically
	hasTimer int32
	// the time when timer was triggered the last time; the zero value means
	// the timer handler has not completed successfully yet
	lastTimer time.Time
	// minimum interval between two timer calls
	timerInterval time.Duration
	// the handler for the timer
	timerHandler func(ctx context.Context) error

	// whether this is a valid errorHandler, must be accessed atomically
	hasErrorHandler int32
	// the error handler, called by doCancel with the error that terminated
	// the handle
	errorHandler func(err error)
}
   128  
   129  func (h *defaultEventHandle) AddEvent(ctx context.Context, event interface{}) error {
   130  	status := atomic.LoadInt32(&h.status)
   131  	if status != handleRunning {
   132  		return cerrors.ErrWorkerPoolHandleCancelled.GenWithStackByArgs()
   133  	}
   134  
   135  	failpoint.Inject("addEventDelayPoint", func() {})
   136  
   137  	task := task{
   138  		handle: h,
   139  		f: func(ctx1 context.Context) error {
   140  			return h.f(ctx, event)
   141  		},
   142  	}
   143  
   144  	select {
   145  	case <-ctx.Done():
   146  		return errors.Trace(ctx.Err())
   147  	case h.worker.taskCh <- task:
   148  	}
   149  	return nil
   150  }
   151  
   152  func (h *defaultEventHandle) AddEvents(ctx context.Context, events []interface{}) error {
   153  	status := atomic.LoadInt32(&h.status)
   154  	if status != handleRunning {
   155  		return cerrors.ErrWorkerPoolHandleCancelled.GenWithStackByArgs()
   156  	}
   157  
   158  	failpoint.Inject("addEventDelayPoint", func() {})
   159  
   160  	task := task{
   161  		handle: h,
   162  		f: func(ctx1 context.Context) error {
   163  			for _, event := range events {
   164  				err := h.f(ctx, event)
   165  				if err != nil {
   166  					return err
   167  				}
   168  			}
   169  			return nil
   170  		},
   171  	}
   172  
   173  	select {
   174  	case <-ctx.Done():
   175  		return errors.Trace(ctx.Err())
   176  	case h.worker.taskCh <- task:
   177  	}
   178  	return nil
   179  }
   180  
// SetTimer installs f as the timer handler of the handle and returns the handle.
// The worker checks the timer on each of its ticks (see doTimer) and calls the
// handler when at least `interval` has passed since the last successful call,
// so the effective period is never finer than the worker's tick.
//
// f is always invoked with the ctx given to SetTimer, not with the context the
// worker passes in. SetTimer may block in synchronize while the worker is busy.
func (h *defaultEventHandle) SetTimer(ctx context.Context, interval time.Duration, f func(ctx context.Context) error) EventHandle {
	// mark the timer handler function as invalid
	atomic.StoreInt32(&h.hasTimer, 0)
	// wait for `hasTimer` to take effect, otherwise we might have a data race, if there was a previous handler.
	h.worker.synchronize()

	// From here on the worker cannot be inside the previous timer handler, and
	// doTimer returns early while hasTimer is 0, so the plain writes below are
	// not observed half-done.
	h.timerInterval = interval
	h.timerHandler = func(ctx1 context.Context) error {
		return f(ctx)
	}
	// mark the timer handler function as valid
	atomic.StoreInt32(&h.hasTimer, 1)

	return h
}
   196  
// Unregister cancels the handle.
//
// If this call moves the handle from handleRunning to handleCancelled, it
// waits for the worker with synchronize and then runs doCancel with
// ErrWorkerPoolHandleCancelled, which removes the handle from the worker,
// invokes the error handler (if any) and delivers the error on ErrCh.
// If the handle was not running (already cancelling or cancelled), it only
// synchronizes with the worker and returns.
func (h *defaultEventHandle) Unregister() {
	if !atomic.CompareAndSwapInt32(&h.status, handleRunning, handleCancelled) {
		// call synchronize so that the returning of Unregister cannot race
		// with the calling of the errorHandler, if an error is already being processed.
		h.worker.synchronize()
		// already cancelled
		return
	}

	failpoint.Inject("unregisterDelayPoint", func() {})

	// call synchronize so that all function executions related to this handle will be
	// linearized BEFORE Unregister.
	h.worker.synchronize()

	h.doCancel(cerrors.ErrWorkerPoolHandleCancelled.GenWithStackByArgs())
}
   214  
// GracefulUnregister cancels the handle after giving the tasks that are
// already queued a chance to run.
//
// It first moves the handle from handleRunning to handleCancelling, which makes
// AddEvent and AddEvents reject new events. It then queues a marker task
// (a task with only doneCh set) on the worker's task channel and waits until
// the worker reaches it; since the channel is FIFO, tasks queued before the
// marker are taken by the worker first.
//
// It returns nil if the marker was processed, or if the handle was not running
// when the call was made. It returns ErrWorkerPoolGracefulUnregisterTimedOut if
// ctx is done or `timeout` elapses before the marker could be queued or was
// processed. Whenever the handle was running, the deferred function finishes
// the cancellation on every return path, including the timeout ones.
func (h *defaultEventHandle) GracefulUnregister(ctx context.Context, timeout time.Duration) error {
	if !atomic.CompareAndSwapInt32(&h.status, handleRunning, handleCancelling) {
		// already cancelling or cancelled
		return nil
	}

	defer func() {
		if !atomic.CompareAndSwapInt32(&h.status, handleCancelling, handleCancelled) {
			// already cancelled
			return
		}

		// call synchronize so that all function executions related to this handle will be
		// linearized BEFORE Unregister.
		h.worker.synchronize()
		h.doCancel(cerrors.ErrWorkerPoolHandleCancelled.GenWithStackByArgs())
	}()

	// both waits below share this deadline
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	// closed by the worker when it dequeues the marker task
	doneCh := make(chan struct{})
	select {
	case <-ctx.Done():
		return cerrors.ErrWorkerPoolGracefulUnregisterTimedOut.GenWithStackByArgs()
	case h.worker.taskCh <- task{
		handle: h,
		doneCh: doneCh,
	}:
	}

	select {
	case <-ctx.Done():
		return cerrors.ErrWorkerPoolGracefulUnregisterTimedOut.GenWithStackByArgs()
	case <-doneCh:
	}

	return nil
}
   254  
// doCancel performs the final steps of cancelling the handle: it removes the
// handle from its worker, calls the error handler with err if one is
// installed, then sends err on errCh and closes the channel.
//
// Callers must first win the compare-and-swap of h.status to handleCancelled;
// that is what guarantees doCancel runs once per handle.
// DO NOT call doCancel multiple times on the same handle: the second call
// would send on, and close, an already closed channel.
func (h *defaultEventHandle) doCancel(err error) {
	h.worker.removeHandle(h)

	if atomic.LoadInt32(&h.hasErrorHandler) == 1 {
		h.errorHandler(err)
	}

	// errCh is created with capacity 1, so this single send does not block
	// even when nobody is receiving.
	h.errCh <- err
	close(h.errCh)
}
   267  
// ErrCh returns the receive-only view of the handle's error channel.
// When the handle is cancelled, doCancel sends the terminating error on it
// once and then closes it.
func (h *defaultEventHandle) ErrCh() <-chan error {
	return h.errCh
}
   271  
// OnExit installs f as the error handler of the handle and returns the handle.
// doCancel calls f with the error that terminated the handle: an error
// returned by the event function or the timer handler, or
// ErrWorkerPoolHandleCancelled when the handle is unregistered.
//
// The existing handler is disabled first and the worker is synchronized before
// the field is overwritten. OnExit may therefore block while the worker is busy.
func (h *defaultEventHandle) OnExit(f func(err error)) EventHandle {
	// invalidate the current handler before touching the field
	atomic.StoreInt32(&h.hasErrorHandler, 0)
	// wait for the worker to finish whatever it is currently executing
	h.worker.synchronize()
	h.errorHandler = f
	// publish the new handler
	atomic.StoreInt32(&h.hasErrorHandler, 1)
	return h
}
   279  
// HashCode returns the ID assigned to the handle when it was registered.
func (h *defaultEventHandle) HashCode() int64 {
	return h.id
}
   283  
   284  func (h *defaultEventHandle) cancelWithErr(err error) {
   285  	if !atomic.CompareAndSwapInt32(&h.status, handleRunning, handleCancelled) {
   286  		// already cancelled
   287  		return
   288  	}
   289  
   290  	h.doCancel(err)
   291  }
   292  
// durationSinceLastTimer returns the time elapsed since lastTimer, i.e. since
// the timer handler last completed successfully. lastTimer is the zero
// time.Time until the first successful call.
func (h *defaultEventHandle) durationSinceLastTimer() time.Duration {
	return time.Since(h.lastTimer)
}
   296  
   297  func (h *defaultEventHandle) doTimer(ctx context.Context) error {
   298  	if atomic.LoadInt32(&h.hasTimer) == 0 {
   299  		return nil
   300  	}
   301  
   302  	if h.durationSinceLastTimer() < h.timerInterval {
   303  		return nil
   304  	}
   305  
   306  	err := h.timerHandler(ctx)
   307  	if err != nil {
   308  		return errors.Trace(err)
   309  	}
   310  
   311  	h.lastTimer = time.Now()
   312  
   313  	return nil
   314  }
   315  
// task is one unit of work sent to a worker through its task channel.
// A regular task carries f; the marker task used by GracefulUnregister carries
// only doneCh.
type task struct {
	// the handle the task belongs to; the worker skips the task if this
	// handle is already cancelled
	handle *defaultEventHandle
	// the work to execute on the worker; nil for the marker task
	f func(ctx context.Context) error

	doneCh chan struct{} // only used in implementing GracefulUnregister
}
   322  
// worker runs the tasks and timer handlers of the handles assigned to it,
// all from the single goroutine executing run.
type worker struct {
	// queue of tasks to execute; buffered (see newWorker)
	taskCh chan task
	// the set of handles currently assigned to this worker,
	// guarded by handleRWLock
	handles      map[*defaultEventHandle]struct{}
	handleRWLock sync.RWMutex
	// A message is passed to handleCancelCh when we need to wait for the
	// current execution of handler to finish. Should be BLOCKING.
	handleCancelCh chan struct{}
	// 1 while run() is looping, 0 otherwise;
	// must be accessed atomically
	isRunning int32
	// notifies exits of run()
	stopNotifier notify.Notifier

	// how long synchronize may wait before it starts logging warnings
	slowSynchronizeThreshold time.Duration
	// rate limiter for those warnings
	slowSynchronizeLimiter *rate.Limiter
}
   338  
   339  func newWorker() *worker {
   340  	return &worker{
   341  		taskCh:         make(chan task, 128),
   342  		handles:        make(map[*defaultEventHandle]struct{}),
   343  		handleCancelCh: make(chan struct{}), // this channel must be unbuffered, i.e. blocking
   344  
   345  		slowSynchronizeThreshold: 10 * time.Second,
   346  		slowSynchronizeLimiter:   rate.NewLimiter(rate.Every(time.Second*5), 1),
   347  	}
   348  }
   349  
   350  func (w *worker) run(ctx context.Context) error {
   351  	ticker := time.NewTicker(workerPoolDefaultClockSourceInterval)
   352  	atomic.StoreInt32(&w.isRunning, 1)
   353  	defer func() {
   354  		ticker.Stop()
   355  		atomic.StoreInt32(&w.isRunning, 0)
   356  		w.stopNotifier.Notify()
   357  	}()
   358  
   359  	for {
   360  		select {
   361  		case <-ctx.Done():
   362  			return errors.Trace(ctx.Err())
   363  		case task := <-w.taskCh:
   364  			if atomic.LoadInt32(&task.handle.status) == handleCancelled {
   365  				// ignored cancelled handle
   366  				continue
   367  			}
   368  
   369  			if task.doneCh != nil {
   370  				close(task.doneCh)
   371  				if task.f != nil {
   372  					log.L().DPanic("unexpected message handler func in cancellation task", zap.Stack("stack"))
   373  				}
   374  				continue
   375  			}
   376  
   377  			err := task.f(ctx)
   378  			if err != nil {
   379  				task.handle.cancelWithErr(err)
   380  			}
   381  		case <-ticker.C:
   382  			var handleErrs []struct {
   383  				h *defaultEventHandle
   384  				e error
   385  			}
   386  
   387  			w.handleRWLock.RLock()
   388  			for handle := range w.handles {
   389  				if atomic.LoadInt32(&handle.status) == handleCancelled {
   390  					// ignored cancelled handle
   391  					continue
   392  				}
   393  				err := handle.doTimer(ctx)
   394  				if err != nil {
   395  					handleErrs = append(handleErrs, struct {
   396  						h *defaultEventHandle
   397  						e error
   398  					}{handle, err})
   399  				}
   400  			}
   401  			w.handleRWLock.RUnlock()
   402  
   403  			// cancelWithErr must be called out side of the loop above,
   404  			// to avoid deadlock.
   405  			for _, handleErr := range handleErrs {
   406  				handleErr.h.cancelWithErr(handleErr.e)
   407  			}
   408  		case <-w.handleCancelCh:
   409  		}
   410  	}
   411  }
   412  
   413  // synchronize waits for the worker to loop at least once, or to exit.
   414  func (w *worker) synchronize() {
   415  	if atomic.LoadInt32(&w.isRunning) == 0 {
   416  		return
   417  	}
   418  
   419  	receiver, err := w.stopNotifier.NewReceiver(time.Millisecond * 100)
   420  	if err != nil {
   421  		if cerrors.ErrOperateOnClosedNotifier.Equal(errors.Cause(err)) {
   422  			return
   423  		}
   424  		log.Panic("unexpected error", zap.Error(err))
   425  	}
   426  	defer receiver.Stop()
   427  
   428  	startTime := time.Now()
   429  	for {
   430  		workerHasFinishedLoop := false
   431  		select {
   432  		case w.handleCancelCh <- struct{}{}:
   433  			workerHasFinishedLoop = true
   434  		case <-receiver.C:
   435  		}
   436  		if workerHasFinishedLoop || atomic.LoadInt32(&w.isRunning) == 0 {
   437  			break
   438  		}
   439  
   440  		if time.Since(startTime) > w.slowSynchronizeThreshold &&
   441  			w.slowSynchronizeLimiter.Allow() {
   442  			// likely the workerpool has deadlocked, or there is a bug
   443  			// in the event handlers.
   444  			logWarn("synchronize is taking too long, report a bug",
   445  				zap.Duration("duration", time.Since(startTime)),
   446  				zap.Stack("stacktrace"))
   447  		}
   448  	}
   449  }
   450  
// A delegate to log.Warn. It exists only for testing: synchronize reports slow
// synchronization through this variable instead of calling log.Warn directly,
// so the call can be substituted.
var logWarn = log.Warn
   453  
// addHandle adds handle to the set of handles served by this worker.
// It takes the write lock on handleRWLock, so it waits while the worker is
// iterating over the set under the read lock.
func (w *worker) addHandle(handle *defaultEventHandle) {
	w.handleRWLock.Lock()
	defer w.handleRWLock.Unlock()

	w.handles[handle] = struct{}{}
}
   460  
// removeHandle removes handle from the set of handles served by this worker;
// removing a handle that is not in the set is a no-op.
// It takes the write lock on handleRWLock and must therefore not be called
// while the caller holds the read lock.
func (w *worker) removeHandle(handle *defaultEventHandle) {
	w.handleRWLock.Lock()
	defer w.handleRWLock.Unlock()

	delete(w.handles, handle)
}