github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/weighted_uniform.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package sampler
     5  
     6  import (
     7  	"errors"
     8  	"math"
     9  
    10  	safemath "github.com/MetalBlockchain/metalgo/utils/math"
    11  )
    12  
    13  var (
    14  	errWeightsTooLarge = errors.New("total weight is too large")
    15  
    16  	_ Weighted = (*weightedUniform)(nil)
    17  )
    18  
    19  // Sampling is performed by indexing into the array to find the correct index.
    20  //
    21  // Initialization takes O(Sum(weights)) time. This results in an exponential
    22  // initialization time. Therefore, the time to execute this operation can be
    23  // extremely long. Initialization takes O(Sum(weights)) space, causing this
    24  // algorithm to be unable to handle large inputs.
    25  //
    26  // Sampling is performed in O(1) time. However, if the Sum(weights) is large,
    27  // this operation can still be relatively slow due to poor cache locality.
    28  type weightedUniform struct {
    29  	indices   []int
    30  	maxWeight uint64
    31  }
    32  
    33  func (s *weightedUniform) Initialize(weights []uint64) error {
    34  	totalWeight := uint64(0)
    35  	for _, weight := range weights {
    36  		newWeight, err := safemath.Add64(totalWeight, weight)
    37  		if err != nil {
    38  			return err
    39  		}
    40  		totalWeight = newWeight
    41  	}
    42  	if totalWeight > s.maxWeight || totalWeight > math.MaxInt32 {
    43  		return errWeightsTooLarge
    44  	}
    45  	size := int(totalWeight)
    46  
    47  	if size > cap(s.indices) {
    48  		s.indices = make([]int, size)
    49  	} else {
    50  		s.indices = s.indices[:size]
    51  	}
    52  
    53  	offset := 0
    54  	for i, weight := range weights {
    55  		for j := uint64(0); j < weight; j++ {
    56  			s.indices[offset] = i
    57  			offset++
    58  		}
    59  	}
    60  
    61  	return nil
    62  }
    63  
    64  func (s *weightedUniform) Sample(value uint64) (int, bool) {
    65  	if uint64(len(s.indices)) <= value {
    66  		return 0, false
    67  	}
    68  	return s.indices[int(value)], true
    69  }