github.com/zuoyebang/bitalostable@v1.0.1-0.20240229032404-e3b99a834294/internal/randvar/deck.go (about)

     1  // Copyright 2019 The LevelDB-Go and Pebble Authors. All rights reserved. Use
     2  // of this source code is governed by a BSD-style license that can be found in
     3  // the LICENSE file.
     4  
     5  package randvar
     6  
     7  import (
     8  	"sync"
     9  
    10  	"golang.org/x/exp/rand"
    11  )
    12  
    13  // Deck is a random number generator that generates numbers in the range
    14  // [0,len(weights)-1] where the probability of i is
    15  // weights(i)/sum(weights). Unlike Weighted, the weights are specified as
    16  // integers and used in a deck-of-cards style random number selection which
    17  // ensures that each element is returned with a desired frequency within the
    18  // size of the deck.
    19  type Deck struct {
    20  	rng *rand.Rand
    21  	mu  struct {
    22  		sync.Mutex
    23  		index int
    24  		deck  []int
    25  	}
    26  }
    27  
    28  // NewDeck returns a new deck random number generator.
    29  func NewDeck(rng *rand.Rand, weights ...int) *Deck {
    30  	var sum int
    31  	for i := range weights {
    32  		sum += weights[i]
    33  	}
    34  	deck := make([]int, 0, sum)
    35  	for i := range weights {
    36  		for j := 0; j < weights[i]; j++ {
    37  			deck = append(deck, i)
    38  		}
    39  	}
    40  	d := &Deck{
    41  		rng: ensureRand(rng),
    42  	}
    43  	d.mu.index = len(deck)
    44  	d.mu.deck = deck
    45  	return d
    46  }
    47  
    48  // Int returns a random number in the range [0,len(weights)-1] where the
    49  // probability of i is weights(i)/sum(weights).
    50  func (d *Deck) Int() int {
    51  	d.mu.Lock()
    52  	if d.mu.index == len(d.mu.deck) {
    53  		d.rng.Shuffle(len(d.mu.deck), func(i, j int) {
    54  			d.mu.deck[i], d.mu.deck[j] = d.mu.deck[j], d.mu.deck[i]
    55  		})
    56  		d.mu.index = 0
    57  	}
    58  	result := d.mu.deck[d.mu.index]
    59  	d.mu.index++
    60  	d.mu.Unlock()
    61  	return result
    62  }