github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/rand.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package sampler
     5  
     6  import (
     7  	"math"
     8  	"sync"
     9  	"time"
    10  
    11  	"gonum.org/v1/gonum/mathext/prng"
    12  )
    13  
    14  var globalRNG = newRNG()
    15  
    16  func newRNG() *rng {
    17  	// We don't use a cryptographically secure source of randomness here, as
    18  	// there's no need to ensure a truly random sampling.
    19  	source := prng.NewMT19937()
    20  	source.Seed(uint64(time.Now().UnixNano()))
    21  	return &rng{rng: source}
    22  }
    23  
    24  type rng struct {
    25  	lock sync.Mutex
    26  	rng  Source
    27  }
    28  
    29  type Source interface {
    30  	// Uint64 returns a random number in [0, MaxUint64] and advances the
    31  	// generator's state.
    32  	Uint64() uint64
    33  }
    34  
    35  // Uint64Inclusive returns a pseudo-random number in [0,n].
    36  //
    37  // Invariant: The result of this function is stored in chain state, so any
    38  // modifications are considered breaking.
    39  func (r *rng) Uint64Inclusive(n uint64) uint64 {
    40  	switch {
    41  	// n+1 is power of two, so we can just mask
    42  	//
    43  	// Note: This does work for MaxUint64 as overflow is explicitly part of the
    44  	// compiler specification: https://go.dev/ref/spec#Integer_overflow
    45  	case n&(n+1) == 0:
    46  		return r.uint64() & n
    47  
    48  	// n is greater than MaxUint64/2 so we need to just iterate until we get a
    49  	// number in the requested range.
    50  	case n > math.MaxInt64:
    51  		v := r.uint64()
    52  		for v > n {
    53  			v = r.uint64()
    54  		}
    55  		return v
    56  
    57  	// n is less than MaxUint64/2 so we generate a number in the range
    58  	// [0, k*(n+1)) where k is the largest integer such that k*(n+1) is less
    59  	// than or equal to MaxUint64/2. We can't easily find k such that k*(n+1) is
    60  	// less than or equal to MaxUint64 because the calculation would overflow.
    61  	//
    62  	// ref: https://github.com/golang/go/blob/ce10e9d84574112b224eae88dc4e0f43710808de/src/math/rand/rand.go#L127-L132
    63  	default:
    64  		max := (1 << 63) - 1 - (1<<63)%(n+1)
    65  		v := r.uint63()
    66  		for v > max {
    67  			v = r.uint63()
    68  		}
    69  		return v % (n + 1)
    70  	}
    71  }
    72  
    73  // uint63 returns a random number in [0, MaxInt64]
    74  func (r *rng) uint63() uint64 {
    75  	return r.uint64() & math.MaxInt64
    76  }
    77  
    78  // uint64 returns a random number in [0, MaxUint64]
    79  func (r *rng) uint64() uint64 {
    80  	// Note: We must grab a write lock here because rng.Uint64 internally
    81  	// modifies state.
    82  	r.lock.Lock()
    83  	n := r.rng.Uint64()
    84  	r.lock.Unlock()
    85  	return n
    86  }