github.com/yaling888/clash@v1.53.0/common/batch/batch.go (about)

     1  package batch
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  )
     7  
     8  type Option[T any] func(b *Batch[T])
     9  
    10  type Result[T any] struct {
    11  	Value T
    12  	Err   error
    13  }
    14  
    15  type Error struct {
    16  	Key string
    17  	Err error
    18  }
    19  
    20  func WithConcurrencyNum[T any](n int) Option[T] {
    21  	return func(b *Batch[T]) {
    22  		q := make(chan struct{}, n)
    23  		for i := 0; i < n; i++ {
    24  			q <- struct{}{}
    25  		}
    26  		b.queue = q
    27  	}
    28  }
    29  
    30  // Batch similar to errgroup, but can control the maximum number of concurrent
    31  type Batch[T any] struct {
    32  	result map[string]Result[T]
    33  	queue  chan struct{}
    34  	wg     sync.WaitGroup
    35  	mux    sync.Mutex
    36  	err    *Error
    37  	once   sync.Once
    38  	cancel func()
    39  }
    40  
    41  func (b *Batch[T]) Go(key string, fn func() (T, error)) {
    42  	b.wg.Add(1)
    43  	go func() {
    44  		defer b.wg.Done()
    45  		if b.queue != nil {
    46  			<-b.queue
    47  			defer func() {
    48  				b.queue <- struct{}{}
    49  			}()
    50  		}
    51  
    52  		value, err := fn()
    53  		if err != nil {
    54  			b.once.Do(func() {
    55  				b.err = &Error{key, err}
    56  				if b.cancel != nil {
    57  					b.cancel()
    58  				}
    59  			})
    60  		}
    61  
    62  		ret := Result[T]{value, err}
    63  		b.mux.Lock()
    64  		defer b.mux.Unlock()
    65  		b.result[key] = ret
    66  	}()
    67  }
    68  
    69  func (b *Batch[T]) Wait() *Error {
    70  	b.wg.Wait()
    71  	if b.cancel != nil {
    72  		b.cancel()
    73  	}
    74  	return b.err
    75  }
    76  
    77  func (b *Batch[T]) WaitAndGetResult() (map[string]Result[T], *Error) {
    78  	err := b.Wait()
    79  	return b.Result(), err
    80  }
    81  
    82  func (b *Batch[T]) Result() map[string]Result[T] {
    83  	b.mux.Lock()
    84  	defer b.mux.Unlock()
    85  	copyM := map[string]Result[T]{}
    86  	for k, v := range b.result {
    87  		copyM[k] = v
    88  	}
    89  	return copyM
    90  }
    91  
    92  func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) {
    93  	ctx, cancel := context.WithCancel(ctx)
    94  
    95  	b := &Batch[T]{
    96  		result: map[string]Result[T]{},
    97  	}
    98  
    99  	for _, o := range opts {
   100  		o(b)
   101  	}
   102  
   103  	b.cancel = cancel
   104  	return b, ctx
   105  }