github.com/pingcap/ticdc@v0.0.0-20220526033649-485a10ef2652/pkg/workerpool/async_pool_impl.go

// Copyright 2020 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package workerpool

import (
	"context"
	"sync"
	"sync/atomic"

	"github.com/pingcap/errors"
	cerrors "github.com/pingcap/ticdc/pkg/errors"
	"github.com/pingcap/ticdc/pkg/retry"
	"golang.org/x/sync/errgroup"
)

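// Parameters for retrying task submission in Go when the pool is not running
// yet: a 1ms base backoff and at most 25 attempts.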
const (
	backoffBaseDelayInMs = 1
	maxTries             = 25
)

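// defaultAsyncPoolImpl dispatches tasks submitted through Go to a fixed set
// of workers in round-robin order. isRunning and runningLock coordinate task
// submission with Run starting up and shutting down.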
type defaultAsyncPoolImpl struct {
	workers      []*asyncWorker
	nextWorkerID int32
	isRunning    int32
	runningLock  sync.RWMutex
}

// NewDefaultAsyncPool creates a new AsyncPool that uses the default implementation.
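//
// A minimal usage sketch (illustrative caller code only; ctx and doWork are
// placeholders): Run must be started for submitted tasks to be executed, and
// it blocks until ctx is cancelled.
//
//	pool := NewDefaultAsyncPool(4)
//	go func() { _ = pool.Run(ctx) }()
//	err := pool.Go(ctx, func() { doWork() })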
func NewDefaultAsyncPool(numWorkers int) AsyncPool {
	return newDefaultAsyncPoolImpl(numWorkers)
}

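// newDefaultAsyncPoolImpl allocates the pool. The worker objects themselves
// are created by prepare, which Run calls before starting them.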
func newDefaultAsyncPoolImpl(numWorkers int) *defaultAsyncPoolImpl {
	workers := make([]*asyncWorker, numWorkers)

	return &defaultAsyncPoolImpl{
		workers: workers,
	}
}

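// Go submits f to one of the pool's workers. If the pool is not running yet,
// or is in the middle of shutting down, the submission is retried with a
// backoff before the error is returned to the caller.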
func (p *defaultAsyncPoolImpl) Go(ctx context.Context, f func()) error {
	// Fast path: the pool is running and the task was enqueued on the first try.
	if p.doGo(ctx, f) == nil {
		return nil
	}

	err := retry.Do(ctx, func() error {
		return errors.Trace(p.doGo(ctx, f))
	}, retry.WithBackoffBaseDelay(backoffBaseDelayInMs), retry.WithMaxTries(maxTries), retry.WithIsRetryableErr(isRetryable))
	return errors.Trace(err)
}

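// isRetryable reports whether a failed submission should be retried: only
// retryable errors that are ErrAsyncPoolExited qualify.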
func isRetryable(err error) bool {
	return cerrors.IsRetryableError(err) && cerrors.ErrAsyncPoolExited.Equal(err)
}

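// doGo makes a single attempt to enqueue f: it picks a worker in round-robin
// order and sends the task to that worker's input channel, failing with
// ErrAsyncPoolExited if the pool or the chosen worker has already stopped.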
func (p *defaultAsyncPoolImpl) doGo(ctx context.Context, f func()) error {
	p.runningLock.RLock()
	defer p.runningLock.RUnlock()

	if atomic.LoadInt32(&p.isRunning) == 0 {
		return cerrors.ErrAsyncPoolExited.GenWithStackByArgs()
	}

	task := &asyncTask{f: f}
	// Convert through uint32 so the index stays non-negative even after the
	// counter wraps around.
	worker := p.workers[int(uint32(atomic.AddInt32(&p.nextWorkerID, 1))%uint32(len(p.workers)))]

	// Holding the worker's read lock prevents close() from closing inputCh
	// while the task is being sent.
	worker.chLock.RLock()
	defer worker.chLock.RUnlock()

	if atomic.LoadInt32(&worker.isClosed) == 1 {
		return cerrors.ErrAsyncPoolExited.GenWithStackByArgs()
	}

	select {
	case <-ctx.Done():
		return errors.Trace(ctx.Err())
	case worker.inputCh <- task:
	}

	return nil
}

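// Run creates and starts the workers, then blocks until ctx is cancelled or a
// worker returns an unexpected error. In either case every worker is closed
// before Run returns.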
func (p *defaultAsyncPoolImpl) Run(ctx context.Context) error {
	p.prepare()
	errg := errgroup.Group{}

	p.runningLock.Lock()
	atomic.StoreInt32(&p.isRunning, 1)
	p.runningLock.Unlock()

	defer func() {
		p.runningLock.Lock()
		atomic.StoreInt32(&p.isRunning, 0)
		p.runningLock.Unlock()
	}()

	errCh := make(chan error, len(p.workers))
	defer close(errCh)

	for _, worker := range p.workers {
		workerFinal := worker // capture the loop variable for the closure
		errg.Go(func() error {
			err := workerFinal.run()
			if err != nil && cerrors.ErrAsyncPoolExited.NotEqual(errors.Cause(err)) {
				errCh <- err
			}
			return nil
		})
	}

	// Close every worker once ctx is cancelled or any worker reports an error.
	errg.Go(func() error {
		var err error
		select {
		case <-ctx.Done():
			err = ctx.Err()
		case err = <-errCh:
		}

		for _, worker := range p.workers {
			worker.close()
		}

		return err
	})

	return errors.Trace(errg.Wait())
}

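// prepare allocates a fresh set of worker objects before Run starts them.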
func (p *defaultAsyncPoolImpl) prepare() {
	for i := range p.workers {
		p.workers[i] = newAsyncWorker()
	}
}

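// asyncTask wraps a single function submitted through Go.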
type asyncTask struct {
	f func()
}

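// asyncWorker owns a buffered channel of pending tasks. chLock guards the
// channel so that it is never written to after it has been closed.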
type asyncWorker struct {
	inputCh  chan *asyncTask
	isClosed int32
	chLock   sync.RWMutex
}

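// newAsyncWorker creates a worker with a generously buffered input channel
// (12800 tasks) so that submissions rarely block on a busy worker.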
func newAsyncWorker() *asyncWorker {
	return &asyncWorker{inputCh: make(chan *asyncTask, 12800)}
}

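// run executes queued tasks until the input channel is closed. A nil task is
// the zero value received from the closed channel and signals the worker to
// exit with ErrAsyncPoolExited.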
func (w *asyncWorker) run() error {
	for {
		task := <-w.inputCh
		if task == nil {
			return cerrors.ErrAsyncPoolExited.GenWithStackByArgs()
		}
		task.f()
	}
}

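// close shuts down the worker exactly once: the atomic swap makes repeated
// calls no-ops, and taking chLock exclusively waits out any in-flight send
// from doGo before the channel is closed.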
func (w *asyncWorker) close() {
	if atomic.SwapInt32(&w.isClosed, 1) == 1 {
		return
	}

	w.chLock.Lock()
	defer w.chLock.Unlock()

	close(w.inputCh)
}