github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/ppln/v2/nserial.go (about) 1 package ppln 2 3 import ( 4 "fmt" 5 "iter" 6 "sync" 7 "sync/atomic" 8 ) 9 10 // NonSerial starts a multi-goroutine transformation pipeline. 11 // 12 // Input is an iterator over the input values to be transformed. 13 // It will be called in a thread-safe manner. 14 // Transform receives an input (a) and a 0-based goroutine number (g), 15 // and returns the result of processing a. 16 // Output acts on a single result, and will be called in a thread-safe manner. 17 // The order of outputs is arbitrary, but correlated with the order of 18 // inputs. 19 // 20 // If one of the functions returns a non-nil error, the process stops and the 21 // error is returned. Otherwise returns nil. 22 func NonSerial[T1 any, T2 any]( 23 ngoroutines int, 24 input iter.Seq2[T1, error], 25 transform func(a T1, g int) (T2, error), 26 output func(a T2) error) error { 27 if ngoroutines < 1 { 28 panic(fmt.Sprintf("bad number of goroutines: %d", ngoroutines)) 29 } 30 pull, pstop := iter.Pull2(input) 31 defer pstop() 32 33 // An optimization for a single thread. 34 if ngoroutines == 1 { 35 for { 36 t1, err, ok := pull() 37 38 if !ok { 39 return nil 40 } 41 if err != nil { 42 return err 43 } 44 45 t2, err := transform(t1, 0) 46 if err != nil { 47 return err 48 } 49 if err := output(t2); err != nil { 50 return err 51 } 52 } 53 } 54 55 ilock := &sync.Mutex{} 56 olock := &sync.Mutex{} 57 errs := make(chan error, ngoroutines) 58 stop := &atomic.Bool{} 59 60 for g := 0; g < ngoroutines; g++ { 61 go func(g int) { 62 for { 63 if stop.Load() { 64 errs <- nil 65 return 66 } 67 68 ilock.Lock() 69 t1, err, ok := pull() 70 ilock.Unlock() 71 72 if !ok { 73 errs <- nil 74 return 75 } 76 if err != nil { 77 stop.Store(true) 78 errs <- err 79 return 80 } 81 82 t2, err := transform(t1, g) 83 if err != nil { 84 stop.Store(true) 85 errs <- err 86 return 87 } 88 89 olock.Lock() 90 err = output(t2) 91 olock.Unlock() 92 if err != nil { 93 stop.Store(true) 94 errs <- err 95 return 96 } 97 } 98 }(g) 99 } 100 101 for g := 0; g < ngoroutines; g++ { 102 if err := <-errs; err != nil { 103 return err 104 } 105 } 106 return nil 107 }