github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/rand_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package sampler
     5  
     6  import (
     7  	"math"
     8  	"math/rand"
     9  	"strconv"
    10  	"testing"
    11  
    12  	"github.com/stretchr/testify/require"
    13  	"github.com/thepudds/fzgen/fuzzer"
    14  	"gonum.org/v1/gonum/mathext/prng"
    15  )
    16  
    17  type testSource struct {
    18  	onInvalid func()
    19  	nums      []uint64
    20  }
    21  
    22  func (s *testSource) Seed(uint64) {
    23  	s.onInvalid()
    24  }
    25  
    26  func (s *testSource) Uint64() uint64 {
    27  	if len(s.nums) == 0 {
    28  		s.onInvalid()
    29  	}
    30  	num := s.nums[0]
    31  	s.nums = s.nums[1:]
    32  	return num
    33  }
    34  
    35  type testSTDSource struct {
    36  	onInvalid func()
    37  	nums      []uint64
    38  }
    39  
    40  func (s *testSTDSource) Seed(int64) {
    41  	s.onInvalid()
    42  }
    43  
    44  func (s *testSTDSource) Int63() int64 {
    45  	return int64(s.Uint64() & (1<<63 - 1))
    46  }
    47  
    48  func (s *testSTDSource) Uint64() uint64 {
    49  	if len(s.nums) == 0 {
    50  		s.onInvalid()
    51  	}
    52  	num := s.nums[0]
    53  	s.nums = s.nums[1:]
    54  	return num
    55  }
    56  
    57  func TestRNG(t *testing.T) {
    58  	tests := []struct {
    59  		max      uint64
    60  		nums     []uint64
    61  		expected uint64
    62  	}{
    63  		{
    64  			max: math.MaxUint64,
    65  			nums: []uint64{
    66  				0x01,
    67  			},
    68  			expected: 0x01,
    69  		},
    70  		{
    71  			max: math.MaxUint64,
    72  			nums: []uint64{
    73  				0x0102030405060708,
    74  			},
    75  			expected: 0x0102030405060708,
    76  		},
    77  		{
    78  			max: math.MaxUint64,
    79  			nums: []uint64{
    80  				0xF102030405060708,
    81  			},
    82  			expected: 0xF102030405060708,
    83  		},
    84  		{
    85  			max: math.MaxInt64,
    86  			nums: []uint64{
    87  				0x01,
    88  			},
    89  			expected: 0x01,
    90  		},
    91  		{
    92  			max: math.MaxInt64,
    93  			nums: []uint64{
    94  				0x0102030405060708,
    95  			},
    96  			expected: 0x0102030405060708,
    97  		},
    98  		{
    99  			max: math.MaxInt64,
   100  			nums: []uint64{
   101  				0x8102030405060708,
   102  			},
   103  			expected: 0x0102030405060708,
   104  		},
   105  		{
   106  			max: 15,
   107  			nums: []uint64{
   108  				0x810203040506071a,
   109  			},
   110  			expected: 0x0a,
   111  		},
   112  		{
   113  			max: math.MaxInt64 + 1,
   114  			nums: []uint64{
   115  				math.MaxInt64 + 1,
   116  			},
   117  			expected: math.MaxInt64 + 1,
   118  		},
   119  		{
   120  			max: math.MaxInt64 + 1,
   121  			nums: []uint64{
   122  				math.MaxInt64 + 2,
   123  				0,
   124  			},
   125  			expected: 0,
   126  		},
   127  		{
   128  			max: math.MaxInt64 + 1,
   129  			nums: []uint64{
   130  				math.MaxInt64 + 2,
   131  				0x0102030405060708,
   132  			},
   133  			expected: 0x0102030405060708,
   134  		},
   135  		{
   136  			max: 2,
   137  			nums: []uint64{
   138  				math.MaxInt64 - 2,
   139  			},
   140  			expected: 0x02,
   141  		},
   142  		{
   143  			max: 2,
   144  			nums: []uint64{
   145  				math.MaxInt64 - 1,
   146  				0x01,
   147  			},
   148  			expected: 0x01,
   149  		},
   150  	}
   151  	for i, test := range tests {
   152  		t.Run(strconv.Itoa(i), func(t *testing.T) {
   153  			require := require.New(t)
   154  
   155  			source := &testSource{
   156  				onInvalid: t.FailNow,
   157  				nums:      test.nums,
   158  			}
   159  			r := &rng{rng: source}
   160  			val := r.Uint64Inclusive(test.max)
   161  			require.Equal(test.expected, val)
   162  			require.Empty(source.nums)
   163  
   164  			if test.max >= math.MaxInt64 {
   165  				return
   166  			}
   167  
   168  			stdSource := &testSTDSource{
   169  				onInvalid: t.FailNow,
   170  				nums:      test.nums,
   171  			}
   172  			mathRNG := rand.New(stdSource) //#nosec G404
   173  			stdVal := mathRNG.Int63n(int64(test.max + 1))
   174  			require.Equal(test.expected, uint64(stdVal))
   175  			require.Empty(source.nums)
   176  		})
   177  	}
   178  }
   179  
   180  func FuzzRNG(f *testing.F) {
   181  	f.Fuzz(func(t *testing.T, data []byte) {
   182  		require := require.New(t)
   183  
   184  		var (
   185  			max        uint64
   186  			sourceNums []uint64
   187  		)
   188  		fz := fuzzer.NewFuzzer(data)
   189  		fz.Fill(&max, &sourceNums)
   190  		if max >= math.MaxInt64 {
   191  			t.SkipNow()
   192  		}
   193  
   194  		source := &testSource{
   195  			onInvalid: t.SkipNow,
   196  			nums:      sourceNums,
   197  		}
   198  		r := &rng{rng: source}
   199  		val := r.Uint64Inclusive(max)
   200  
   201  		stdSource := &testSTDSource{
   202  			onInvalid: t.SkipNow,
   203  			nums:      sourceNums,
   204  		}
   205  		mathRNG := rand.New(stdSource) //#nosec G404
   206  		stdVal := mathRNG.Int63n(int64(max + 1))
   207  		require.Equal(val, uint64(stdVal))
   208  		require.Len(stdSource.nums, len(source.nums))
   209  	})
   210  }
   211  
   212  func BenchmarkSeed32(b *testing.B) {
   213  	source := prng.NewMT19937()
   214  
   215  	b.ResetTimer()
   216  	for i := 0; i < b.N; i++ {
   217  		source.Seed(0)
   218  	}
   219  }
   220  
   221  func BenchmarkSeed64(b *testing.B) {
   222  	source := prng.NewMT19937_64()
   223  
   224  	b.ResetTimer()
   225  	for i := 0; i < b.N; i++ {
   226  		source.Seed(0)
   227  	}
   228  }