git.frostfs.info/TrueCloudLab/frostfs-sdk-go@v0.0.0-20241022124111-5361f0ecebd3/pool/sampler.go (about)

     1  package pool
     2  
     3  import (
     4  	"math/rand"
     5  	"sync"
     6  )
     7  
     8  // sampler implements weighted random number generation using Vose's Alias
     9  // Method (https://www.keithschwarz.com/darts-dice-coins/).
    10  type sampler struct {
    11  	mu              sync.Mutex
    12  	randomGenerator *rand.Rand
    13  
    14  	probabilities []float64
    15  	alias         []int
    16  }
    17  
    18  // newSampler creates new sampler with a given set of probabilities using
    19  // given source of randomness. Created sampler will produce numbers from
    20  // 0 to len(probabilities).
    21  func newSampler(probabilities []float64, source rand.Source) *sampler {
    22  	sampler := &sampler{}
    23  	var (
    24  		small workList
    25  		large workList
    26  	)
    27  	n := len(probabilities)
    28  	sampler.randomGenerator = rand.New(source)
    29  	sampler.probabilities = make([]float64, n)
    30  	sampler.alias = make([]int, n)
    31  	// Compute scaled probabilities.
    32  	p := make([]float64, n)
    33  	for i := range n {
    34  		p[i] = probabilities[i] * float64(n)
    35  	}
    36  	for i, pi := range p {
    37  		if pi < 1 {
    38  			small.add(i)
    39  		} else {
    40  			large.add(i)
    41  		}
    42  	}
    43  	for len(small) > 0 && len(large) > 0 {
    44  		l, g := small.remove(), large.remove()
    45  		sampler.probabilities[l] = p[l]
    46  		sampler.alias[l] = g
    47  		p[g] = p[g] + p[l] - 1
    48  		if p[g] < 1 {
    49  			small.add(g)
    50  		} else {
    51  			large.add(g)
    52  		}
    53  	}
    54  	for len(large) > 0 {
    55  		g := large.remove()
    56  		sampler.probabilities[g] = 1
    57  	}
    58  	for len(small) > 0 {
    59  		l := small.remove()
    60  		sampler.probabilities[l] = 1
    61  	}
    62  	return sampler
    63  }
    64  
    65  // Next returns the next (not so) random number from sampler.
    66  // This method is safe for concurrent use by multiple goroutines.
    67  func (g *sampler) Next() int {
    68  	n := len(g.alias)
    69  
    70  	g.mu.Lock()
    71  	i := g.randomGenerator.Intn(n)
    72  	f := g.randomGenerator.Float64()
    73  	g.mu.Unlock()
    74  
    75  	if f < g.probabilities[i] {
    76  		return i
    77  	}
    78  	return g.alias[i]
    79  }
    80  
    81  type workList []int
    82  
    83  func (wl *workList) add(e int) {
    84  	*wl = append(*wl, e)
    85  }
    86  
    87  func (wl *workList) remove() int {
    88  	l := len(*wl) - 1
    89  	n := (*wl)[l]
    90  	*wl = (*wl)[:l]
    91  	return n
    92  }