github.com/prysmaticlabs/prysm@v1.4.4/shared/mputil/scatter.go (about)

// Package mputil contains helpers for scattering a computation across
// multiple goroutines so that it can make use of multiple processors.
     3  package mputil
     4  
     5  import (
     6  	"errors"
     7  	"runtime"
     8  	"sync"
     9  )
    10  
    11  // WorkerResults are the results of a scatter worker.
    12  type WorkerResults struct {
    13  	Offset int
    14  	Extent interface{}
    15  }
    16  
    17  // Scatter scatters a computation across multiple goroutines.
    18  // This breaks the task in to a number of chunks and executes those chunks in parallel with the function provided.
    19  // Results returned are collected and presented a a set of WorkerResults, which can be reassembled by the calling function.
    20  // Any error that occurs in the workers will be passed back to the calling function.
    21  func Scatter(inputLen int, sFunc func(int, int, *sync.RWMutex) (interface{}, error)) ([]*WorkerResults, error) {
    22  	if inputLen <= 0 {
    23  		return nil, errors.New("input length must be greater than 0")
    24  	}
    25  
    26  	chunkSize := calculateChunkSize(inputLen)
    27  	workers := inputLen / chunkSize
    28  	if inputLen%chunkSize != 0 {
    29  		workers++
    30  	}
    31  	resultCh := make(chan *WorkerResults, workers)
    32  	defer close(resultCh)
    33  	errorCh := make(chan error, workers)
    34  	defer close(errorCh)
    35  	mutex := new(sync.RWMutex)
    36  	for worker := 0; worker < workers; worker++ {
    37  		offset := worker * chunkSize
    38  		entries := chunkSize
    39  		if offset+entries > inputLen {
    40  			entries = inputLen - offset
    41  		}
    42  		go func(offset int, entries int) {
    43  			extent, err := sFunc(offset, entries, mutex)
    44  			if err != nil {
    45  				errorCh <- err
    46  			} else {
    47  				resultCh <- &WorkerResults{
    48  					Offset: offset,
    49  					Extent: extent,
    50  				}
    51  			}
    52  		}(offset, entries)
    53  	}
    54  
    55  	// Collect results from workers
    56  	results := make([]*WorkerResults, workers)
    57  	for i := 0; i < workers; i++ {
    58  		select {
    59  		case result := <-resultCh:
    60  			results[i] = result
    61  		case err := <-errorCh:
    62  			return nil, err
    63  		}
    64  	}
    65  	return results, nil
    66  }
    67  
    68  // calculateChunkSize calculates a suitable chunk size for the purposes of parallelisation.
    69  func calculateChunkSize(items int) int {
    70  	// Start with a simple even split
    71  	chunkSize := items / runtime.GOMAXPROCS(0)
    72  
    73  	// Add 1 if we have leftovers (or if we have fewer items than processors).
    74  	if chunkSize == 0 || items%chunkSize != 0 {
    75  		chunkSize++
    76  	}
    77  
    78  	return chunkSize
    79  }