github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/traverse/traverse.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package traverse provides primitives for concurrent and parallel
     6  // traversal of slices or user-defined collections.
     7  package traverse
     8  
     9  import (
    10  	"fmt"
    11  	"log"
    12  	"runtime"
    13  	"runtime/debug"
    14  	"sync"
    15  	"sync/atomic"
    16  
    17  	"github.com/Schaudge/grailbase/errors"
    18  )
    19  
    20  const cachelineSize = 64
    21  
// A T is a traverser: it provides facilities for concurrently
// invoking functions that traverse collections of data. The zero value
// is a usable traverser that imposes no concurrency limit.
type T struct {
	// Limit is the traverser's concurrency limit: there will be no more
	// than Limit concurrent invocations per traversal. A limit value of
	// zero (the default value) denotes no limit. Negative values are not
	// supported; the Limit and LimitSequential constructors panic on them.
	Limit int
	// Sequential indicates that early indexes should be handled before later
	// ones.  E.g. if there are 40000 tasks and Limit == 40, the initial
	// assignment is usually
	//   worker 0 <- tasks 0-999
	//   worker 1 <- tasks 1000-1999
	//   ...
	//   worker 39 <- tasks 39000-39999
	// but when Sequential == true, only tasks 0-39 are initially assigned, then
	// task 40 goes to the first worker to finish, etc.
	// Note that this increases synchronization overhead.  It should not be used
	// with e.g. > 1 billion tiny tasks; in that scenario, the caller should
	// organize such tasks into e.g. 10000-task chunks and perform a
	// sequential-traverse on the chunks.
	// This scheduling algorithm does perform well when tasks are sorted in order
	// of decreasing size.
	// Range does not support Sequential and panics when it is set.
	Sequential bool
	// Reporter receives status reports for each traversal. It is
	// intended for users who wish to monitor the progress of large
	// traversal jobs. Each calls Init(n) at the start of a traversal and
	// Complete() when it finishes. Begin(i) and End(i) are called around
	// individual invocations. In the concurrent strategies these calls are
	// made from worker goroutines, so an implementation must tolerate
	// concurrent calls. A nil Reporter disables reporting.
	Reporter Reporter
}
    50  
    51  // Limit returns a traverser with limit n.
    52  func Limit(n int) T {
    53  	if n <= 0 {
    54  		log.Panicf("traverse.Limit: invalid limit: %d", n)
    55  	}
    56  	return T{Limit: n}
    57  }
    58  
    59  // LimitSequential returns a sequential traverser with limit n.
    60  func LimitSequential(n int) T {
    61  	if n <= 0 {
    62  		log.Panicf("traverse.LimitSequential: invalid limit: %d", n)
    63  	}
    64  	return T{Limit: n, Sequential: true}
    65  }
    66  
    67  // Parallel is the default traverser for parallel traversal, intended
    68  // CPU-intensive parallel computing. Parallel limits the number of
    69  // concurrent invocations to a small multiple of the runtime's
    70  // available processors.
    71  var Parallel = T{Limit: 2 * runtime.GOMAXPROCS(0)}
    72  
    73  // Each performs a traversal on fn. Specifically, Each invokes fn(i)
    74  // for 0 <= i < n, managing concurrency and error propagation. Each
    75  // returns when the all invocations have completed, or after the
    76  // first invocation fails, in which case the first invocation error
    77  // is returned. Each also propagates panics from underlying invocations
    78  // to the caller. Note that if a function panics and doesn't release
    79  // shared resources that fn might need in a traverse child, this could
    80  // lead to deadlock.
    81  func (t T) Each(n int, fn func(i int) error) error {
    82  	if t.Reporter != nil {
    83  		t.Reporter.Init(n)
    84  		defer t.Reporter.Complete()
    85  	}
    86  	var err error
    87  	if t.Limit == 1 || n == 1 {
    88  		err = t.eachSerial(n, fn)
    89  	} else if t.Limit == 0 || t.Limit >= n {
    90  		err = t.each(n, fn)
    91  	} else if t.Sequential {
    92  		err = t.eachSequential(n, fn)
    93  	} else {
    94  		err = t.eachLimit(n, fn)
    95  	}
    96  	if err == nil {
    97  		return nil
    98  	}
    99  	// Propagate panics.
   100  	if err, ok := err.(panicErr); ok {
   101  		panic(fmt.Sprintf("traverse child: %v\n%s", err.v, string(err.stack)))
   102  	}
   103  	return err
   104  }
   105  
   106  func (t T) each(n int, fn func(i int) error) error {
   107  	var (
   108  		errors errors.Once
   109  		wg     sync.WaitGroup
   110  	)
   111  	wg.Add(n)
   112  	for i := 0; i < n; i++ {
   113  		go func(i int) {
   114  			if t.Reporter != nil {
   115  				t.Reporter.Begin(i)
   116  			}
   117  			if err := apply(fn, i); err != nil {
   118  				errors.Set(err)
   119  			}
   120  			if t.Reporter != nil {
   121  				t.Reporter.End(i)
   122  			}
   123  			wg.Done()
   124  		}(i)
   125  	}
   126  	wg.Wait()
   127  	return errors.Err()
   128  }
   129  
   130  // eachSerial runs on the local thread using a conventional for loop.
   131  // all invocations will be run in numerical order.
   132  func (t T) eachSerial(n int, fn func(i int) error) error {
   133  	for i := 0; i < n; i++ {
   134  		if t.Reporter != nil {
   135  			t.Reporter.Begin(i)
   136  		}
   137  		if err := apply(fn, i); err != nil {
   138  			return err
   139  		}
   140  		if t.Reporter != nil {
   141  			t.Reporter.End(i)
   142  		}
   143  	}
   144  	return nil
   145  }
   146  
   147  // eachSequential performs a concurrent run where tasks are assigned in strict
   148  // numerical order.  Unlike eachLimit(), it can be used when the traversal must
   149  // be done sequentially.
   150  func (t T) eachSequential(n int, fn func(i int) error) error {
   151  	var (
   152  		errors     errors.Once
   153  		wg         sync.WaitGroup
   154  		syncStruct struct {
   155  			_ [cachelineSize - 8]byte // cache padding
   156  			N int64
   157  			_ [cachelineSize - 8]byte // cache padding
   158  		}
   159  	)
   160  	syncStruct.N = -1
   161  	wg.Add(t.Limit)
   162  	for i := 0; i < t.Limit; i++ {
   163  		go func() {
   164  			for errors.Err() == nil {
   165  				idx := int(atomic.AddInt64(&syncStruct.N, 1))
   166  				if idx >= n {
   167  					break
   168  				}
   169  				if t.Reporter != nil {
   170  					t.Reporter.Begin(idx)
   171  				}
   172  				if err := apply(fn, idx); err != nil {
   173  					errors.Set(err)
   174  				}
   175  				if t.Reporter != nil {
   176  					t.Reporter.End(idx)
   177  				}
   178  			}
   179  			wg.Done()
   180  		}()
   181  	}
   182  	wg.Wait()
   183  	return errors.Err()
   184  }
   185  
