github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/ppln/serial.go (about) 1 package ppln 2 3 import ( 4 "fmt" 5 "sync" 6 7 "github.com/fluhus/gostuff/heaps" 8 ) 9 10 const ( 11 // Size of pipeline channel buffers per goroutine. 12 chanLenCoef = 1000 13 ) 14 15 // Serial starts a multi-goroutine transformation pipeline that maintains the 16 // order of the inputs. 17 // 18 // Pusher should call push on every input value. Stop indicates if an error was 19 // returned and pushing should stop. 20 // Mapper receives an input (a), 0-based input serial number (i), 0-based 21 // goroutine number (g), and returns the result of processing a. 22 // Puller acts on a single result, and will be called by the same 23 // order of pusher's input. 24 // 25 // If one of the functions returns a non-nil error, the process stops and the 26 // error is returned. Otherwise returns nil. 27 func Serial[T1 any, T2 any]( 28 ngoroutines int, 29 pusher func(push func(T1), stop func() bool) error, 30 mapper func(a T1, i int, g int) (T2, error), 31 puller func(a T2) error) error { 32 if ngoroutines < 0 { 33 panic(fmt.Sprintf("bad number of goroutines: %d", ngoroutines)) 34 } 35 36 var err error 37 38 // An optimization for a single thread. 39 if ngoroutines == 0 { 40 i := 0 41 pusher(func(a T1) { 42 if err != nil { 43 return 44 } 45 t2, merr := mapper(a, i, 0) 46 if merr != nil { 47 err = merr 48 return 49 } 50 perr := puller(t2) 51 if perr != nil { 52 err = perr 53 return 54 } 55 i++ 56 }, func() bool { return err != nil }) 57 return err 58 } 59 60 push := make(chan serialItem[T1], ngoroutines*chanLenCoef) 61 pull := make(chan serialItem[T2], ngoroutines*chanLenCoef) 62 wait := &sync.WaitGroup{} 63 wait.Add(ngoroutines) 64 65 go func() { 66 i := 0 67 perr := pusher(func(a T1) { 68 if err != nil { 69 return 70 } 71 push <- serialItem[T1]{i, a} 72 i++ 73 }, func() bool { return err != nil }) 74 if perr != nil && err == nil { 75 err = perr 76 } 77 close(push) 78 }() 79 for i := 0; i < ngoroutines; i++ { 80 i := i 81 go func() { 82 for item := range push { 83 if err != nil { 84 continue // Drain channel. 85 } 86 t2, merr := mapper(item.data, item.i, i) 87 if merr != nil { 88 if err == nil { 89 err = merr 90 } 91 continue 92 } 93 pull <- serialItem[T2]{item.i, t2} 94 } 95 wait.Done() 96 }() 97 } 98 go func() { 99 items := &serialHeap[T2]{ 100 data: heaps.New(func(a, b serialItem[T2]) bool { 101 return a.i < b.i 102 }), 103 } 104 for item := range pull { 105 if err != nil { 106 continue // Drain channel. 107 } 108 items.put(item) 109 for items.ok() { 110 perr := puller(items.pop()) 111 if perr != nil && err == nil { 112 err = perr 113 } 114 } 115 } 116 wait.Done() 117 }() 118 119 wait.Wait() // Wait for workers. 120 wait.Add(1) 121 close(pull) 122 wait.Wait() // Wait for pull. 123 124 return err 125 } 126 127 // General data with a serial number. 128 type serialItem[T any] struct { 129 i int 130 data T 131 } 132 133 // A heap of serial items. Sorts by serial number. 134 type serialHeap[T any] struct { 135 next int 136 data *heaps.Heap[serialItem[T]] 137 } 138 139 // Checks whether the minimal element in the heap is the next in the series. 140 func (s *serialHeap[T]) ok() bool { 141 return s.data.Len() > 0 && s.data.Head().i == s.next 142 } 143 144 // Removes and returns the minimal element in the heap. Panics if the element 145 // is not the next in the series. 146 func (s *serialHeap[T]) pop() T { 147 if !s.ok() { 148 panic("get when not ok") 149 } 150 s.next++ 151 a := s.data.Pop() 152 return a.data 153 } 154 155 // Adds an item to the heap. 156 func (s *serialHeap[T]) put(item serialItem[T]) { 157 if item.i < s.next { 158 panic(fmt.Sprintf("put(%d) when next is %d", item.i, s.next)) 159 } 160 s.data.Push(item) 161 }