github.com/lyft/flytestdlib@v0.3.12-0.20210213045714-8cdd111ecda1/random/weighted_random_list.go (about) 1 package random 2 3 import ( 4 "context" 5 "fmt" 6 "math/rand" 7 "sort" 8 "time" 9 10 "github.com/lyft/flytestdlib/logger" 11 ) 12 13 //go:generate mockery -all -case=underscore 14 15 // Interface to use the Weighted Random 16 type WeightedRandomList interface { 17 Get() Comparable 18 GetWithSeed(seed rand.Source) (Comparable, error) 19 List() []Comparable 20 Len() int 21 } 22 23 // Interface for items that can be used along with WeightedRandomList 24 type Comparable interface { 25 Compare(to Comparable) bool 26 } 27 28 // Structure of each entry to select from 29 type Entry struct { 30 Item Comparable 31 Weight float32 32 } 33 34 type internalEntry struct { 35 entry Entry 36 currentTotal float32 37 } 38 39 // WeightedRandomList selects elements randomly from the list taking into account individual weights. 40 // Weight has to be assigned between 0 and 1. 41 // Support deterministic results when given a particular seed source 42 type weightedRandomListImpl struct { 43 entries []internalEntry 44 totalWeight float32 45 } 46 47 func validateEntries(entries []Entry) error { 48 if len(entries) == 0 { 49 return fmt.Errorf("entries is empty") 50 } 51 for index, entry := range entries { 52 if entry.Item == nil { 53 return fmt.Errorf("invalid entry: nil, index %d", index) 54 } 55 if entry.Weight < 0 || entry.Weight > float32(1) { 56 return fmt.Errorf("invalid weight %f, index %d", entry.Weight, index) 57 } 58 } 59 return nil 60 } 61 62 // Given a list of entries with weights, returns WeightedRandomList 63 func NewWeightedRandom(ctx context.Context, entries []Entry) (WeightedRandomList, error) { 64 err := validateEntries(entries) 65 if err != nil { 66 return nil, err 67 } 68 69 sort.Slice(entries, func(i, j int) bool { 70 return entries[i].Item.Compare(entries[j].Item) 71 }) 72 var internalEntries []internalEntry 73 numberOfEntries := len(entries) 74 totalWeight := float32(0) 75 for _, e := range entries { 76 totalWeight += e.Weight 77 } 78 79 currentTotal := float32(0) 80 for _, e := range entries { 81 if totalWeight == 0 { 82 // This indicates that none of the entries have weight assigned. 83 // We will assign equal weights to everyone 84 currentTotal += 1.0 / float32(numberOfEntries) 85 } else if e.Weight == 0 { 86 // Entries which have zero weight are ignored 87 logger.Debug(ctx, "ignoring entry due to empty weight %v", e) 88 continue 89 } 90 91 currentTotal += e.Weight 92 internalEntries = append(internalEntries, internalEntry{ 93 entry: e, 94 currentTotal: currentTotal, 95 }) 96 } 97 98 return &weightedRandomListImpl{ 99 entries: internalEntries, 100 totalWeight: currentTotal, 101 }, nil 102 } 103 104 func (w *weightedRandomListImpl) get(generator *rand.Rand) Comparable { 105 randomWeight := generator.Float32() * w.totalWeight 106 for _, e := range w.entries { 107 if e.currentTotal >= randomWeight && e.currentTotal > 0 { 108 return e.entry.Item 109 } 110 } 111 return w.entries[len(w.entries)-1].entry.Item 112 } 113 114 // Returns a random entry based on the weights 115 func (w *weightedRandomListImpl) Get() Comparable { 116 randGenerator := rand.New(rand.NewSource(time.Now().UTC().UnixNano())) 117 return w.get(randGenerator) 118 } 119 120 // For a given seed, the same entry will be returned all the time. 121 func (w *weightedRandomListImpl) GetWithSeed(seed rand.Source) (Comparable, error) { 122 randGenerator := rand.New(seed) 123 return w.get(randGenerator), nil 124 } 125 126 // Lists all the entries that are eligible for selection 127 func (w *weightedRandomListImpl) List() []Comparable { 128 entries := make([]Comparable, len(w.entries)) 129 for index, indexedItem := range w.entries { 130 entries[index] = indexedItem.entry.Item 131 } 132 return entries 133 } 134 135 // Gets the number of items that are being considered for selection. 136 func (w *weightedRandomListImpl) Len() int { 137 return len(w.entries) 138 }