// eachLimit performs a concurrent run where tasks can be assigned in any
// order.
//
// The index space [0, n) is split into t.Limit contiguous segments of
// size ceil(n/t.Limit). Segment w covers [w*size, (w+1)*size), clipped
// to n. Worker w starts on its own segment and claims offsets in it via
// an atomic per-segment cursor (next[w].N), so each index is claimed by
// exactly one worker. When its current segment is exhausted, a worker
// moves to the following segment (wrapping around) and helps drain it.
// It exits when advancing would bring it back to its starting segment.
// Workers also stop claiming once an error has been recorded. The
// retained error (nil if none) is returned after all workers exit.
func (t T) eachLimit(n int, fn func(i int) error) error {
	var (
		errors errors.Once
		wg     sync.WaitGroup
		// One claim cursor per segment, each padded to a cache line so
		// workers draining different segments don't false-share.
		next   = make([]struct {
			N int64
			_ [cachelineSize - 8]byte // cache padding
		}, t.Limit)
		// Segment length: ceil(n / t.Limit).
		size = (n + t.Limit - 1) / t.Limit
	)
	wg.Add(t.Limit)
	for i := 0; i < t.Limit; i++ {
		go func(w int) {
			orig := w
			for errors.Err() == nil {
				// Each worker traverses contiguous segments since there is
				// often usable data locality associated with index locality.
				// idx is the offset claimed in segment w; which is the
				// corresponding absolute index.
				idx := int(atomic.AddInt64(&next[w].N, 1) - 1)
				which := w*size + idx
				if idx >= size || which >= n {
					// Segment w is exhausted (or lies past n): move on to the
					// next segment, stopping after a full cycle.
					w = (w + 1) % t.Limit
					if w == orig {
						break
					}
					continue
				}
				if t.Reporter != nil {
					t.Reporter.Begin(which)
				}
				if err := apply(fn, which); err != nil {
					errors.Set(err)
				}
				if t.Reporter != nil {
					t.Reporter.End(which)
				}
			}
			wg.Done()
		}(i)
	}
	wg.Wait()
	return errors.Err()
}
   230  
   231  // Range performs ranged traversal on fn: n is split into
   232  // contiguous ranges, and fn is invoked for each range. The range
   233  // sizes are determined by the traverser's concurrency limits. Range
   234  // allows the caller to amortize function call costs, and is
   235  // typically used when limit is small and n is large, for example on
   236  // parallel traversal over large collections, where each item's
   237  // processing time is comparatively small.
   238  func (t T) Range(n int, fn func(start, end int) error) error {
   239  	if t.Sequential {
   240  		// interface for this should take a chunk size.
   241  		log.Panicf("traverse.Range: sequential traversal unsupported")
   242  	}
   243  	m := n
   244  	if t.Limit > 0 && t.Limit < n {
   245  		m = t.Limit
   246  	}
   247  	// TODO: consider splitting ranges into smaller chunks so that can
   248  	// take better advantage of the load balancing underneath.
   249  	return t.Each(m, func(i int) error {
   250  		var (
   251  			size  = float64(n) / float64(m)
   252  			start = int(float64(i) * size)
   253  			end   = int(float64(i+1) * size)
   254  		)
   255  		if start >= n {
   256  			return nil
   257  		}
   258  		if i == m-1 {
   259  			end = n
   260  		}
   261  		return fn(start, end)
   262  	})
   263  }
   264  
   265  var defaultT = T{}
   266  
   267  // Each performs concurrent traversal over n elements. It is a
   268  // shorthand for (T{}).Each.
   269  func Each(n int, fn func(i int) error) error {
   270  	return defaultT.Each(n, fn)
   271  }
   272  
   273  // CPU calls the function fn for each available system CPU. CPU
   274  // returns when all calls have completed or on first error.
   275  func CPU(fn func() error) error {
   276  	return Each(runtime.NumCPU(), func(int) error { return fn() })
   277  }
   278  
   279  func apply(fn func(i int) error, i int) (err error) {
   280  	defer func() {
   281  		if perr := recover(); perr != nil {
   282  			err = panicErr{perr, debug.Stack()}
   283  		}
   284  	}()
   285  	return fn(i)
   286  }
   287  
   288  type panicErr struct {
   289  	v     interface{}
   290  	stack []byte
   291  }
   292  
   293  func (p panicErr) Error() string { return fmt.Sprint(p.v) }