github.com/lyft/flytestdlib@v0.3.12-0.20210213045714-8cdd111ecda1/random/weighted_random_list.go (about)

     1  package random
     2  
import (
	"context"
	"fmt"
	"math/rand"
	"sort"
	"sync"
	"time"

	"github.com/lyft/flytestdlib/logger"
)
    12  
    13  //go:generate mockery -all -case=underscore
    14  
    15  // Interface to use the Weighted Random
    16  type WeightedRandomList interface {
    17  	Get() Comparable
    18  	GetWithSeed(seed rand.Source) (Comparable, error)
    19  	List() []Comparable
    20  	Len() int
    21  }
    22  
    23  // Interface for items that can be used along with WeightedRandomList
    24  type Comparable interface {
    25  	Compare(to Comparable) bool
    26  }
    27  
    28  // Structure of each entry to select from
    29  type Entry struct {
    30  	Item   Comparable
    31  	Weight float32
    32  }
    33  
    34  type internalEntry struct {
    35  	entry        Entry
    36  	currentTotal float32
    37  }
    38  
    39  // WeightedRandomList selects elements randomly from the list taking into account individual weights.
    40  // Weight has to be assigned between 0 and 1.
    41  // Support deterministic results when given a particular seed source
    42  type weightedRandomListImpl struct {
    43  	entries     []internalEntry
    44  	totalWeight float32
    45  }
    46  
    47  func validateEntries(entries []Entry) error {
    48  	if len(entries) == 0 {
    49  		return fmt.Errorf("entries is empty")
    50  	}
    51  	for index, entry := range entries {
    52  		if entry.Item == nil {
    53  			return fmt.Errorf("invalid entry: nil, index %d", index)
    54  		}
    55  		if entry.Weight < 0 || entry.Weight > float32(1) {
    56  			return fmt.Errorf("invalid weight %f, index %d", entry.Weight, index)
    57  		}
    58  	}
    59  	return nil
    60  }
    61  
    62  // Given a list of entries with weights, returns WeightedRandomList
    63  func NewWeightedRandom(ctx context.Context, entries []Entry) (WeightedRandomList, error) {
    64  	err := validateEntries(entries)
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  
    69  	sort.Slice(entries, func(i, j int) bool {
    70  		return entries[i].Item.Compare(entries[j].Item)
    71  	})
    72  	var internalEntries []internalEntry
    73  	numberOfEntries := len(entries)
    74  	totalWeight := float32(0)
    75  	for _, e := range entries {
    76  		totalWeight += e.Weight
    77  	}
    78  
    79  	currentTotal := float32(0)
    80  	for _, e := range entries {
    81  		if totalWeight == 0 {
    82  			// This indicates that none of the entries have weight assigned.
    83  			// We will assign equal weights to everyone
    84  			currentTotal += 1.0 / float32(numberOfEntries)
    85  		} else if e.Weight == 0 {
    86  			// Entries which have zero weight are ignored
    87  			logger.Debug(ctx, "ignoring entry due to empty weight %v", e)
    88  			continue
    89  		}
    90  
    91  		currentTotal += e.Weight
    92  		internalEntries = append(internalEntries, internalEntry{
    93  			entry:        e,
    94  			currentTotal: currentTotal,
    95  		})
    96  	}
    97  
    98  	return &weightedRandomListImpl{
    99  		entries:     internalEntries,
   100  		totalWeight: currentTotal,
   101  	}, nil
   102  }
   103  
   104  func (w *weightedRandomListImpl) get(generator *rand.Rand) Comparable {
   105  	randomWeight := generator.Float32() * w.totalWeight
   106  	for _, e := range w.entries {
   107  		if e.currentTotal >= randomWeight && e.currentTotal > 0 {
   108  			return e.entry.Item
   109  		}
   110  	}
   111  	return w.entries[len(w.entries)-1].entry.Item
   112  }
   113  
   114  // Returns a random entry based on the weights
   115  func (w *weightedRandomListImpl) Get() Comparable {
   116  	randGenerator := rand.New(rand.NewSource(time.Now().UTC().UnixNano()))
   117  	return w.get(randGenerator)
   118  }
   119  
   120  // For a given seed, the same entry will be returned all the time.
   121  func (w *weightedRandomListImpl) GetWithSeed(seed rand.Source) (Comparable, error) {
   122  	randGenerator := rand.New(seed)
   123  	return w.get(randGenerator), nil
   124  }
   125  
   126  // Lists all the entries that are eligible for selection
   127  func (w *weightedRandomListImpl) List() []Comparable {
   128  	entries := make([]Comparable, len(w.entries))
   129  	for index, indexedItem := range w.entries {
   130  		entries[index] = indexedItem.entry.Item
   131  	}
   132  	return entries
   133  }
   134  
   135  // Gets the number of items that are being considered for selection.
   136  func (w *weightedRandomListImpl) Len() int {
   137  	return len(w.entries)
   138  }