github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/ppln/v2/nserial.go (about)

     1  package ppln
     2  
     3  import (
     4  	"fmt"
     5  	"iter"
     6  	"sync"
     7  	"sync/atomic"
     8  )
     9  
    10  // NonSerial starts a multi-goroutine transformation pipeline.
    11  //
    12  // Input is an iterator over the input values to be transformed.
    13  // It will be called in a thread-safe manner.
    14  // Transform receives an input (a) and a 0-based goroutine number (g),
    15  // and returns the result of processing a.
    16  // Output acts on a single result, and will be called in a thread-safe manner.
    17  // The order of outputs is arbitrary, but correlated with the order of
    18  // inputs.
    19  //
    20  // If one of the functions returns a non-nil error, the process stops and the
    21  // error is returned. Otherwise returns nil.
    22  func NonSerial[T1 any, T2 any](
    23  	ngoroutines int,
    24  	input iter.Seq2[T1, error],
    25  	transform func(a T1, g int) (T2, error),
    26  	output func(a T2) error) error {
    27  	if ngoroutines < 1 {
    28  		panic(fmt.Sprintf("bad number of goroutines: %d", ngoroutines))
    29  	}
    30  	pull, pstop := iter.Pull2(input)
    31  	defer pstop()
    32  
    33  	// An optimization for a single thread.
    34  	if ngoroutines == 1 {
    35  		for {
    36  			t1, err, ok := pull()
    37  
    38  			if !ok {
    39  				return nil
    40  			}
    41  			if err != nil {
    42  				return err
    43  			}
    44  
    45  			t2, err := transform(t1, 0)
    46  			if err != nil {
    47  				return err
    48  			}
    49  			if err := output(t2); err != nil {
    50  				return err
    51  			}
    52  		}
    53  	}
    54  
    55  	ilock := &sync.Mutex{}
    56  	olock := &sync.Mutex{}
    57  	errs := make(chan error, ngoroutines)
    58  	stop := &atomic.Bool{}
    59  
    60  	for g := 0; g < ngoroutines; g++ {
    61  		go func(g int) {
    62  			for {
    63  				if stop.Load() {
    64  					errs <- nil
    65  					return
    66  				}
    67  
    68  				ilock.Lock()
    69  				t1, err, ok := pull()
    70  				ilock.Unlock()
    71  
    72  				if !ok {
    73  					errs <- nil
    74  					return
    75  				}
    76  				if err != nil {
    77  					stop.Store(true)
    78  					errs <- err
    79  					return
    80  				}
    81  
    82  				t2, err := transform(t1, g)
    83  				if err != nil {
    84  					stop.Store(true)
    85  					errs <- err
    86  					return
    87  				}
    88  
    89  				olock.Lock()
    90  				err = output(t2)
    91  				olock.Unlock()
    92  				if err != nil {
    93  					stop.Store(true)
    94  					errs <- err
    95  					return
    96  				}
    97  			}
    98  		}(g)
    99  	}
   100  
   101  	for g := 0; g < ngoroutines; g++ {
   102  		if err := <-errs; err != nil {
   103  			return err
   104  		}
   105  	}
   106  	return nil
   107  }