github.com/sagernet/sing@v0.2.6/common/task/task.go (about) 1 package task 2 3 import ( 4 "context" 5 "sync" 6 7 "github.com/sagernet/sing/common" 8 E "github.com/sagernet/sing/common/exceptions" 9 ) 10 11 type taskItem struct { 12 Name string 13 Run func(ctx context.Context) error 14 } 15 16 type errTaskSucceed struct{} 17 18 func (e errTaskSucceed) Error() string { 19 return "task succeed" 20 } 21 22 type Group struct { 23 tasks []taskItem 24 cleanup func() 25 fastFail bool 26 } 27 28 func (g *Group) Append(name string, f func(ctx context.Context) error) { 29 g.tasks = append(g.tasks, taskItem{ 30 Name: name, 31 Run: f, 32 }) 33 } 34 35 func (g *Group) Append0(f func(ctx context.Context) error) { 36 g.tasks = append(g.tasks, taskItem{ 37 Run: f, 38 }) 39 } 40 41 func (g *Group) Cleanup(f func()) { 42 g.cleanup = f 43 } 44 45 func (g *Group) FastFail() { 46 g.fastFail = true 47 } 48 49 func (g *Group) Run(contextList ...context.Context) error { 50 return g.RunContextList(contextList) 51 } 52 53 func (g *Group) RunContextList(contextList []context.Context) error { 54 if len(contextList) == 0 { 55 contextList = append(contextList, context.Background()) 56 } 57 58 taskContext, taskFinish := common.ContextWithCancelCause(context.Background()) 59 taskCancelContext, taskCancel := common.ContextWithCancelCause(context.Background()) 60 61 var errorAccess sync.Mutex 62 var returnError error 63 taskCount := int8(len(g.tasks)) 64 65 for _, task := range g.tasks { 66 currentTask := task 67 go func() { 68 err := currentTask.Run(taskCancelContext) 69 errorAccess.Lock() 70 if err != nil { 71 if currentTask.Name != "" { 72 err = E.Cause(err, currentTask.Name) 73 } 74 returnError = E.Errors(returnError, err) 75 if g.fastFail { 76 taskCancel(err) 77 } 78 } 79 taskCount-- 80 currentCount := taskCount 81 errorAccess.Unlock() 82 if currentCount == 0 { 83 taskCancel(errTaskSucceed{}) 84 taskFinish(errTaskSucceed{}) 85 } 86 }() 87 } 88 89 selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskCancelContext}, contextList...)) 90 if selectedContext != 0 { 91 returnError = E.Append(returnError, upstreamErr, func(err error) error { 92 return E.Cause(err, "upstream") 93 }) 94 } 95 96 if g.cleanup != nil { 97 g.cleanup() 98 } 99 100 <-taskContext.Done() 101 return returnError 102 }