github.com/prysmaticlabs/prysm@v1.4.4/shared/mputil/scatter_test.go (about)

     1  package mputil_test
     2  
     3  import (
     4  	"errors"
     5  	"sync"
     6  	"testing"
     7  
     8  	"github.com/prysmaticlabs/prysm/shared/mputil"
     9  	"github.com/prysmaticlabs/prysm/shared/testutil/assert"
    10  	"github.com/prysmaticlabs/prysm/shared/testutil/require"
    11  )
    12  
    13  func TestDouble(t *testing.T) {
    14  	tests := []struct {
    15  		name     string
    16  		inValues int
    17  		err      error
    18  	}{
    19  		{
    20  			name:     "0",
    21  			inValues: 0,
    22  			err:      errors.New("input length must be greater than 0"),
    23  		},
    24  		{
    25  			name:     "1",
    26  			inValues: 1,
    27  		},
    28  		{
    29  			name:     "1023",
    30  			inValues: 1023,
    31  		},
    32  		{
    33  			name:     "1024",
    34  			inValues: 1024,
    35  		},
    36  		{
    37  			name:     "1025",
    38  			inValues: 1025,
    39  		},
    40  	}
    41  
    42  	for _, test := range tests {
    43  		t.Run(test.name, func(t *testing.T) {
    44  			inValues := make([]int, test.inValues)
    45  			for i := 0; i < test.inValues; i++ {
    46  				inValues[i] = i
    47  			}
    48  			outValues := make([]int, test.inValues)
    49  			workerResults, err := mputil.Scatter(len(inValues), func(offset int, entries int, _ *sync.RWMutex) (interface{}, error) {
    50  				extent := make([]int, entries)
    51  				for i := 0; i < entries; i++ {
    52  					extent[i] = inValues[offset+i] * 2
    53  				}
    54  				return extent, nil
    55  			})
    56  			if test.err != nil {
    57  				assert.ErrorContains(t, test.err.Error(), err)
    58  			} else {
    59  				require.NoError(t, err)
    60  				for _, result := range workerResults {
    61  					copy(outValues[result.Offset:], result.Extent.([]int))
    62  				}
    63  
    64  				for i := 0; i < test.inValues; i++ {
    65  					require.Equal(t, inValues[i]*2, outValues[i], "Outvalue at %d mismatch", i)
    66  				}
    67  			}
    68  		})
    69  	}
    70  }
    71  
    72  func TestMutex(t *testing.T) {
    73  	totalRuns := 1048576
    74  	val := 0
    75  	_, err := mputil.Scatter(totalRuns, func(offset int, entries int, mu *sync.RWMutex) (interface{}, error) {
    76  		for i := 0; i < entries; i++ {
    77  			mu.Lock()
    78  			val++
    79  			mu.Unlock()
    80  		}
    81  		return nil, nil
    82  	})
    83  	require.NoError(t, err)
    84  
    85  	if val != totalRuns {
    86  		t.Fatalf("Unexpected value: expected \"%v\", found \"%v\"", totalRuns, val)
    87  	}
    88  }
    89  
    90  func TestError(t *testing.T) {
    91  	totalRuns := 1024
    92  	val := 0
    93  	_, err := mputil.Scatter(totalRuns, func(offset int, entries int, mu *sync.RWMutex) (interface{}, error) {
    94  		for i := 0; i < entries; i++ {
    95  			mu.Lock()
    96  			val++
    97  			if val == 1011 {
    98  				mu.Unlock()
    99  				return nil, errors.New("bad number")
   100  			}
   101  			mu.Unlock()
   102  		}
   103  		return nil, nil
   104  	})
   105  	if err == nil {
   106  		t.Fatalf("Missing expected error")
   107  	}
   108  }