github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/ppln/v2/serial.go (about) 1 package ppln 2 3 import ( 4 "fmt" 5 "iter" 6 "sync" 7 "sync/atomic" 8 9 "github.com/fluhus/gostuff/heaps" 10 ) 11 12 // Serial starts a multi-goroutine transformation pipeline that maintains the 13 // order of the inputs. 14 // 15 // Input is an iterator over the input values to be transformed. 16 // It will be called in a thread-safe manner. 17 // Transform receives an input (a), 0-based input serial number (i), 0-based 18 // goroutine number (g), and returns the result of processing a. 19 // Output acts on a single result, and will be called by the same 20 // order of the input, in a thread-safe manner. 21 // 22 // If one of the functions returns a non-nil error, the process stops and the 23 // error is returned. Otherwise returns nil. 24 func Serial[T1 any, T2 any]( 25 ngoroutines int, 26 input iter.Seq2[T1, error], 27 transform func(a T1, i int, g int) (T2, error), 28 output func(a T2) error) error { 29 if ngoroutines < 1 { 30 panic(fmt.Sprintf("bad number of goroutines: %d", ngoroutines)) 31 } 32 pull, pstop := iter.Pull2(input) 33 defer pstop() 34 35 // An optimization for a single thread. 36 if ngoroutines == 1 { 37 i := 0 38 for { 39 t1, err, ok := pull() 40 ii := i 41 i++ 42 43 if !ok { 44 return nil 45 } 46 if err != nil { 47 return err 48 } 49 50 t2, err := transform(t1, ii, 0) 51 if err != nil { 52 return err 53 } 54 if err := output(t2); err != nil { 55 return err 56 } 57 } 58 } 59 60 ilock := &sync.Mutex{} 61 olock := &sync.Mutex{} 62 errs := make(chan error, ngoroutines) 63 stop := &atomic.Bool{} 64 items := &serialHeap[T2]{ 65 data: heaps.New(func(a, b serialItem[T2]) bool { 66 return a.i < b.i 67 }), 68 } 69 70 i := 0 71 for g := 0; g < ngoroutines; g++ { 72 go func(g int) { 73 for { 74 if stop.Load() { 75 errs <- nil 76 return 77 } 78 79 ilock.Lock() 80 t1, err, ok := pull() 81 ii := i 82 i++ 83 ilock.Unlock() 84 85 if !ok { 86 errs <- nil 87 return 88 } 89 if err != nil { 90 stop.Store(true) 91 errs <- err 92 return 93 } 94 95 t2, err := transform(t1, ii, g) 96 if err != nil { 97 stop.Store(true) 98 errs <- err 99 return 100 } 101 102 olock.Lock() 103 items.put(serialItem[T2]{ii, t2}) 104 for items.ok() { 105 err = output(items.pop()) 106 if err != nil { 107 olock.Unlock() 108 stop.Store(true) 109 errs <- err 110 return 111 } 112 } 113 olock.Unlock() 114 } 115 }(g) 116 } 117 118 for g := 0; g < ngoroutines; g++ { 119 if err := <-errs; err != nil { 120 return err 121 } 122 } 123 return nil 124 } 125 126 // General data with a serial number. 127 type serialItem[T any] struct { 128 i int 129 data T 130 } 131 132 // A heap of serial items. Sorts by serial number. 133 type serialHeap[T any] struct { 134 next int 135 data *heaps.Heap[serialItem[T]] 136 } 137 138 // Checks whether the minimal element in the heap is the next in the series. 139 func (s *serialHeap[T]) ok() bool { 140 return s.data.Len() > 0 && s.data.Head().i == s.next 141 } 142 143 // Removes and returns the minimal element in the heap. Panics if the element 144 // is not the next in the series. 145 func (s *serialHeap[T]) pop() T { 146 if !s.ok() { 147 panic("get when not ok") 148 } 149 s.next++ 150 a := s.data.Pop() 151 return a.data 152 } 153 154 // Adds an item to the heap. 155 func (s *serialHeap[T]) put(item serialItem[T]) { 156 if item.i < s.next { 157 panic(fmt.Sprintf("put(%d) when next is %d", item.i, s.next)) 158 } 159 s.data.Push(item) 160 }