github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/common/task/task.go (about)

     1  package task
     2  
     3  import (
     4  	"context"
     5  
     6  	"github.com/xmplusdev/xmcore/common/signal/semaphore"
     7  )
     8  
     9  // OnSuccess executes g() after f() returns nil.
    10  func OnSuccess(f func() error, g func() error) func() error {
    11  	return func() error {
    12  		if err := f(); err != nil {
    13  			return err
    14  		}
    15  		return g()
    16  	}
    17  }
    18  
    19  // Run executes a list of tasks in parallel, returns the first error encountered or nil if all tasks pass.
    20  func Run(ctx context.Context, tasks ...func() error) error {
    21  	n := len(tasks)
    22  	s := semaphore.New(n)
    23  	done := make(chan error, 1)
    24  
    25  	for _, task := range tasks {
    26  		<-s.Wait()
    27  		go func(f func() error) {
    28  			err := f()
    29  			if err == nil {
    30  				s.Signal()
    31  				return
    32  			}
    33  
    34  			select {
    35  			case done <- err:
    36  			default:
    37  			}
    38  		}(task)
    39  	}
    40  
    41  	/*
    42  		if altctx := ctx.Value("altctx"); altctx != nil {
    43  			ctx = altctx.(context.Context)
    44  		}
    45  	*/
    46  
    47  	for i := 0; i < n; i++ {
    48  		select {
    49  		case err := <-done:
    50  			return err
    51  		case <-ctx.Done():
    52  			return ctx.Err()
    53  		case <-s.Wait():
    54  		}
    55  	}
    56  
    57  	/*
    58  		if cancel := ctx.Value("cancel"); cancel != nil {
    59  			cancel.(context.CancelFunc)()
    60  		}
    61  	*/
    62  
    63  	return nil
    64  }