github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/ppln/nserial.go (about) 1 package ppln 2 3 import ( 4 "fmt" 5 "sync" 6 ) 7 8 // NonSerial starts a multi-goroutine transformation pipeline. 9 // 10 // Pusher should call push on every input value. 11 // Stop indicates if an error was returned and pushing should stop. 12 // Mapper receives an input (a), a push function for the results, 0-based 13 // goroutine number (g). 14 // It should call push zero or more times with the processing results of a. 15 // Puller acts on a single output. The order of the outputs is arbitrary, but 16 // correlated with the order of pusher's inputs. 17 // 18 // If one of the functions returns a non-nil error, the process stops and the 19 // error is returned. Otherwise returns nil. 20 func NonSerial[T1 any, T2 any]( 21 ngoroutines int, 22 pusher func(push func(T1), stop func() bool) error, 23 mapper func(a T1, push func(T2), g int) error, 24 puller func(a T2) error) error { 25 if ngoroutines < 0 { 26 panic(fmt.Sprintf("bad number of goroutines: %d", ngoroutines)) 27 } 28 29 var err error 30 31 // An optimization for a single thread. 32 if ngoroutines == 0 { 33 perr := pusher(func(a T1) { 34 if err != nil { 35 return 36 } 37 merr := mapper(a, func(i T2) { 38 if err != nil { 39 return 40 } 41 perr := puller(i) 42 if perr != nil && err == nil { 43 err = perr 44 } 45 }, 0) 46 if merr != nil && err == nil { 47 err = merr 48 } 49 }, func() bool { return err != nil }) 50 if perr != nil && err == nil { 51 err = perr 52 } 53 return err 54 } 55 56 push := make(chan T1, ngoroutines*chanLenCoef) 57 pull := make(chan T2, ngoroutines*chanLenCoef) 58 wait := &sync.WaitGroup{} 59 wait.Add(ngoroutines) 60 61 go func() { 62 perr := pusher(func(a T1) { 63 if err != nil { 64 return 65 } 66 push <- a 67 }, func() bool { return err != nil }) 68 if perr != nil && err == nil { 69 err = perr 70 } 71 close(push) 72 }() 73 for i := 0; i < ngoroutines; i++ { 74 i := i 75 go func() { 76 for item := range push { 77 if err != nil { 78 continue // Drain channel. 79 } 80 merr := mapper(item, func(a T2) { 81 if err != nil { 82 return 83 } 84 pull <- a 85 }, i) 86 if merr != nil && err == nil { 87 err = merr 88 } 89 } 90 wait.Done() 91 }() 92 } 93 go func() { 94 for item := range pull { 95 if err != nil { 96 continue // Drain channel. 97 } 98 perr := puller(item) 99 if perr != nil && err == nil { 100 err = perr 101 } 102 } 103 for range pull { // Drain channel. 104 } 105 wait.Done() 106 }() 107 108 wait.Wait() // Wait for workers. 109 wait.Add(1) 110 close(pull) 111 wait.Wait() // Wait for pull. 112 113 return err 114 }