github.com/whoyao/protocol@v0.0.0-20230519045905-2d8ace718ca5/rpc/race.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  )
     7  
     8  type raceResult[T any] struct {
     9  	i   int
    10  	val *T
    11  	err error
    12  }
    13  
    14  type Race[T any] struct {
    15  	ctx       context.Context
    16  	cancel    context.CancelFunc
    17  	nextIndex int
    18  
    19  	resultLock sync.Mutex
    20  	result     *raceResult[T]
    21  }
    22  
    23  // NewRace creates a race to yield the result from one or more candidate
    24  // functions
    25  func NewRace[T any](ctx context.Context) *Race[T] {
    26  	ctx, cancel := context.WithCancel(ctx)
    27  	return &Race[T]{
    28  		ctx:    ctx,
    29  		cancel: cancel,
    30  	}
    31  }
    32  
    33  // Go adds a candidate function to the race by running it in a new goroutine
    34  func (r *Race[T]) Go(fn func(ctx context.Context) (*T, error)) {
    35  	i := r.nextIndex
    36  	r.nextIndex++
    37  
    38  	go func() {
    39  		val, err := fn(r.ctx)
    40  
    41  		r.resultLock.Lock()
    42  		if r.result == nil {
    43  			r.result = &raceResult[T]{i, val, err}
    44  		}
    45  		r.resultLock.Unlock()
    46  
    47  		r.cancel()
    48  	}()
    49  }
    50  
    51  // Wait awaits the first complete function and returns the index and results
    52  // or -1 if the context is cancelled before any candidate finishes.
    53  func (r *Race[T]) Wait() (int, *T, error) {
    54  	<-r.ctx.Done()
    55  
    56  	r.resultLock.Lock()
    57  	res := r.result
    58  	r.resultLock.Unlock()
    59  	if res != nil {
    60  		return res.i, res.val, res.err
    61  	}
    62  	return -1, nil, r.ctx.Err()
    63  }