github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/weighted_best.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package sampler
     5  
     6  import (
     7  	"errors"
     8  	"math"
     9  	"time"
    10  
    11  	"github.com/MetalBlockchain/metalgo/utils/timer/mockable"
    12  
    13  	safemath "github.com/MetalBlockchain/metalgo/utils/math"
    14  )
    15  
    16  var (
    17  	errNoValidWeightedSamplers = errors.New("no valid weighted samplers found")
    18  
    19  	_ Weighted = (*weightedBest)(nil)
    20  )
    21  
// weightedBest implements Weighted by delegating sampling to another
// implementation of the Weighted interface.
//
// Initialization attempts to find the best sampling algorithm for the given
// weights by benchmarking each of the provided candidate implementations on
// the same set of sample values and keeping the fastest one.
type weightedBest struct {
	// Weighted is the selected candidate. Initialize resets it to nil and then
	// sets it to the fastest candidate that succeeded; it remains nil if
	// Initialize returns an error.
	Weighted
	// samplers are the candidate implementations that Initialize benchmarks.
	samplers            []Weighted
	// benchmarkIterations is the number of random sample values each candidate
	// is timed against (only when the total weight is non-zero).
	benchmarkIterations int
	// clock supplies the timestamps used to measure each candidate.
	// NOTE(review): presumably mockable so tests can control timing — confirm.
	clock               mockable.Clock
}
    33  
    34  func (s *weightedBest) Initialize(weights []uint64) error {
    35  	totalWeight := uint64(0)
    36  	for _, weight := range weights {
    37  		newWeight, err := safemath.Add64(totalWeight, weight)
    38  		if err != nil {
    39  			return err
    40  		}
    41  		totalWeight = newWeight
    42  	}
    43  
    44  	samples := []uint64(nil)
    45  	if totalWeight > 0 {
    46  		samples = make([]uint64, s.benchmarkIterations)
    47  		for i := range samples {
    48  			samples[i] = globalRNG.Uint64Inclusive(totalWeight - 1)
    49  		}
    50  	}
    51  
    52  	s.Weighted = nil
    53  	bestDuration := time.Duration(math.MaxInt64)
    54  
    55  samplerLoop:
    56  	for _, sampler := range s.samplers {
    57  		if err := sampler.Initialize(weights); err != nil {
    58  			continue
    59  		}
    60  
    61  		start := s.clock.Time()
    62  		for _, sample := range samples {
    63  			if _, ok := sampler.Sample(sample); !ok {
    64  				continue samplerLoop
    65  			}
    66  		}
    67  		end := s.clock.Time()
    68  		duration := end.Sub(start)
    69  		if duration < bestDuration {
    70  			bestDuration = duration
    71  			s.Weighted = sampler
    72  		}
    73  	}
    74  
    75  	if s.Weighted == nil {
    76  		return errNoValidWeightedSamplers
    77  	}
    78  	return nil
    79  }