github.com/reusee/pr2@v0.0.0-20230630035947-72a20ff5e864/consume.go (about)

     1  package pr2
     2  
     3  import (
     4  	"container/list"
     5  	"context"
     6  	"fmt"
     7  	"math"
     8  	"sync"
     9  
    10  	"github.com/reusee/e5"
    11  )
    12  
    13  type Put[T any] func(T) bool
    14  
    15  type Wait = func(noMorePut bool) error
    16  
    17  type ConsumeOption interface {
    18  	IsConsumeOption()
    19  }
    20  
    21  type BacklogSize int
    22  
    23  func (BacklogSize) IsConsumeOption() {}
    24  
    25  func Consume[T any](
    26  	ctx context.Context,
    27  	numThread int,
    28  	fn func(threadID int, value T) error,
    29  	options ...ConsumeOption,
    30  ) (
    31  	put Put[T],
    32  	wait Wait,
    33  ) {
    34  
    35  	backlogSize := int(math.MaxInt32)
    36  
    37  	for _, option := range options {
    38  		switch option := option.(type) {
    39  		case BacklogSize:
    40  			backlogSize = int(option)
    41  		default:
    42  			panic(fmt.Errorf("unknown option: %T", option))
    43  		}
    44  	}
    45  
    46  	inCh := make(chan T)
    47  	outCh := make(chan T)
    48  	errCh := make(chan error, 1)
    49  	valueCond := sync.NewCond(new(sync.Mutex))
    50  	numValue := 0
    51  	wg := NewWaitGroup(ctx)
    52  
    53  	wg.Go(func() {
    54  		values := list.New()
    55  		var c chan T
    56  	loop:
    57  		for {
    58  
    59  			c = inCh
    60  			if values.Len() > backlogSize {
    61  				c = nil
    62  			}
    63  
    64  			if values.Len() > 0 {
    65  				select {
    66  
    67  				case outCh <- values.Front().Value.(T):
    68  					values.Remove(values.Front())
    69  
    70  				case v, ok := <-c:
    71  					if !ok {
    72  						break loop
    73  					}
    74  					values.PushBack(v)
    75  
    76  				case <-ctx.Done():
    77  					break loop
    78  
    79  				}
    80  
    81  			} else {
    82  				select {
    83  
    84  				case v, ok := <-c:
    85  					if !ok {
    86  						break loop
    87  					}
    88  					select {
    89  					case outCh <- v:
    90  					default:
    91  						values.PushBack(v)
    92  					}
    93  
    94  				case <-ctx.Done():
    95  					break loop
    96  
    97  				}
    98  			}
    99  
   100  		}
   101  
   102  		elem := values.Front()
   103  		for elem != nil {
   104  			outCh <- elem.Value.(T)
   105  			elem = elem.Next()
   106  		}
   107  
   108  		close(outCh)
   109  
   110  	})
   111  
   112  	var putLock sync.RWMutex
   113  	putClosed := false
   114  	put = func(v T) bool {
   115  
   116  		putLock.RLock()
   117  		defer putLock.RUnlock()
   118  
   119  		if putClosed {
   120  			return false
   121  		}
   122  
   123  		if len(errCh) > 0 {
   124  			return false
   125  		}
   126  
   127  		select {
   128  
   129  		case inCh <- v:
   130  			valueCond.L.Lock()
   131  			numValue++
   132  			n := numValue
   133  			valueCond.L.Unlock()
   134  			if n == 0 {
   135  				valueCond.Signal()
   136  			}
   137  			return true
   138  
   139  		case <-ctx.Done():
   140  			return false
   141  
   142  		}
   143  	}
   144  
   145  	var closeOnce sync.Once
   146  	closePut := func() {
   147  		closeOnce.Do(func() {
   148  			putLock.Lock()
   149  			defer putLock.Unlock()
   150  			putClosed = true
   151  			close(inCh)
   152  		})
   153  	}
   154  
   155  	wait = func(noMorePut bool) error {
   156  
   157  		if noMorePut {
   158  			closePut()
   159  			wg.Wait()
   160  		}
   161  
   162  		valueCond.L.Lock()
   163  		for numValue != 0 {
   164  			valueCond.Wait()
   165  		}
   166  		valueCond.L.Unlock()
   167  
   168  		select {
   169  		case err := <-errCh:
   170  			return err
   171  		default:
   172  		}
   173  
   174  		return nil
   175  	}
   176  
   177  	for i := 0; i < numThread; i++ {
   178  		i := i
   179  
   180  		wg.Go(func() {
   181  
   182  			for v := range outCh {
   183  				err := func() (err error) {
   184  					defer e5.Handle(&err)
   185  					return fn(i, v)
   186  				}()
   187  				if err != nil {
   188  					select {
   189  					case errCh <- err:
   190  					default:
   191  					}
   192  				}
   193  				valueCond.L.Lock()
   194  				numValue--
   195  				n := numValue
   196  				valueCond.L.Unlock()
   197  				if n == 0 {
   198  					valueCond.Signal()
   199  				}
   200  			}
   201  
   202  		})
   203  
   204  	}
   205  
   206  	return
   207  
   208  }