github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/weighted_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package sampler
     5  
     6  import (
     7  	"fmt"
     8  	"math"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/require"
    12  
    13  	safemath "github.com/MetalBlockchain/metalgo/utils/math"
    14  )
    15  
    16  var (
    17  	weightedSamplers = []struct {
    18  		name    string
    19  		sampler Weighted
    20  	}{
    21  		{
    22  			name:    "inverse uniform cdf",
    23  			sampler: &weightedArray{},
    24  		},
    25  		{
    26  			name:    "heap division",
    27  			sampler: &weightedHeap{},
    28  		},
    29  		{
    30  			name:    "linear scan",
    31  			sampler: &weightedLinear{},
    32  		},
    33  		{
    34  			name: "lookup",
    35  			sampler: &weightedUniform{
    36  				maxWeight: 1024,
    37  			},
    38  		},
    39  		{
    40  			name: "best with k=30",
    41  			sampler: &weightedBest{
    42  				samplers: []Weighted{
    43  					&weightedArray{},
    44  					&weightedHeap{},
    45  					&weightedUniform{
    46  						maxWeight: 1024,
    47  					},
    48  				},
    49  				benchmarkIterations: 30,
    50  			},
    51  		},
    52  	}
    53  	weightedTests = []struct {
    54  		name string
    55  		test func(*testing.T, Weighted)
    56  	}{
    57  		{
    58  			name: "initialize overflow",
    59  			test: WeightedInitializeOverflowTest,
    60  		},
    61  		{
    62  			name: "out of range",
    63  			test: WeightedOutOfRangeTest,
    64  		},
    65  		{
    66  			name: "singleton",
    67  			test: WeightedSingletonTest,
    68  		},
    69  		{
    70  			name: "with zero",
    71  			test: WeightedWithZeroTest,
    72  		},
    73  		{
    74  			name: "distribution",
    75  			test: WeightedDistributionTest,
    76  		},
    77  	}
    78  )
    79  
    80  func TestAllWeighted(t *testing.T) {
    81  	for _, s := range weightedSamplers {
    82  		for _, test := range weightedTests {
    83  			t.Run(fmt.Sprintf("sampler %s test %s", s.name, test.name), func(t *testing.T) {
    84  				test.test(t, s.sampler)
    85  			})
    86  		}
    87  	}
    88  }
    89  
    90  func WeightedInitializeOverflowTest(t *testing.T, s Weighted) {
    91  	err := s.Initialize([]uint64{1, math.MaxUint64})
    92  	require.ErrorIs(t, err, safemath.ErrOverflow)
    93  }
    94  
    95  func WeightedOutOfRangeTest(t *testing.T, s Weighted) {
    96  	require := require.New(t)
    97  
    98  	require.NoError(s.Initialize([]uint64{1}))
    99  
   100  	_, ok := s.Sample(1)
   101  	require.False(ok)
   102  }
   103  
   104  func WeightedSingletonTest(t *testing.T, s Weighted) {
   105  	require := require.New(t)
   106  
   107  	require.NoError(s.Initialize([]uint64{1}))
   108  
   109  	index, ok := s.Sample(0)
   110  	require.True(ok)
   111  	require.Zero(index)
   112  }
   113  
   114  func WeightedWithZeroTest(t *testing.T, s Weighted) {
   115  	require := require.New(t)
   116  
   117  	require.NoError(s.Initialize([]uint64{0, 1}))
   118  
   119  	index, ok := s.Sample(0)
   120  	require.True(ok)
   121  	require.Equal(1, index)
   122  }
   123  
   124  func WeightedDistributionTest(t *testing.T, s Weighted) {
   125  	require := require.New(t)
   126  
   127  	require.NoError(s.Initialize([]uint64{1, 1, 2, 3, 4}))
   128  
   129  	counts := make([]int, 5)
   130  	for i := uint64(0); i < 11; i++ {
   131  		index, ok := s.Sample(i)
   132  		require.True(ok)
   133  		counts[index]++
   134  	}
   135  	require.Equal([]int{1, 1, 2, 3, 4}, counts)
   136  }