github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/task/task.go (about)

     1  package task
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  
     7  	"github.com/sagernet/sing/common"
     8  	E "github.com/sagernet/sing/common/exceptions"
     9  )
    10  
    11  type taskItem struct {
    12  	Name string
    13  	Run  func(ctx context.Context) error
    14  }
    15  
    16  type errTaskSucceed struct{}
    17  
    18  func (e errTaskSucceed) Error() string {
    19  	return "task succeed"
    20  }
    21  
    22  type Group struct {
    23  	tasks    []taskItem
    24  	cleanup  func()
    25  	fastFail bool
    26  	queue    chan struct{}
    27  }
    28  
    29  func (g *Group) Append(name string, f func(ctx context.Context) error) {
    30  	g.tasks = append(g.tasks, taskItem{
    31  		Name: name,
    32  		Run:  f,
    33  	})
    34  }
    35  
    36  func (g *Group) Append0(f func(ctx context.Context) error) {
    37  	g.tasks = append(g.tasks, taskItem{
    38  		Run: f,
    39  	})
    40  }
    41  
    42  func (g *Group) Cleanup(f func()) {
    43  	g.cleanup = f
    44  }
    45  
    46  func (g *Group) FastFail() {
    47  	g.fastFail = true
    48  }
    49  
    50  func (g *Group) Concurrency(n int) {
    51  	g.queue = make(chan struct{}, n)
    52  	for i := 0; i < n; i++ {
    53  		g.queue <- struct{}{}
    54  	}
    55  }
    56  
    57  func (g *Group) Run(contextList ...context.Context) error {
    58  	return g.RunContextList(contextList)
    59  }
    60  
    61  func (g *Group) RunContextList(contextList []context.Context) error {
    62  	if len(contextList) == 0 {
    63  		contextList = append(contextList, context.Background())
    64  	}
    65  
    66  	taskContext, taskFinish := common.ContextWithCancelCause(context.Background())
    67  	taskCancelContext, taskCancel := common.ContextWithCancelCause(context.Background())
    68  
    69  	var errorAccess sync.Mutex
    70  	var returnError error
    71  	taskCount := len(g.tasks)
    72  
    73  	for _, task := range g.tasks {
    74  		currentTask := task
    75  		go func() {
    76  			if g.queue != nil {
    77  				select {
    78  				case <-taskCancelContext.Done():
    79  					errorAccess.Lock()
    80  					taskCount--
    81  					currentCount := taskCount
    82  					if currentCount == 0 {
    83  						taskCancel(errTaskSucceed{})
    84  						taskFinish(errTaskSucceed{})
    85  					}
    86  					errorAccess.Unlock()
    87  					return
    88  				case <-g.queue:
    89  				}
    90  			}
    91  			err := currentTask.Run(taskCancelContext)
    92  			errorAccess.Lock()
    93  			if err != nil {
    94  				if currentTask.Name != "" {
    95  					err = E.Cause(err, currentTask.Name)
    96  				}
    97  				returnError = E.Errors(returnError, err)
    98  				if g.fastFail {
    99  					taskCancel(err)
   100  				}
   101  			}
   102  			taskCount--
   103  			currentCount := taskCount
   104  			errorAccess.Unlock()
   105  			if currentCount == 0 {
   106  				taskCancel(errTaskSucceed{})
   107  				taskFinish(errTaskSucceed{})
   108  			}
   109  			if g.queue != nil {
   110  				g.queue <- struct{}{}
   111  			}
   112  		}()
   113  	}
   114  
   115  	selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskCancelContext}, contextList...))
   116  
   117  	if selectedContext != 0 {
   118  		taskCancel(upstreamErr)
   119  	}
   120  
   121  	if g.cleanup != nil {
   122  		g.cleanup()
   123  	}
   124  
   125  	<-taskContext.Done()
   126  
   127  	if selectedContext != 0 {
   128  		returnError = E.Append(returnError, upstreamErr, func(err error) error {
   129  			return E.Cause(err, "upstream")
   130  		})
   131  	}
   132  
   133  	return returnError
   134  }