github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/batch/batch.go (about)

     1  package batch
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  
     7  	"github.com/sagernet/sing/common"
     8  	E "github.com/sagernet/sing/common/exceptions"
     9  )
    10  
// Option is a function that configures a Batch. New calls every supplied
// Option on the freshly created Batch before returning it.
type Option[T any] func(b *Batch[T])
    12  
// Result is the outcome of one task started with Batch.Go. Value and Err are
// the two values returned by the task function; both are stored as returned,
// even when Err is non-nil.
type Result[T any] struct {
	Value T
	Err   error
}
    17  
// Error pairs an error returned by a task with the key that task was
// submitted under. A Batch records one Error, for the first task failure it
// observes.
type Error struct {
	// Key is the key passed to Batch.Go for the failing task.
	Key string
	// Err is the error returned by the failing task function.
	Err error
}
    22  
    23  func (e *Error) Error() string {
    24  	return E.Cause(e.Err, e.Key).Error()
    25  }
    26  
    27  func WithConcurrencyNum[T any](n int) Option[T] {
    28  	return func(b *Batch[T]) {
    29  		q := make(chan struct{}, n)
    30  		for i := 0; i < n; i++ {
    31  			q <- struct{}{}
    32  		}
    33  		b.queue = q
    34  	}
    35  }
    36  
// Batch is similar to errgroup, but can also limit the maximum number of
// tasks that run concurrently.
type Batch[T any] struct {
	// result maps each task key to that task's outcome; guarded by mux.
	result map[string]Result[T]
	// queue holds the concurrency tokens. It is nil when no limit was
	// configured, in which case tasks start without waiting.
	queue  chan struct{}
	// wg tracks the goroutines started by Go so that Wait can block on them.
	wg     sync.WaitGroup
	// mux guards result.
	mux    sync.Mutex
	// err is the first task failure recorded; it is written inside once.Do
	// and returned by Wait.
	err    *Error
	// once ensures that only one failure is recorded and passed to cancel.
	once   sync.Once
	// cancel is the function obtained from common.ContextWithCancelCause in
	// New. It is called with the first failure, and with nil by Wait.
	cancel common.ContextCancelCauseFunc
}
    47  
    48  func (b *Batch[T]) Go(key string, fn func() (T, error)) {
    49  	b.wg.Add(1)
    50  	go func() {
    51  		defer b.wg.Done()
    52  		if b.queue != nil {
    53  			<-b.queue
    54  			defer func() {
    55  				b.queue <- struct{}{}
    56  			}()
    57  		}
    58  
    59  		value, err := fn()
    60  		if err != nil {
    61  			b.once.Do(func() {
    62  				b.err = &Error{key, err}
    63  				if b.cancel != nil {
    64  					b.cancel(b.err)
    65  				}
    66  			})
    67  		}
    68  
    69  		ret := Result[T]{value, err}
    70  		b.mux.Lock()
    71  		defer b.mux.Unlock()
    72  		b.result[key] = ret
    73  	}()
    74  }
    75  
    76  func (b *Batch[T]) Wait() *Error {
    77  	b.wg.Wait()
    78  	if b.cancel != nil {
    79  		b.cancel(nil)
    80  	}
    81  	return b.err
    82  }
    83  
    84  func (b *Batch[T]) WaitAndGetResult() (map[string]Result[T], *Error) {
    85  	err := b.Wait()
    86  	return b.Result(), err
    87  }
    88  
    89  func (b *Batch[T]) Result() map[string]Result[T] {
    90  	b.mux.Lock()
    91  	defer b.mux.Unlock()
    92  	copy := map[string]Result[T]{}
    93  	for k, v := range b.result {
    94  		copy[k] = v
    95  	}
    96  	return copy
    97  }
    98  
    99  func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) {
   100  	ctx, cancel := common.ContextWithCancelCause(ctx)
   101  
   102  	b := &Batch[T]{
   103  		result: map[string]Result[T]{},
   104  	}
   105  
   106  	for _, o := range opts {
   107  		o(b)
   108  	}
   109  
   110  	b.cancel = cancel
   111  	return b, ctx
   112  }