github.com/pachyderm/pachyderm@v1.13.4/src/server/pkg/storage/chunk/chain.go (about)

     1  package chunk
     2  
import (
	"context"
	"sync"

	"golang.org/x/sync/errgroup"
)
     8  
     9  // TaskChain manages a chain of tasks that have a parallel and serial part.
    10  // The parallel part should be executed by the user specified callback and the
    11  // serial part should be executed within the callback passed into the user
    12  // specified callback.
    13  type TaskChain struct {
    14  	eg       *errgroup.Group
    15  	ctx      context.Context
    16  	prevChan chan struct{}
    17  }
    18  
    19  // NewTaskChain creates a new task chain.
    20  func NewTaskChain(ctx context.Context) *TaskChain {
    21  	eg, errCtx := errgroup.WithContext(ctx)
    22  	prevChan := make(chan struct{})
    23  	close(prevChan)
    24  	return &TaskChain{
    25  		eg:       eg,
    26  		ctx:      errCtx,
    27  		prevChan: prevChan,
    28  	}
    29  }
    30  
    31  // CreateTask creates a new task in the task chain.
    32  func (c *TaskChain) CreateTask(cb func(context.Context, func(func() error) error) error) error {
    33  	select {
    34  	case <-c.ctx.Done():
    35  		return c.ctx.Err()
    36  	default:
    37  	}
    38  	scb := c.serialCallback()
    39  	c.eg.Go(func() error {
    40  		return cb(c.ctx, scb)
    41  	})
    42  	return nil
    43  }
    44  
    45  // Wait waits on the currently executing tasks to finish.
    46  func (c *TaskChain) Wait() error {
    47  	select {
    48  	case <-c.ctx.Done():
    49  		return c.ctx.Err()
    50  	case <-c.prevChan:
    51  		return nil
    52  	}
    53  }
    54  
    55  func (c *TaskChain) serialCallback() func(func() error) error {
    56  	prevChan := c.prevChan
    57  	nextChan := make(chan struct{})
    58  	c.prevChan = nextChan
    59  	return func(cb func() error) error {
    60  		defer close(nextChan)
    61  		select {
    62  		case <-prevChan:
    63  			return cb()
    64  		case <-c.ctx.Done():
    65  			return c.ctx.Err()
    66  		}
    67  	}
    68  }