github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/ppln/nserial.go (about)

     1  package ppln
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  )
     7  
     8  // NonSerial starts a multi-goroutine transformation pipeline.
     9  //
    10  // Pusher should call push on every input value.
    11  // Stop indicates if an error was returned and pushing should stop.
    12  // Mapper receives an input (a), a push function for the results, 0-based
    13  // goroutine number (g).
    14  // It should call push zero or more times with the processing results of a.
    15  // Puller acts on a single output. The order of the outputs is arbitrary, but
    16  // correlated with the order of pusher's inputs.
    17  //
    18  // If one of the functions returns a non-nil error, the process stops and the
    19  // error is returned. Otherwise returns nil.
    20  func NonSerial[T1 any, T2 any](
    21  	ngoroutines int,
    22  	pusher func(push func(T1), stop func() bool) error,
    23  	mapper func(a T1, push func(T2), g int) error,
    24  	puller func(a T2) error) error {
    25  	if ngoroutines < 0 {
    26  		panic(fmt.Sprintf("bad number of goroutines: %d", ngoroutines))
    27  	}
    28  
    29  	var err error
    30  
    31  	// An optimization for a single thread.
    32  	if ngoroutines == 0 {
    33  		perr := pusher(func(a T1) {
    34  			if err != nil {
    35  				return
    36  			}
    37  			merr := mapper(a, func(i T2) {
    38  				if err != nil {
    39  					return
    40  				}
    41  				perr := puller(i)
    42  				if perr != nil && err == nil {
    43  					err = perr
    44  				}
    45  			}, 0)
    46  			if merr != nil && err == nil {
    47  				err = merr
    48  			}
    49  		}, func() bool { return err != nil })
    50  		if perr != nil && err == nil {
    51  			err = perr
    52  		}
    53  		return err
    54  	}
    55  
    56  	push := make(chan T1, ngoroutines*chanLenCoef)
    57  	pull := make(chan T2, ngoroutines*chanLenCoef)
    58  	wait := &sync.WaitGroup{}
    59  	wait.Add(ngoroutines)
    60  
    61  	go func() {
    62  		perr := pusher(func(a T1) {
    63  			if err != nil {
    64  				return
    65  			}
    66  			push <- a
    67  		}, func() bool { return err != nil })
    68  		if perr != nil && err == nil {
    69  			err = perr
    70  		}
    71  		close(push)
    72  	}()
    73  	for i := 0; i < ngoroutines; i++ {
    74  		i := i
    75  		go func() {
    76  			for item := range push {
    77  				if err != nil {
    78  					continue // Drain channel.
    79  				}
    80  				merr := mapper(item, func(a T2) {
    81  					if err != nil {
    82  						return
    83  					}
    84  					pull <- a
    85  				}, i)
    86  				if merr != nil && err == nil {
    87  					err = merr
    88  				}
    89  			}
    90  			wait.Done()
    91  		}()
    92  	}
    93  	go func() {
    94  		for item := range pull {
    95  			if err != nil {
    96  				continue // Drain channel.
    97  			}
    98  			perr := puller(item)
    99  			if perr != nil && err == nil {
   100  				err = perr
   101  			}
   102  		}
   103  		for range pull { // Drain channel.
   104  		}
   105  		wait.Done()
   106  	}()
   107  
   108  	wait.Wait() // Wait for workers.
   109  	wait.Add(1)
   110  	close(pull)
   111  	wait.Wait() // Wait for pull.
   112  
   113  	return err
   114  }