github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/weighted_without_replacement_generic.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package sampler
     5  
     6  import (
     7  	safemath "github.com/MetalBlockchain/metalgo/utils/math"
     8  )
     9  
// weightedWithoutReplacementGeneric samples indices of a weights slice by
// composing two samplers. u is initialized with the sum of all weights, and
// each value it yields from Next is handed to w.Sample to select an index.
// Nothing in this type de-duplicates the selected indices.
type weightedWithoutReplacementGeneric struct {
	u Uniform  // initialized with the total weight; supplies the values to map
	w Weighted // initialized with the raw weights; maps a value to an index
}
    14  
    15  func (s *weightedWithoutReplacementGeneric) Initialize(weights []uint64) error {
    16  	totalWeight := uint64(0)
    17  	for _, weight := range weights {
    18  		newWeight, err := safemath.Add64(totalWeight, weight)
    19  		if err != nil {
    20  			return err
    21  		}
    22  		totalWeight = newWeight
    23  	}
    24  	s.u.Initialize(totalWeight)
    25  	return s.w.Initialize(weights)
    26  }
    27  
    28  func (s *weightedWithoutReplacementGeneric) Sample(count int) ([]int, bool) {
    29  	s.u.Reset()
    30  
    31  	indices := make([]int, count)
    32  	for i := 0; i < count; i++ {
    33  		weight, ok := s.u.Next()
    34  		if !ok {
    35  			return nil, false
    36  		}
    37  
    38  		indices[i], ok = s.w.Sample(weight)
    39  		if !ok {
    40  			return nil, false
    41  		}
    42  	}
    43  	return indices, true
    44  }