github.com/reusee/pr2@v0.0.0-20230630035947-72a20ff5e864/consume.go (about) 1 package pr2 2 3 import ( 4 "container/list" 5 "context" 6 "fmt" 7 "math" 8 "sync" 9 10 "github.com/reusee/e5" 11 ) 12 13 type Put[T any] func(T) bool 14 15 type Wait = func(noMorePut bool) error 16 17 type ConsumeOption interface { 18 IsConsumeOption() 19 } 20 21 type BacklogSize int 22 23 func (BacklogSize) IsConsumeOption() {} 24 25 func Consume[T any]( 26 ctx context.Context, 27 numThread int, 28 fn func(threadID int, value T) error, 29 options ...ConsumeOption, 30 ) ( 31 put Put[T], 32 wait Wait, 33 ) { 34 35 backlogSize := int(math.MaxInt32) 36 37 for _, option := range options { 38 switch option := option.(type) { 39 case BacklogSize: 40 backlogSize = int(option) 41 default: 42 panic(fmt.Errorf("unknown option: %T", option)) 43 } 44 } 45 46 inCh := make(chan T) 47 outCh := make(chan T) 48 errCh := make(chan error, 1) 49 valueCond := sync.NewCond(new(sync.Mutex)) 50 numValue := 0 51 wg := NewWaitGroup(ctx) 52 53 wg.Go(func() { 54 values := list.New() 55 var c chan T 56 loop: 57 for { 58 59 c = inCh 60 if values.Len() > backlogSize { 61 c = nil 62 } 63 64 if values.Len() > 0 { 65 select { 66 67 case outCh <- values.Front().Value.(T): 68 values.Remove(values.Front()) 69 70 case v, ok := <-c: 71 if !ok { 72 break loop 73 } 74 values.PushBack(v) 75 76 case <-ctx.Done(): 77 break loop 78 79 } 80 81 } else { 82 select { 83 84 case v, ok := <-c: 85 if !ok { 86 break loop 87 } 88 select { 89 case outCh <- v: 90 default: 91 values.PushBack(v) 92 } 93 94 case <-ctx.Done(): 95 break loop 96 97 } 98 } 99 100 } 101 102 elem := values.Front() 103 for elem != nil { 104 outCh <- elem.Value.(T) 105 elem = elem.Next() 106 } 107 108 close(outCh) 109 110 }) 111 112 var putLock sync.RWMutex 113 putClosed := false 114 put = func(v T) bool { 115 116 putLock.RLock() 117 defer putLock.RUnlock() 118 119 if putClosed { 120 return false 121 } 122 123 if len(errCh) > 0 { 124 return false 125 } 126 127 select { 128 129 case inCh <- v: 130 valueCond.L.Lock() 131 numValue++ 132 n := numValue 133 valueCond.L.Unlock() 134 if n == 0 { 135 valueCond.Signal() 136 } 137 return true 138 139 case <-ctx.Done(): 140 return false 141 142 } 143 } 144 145 var closeOnce sync.Once 146 closePut := func() { 147 closeOnce.Do(func() { 148 putLock.Lock() 149 defer putLock.Unlock() 150 putClosed = true 151 close(inCh) 152 }) 153 } 154 155 wait = func(noMorePut bool) error { 156 157 if noMorePut { 158 closePut() 159 wg.Wait() 160 } 161 162 valueCond.L.Lock() 163 for numValue != 0 { 164 valueCond.Wait() 165 } 166 valueCond.L.Unlock() 167 168 select { 169 case err := <-errCh: 170 return err 171 default: 172 } 173 174 return nil 175 } 176 177 for i := 0; i < numThread; i++ { 178 i := i 179 180 wg.Go(func() { 181 182 for v := range outCh { 183 err := func() (err error) { 184 defer e5.Handle(&err) 185 return fn(i, v) 186 }() 187 if err != nil { 188 select { 189 case errCh <- err: 190 default: 191 } 192 } 193 valueCond.L.Lock() 194 numValue-- 195 n := numValue 196 valueCond.L.Unlock() 197 if n == 0 { 198 valueCond.Signal() 199 } 200 } 201 202 }) 203 204 } 205 206 return 207 208 }