git.frostfs.info/TrueCloudLab/frostfs-sdk-go@v0.0.0-20241022124111-5361f0ecebd3/pool/sampler.go (about) 1 package pool 2 3 import ( 4 "math/rand" 5 "sync" 6 ) 7 8 // sampler implements weighted random number generation using Vose's Alias 9 // Method (https://www.keithschwarz.com/darts-dice-coins/). 10 type sampler struct { 11 mu sync.Mutex 12 randomGenerator *rand.Rand 13 14 probabilities []float64 15 alias []int 16 } 17 18 // newSampler creates new sampler with a given set of probabilities using 19 // given source of randomness. Created sampler will produce numbers from 20 // 0 to len(probabilities). 21 func newSampler(probabilities []float64, source rand.Source) *sampler { 22 sampler := &sampler{} 23 var ( 24 small workList 25 large workList 26 ) 27 n := len(probabilities) 28 sampler.randomGenerator = rand.New(source) 29 sampler.probabilities = make([]float64, n) 30 sampler.alias = make([]int, n) 31 // Compute scaled probabilities. 32 p := make([]float64, n) 33 for i := range n { 34 p[i] = probabilities[i] * float64(n) 35 } 36 for i, pi := range p { 37 if pi < 1 { 38 small.add(i) 39 } else { 40 large.add(i) 41 } 42 } 43 for len(small) > 0 && len(large) > 0 { 44 l, g := small.remove(), large.remove() 45 sampler.probabilities[l] = p[l] 46 sampler.alias[l] = g 47 p[g] = p[g] + p[l] - 1 48 if p[g] < 1 { 49 small.add(g) 50 } else { 51 large.add(g) 52 } 53 } 54 for len(large) > 0 { 55 g := large.remove() 56 sampler.probabilities[g] = 1 57 } 58 for len(small) > 0 { 59 l := small.remove() 60 sampler.probabilities[l] = 1 61 } 62 return sampler 63 } 64 65 // Next returns the next (not so) random number from sampler. 66 // This method is safe for concurrent use by multiple goroutines. 67 func (g *sampler) Next() int { 68 n := len(g.alias) 69 70 g.mu.Lock() 71 i := g.randomGenerator.Intn(n) 72 f := g.randomGenerator.Float64() 73 g.mu.Unlock() 74 75 if f < g.probabilities[i] { 76 return i 77 } 78 return g.alias[i] 79 } 80 81 type workList []int 82 83 func (wl *workList) add(e int) { 84 *wl = append(*wl, e) 85 } 86 87 func (wl *workList) remove() int { 88 l := len(*wl) - 1 89 n := (*wl)[l] 90 *wl = (*wl)[:l] 91 return n 92 }