github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/weighted_array.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package sampler
     5  
     6  import (
     7  	"cmp"
     8  
     9  	"github.com/MetalBlockchain/metalgo/utils"
    10  	"github.com/MetalBlockchain/metalgo/utils/math"
    11  )
    12  
// Compile-time interface checks: *weightedArray must satisfy Weighted, and
// weightedArrayElement must satisfy utils.Sortable so that utils.Sort can
// order a slice of elements. Neither assignment has any runtime effect.
var (
	_ Weighted                             = (*weightedArray)(nil)
	_ utils.Sortable[weightedArrayElement] = weightedArrayElement{}
)
    17  
    18  type weightedArrayElement struct {
    19  	cumulativeWeight uint64
    20  	index            int
    21  }
    22  
    23  // Note that this sorts in order of decreasing weight.
    24  func (e weightedArrayElement) Compare(other weightedArrayElement) int {
    25  	return cmp.Compare(other.cumulativeWeight, e.cumulativeWeight)
    26  }
    27  
// Sampling is performed by executing a modified binary search over the provided
// elements. Rather than cutting the remaining dataset in half, the algorithm
// attempts to jump to where it thinks the value will be, assuming a linear
// distribution of the element weights.
//
// Initialization takes O(n * log(n)) time, where n is the number of elements
// that can be sampled.
// Sampling can take up to O(n) time. If the distribution is linearly
// distributed, then the runtime is constant.
type weightedArray struct {
	// arr holds one entry per weight passed to Initialize. After a successful
	// Initialize, each entry's cumulativeWeight is the running total of the
	// weights up to and including that position, which is what Sample
	// searches over.
	arr []weightedArrayElement
}
    40  
    41  func (s *weightedArray) Initialize(weights []uint64) error {
    42  	numWeights := len(weights)
    43  	if numWeights <= cap(s.arr) {
    44  		s.arr = s.arr[:numWeights]
    45  	} else {
    46  		s.arr = make([]weightedArrayElement, numWeights)
    47  	}
    48  
    49  	for i, weight := range weights {
    50  		s.arr[i] = weightedArrayElement{
    51  			cumulativeWeight: weight,
    52  			index:            i,
    53  		}
    54  	}
    55  
    56  	// Optimize so that the array is closer to the uniform distribution
    57  	utils.Sort(s.arr)
    58  
    59  	maxIndex := len(s.arr) - 1
    60  	oneIfOdd := 1 & maxIndex
    61  	oneIfEven := 1 - oneIfOdd
    62  	end := maxIndex - oneIfEven
    63  	for i := 1; i < end; i += 2 {
    64  		s.arr[i], s.arr[end] = s.arr[end], s.arr[i]
    65  		end -= 2
    66  	}
    67  
    68  	cumulativeWeight := uint64(0)
    69  	for i := 0; i < len(s.arr); i++ {
    70  		newWeight, err := math.Add64(
    71  			cumulativeWeight,
    72  			s.arr[i].cumulativeWeight,
    73  		)
    74  		if err != nil {
    75  			return err
    76  		}
    77  		cumulativeWeight = newWeight
    78  		s.arr[i].cumulativeWeight = cumulativeWeight
    79  	}
    80  
    81  	return nil
    82  }
    83  
    84  func (s *weightedArray) Sample(value uint64) (int, bool) {
    85  	if len(s.arr) == 0 || s.arr[len(s.arr)-1].cumulativeWeight <= value {
    86  		return 0, false
    87  	}
    88  	minIndex := 0
    89  	maxIndex := len(s.arr) - 1
    90  	maxCumulativeWeight := float64(s.arr[len(s.arr)-1].cumulativeWeight)
    91  	index := int((float64(value) * float64(maxIndex+1)) / maxCumulativeWeight)
    92  
    93  	for {
    94  		previousWeight := uint64(0)
    95  		if index > 0 {
    96  			previousWeight = s.arr[index-1].cumulativeWeight
    97  		}
    98  		currentElem := s.arr[index]
    99  		currentWeight := currentElem.cumulativeWeight
   100  		if previousWeight <= value && value < currentWeight {
   101  			return currentElem.index, true
   102  		}
   103  
   104  		if value < previousWeight {
   105  			// go to the left
   106  			maxIndex = index - 1
   107  		} else {
   108  			// go to the right
   109  			minIndex = index + 1
   110  		}
   111  
   112  		minWeight := uint64(0)
   113  		if minIndex > 0 {
   114  			minWeight = s.arr[minIndex-1].cumulativeWeight
   115  		}
   116  		maxWeight := s.arr[maxIndex].cumulativeWeight
   117  
   118  		valueRange := maxWeight - minWeight
   119  		adjustedLookupValue := value - minWeight
   120  		indexRange := maxIndex - minIndex + 1
   121  		lookupMass := float64(adjustedLookupValue) * float64(indexRange)
   122  
   123  		index = int(lookupMass/float64(valueRange)) + minIndex
   124  	}
   125  }