github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/weighted_heap.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package sampler
     5  
     6  import (
     7  	"cmp"
     8  
     9  	"github.com/MetalBlockchain/metalgo/utils"
    10  	"github.com/MetalBlockchain/metalgo/utils/math"
    11  )
    12  
// Compile-time interface assertions; these declarations produce no runtime
// state.
var (
	// *weightedHeap must implement the Weighted sampler interface.
	_ Weighted                            = (*weightedHeap)(nil)
	// weightedHeapElement must implement utils.Sortable (via its Compare
	// method) so that a slice of elements can be handed to utils.Sort.
	_ utils.Sortable[weightedHeapElement] = weightedHeapElement{}
)
    17  
    18  type weightedHeapElement struct {
    19  	weight           uint64
    20  	cumulativeWeight uint64
    21  	index            int
    22  }
    23  
    24  // Compare the elements. Weight is in decreasing order. Index is in increasing
    25  // order.
    26  func (e weightedHeapElement) Compare(other weightedHeapElement) int {
    27  	// By accounting for the initial index of the weights, this results in a
    28  	// stable sort. We do this rather than using `sort.Stable` because of the
    29  	// reported change in performance of the sort used.
    30  	if weightCmp := cmp.Compare(other.weight, e.weight); weightCmp != 0 {
    31  		return weightCmp
    32  	}
    33  	return cmp.Compare(e.index, other.index)
    34  }
    35  
// Sampling is performed by executing a search over a tree of elements in the
// order of their probabilistic occurrence.
//
// Initialization takes O(n * log(n)) time, where n is the number of elements
// that can be sampled.
// Sampling can take up to O(log(n)) time. As the distribution becomes more
// biased, sampling will become faster in expectation.
type weightedHeap struct {
	// heap stores the tree implicitly in a slice: the children of the element
	// at position i live at positions 2*i+1 and 2*i+2, and its parent lives at
	// position (i-1)/2. Elements are placed in sorted order (heaviest first),
	// so the most likely samples sit closest to the root.
	heap []weightedHeapElement
}
    46  
    47  func (s *weightedHeap) Initialize(weights []uint64) error {
    48  	numWeights := len(weights)
    49  	if numWeights <= cap(s.heap) {
    50  		s.heap = s.heap[:numWeights]
    51  	} else {
    52  		s.heap = make([]weightedHeapElement, numWeights)
    53  	}
    54  	for i, weight := range weights {
    55  		s.heap[i] = weightedHeapElement{
    56  			weight:           weight,
    57  			cumulativeWeight: weight,
    58  			index:            i,
    59  		}
    60  	}
    61  
    62  	// Optimize so that the most probable values are at the top of the heap
    63  	utils.Sort(s.heap)
    64  
    65  	// Initialize the heap
    66  	for i := len(s.heap) - 1; i > 0; i-- {
    67  		// Explicitly performing a shift here allows the compiler to avoid
    68  		// checking for negative numbers, which saves a couple cycles
    69  		parentIndex := (i - 1) >> 1
    70  		newWeight, err := math.Add64(
    71  			s.heap[parentIndex].cumulativeWeight,
    72  			s.heap[i].cumulativeWeight,
    73  		)
    74  		if err != nil {
    75  			return err
    76  		}
    77  		s.heap[parentIndex].cumulativeWeight = newWeight
    78  	}
    79  
    80  	return nil
    81  }
    82  
    83  func (s *weightedHeap) Sample(value uint64) (int, bool) {
    84  	if len(s.heap) == 0 || s.heap[0].cumulativeWeight <= value {
    85  		return 0, false
    86  	}
    87  
    88  	index := 0
    89  	for {
    90  		currentElement := s.heap[index]
    91  		currentWeight := currentElement.weight
    92  		if value < currentWeight {
    93  			return currentElement.index, true
    94  		}
    95  		value -= currentWeight
    96  
    97  		// We shouldn't return the root, so check the left child
    98  		index = index*2 + 1
    99  
   100  		if leftWeight := s.heap[index].cumulativeWeight; leftWeight <= value {
   101  			// If the weight is greater than the left weight, you should move to
   102  			// the right child
   103  			value -= leftWeight
   104  			index++
   105  		}
   106  	}
   107  }