github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/task/task.go (about) 1 package task 2 3 import ( 4 "context" 5 "sync" 6 7 "github.com/sagernet/sing/common" 8 E "github.com/sagernet/sing/common/exceptions" 9 ) 10 11 type taskItem struct { 12 Name string 13 Run func(ctx context.Context) error 14 } 15 16 type errTaskSucceed struct{} 17 18 func (e errTaskSucceed) Error() string { 19 return "task succeed" 20 } 21 22 type Group struct { 23 tasks []taskItem 24 cleanup func() 25 fastFail bool 26 queue chan struct{} 27 } 28 29 func (g *Group) Append(name string, f func(ctx context.Context) error) { 30 g.tasks = append(g.tasks, taskItem{ 31 Name: name, 32 Run: f, 33 }) 34 } 35 36 func (g *Group) Append0(f func(ctx context.Context) error) { 37 g.tasks = append(g.tasks, taskItem{ 38 Run: f, 39 }) 40 } 41 42 func (g *Group) Cleanup(f func()) { 43 g.cleanup = f 44 } 45 46 func (g *Group) FastFail() { 47 g.fastFail = true 48 } 49 50 func (g *Group) Concurrency(n int) { 51 g.queue = make(chan struct{}, n) 52 for i := 0; i < n; i++ { 53 g.queue <- struct{}{} 54 } 55 } 56 57 func (g *Group) Run(contextList ...context.Context) error { 58 return g.RunContextList(contextList) 59 } 60 61 func (g *Group) RunContextList(contextList []context.Context) error { 62 if len(contextList) == 0 { 63 contextList = append(contextList, context.Background()) 64 } 65 66 taskContext, taskFinish := common.ContextWithCancelCause(context.Background()) 67 taskCancelContext, taskCancel := common.ContextWithCancelCause(context.Background()) 68 69 var errorAccess sync.Mutex 70 var returnError error 71 taskCount := len(g.tasks) 72 73 for _, task := range g.tasks { 74 currentTask := task 75 go func() { 76 if g.queue != nil { 77 select { 78 case <-taskCancelContext.Done(): 79 errorAccess.Lock() 80 taskCount-- 81 currentCount := taskCount 82 if currentCount == 0 { 83 taskCancel(errTaskSucceed{}) 84 taskFinish(errTaskSucceed{}) 85 } 86 errorAccess.Unlock() 87 return 88 case <-g.queue: 89 } 90 } 91 err := currentTask.Run(taskCancelContext) 92 errorAccess.Lock() 93 if err != nil { 94 if currentTask.Name != "" { 95 err = E.Cause(err, currentTask.Name) 96 } 97 returnError = E.Errors(returnError, err) 98 if g.fastFail { 99 taskCancel(err) 100 } 101 } 102 taskCount-- 103 currentCount := taskCount 104 errorAccess.Unlock() 105 if currentCount == 0 { 106 taskCancel(errTaskSucceed{}) 107 taskFinish(errTaskSucceed{}) 108 } 109 if g.queue != nil { 110 g.queue <- struct{}{} 111 } 112 }() 113 } 114 115 selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskCancelContext}, contextList...)) 116 117 if selectedContext != 0 { 118 taskCancel(upstreamErr) 119 } 120 121 if g.cleanup != nil { 122 g.cleanup() 123 } 124 125 <-taskContext.Done() 126 127 if selectedContext != 0 { 128 returnError = E.Append(returnError, upstreamErr, func(err error) error { 129 return E.Cause(err, "upstream") 130 }) 131 } 132 133 return returnError 134 }