github.com/grailbio/base@v0.0.11/traverse/traverse.go (about) 1 // Copyright 2018 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache-2.0 3 // license that can be found in the LICENSE file. 4 5 // Package traverse provides primitives for concurrent and parallel 6 // traversal of slices or user-defined collections. 7 package traverse 8 9 import ( 10 "fmt" 11 "log" 12 "runtime" 13 "runtime/debug" 14 "sync" 15 "sync/atomic" 16 17 "github.com/grailbio/base/errors" 18 ) 19 20 const cachelineSize = 64 21 22 // A T is a traverser: it provides facilities for concurrently 23 // invoking functions that traverse collections of data. 24 type T struct { 25 // Limit is the traverser's concurrency limit: there will be no more 26 // than Limit concurrent invocations per traversal. A limit value of 27 // zero (the default value) denotes no limit. 28 Limit int 29 // Sequential indicates that early indexes should be handled before later 30 // ones. E.g. if there are 40000 tasks and Limit == 40, the initial 31 // assignment is usually 32 // worker 0 <- tasks 0-999 33 // worker 1 <- tasks 1000-1999 34 // ... 35 // worker 39 <- tasks 39000-39999 36 // but when Sequential == true, only tasks 0-39 are initially assigned, then 37 // task 40 goes to the first worker to finish, etc. 38 // Note that this increases synchronization overhead. It should not be used 39 // with e.g. > 1 billion tiny tasks; in that scenario, the caller should 40 // organize such tasks into e.g. 10000-task chunks and perform a 41 // sequential-traverse on the chunks. 42 // This scheduling algorithm does perform well when tasks are sorted in order 43 // of decreasing size. 44 Sequential bool 45 // Reporter receives status reports for each traversal. It is 46 // intended for users who wish to monitor the progress of large 47 // traversal jobs. 48 Reporter Reporter 49 } 50 51 // Limit returns a traverser with limit n. 52 func Limit(n int) T { 53 if n <= 0 { 54 log.Panicf("traverse.Limit: invalid limit: %d", n) 55 } 56 return T{Limit: n} 57 } 58 59 // LimitSequential returns a sequential traverser with limit n. 60 func LimitSequential(n int) T { 61 if n <= 0 { 62 log.Panicf("traverse.LimitSequential: invalid limit: %d", n) 63 } 64 return T{Limit: n, Sequential: true} 65 } 66 67 // Parallel is the default traverser for parallel traversal, intended 68 // CPU-intensive parallel computing. Parallel limits the number of 69 // concurrent invocations to a small multiple of the runtime's 70 // available processors. 71 var Parallel = T{Limit: 2 * runtime.GOMAXPROCS(0)} 72 73 // Each performs a traversal on fn. Specifically, Each invokes fn(i) 74 // for 0 <= i < n, managing concurrency and error propagation. Each 75 // returns when the all invocations have completed, or after the 76 // first invocation fails, in which case the first invocation error 77 // is returned. Each also propagates panics from underlying invocations 78 // to the caller. Note that if a function panics and doesn't release 79 // shared resources that fn might need in a traverse child, this could 80 // lead to deadlock. 81 func (t T) Each(n int, fn func(i int) error) error { 82 if t.Reporter != nil { 83 t.Reporter.Init(n) 84 defer t.Reporter.Complete() 85 } 86 var err error 87 if t.Limit == 1 || n == 1 { 88 err = t.eachSerial(n, fn) 89 } else if t.Limit == 0 || t.Limit >= n { 90 err = t.each(n, fn) 91 } else if t.Sequential { 92 err = t.eachSequential(n, fn) 93 } else { 94 err = t.eachLimit(n, fn) 95 } 96 if err == nil { 97 return nil 98 } 99 // Propagate panics. 100 if err, ok := err.(panicErr); ok { 101 panic(fmt.Sprintf("traverse child: %v\n%s", err.v, string(err.stack))) 102 } 103 return err 104 } 105 106 func (t T) each(n int, fn func(i int) error) error { 107 var ( 108 errors errors.Once 109 wg sync.WaitGroup 110 ) 111 wg.Add(n) 112 for i := 0; i < n; i++ { 113 go func(i int) { 114 if t.Reporter != nil { 115 t.Reporter.Begin(i) 116 } 117 if err := apply(fn, i); err != nil { 118 errors.Set(err) 119 } 120 if t.Reporter != nil { 121 t.Reporter.End(i) 122 } 123 wg.Done() 124 }(i) 125 } 126 wg.Wait() 127 return errors.Err() 128 } 129 130 // eachSerial runs on the local thread using a conventional for loop. 131 // all invocations will be run in numerical order. 132 func (t T) eachSerial(n int, fn func(i int) error) error { 133 for i := 0; i < n; i++ { 134 if t.Reporter != nil { 135 t.Reporter.Begin(i) 136 } 137 if err := apply(fn, i); err != nil { 138 return err 139 } 140 if t.Reporter != nil { 141 t.Reporter.End(i) 142 } 143 } 144 return nil 145 } 146 147 // eachSequential performs a concurrent run where tasks are assigned in strict 148 // numerical order. Unlike eachLimit(), it can be used when the traversal must 149 // be done sequentially. 150 func (t T) eachSequential(n int, fn func(i int) error) error { 151 var ( 152 errors errors.Once 153 wg sync.WaitGroup 154 syncStruct struct { 155 _ [cachelineSize - 8]byte // cache padding 156 N int64 157 _ [cachelineSize - 8]byte // cache padding 158 } 159 ) 160 syncStruct.N = -1 161 wg.Add(t.Limit) 162 for i := 0; i < t.Limit; i++ { 163 go func() { 164 for errors.Err() == nil { 165 idx := int(atomic.AddInt64(&syncStruct.N, 1)) 166 if idx >= n { 167 break 168 } 169 if t.Reporter != nil { 170 t.Reporter.Begin(idx) 171 } 172 if err := apply(fn, idx); err != nil { 173 errors.Set(err) 174 } 175 if t.Reporter != nil { 176 t.Reporter.End(idx) 177 } 178 } 179 wg.Done() 180 }() 181 } 182 wg.Wait() 183 return errors.Err() 184 } 185 186 // eachLimit performs a concurrent run where tasks can be assigned in any 187 // order. 188 func (t T) eachLimit(n int, fn func(i int) error) error { 189 var ( 190 errors errors.Once 191 wg sync.WaitGroup 192 next = make([]struct { 193 N int64 194 _ [cachelineSize - 8]byte // cache padding 195 }, t.Limit) 196 size = (n + t.Limit - 1) / t.Limit 197 ) 198 wg.Add(t.Limit) 199 for i := 0; i < t.Limit; i++ { 200 go func(w int) { 201 orig := w 202 for errors.Err() == nil { 203 // Each worker traverses contiguous segments since there is 204 // often usable data locality associated with index locality. 205 idx := int(atomic.AddInt64(&next[w].N, 1) - 1) 206 which := w*size + idx 207 if idx >= size || which >= n { 208 w = (w + 1) % t.Limit 209 if w == orig { 210 break 211 } 212 continue 213 } 214 if t.Reporter != nil { 215 t.Reporter.Begin(which) 216 } 217 if err := apply(fn, which); err != nil { 218 errors.Set(err) 219 } 220 if t.Reporter != nil { 221 t.Reporter.End(which) 222 } 223 } 224 wg.Done() 225 }(i) 226 } 227 wg.Wait() 228 return errors.Err() 229 } 230 231 // Range performs ranged traversal on fn: n is split into 232 // contiguous ranges, and fn is invoked for each range. The range 233 // sizes are determined by the traverser's concurrency limits. Range 234 // allows the caller to amortize function call costs, and is 235 // typically used when limit is small and n is large, for example on 236 // parallel traversal over large collections, where each item's 237 // processing time is comparatively small. 238 func (t T) Range(n int, fn func(start, end int) error) error { 239 if t.Sequential { 240 // interface for this should take a chunk size. 241 log.Panicf("traverse.Range: sequential traversal unsupported") 242 } 243 m := n 244 if t.Limit > 0 && t.Limit < n { 245 m = t.Limit 246 } 247 // TODO: consider splitting ranges into smaller chunks so that can 248 // take better advantage of the load balancing underneath. 249 return t.Each(m, func(i int) error { 250 var ( 251 size = float64(n) / float64(m) 252 start = int(float64(i) * size) 253 end = int(float64(i+1) * size) 254 ) 255 if start >= n { 256 return nil 257 } 258 if i == m-1 { 259 end = n 260 } 261 return fn(start, end) 262 }) 263 } 264 265 var defaultT = T{} 266 267 // Each performs concurrent traversal over n elements. It is a 268 // shorthand for (T{}).Each. 269 func Each(n int, fn func(i int) error) error { 270 return defaultT.Each(n, fn) 271 } 272 273 // CPU calls the function fn for each available system CPU. CPU 274 // returns when all calls have completed or on first error. 275 func CPU(fn func() error) error { 276 return Each(runtime.NumCPU(), func(int) error { return fn() }) 277 } 278 279 func apply(fn func(i int) error, i int) (err error) { 280 defer func() { 281 if perr := recover(); perr != nil { 282 err = panicErr{perr, debug.Stack()} 283 } 284 }() 285 return fn(i) 286 } 287 288 type panicErr struct { 289 v interface{} 290 stack []byte 291 } 292 293 func (p panicErr) Error() string { return fmt.Sprint(p.v) }