github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/uniform_replacer.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package sampler
     5  
     6  type defaultMap map[uint64]uint64
     7  
     8  func (m defaultMap) get(key uint64, defaultVal uint64) uint64 {
     9  	if val, ok := m[key]; ok {
    10  		return val
    11  	}
    12  	return defaultVal
    13  }
    14  
// uniformReplacer samples values from [0, length) with uniform probability
// and without replacement.
//
// Sampling lazily runs a Fisher-Yates (Knuth) shuffle over the virtual array
// [0, 1, ..., length - 1]. The i-th draw (0-based) picks a uniformly random
// position j in [i, length), returns the value currently at j, and moves the
// value at position i into j. Stopping after the first count swaps yields
// count values, each chosen uniformly among the values not yet drawn.
//
// The virtual array is never materialized. Only positions whose value was
// changed by a swap are recorded in drawn; any other position i implicitly
// holds i.
//
// Initialization takes O(1) time regardless of length.
//
// Sampling count values takes O(count) time (expected, as it relies on a hash
// map) and O(count) space.
//
// The struct carries no lock of its own. NOTE(review): presumably not safe
// for concurrent use; confirm with callers.
type uniformReplacer struct {
	rng        *rng       // source of the random offset chosen for each swap
	length     uint64     // population size; values are drawn from [0, length)
	drawn      defaultMap // sparse overlay of the virtual array: position -> current value
	drawsCount uint64     // draws since the last Initialize/Reset; positions [0, drawsCount) are finalized
}
    32  
    33  func (s *uniformReplacer) Initialize(length uint64) {
    34  	s.length = length
    35  	s.drawn = make(defaultMap)
    36  	s.drawsCount = 0
    37  }
    38  
    39  func (s *uniformReplacer) Sample(count int) ([]uint64, bool) {
    40  	s.Reset()
    41  
    42  	results := make([]uint64, count)
    43  	for i := 0; i < count; i++ {
    44  		ret, hasNext := s.Next()
    45  		if !hasNext {
    46  			return nil, false
    47  		}
    48  		results[i] = ret
    49  	}
    50  	return results, true
    51  }
    52  
    53  func (s *uniformReplacer) Reset() {
    54  	clear(s.drawn)
    55  	s.drawsCount = 0
    56  }
    57  
    58  func (s *uniformReplacer) Next() (uint64, bool) {
    59  	if s.drawsCount >= s.length {
    60  		return 0, false
    61  	}
    62  
    63  	draw := s.rng.Uint64Inclusive(s.length-1-s.drawsCount) + s.drawsCount
    64  	ret := s.drawn.get(draw, draw)
    65  	s.drawn[draw] = s.drawn.get(s.drawsCount, s.drawsCount)
    66  	s.drawsCount++
    67  
    68  	return ret, true
    69  }