github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/batch/batch.go (about) 1 package batch 2 3 import ( 4 "context" 5 "sync" 6 7 "github.com/sagernet/sing/common" 8 E "github.com/sagernet/sing/common/exceptions" 9 ) 10 11 type Option[T any] func(b *Batch[T]) 12 13 type Result[T any] struct { 14 Value T 15 Err error 16 } 17 18 type Error struct { 19 Key string 20 Err error 21 } 22 23 func (e *Error) Error() string { 24 return E.Cause(e.Err, e.Key).Error() 25 } 26 27 func WithConcurrencyNum[T any](n int) Option[T] { 28 return func(b *Batch[T]) { 29 q := make(chan struct{}, n) 30 for i := 0; i < n; i++ { 31 q <- struct{}{} 32 } 33 b.queue = q 34 } 35 } 36 37 // Batch similar to errgroup, but can control the maximum number of concurrent 38 type Batch[T any] struct { 39 result map[string]Result[T] 40 queue chan struct{} 41 wg sync.WaitGroup 42 mux sync.Mutex 43 err *Error 44 once sync.Once 45 cancel common.ContextCancelCauseFunc 46 } 47 48 func (b *Batch[T]) Go(key string, fn func() (T, error)) { 49 b.wg.Add(1) 50 go func() { 51 defer b.wg.Done() 52 if b.queue != nil { 53 <-b.queue 54 defer func() { 55 b.queue <- struct{}{} 56 }() 57 } 58 59 value, err := fn() 60 if err != nil { 61 b.once.Do(func() { 62 b.err = &Error{key, err} 63 if b.cancel != nil { 64 b.cancel(b.err) 65 } 66 }) 67 } 68 69 ret := Result[T]{value, err} 70 b.mux.Lock() 71 defer b.mux.Unlock() 72 b.result[key] = ret 73 }() 74 } 75 76 func (b *Batch[T]) Wait() *Error { 77 b.wg.Wait() 78 if b.cancel != nil { 79 b.cancel(nil) 80 } 81 return b.err 82 } 83 84 func (b *Batch[T]) WaitAndGetResult() (map[string]Result[T], *Error) { 85 err := b.Wait() 86 return b.Result(), err 87 } 88 89 func (b *Batch[T]) Result() map[string]Result[T] { 90 b.mux.Lock() 91 defer b.mux.Unlock() 92 copy := map[string]Result[T]{} 93 for k, v := range b.result { 94 copy[k] = v 95 } 96 return copy 97 } 98 99 func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) { 100 ctx, cancel := common.ContextWithCancelCause(ctx) 101 102 b := &Batch[T]{ 103 result: map[string]Result[T]{}, 104 } 105 106 for _, o := range opts { 107 o(b) 108 } 109 110 b.cancel = cancel 111 return b, ctx 112 }