github.com/unicornultrafoundation/go-u2u@v1.0.0-rc1.0.20240205080301-e74a83d3fadc/utils/weighted_shuffle_test.go (about)

     1  package utils
     2  
     3  import (
     4  	"crypto/sha256"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  	"github.com/unicornultrafoundation/go-helios/common/littleendian"
     9  	"github.com/unicornultrafoundation/go-helios/hash"
    10  	"github.com/unicornultrafoundation/go-helios/native/pos"
    11  )
    12  
    13  func getTestWeightsIncreasing(num int) []pos.Weight {
    14  	weights := make([]pos.Weight, num)
    15  	for i := 0; i < num; i++ {
    16  		weights[i] = pos.Weight(i+1) * 1000
    17  	}
    18  	return weights
    19  }
    20  
    21  func getTestWeightsEqual(num int) []pos.Weight {
    22  	weights := make([]pos.Weight, num)
    23  	for i := 0; i < num; i++ {
    24  		weights[i] = 1000
    25  	}
    26  	return weights
    27  }
    28  
    29  // Test average distribution of the shuffle
    30  func Test_Permutation_distribution(t *testing.T) {
    31  	weightsArr := getTestWeightsIncreasing(30)
    32  
    33  	weightHits := make(map[int]int) // weight -> number of occurrences
    34  	for roundSeed := 0; roundSeed < 3000; roundSeed++ {
    35  		seed := hashOf(hash.Hash{}, uint32(roundSeed))
    36  		perm := WeightedPermutation(len(weightsArr)/10, weightsArr, seed)
    37  		for _, p := range perm {
    38  			weight := weightsArr[p]
    39  			weightFactor := int(weight / 1000)
    40  
    41  			_, ok := weightHits[weightFactor]
    42  			if !ok {
    43  				weightHits[weightFactor] = 0
    44  			}
    45  			weightHits[weightFactor]++
    46  		}
    47  	}
    48  
    49  	assertar := assert.New(t)
    50  	for weightFactor, hits := range weightHits {
    51  		//fmt.Printf("Test_RandomElection_distribution: %d \n", hits/weightFactor)
    52  		assertar.Equal((hits/weightFactor) > 20-8, true)
    53  		assertar.Equal((hits/weightFactor) < 20+8, true)
    54  		if t.Failed() {
    55  			return
    56  		}
    57  	}
    58  }
    59  
    60  // test that WeightedPermutation provides a correct permaition
    61  func testCorrectPermutation(t *testing.T, weightsArr []pos.Weight) {
    62  	assertar := assert.New(t)
    63  
    64  	perm := WeightedPermutation(len(weightsArr), weightsArr, hash.Hash{})
    65  	assertar.Equal(len(weightsArr), len(perm))
    66  
    67  	met := make(map[int]bool)
    68  	for _, p := range perm {
    69  		assertar.True(p >= 0)
    70  		assertar.True(p < len(weightsArr))
    71  		assertar.False(met[p])
    72  		met[p] = true
    73  	}
    74  }
    75  
    76  func Test_Permutation_correctness(t *testing.T) {
    77  	testCorrectPermutation(t, getTestWeightsIncreasing(1))
    78  	testCorrectPermutation(t, getTestWeightsIncreasing(30))
    79  	testCorrectPermutation(t, getTestWeightsEqual(1000))
    80  }
    81  
    82  func hashOf(a hash.Hash, b uint32) hash.Hash {
    83  	hasher := sha256.New()
    84  	hasher.Write(a.Bytes())
    85  	hasher.Write(littleendian.Uint32ToBytes(uint32(b)))
    86  	return hash.FromBytes(hasher.Sum(nil))
    87  }
    88  
    89  func Test_Permutation_determinism(t *testing.T) {
    90  	weightsArr := getTestWeightsIncreasing(5)
    91  
    92  	assertar := assert.New(t)
    93  
    94  	assertar.Equal([]int{4, 0, 1, 2, 3}, WeightedPermutation(len(weightsArr), weightsArr, hashOf(hash.Hash{}, 0)))
    95  	assertar.Equal([]int{2, 4, 3, 1, 0}, WeightedPermutation(len(weightsArr), weightsArr, hashOf(hash.Hash{}, 1)))
    96  	assertar.Equal([]int{4, 2, 3, 1, 0}, WeightedPermutation(len(weightsArr), weightsArr, hashOf(hash.Hash{}, 2)))
    97  	assertar.Equal([]int{0, 2, 1, 3, 4}, WeightedPermutation(len(weightsArr), weightsArr, hashOf(hash.Hash{}, 3)))
    98  	assertar.Equal([]int{1, 2}, WeightedPermutation(len(weightsArr)/2, weightsArr, hashOf(hash.Hash{}, 4)))
    99  }