gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/stat/sampleuv/weighted.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this code is governed by a BSD-style
     3  // license that can be found in the LICENSE file
     4  
     5  package sampleuv
     6  
     7  import "golang.org/x/exp/rand"
     8  
// Weighted provides sampling without replacement from a collection of items with
// non-uniform probability.
type Weighted struct {
	// weights holds the current weight of each item, indexed by
	// item. Take sets the weight of the item it returns to zero.
	weights []float64
	// heap is a weight heap.
	//
	// It keeps a heap-organised sum of remaining
	// index weights that are available to be taken
	// from.
	//
	// Each element holds the sum of weights for
	// the corresponding index, plus the sum of
	// its children's weights; the children of
	// an element i can be found at positions
	// 2*(i+1)-1 and 2*(i+1). The root of the
	// weight heap is at element 0, so heap[0]
	// is the total of all remaining weights.
	//
	// See comments in container/heap for an
	// explanation of the layout of a heap.
	heap []float64
	rnd  *rand.Rand // Generator used by Take; nil selects the rand package's top-level functions.
}
    31  
    32  // NewWeighted returns a Weighted for the weights w. If src is nil, rand.Rand is
    33  // used as the random number generator.
    34  //
    35  // Note that sampling from weights with a high variance or overall low absolute
    36  // value sum may result in problems with numerical stability.
    37  func NewWeighted(w []float64, src rand.Source) Weighted {
    38  	s := Weighted{
    39  		weights: make([]float64, len(w)),
    40  		heap:    make([]float64, len(w)),
    41  	}
    42  	if src != nil {
    43  		s.rnd = rand.New(src)
    44  	}
    45  	s.ReweightAll(w)
    46  	return s
    47  }
    48  
    49  // Len returns the number of items held by the Weighted, including items
    50  // already taken.
    51  func (s Weighted) Len() int { return len(s.weights) }
    52  
    53  // Take returns an index from the Weighted with probability proportional
    54  // to the weight of the item. The weight of the item is then set to zero.
    55  // Take returns false if there are no items remaining.
    56  func (s Weighted) Take() (idx int, ok bool) {
    57  	if s.heap[0] == 0 {
    58  		return -1, false
    59  	}
    60  
    61  	var r float64
    62  	if s.rnd == nil {
    63  		r = rand.Float64()
    64  	} else {
    65  		r = s.rnd.Float64()
    66  	}
    67  
    68  	r *= s.heap[0]
    69  	i := 0
    70  	for {
    71  		r -= s.weights[i]
    72  		if r < 0 {
    73  			break // Fall within item i.
    74  		}
    75  
    76  		li := i*2 + 1 // Move to left child.
    77  		// Left node should exist, because r is non-negative,
    78  		// but there could be floating point errors, so we
    79  		// check index explicitly.
    80  		if li >= len(s.heap) {
    81  			break
    82  		}
    83  
    84  		i = li
    85  
    86  		d := s.heap[i]
    87  		if r >= d {
    88  			// If there is enough r to pass left child try to
    89  			// move to the right child.
    90  			r -= d
    91  			ri := i + 1
    92  
    93  			if ri >= len(s.heap) {
    94  				break
    95  			}
    96  
    97  			i = ri
    98  		}
    99  	}
   100  
   101  	s.Reweight(i, 0)
   102  
   103  	return i, true
   104  }
   105  
   106  // Reweight sets the weight of item idx to w.
   107  func (s Weighted) Reweight(idx int, w float64) {
   108  	s.weights[idx] = w
   109  
   110  	// We want to keep the heap state here consistent
   111  	// with the result of a reset call. So we sum
   112  	// weights in the same order, since floating point
   113  	// addition is not associative.
   114  	for {
   115  		w = s.weights[idx]
   116  
   117  		ri := idx*2 + 2
   118  		if ri < len(s.heap) {
   119  			w += s.heap[ri]
   120  		}
   121  
   122  		li := ri - 1
   123  		if li < len(s.heap) {
   124  			w += s.heap[li]
   125  		}
   126  
   127  		s.heap[idx] = w
   128  
   129  		if idx == 0 {
   130  			break
   131  		}
   132  
   133  		idx = (idx - 1) / 2
   134  	}
   135  }
   136  
   137  // ReweightAll sets the weight of all items in the Weighted. ReweightAll
   138  // panics if len(w) != s.Len.
   139  func (s Weighted) ReweightAll(w []float64) {
   140  	if len(w) != s.Len() {
   141  		panic("floats: length of the slices do not match")
   142  	}
   143  	copy(s.weights, w)
   144  	s.reset()
   145  }
   146  
   147  func (s Weighted) reset() {
   148  	copy(s.heap, s.weights)
   149  	for i := len(s.heap) - 1; i > 0; i-- {
   150  		// Sometimes 1-based counting makes sense.
   151  		s.heap[((i+1)>>1)-1] += s.heap[i]
   152  	}
   153  }