github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/weighted_best.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package sampler 5 6 import ( 7 "errors" 8 "math" 9 "time" 10 11 "github.com/MetalBlockchain/metalgo/utils/timer/mockable" 12 13 safemath "github.com/MetalBlockchain/metalgo/utils/math" 14 ) 15 16 var ( 17 errNoValidWeightedSamplers = errors.New("no valid weighted samplers found") 18 19 _ Weighted = (*weightedBest)(nil) 20 ) 21 22 // Sampling is performed by using another implementation of the Weighted 23 // interface. 24 // 25 // Initialization attempts to find the best sampling algorithm given the dataset 26 // by performing a benchmark of the provided implementations. 27 type weightedBest struct { 28 Weighted 29 samplers []Weighted 30 benchmarkIterations int 31 clock mockable.Clock 32 } 33 34 func (s *weightedBest) Initialize(weights []uint64) error { 35 totalWeight := uint64(0) 36 for _, weight := range weights { 37 newWeight, err := safemath.Add64(totalWeight, weight) 38 if err != nil { 39 return err 40 } 41 totalWeight = newWeight 42 } 43 44 samples := []uint64(nil) 45 if totalWeight > 0 { 46 samples = make([]uint64, s.benchmarkIterations) 47 for i := range samples { 48 samples[i] = globalRNG.Uint64Inclusive(totalWeight - 1) 49 } 50 } 51 52 s.Weighted = nil 53 bestDuration := time.Duration(math.MaxInt64) 54 55 samplerLoop: 56 for _, sampler := range s.samplers { 57 if err := sampler.Initialize(weights); err != nil { 58 continue 59 } 60 61 start := s.clock.Time() 62 for _, sample := range samples { 63 if _, ok := sampler.Sample(sample); !ok { 64 continue samplerLoop 65 } 66 } 67 end := s.clock.Time() 68 duration := end.Sub(start) 69 if duration < bestDuration { 70 bestDuration = duration 71 s.Weighted = sampler 72 } 73 } 74 75 if s.Weighted == nil { 76 return errNoValidWeightedSamplers 77 } 78 return nil 79 }