github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/weighted_test.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package sampler 5 6 import ( 7 "fmt" 8 "math" 9 "testing" 10 11 "github.com/stretchr/testify/require" 12 13 safemath "github.com/MetalBlockchain/metalgo/utils/math" 14 ) 15 16 var ( 17 weightedSamplers = []struct { 18 name string 19 sampler Weighted 20 }{ 21 { 22 name: "inverse uniform cdf", 23 sampler: &weightedArray{}, 24 }, 25 { 26 name: "heap division", 27 sampler: &weightedHeap{}, 28 }, 29 { 30 name: "linear scan", 31 sampler: &weightedLinear{}, 32 }, 33 { 34 name: "lookup", 35 sampler: &weightedUniform{ 36 maxWeight: 1024, 37 }, 38 }, 39 { 40 name: "best with k=30", 41 sampler: &weightedBest{ 42 samplers: []Weighted{ 43 &weightedArray{}, 44 &weightedHeap{}, 45 &weightedUniform{ 46 maxWeight: 1024, 47 }, 48 }, 49 benchmarkIterations: 30, 50 }, 51 }, 52 } 53 weightedTests = []struct { 54 name string 55 test func(*testing.T, Weighted) 56 }{ 57 { 58 name: "initialize overflow", 59 test: WeightedInitializeOverflowTest, 60 }, 61 { 62 name: "out of range", 63 test: WeightedOutOfRangeTest, 64 }, 65 { 66 name: "singleton", 67 test: WeightedSingletonTest, 68 }, 69 { 70 name: "with zero", 71 test: WeightedWithZeroTest, 72 }, 73 { 74 name: "distribution", 75 test: WeightedDistributionTest, 76 }, 77 } 78 ) 79 80 func TestAllWeighted(t *testing.T) { 81 for _, s := range weightedSamplers { 82 for _, test := range weightedTests { 83 t.Run(fmt.Sprintf("sampler %s test %s", s.name, test.name), func(t *testing.T) { 84 test.test(t, s.sampler) 85 }) 86 } 87 } 88 } 89 90 func WeightedInitializeOverflowTest(t *testing.T, s Weighted) { 91 err := s.Initialize([]uint64{1, math.MaxUint64}) 92 require.ErrorIs(t, err, safemath.ErrOverflow) 93 } 94 95 func WeightedOutOfRangeTest(t *testing.T, s Weighted) { 96 require := require.New(t) 97 98 require.NoError(s.Initialize([]uint64{1})) 99 100 _, ok := s.Sample(1) 101 require.False(ok) 102 } 103 104 func WeightedSingletonTest(t *testing.T, s Weighted) { 105 require := require.New(t) 106 107 require.NoError(s.Initialize([]uint64{1})) 108 109 index, ok := s.Sample(0) 110 require.True(ok) 111 require.Zero(index) 112 } 113 114 func WeightedWithZeroTest(t *testing.T, s Weighted) { 115 require := require.New(t) 116 117 require.NoError(s.Initialize([]uint64{0, 1})) 118 119 index, ok := s.Sample(0) 120 require.True(ok) 121 require.Equal(1, index) 122 } 123 124 func WeightedDistributionTest(t *testing.T, s Weighted) { 125 require := require.New(t) 126 127 require.NoError(s.Initialize([]uint64{1, 1, 2, 3, 4})) 128 129 counts := make([]int, 5) 130 for i := uint64(0); i < 11; i++ { 131 index, ok := s.Sample(i) 132 require.True(ok) 133 counts[index]++ 134 } 135 require.Equal([]int{1, 1, 2, 3, 4}, counts) 136 }