github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/weighted_linear.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package sampler
     5  
     6  import (
     7  	"cmp"
     8  
     9  	"github.com/MetalBlockchain/metalgo/utils"
    10  	"github.com/MetalBlockchain/metalgo/utils/math"
    11  )
    12  
    13  var (
    14  	_ Weighted                              = (*weightedLinear)(nil)
    15  	_ utils.Sortable[weightedLinearElement] = weightedLinearElement{}
    16  )
    17  
    18  type weightedLinearElement struct {
    19  	cumulativeWeight uint64
    20  	index            int
    21  }
    22  
    23  // Note that this sorts in order of decreasing cumulative weight.
    24  func (e weightedLinearElement) Compare(other weightedLinearElement) int {
    25  	return cmp.Compare(other.cumulativeWeight, e.cumulativeWeight)
    26  }
    27  
    28  // Sampling is performed by executing a linear search over the provided elements
    29  // in the order of their probabilistic occurrence.
    30  //
    31  // Initialization takes O(n * log(n)) time, where n is the number of elements
    32  // that can be sampled.
    33  // Sampling can take up to O(n) time. As the distribution becomes more biased,
    34  // sampling will become faster in expectation.
    35  type weightedLinear struct {
    36  	arr []weightedLinearElement
    37  }
    38  
    39  func (s *weightedLinear) Initialize(weights []uint64) error {
    40  	numWeights := len(weights)
    41  	if numWeights <= cap(s.arr) {
    42  		s.arr = s.arr[:numWeights]
    43  	} else {
    44  		s.arr = make([]weightedLinearElement, numWeights)
    45  	}
    46  
    47  	for i, weight := range weights {
    48  		s.arr[i] = weightedLinearElement{
    49  			cumulativeWeight: weight,
    50  			index:            i,
    51  		}
    52  	}
    53  
    54  	// Optimize so that the most probable values are at the front of the array
    55  	utils.Sort(s.arr)
    56  
    57  	for i := 1; i < len(s.arr); i++ {
    58  		newWeight, err := math.Add64(
    59  			s.arr[i-1].cumulativeWeight,
    60  			s.arr[i].cumulativeWeight,
    61  		)
    62  		if err != nil {
    63  			return err
    64  		}
    65  		s.arr[i].cumulativeWeight = newWeight
    66  	}
    67  
    68  	return nil
    69  }
    70  
    71  func (s *weightedLinear) Sample(value uint64) (int, bool) {
    72  	if len(s.arr) == 0 || s.arr[len(s.arr)-1].cumulativeWeight <= value {
    73  		return 0, false
    74  	}
    75  
    76  	index := 0
    77  	for {
    78  		if elem := s.arr[index]; value < elem.cumulativeWeight {
    79  			return elem.index, true
    80  		}
    81  		index++
    82  	}
    83  }