github.com/amp-space/amp-sdk-go@v0.7.6/stdlib/utils/channels.go (about)

     1  package utils
     2  
     3  import (
     4  	"context"
     5  	"reflect"
     6  	"sync/atomic"
     7  	"time"
     8  )
     9  
    10  // ContextFromChan creates a context that finishes when the provided channel
    11  // receives or is closed.
    12  func ContextFromChan(chStop <-chan struct{}) (context.Context, context.CancelFunc) {
    13  	ctx, cancel := context.WithCancel(context.Background())
    14  	go func() {
    15  		select {
    16  		case <-chStop:
    17  			cancel()
    18  		case <-ctx.Done():
    19  		}
    20  	}()
    21  	return ctx, cancel
    22  }
    23  
    24  // WaitGroupChan creates a channel that closes when the provided sync.WaitGroup is done.
    25  type WaitGroupChan struct {
    26  	i         int
    27  	x         int
    28  	chAdd     chan wgAdd
    29  	chWait    chan struct{}
    30  	chCtxDone <-chan struct{}
    31  	chStop    chan struct{}
    32  	waitCalls uint32
    33  }
    34  
    35  type wgAdd struct {
    36  	i   int
    37  	err chan string
    38  }
    39  
    40  func NewWaitGroupChan(ctx context.Context) *WaitGroupChan {
    41  	wg := &WaitGroupChan{
    42  		chAdd:  make(chan wgAdd),
    43  		chWait: make(chan struct{}),
    44  		chStop: make(chan struct{}),
    45  	}
    46  	if ctx != nil {
    47  		wg.chCtxDone = ctx.Done()
    48  	}
    49  
    50  	go func() {
    51  		var done bool
    52  		for {
    53  			select {
    54  			case <-wg.chCtxDone:
    55  				if !done {
    56  					close(wg.chWait)
    57  				}
    58  				return
    59  			case <-wg.chStop:
    60  				if !done {
    61  					close(wg.chWait)
    62  				}
    63  				return
    64  			case wgAdd := <-wg.chAdd:
    65  				if done {
    66  					wgAdd.err <- "WaitGroupChan already finished. Do you need to add a bounding wg.Add(1) and wg.Done()?"
    67  					return
    68  				}
    69  				wg.i += wgAdd.i
    70  				if wg.i < 0 {
    71  					wgAdd.err <- "called Done() too many times"
    72  					close(wg.chWait)
    73  					return
    74  				} else if wg.i == 0 {
    75  					done = true
    76  					close(wg.chWait)
    77  				}
    78  				wgAdd.err <- ""
    79  			}
    80  		}
    81  	}()
    82  
    83  	return wg
    84  }
    85  
    86  func (wg *WaitGroupChan) Close() {
    87  	close(wg.chStop)
    88  }
    89  
    90  func (wg *WaitGroupChan) Add(i int) {
    91  	if atomic.LoadUint32(&wg.waitCalls) > 0 {
    92  		panic("cannot call Add() after Wait()")
    93  	}
    94  	ch := make(chan string)
    95  	select {
    96  	case <-wg.chCtxDone:
    97  	case <-wg.chStop:
    98  	case wg.chAdd <- wgAdd{i, ch}:
    99  		err := <-ch
   100  		if err != "" {
   101  			panic(err)
   102  		}
   103  	}
   104  }
   105  
   106  func (wg *WaitGroupChan) Done() {
   107  	ch := make(chan string)
   108  	select {
   109  	case <-wg.chCtxDone:
   110  	case <-wg.chStop:
   111  	case <-wg.chWait:
   112  	case wg.chAdd <- wgAdd{-1, ch}:
   113  		err := <-ch
   114  		if err != "" {
   115  			panic(err)
   116  		}
   117  	}
   118  }
   119  
   120  func (wg *WaitGroupChan) Wait() <-chan struct{} {
   121  	atomic.StoreUint32(&wg.waitCalls, 1)
   122  	return wg.chWait
   123  }
   124  
   125  // CombinedContext creates a context that finishes when any of the provided
   126  // signals finish.  A signal can be a `context.Context`, a `chan struct{}`, or
   127  // a `time.Duration` (which is transformed into a `context.WithTimeout`).
   128  func CombinedContext(signals ...interface{}) (context.Context, context.CancelFunc) {
   129  	ctx, cancel := context.WithCancel(context.Background())
   130  	if len(signals) == 0 {
   131  		return ctx, cancel
   132  	}
   133  	signals = append(signals, ctx)
   134  
   135  	var cases []reflect.SelectCase
   136  	var cancel2 context.CancelFunc
   137  	for _, signal := range signals {
   138  		var ch reflect.Value
   139  
   140  		switch sig := signal.(type) {
   141  		case context.Context:
   142  			ch = reflect.ValueOf(sig.Done())
   143  		case <-chan struct{}:
   144  			ch = reflect.ValueOf(sig)
   145  		case chan struct{}:
   146  			ch = reflect.ValueOf(sig)
   147  		case time.Duration:
   148  			var ctxTimeout context.Context
   149  			ctxTimeout, cancel2 = context.WithTimeout(ctx, sig)
   150  			ch = reflect.ValueOf(ctxTimeout.Done())
   151  		default:
   152  			continue
   153  		}
   154  		cases = append(cases, reflect.SelectCase{Chan: ch, Dir: reflect.SelectRecv})
   155  	}
   156  
   157  	go func() {
   158  		defer cancel()
   159  		if cancel2 != nil {
   160  			defer cancel2()
   161  		}
   162  		_, _, _ = reflect.Select(cases)
   163  	}()
   164  
   165  	return context.WithCancel(ctx)
   166  }
   167  
   168  type ChanContext chan struct{}
   169  
   170  var _ context.Context = ChanContext(nil)
   171  
   172  func (ch ChanContext) Deadline() (deadline time.Time, ok bool) {
   173  	return time.Time{}, false
   174  }
   175  
   176  func (ch ChanContext) Done() <-chan struct{} {
   177  	return ch
   178  }
   179  
   180  func (ch ChanContext) Err() error {
   181  	select {
   182  	case <-ch:
   183  		return context.Canceled
   184  	default:
   185  		return nil
   186  	}
   187  }
   188  
   189  func (ch ChanContext) Value(key interface{}) interface{} {
   190  	return nil
   191  }