github.com/sagernet/sing@v0.2.6/common/task/task.go (about)

     1  package task
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  
     7  	"github.com/sagernet/sing/common"
     8  	E "github.com/sagernet/sing/common/exceptions"
     9  )
    10  
    11  type taskItem struct {
    12  	Name string
    13  	Run  func(ctx context.Context) error
    14  }
    15  
    16  type errTaskSucceed struct{}
    17  
    18  func (e errTaskSucceed) Error() string {
    19  	return "task succeed"
    20  }
    21  
    22  type Group struct {
    23  	tasks    []taskItem
    24  	cleanup  func()
    25  	fastFail bool
    26  }
    27  
    28  func (g *Group) Append(name string, f func(ctx context.Context) error) {
    29  	g.tasks = append(g.tasks, taskItem{
    30  		Name: name,
    31  		Run:  f,
    32  	})
    33  }
    34  
    35  func (g *Group) Append0(f func(ctx context.Context) error) {
    36  	g.tasks = append(g.tasks, taskItem{
    37  		Run: f,
    38  	})
    39  }
    40  
    41  func (g *Group) Cleanup(f func()) {
    42  	g.cleanup = f
    43  }
    44  
    45  func (g *Group) FastFail() {
    46  	g.fastFail = true
    47  }
    48  
    49  func (g *Group) Run(contextList ...context.Context) error {
    50  	return g.RunContextList(contextList)
    51  }
    52  
    53  func (g *Group) RunContextList(contextList []context.Context) error {
    54  	if len(contextList) == 0 {
    55  		contextList = append(contextList, context.Background())
    56  	}
    57  
    58  	taskContext, taskFinish := common.ContextWithCancelCause(context.Background())
    59  	taskCancelContext, taskCancel := common.ContextWithCancelCause(context.Background())
    60  
    61  	var errorAccess sync.Mutex
    62  	var returnError error
    63  	taskCount := int8(len(g.tasks))
    64  
    65  	for _, task := range g.tasks {
    66  		currentTask := task
    67  		go func() {
    68  			err := currentTask.Run(taskCancelContext)
    69  			errorAccess.Lock()
    70  			if err != nil {
    71  				if currentTask.Name != "" {
    72  					err = E.Cause(err, currentTask.Name)
    73  				}
    74  				returnError = E.Errors(returnError, err)
    75  				if g.fastFail {
    76  					taskCancel(err)
    77  				}
    78  			}
    79  			taskCount--
    80  			currentCount := taskCount
    81  			errorAccess.Unlock()
    82  			if currentCount == 0 {
    83  				taskCancel(errTaskSucceed{})
    84  				taskFinish(errTaskSucceed{})
    85  			}
    86  		}()
    87  	}
    88  
    89  	selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskCancelContext}, contextList...))
    90  	if selectedContext != 0 {
    91  		returnError = E.Append(returnError, upstreamErr, func(err error) error {
    92  			return E.Cause(err, "upstream")
    93  		})
    94  	}
    95  
    96  	if g.cleanup != nil {
    97  		g.cleanup()
    98  	}
    99  
   100  	<-taskContext.Done()
   101  	return returnError
   102  }