github.com/prysmaticlabs/prysm@v1.4.4/beacon-chain/core/helpers/shuffle_test.go (about)

     1  package helpers
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"testing"
     7  
     8  	types "github.com/prysmaticlabs/eth2-types"
     9  	"github.com/prysmaticlabs/prysm/shared/params"
    10  	"github.com/prysmaticlabs/prysm/shared/sliceutil"
    11  	"github.com/prysmaticlabs/prysm/shared/testutil/assert"
    12  	"github.com/prysmaticlabs/prysm/shared/testutil/require"
    13  )
    14  
    15  func TestShuffleList_InvalidValidatorCount(t *testing.T) {
    16  	maxShuffleListSize = 20
    17  	list := make([]types.ValidatorIndex, 21)
    18  	if _, err := ShuffleList(list, [32]byte{123, 125}); err == nil {
    19  		t.Error("Shuffle should have failed when validator count exceeds ModuloBias")
    20  		maxShuffleListSize = 1 << 40
    21  	}
    22  	maxShuffleListSize = 1 << 40
    23  }
    24  
    25  func TestShuffleList_OK(t *testing.T) {
    26  	var list1 []types.ValidatorIndex
    27  	seed1 := [32]byte{1, 128, 12}
    28  	seed2 := [32]byte{2, 128, 12}
    29  	for i := 0; i < 10; i++ {
    30  		list1 = append(list1, types.ValidatorIndex(i))
    31  	}
    32  
    33  	list2 := make([]types.ValidatorIndex, len(list1))
    34  	copy(list2, list1)
    35  
    36  	list1, err := ShuffleList(list1, seed1)
    37  	assert.NoError(t, err, "Shuffle failed with")
    38  
    39  	list2, err = ShuffleList(list2, seed2)
    40  	assert.NoError(t, err, "Shuffle failed with")
    41  
    42  	if reflect.DeepEqual(list1, list2) {
    43  		t.Errorf("2 shuffled lists shouldn't be equal")
    44  	}
    45  	assert.DeepEqual(t, []types.ValidatorIndex{0, 7, 8, 6, 3, 9, 4, 5, 2, 1}, list1, "List 1 was incorrectly shuffled got")
    46  	assert.DeepEqual(t, []types.ValidatorIndex{0, 5, 2, 1, 6, 8, 7, 3, 4, 9}, list2, "List 2 was incorrectly shuffled got")
    47  }
    48  
    49  func TestSplitIndices_OK(t *testing.T) {
    50  	var l []uint64
    51  	numValidators := uint64(64000)
    52  	for i := uint64(0); i < numValidators; i++ {
    53  		l = append(l, i)
    54  	}
    55  	split := SplitIndices(l, uint64(params.BeaconConfig().SlotsPerEpoch))
    56  	assert.Equal(t, uint64(params.BeaconConfig().SlotsPerEpoch), uint64(len(split)), "Split list failed due to incorrect length")
    57  
    58  	for _, s := range split {
    59  		assert.Equal(t, numValidators/uint64(params.BeaconConfig().SlotsPerEpoch), uint64(len(s)), "Split list failed due to incorrect length")
    60  	}
    61  }
    62  
    63  func TestShuffleList_Vs_ShuffleIndex(t *testing.T) {
    64  	var list []types.ValidatorIndex
    65  	listSize := uint64(1000)
    66  	seed := [32]byte{123, 42}
    67  	for i := types.ValidatorIndex(0); uint64(i) < listSize; i++ {
    68  		list = append(list, i)
    69  	}
    70  	shuffledListByIndex := make([]types.ValidatorIndex, listSize)
    71  	for i := types.ValidatorIndex(0); uint64(i) < listSize; i++ {
    72  		si, err := ShuffledIndex(i, listSize, seed)
    73  		assert.NoError(t, err)
    74  		shuffledListByIndex[si] = i
    75  	}
    76  	shuffledList, err := ShuffleList(list, seed)
    77  	require.NoError(t, err, "Shuffled list error")
    78  	assert.DeepEqual(t, shuffledListByIndex, shuffledList, "Shuffled lists ar not equal")
    79  }
    80  
    81  func BenchmarkShuffledIndex(b *testing.B) {
    82  	listSizes := []uint64{4000000, 40000, 400}
    83  	seed := [32]byte{123, 42}
    84  	for _, listSize := range listSizes {
    85  		b.Run(fmt.Sprintf("ShuffledIndex_%d", listSize), func(ib *testing.B) {
    86  			for i := uint64(0); i < uint64(ib.N); i++ {
    87  				_, err := ShuffledIndex(types.ValidatorIndex(i%listSize), listSize, seed)
    88  				assert.NoError(b, err)
    89  			}
    90  		})
    91  	}
    92  }
    93  
    94  func BenchmarkIndexComparison(b *testing.B) {
    95  	listSizes := []uint64{400000, 40000, 400}
    96  	seed := [32]byte{123, 42}
    97  	for _, listSize := range listSizes {
    98  		b.Run(fmt.Sprintf("Indexwise_ShuffleList_%d", listSize), func(ib *testing.B) {
    99  			for i := 0; i < ib.N; i++ {
   100  				// Simulate a list-shuffle by running shuffle-index listSize times.
   101  				for j := types.ValidatorIndex(0); uint64(j) < listSize; j++ {
   102  					_, err := ShuffledIndex(j, listSize, seed)
   103  					assert.NoError(b, err)
   104  				}
   105  			}
   106  		})
   107  	}
   108  }
   109  
   110  func BenchmarkShuffleList(b *testing.B) {
   111  	listSizes := []uint64{400000, 40000, 400}
   112  	seed := [32]byte{123, 42}
   113  	for _, listSize := range listSizes {
   114  		testIndices := make([]types.ValidatorIndex, listSize)
   115  		for i := uint64(0); i < listSize; i++ {
   116  			testIndices[i] = types.ValidatorIndex(i)
   117  		}
   118  		b.Run(fmt.Sprintf("ShuffleList_%d", listSize), func(ib *testing.B) {
   119  			for i := 0; i < ib.N; i++ {
   120  				_, err := ShuffleList(testIndices, seed)
   121  				assert.NoError(b, err)
   122  			}
   123  		})
   124  	}
   125  }
   126  
   127  func TestShuffledIndex(t *testing.T) {
   128  	var list []types.ValidatorIndex
   129  	listSize := uint64(399)
   130  	for i := types.ValidatorIndex(0); uint64(i) < listSize; i++ {
   131  		list = append(list, i)
   132  	}
   133  	shuffledList := make([]types.ValidatorIndex, listSize)
   134  	unshuffledlist := make([]types.ValidatorIndex, listSize)
   135  	seed := [32]byte{123, 42}
   136  	for i := types.ValidatorIndex(0); uint64(i) < listSize; i++ {
   137  		si, err := ShuffledIndex(i, listSize, seed)
   138  		assert.NoError(t, err)
   139  		shuffledList[si] = i
   140  	}
   141  	for i := types.ValidatorIndex(0); uint64(i) < listSize; i++ {
   142  		ui, err := UnShuffledIndex(i, listSize, seed)
   143  		assert.NoError(t, err)
   144  		unshuffledlist[ui] = shuffledList[i]
   145  	}
   146  	assert.DeepEqual(t, list, unshuffledlist)
   147  }
   148  
   149  func TestSplitIndicesAndOffset_OK(t *testing.T) {
   150  	var l []uint64
   151  	validators := uint64(64000)
   152  	for i := uint64(0); i < validators; i++ {
   153  		l = append(l, i)
   154  	}
   155  	chunks := uint64(6)
   156  	split := SplitIndices(l, chunks)
   157  	for i := uint64(0); i < chunks; i++ {
   158  		if !reflect.DeepEqual(split[i], l[sliceutil.SplitOffset(uint64(len(l)), chunks, i):sliceutil.SplitOffset(uint64(len(l)), chunks, i+1)]) {
   159  			t.Errorf("Want: %v got: %v", l[sliceutil.SplitOffset(uint64(len(l)), chunks, i):sliceutil.SplitOffset(uint64(len(l)), chunks, i+1)], split[i])
   160  			break
   161  		}
   162  	}
   163  }