github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/zsync/promise_extension.go (about)

     1  package zsync
     2  
     3  import (
     4  	"context"
     5  	"strings"
     6  )
     7  
     8  func PromiseAllContext[T any](ctx context.Context, promises ...*Promise[T]) *Promise[[]T] {
     9  	return NewPromiseContext(ctx, func() (res []T, err error) {
    10  		if len(promises) == 0 {
    11  			return
    12  		}
    13  
    14  		res = make([]T, len(promises))
    15  		for index := range promises {
    16  			value, err := promises[index].Done()
    17  			if err != nil {
    18  				return nil, err
    19  			}
    20  			res[index] = value
    21  		}
    22  
    23  		return
    24  	})
    25  }
    26  
    27  func PromiseAll[T any](promises ...*Promise[T]) *Promise[[]T] {
    28  	return PromiseAllContext(context.Background(), promises...)
    29  }
    30  
    31  func PromiseRaceContext[T any](ctx context.Context, promises ...*Promise[T]) *Promise[T] {
    32  	return NewPromiseContext(ctx, func() (res T, err error) {
    33  		if len(promises) == 0 {
    34  			return
    35  		}
    36  
    37  		valC := make(chan T, len(promises))
    38  		errC := make(chan error, len(promises))
    39  		for index := range promises {
    40  			go func(index int) {
    41  				value, err := promises[index].Done()
    42  				if err != nil {
    43  					errC <- err
    44  					return
    45  				}
    46  				valC <- value
    47  			}(index)
    48  		}
    49  
    50  		select {
    51  		case res = <-valC:
    52  		case err = <-errC:
    53  		case <-ctx.Done():
    54  			err = ctx.Err()
    55  		}
    56  
    57  		return
    58  	})
    59  }
    60  
    61  func PromiseRace[T any](promises ...*Promise[T]) *Promise[T] {
    62  	return PromiseRaceContext(context.Background(), promises...)
    63  }
    64  
    65  type AggregateError struct {
    66  	Errors []error
    67  }
    68  
    69  func (ae *AggregateError) Error() string {
    70  	errStrings := make([]string, len(ae.Errors))
    71  
    72  	for i, err := range ae.Errors {
    73  		errStrings[i] = err.Error()
    74  	}
    75  
    76  	return "All promises were rejected: " + strings.Join(errStrings, ", ")
    77  }
    78  
    79  func PromiseAnyContext[T any](ctx context.Context, promises ...*Promise[T]) *Promise[T] {
    80  	return NewPromiseContext(ctx, func() (res T, err error) {
    81  		if len(promises) == 0 {
    82  			return
    83  		}
    84  
    85  		valC := make(chan T, len(promises))
    86  		errC := make(chan error, len(promises))
    87  		for index := range promises {
    88  			go func(index int) {
    89  				value, err := promises[index].Done()
    90  				if err != nil {
    91  					errC <- err
    92  					return
    93  				}
    94  				valC <- value
    95  			}(index)
    96  		}
    97  
    98  		errs := make([]error, 0, len(promises))
    99  	hander:
   100  		select {
   101  		case res = <-valC:
   102  			return
   103  		case e := <-errC:
   104  			errs = append(errs, e)
   105  			if len(errs) == len(promises) {
   106  				err = &AggregateError{
   107  					Errors: errs,
   108  				}
   109  				return
   110  			}
   111  			goto hander
   112  		case <-ctx.Done():
   113  			err = ctx.Err()
   114  			return
   115  		}
   116  	})
   117  }
   118  
   119  func PromiseAny[T any](promises ...*Promise[T]) *Promise[T] {
   120  	return PromiseAnyContext(context.Background(), promises...)
   121  }