github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/sampleuv/weighted.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this code is governed by a BSD-style
     3  // license that can be found in the LICENSE file
     4  
     5  package sampleuv
     6  
     7  import (
     8  	"golang.org/x/exp/rand"
     9  
    10  	"github.com/jingcheng-WU/gonum/floats/scalar"
    11  )
    12  
    13  // Weighted provides sampling without replacement from a collection of items with
    14  // non-uniform probability.
    15  type Weighted struct {
    16  	weights []float64
    17  	// heap is a weight heap.
    18  	//
    19  	// It keeps a heap-organised sum of remaining
    20  	// index weights that are available to be taken
    21  	// from.
    22  	//
    23  	// Each element holds the sum of weights for
    24  	// the corresponding index, plus the sum of
    25  	// its children's weights; the children of
    26  	// an element i can be found at positions
    27  	// 2*(i+1)-1 and 2*(i+1). The root of the
    28  	// weight heap is at element 0.
    29  	//
    30  	// See comments in container/heap for an
    31  	// explanation of the layout of a heap.
    32  	heap []float64
    33  	rnd  *rand.Rand
    34  }
    35  
    36  // NewWeighted returns a Weighted for the weights w. If src is nil, rand.Rand is
    37  // used as the random number generator.
    38  //
    39  // Note that sampling from weights with a high variance or overall low absolute
    40  // value sum may result in problems with numerical stability.
    41  func NewWeighted(w []float64, src rand.Source) Weighted {
    42  	s := Weighted{
    43  		weights: make([]float64, len(w)),
    44  		heap:    make([]float64, len(w)),
    45  	}
    46  	if src != nil {
    47  		s.rnd = rand.New(src)
    48  	}
    49  	s.ReweightAll(w)
    50  	return s
    51  }
    52  
    53  // Len returns the number of items held by the Weighted, including items
    54  // already taken.
    55  func (s Weighted) Len() int { return len(s.weights) }
    56  
    57  // Take returns an index from the Weighted with probability proportional
    58  // to the weight of the item. The weight of the item is then set to zero.
    59  // Take returns false if there are no items remaining.
    60  func (s Weighted) Take() (idx int, ok bool) {
    61  	const small = 1e-12
    62  	if scalar.EqualWithinAbsOrRel(s.heap[0], 0, small, small) {
    63  		return -1, false
    64  	}
    65  
    66  	var r float64
    67  	if s.rnd == nil {
    68  		r = s.heap[0] * rand.Float64()
    69  	} else {
    70  		r = s.heap[0] * s.rnd.Float64()
    71  	}
    72  	i := 1
    73  	last := -1
    74  	left := len(s.weights)
    75  	for {
    76  		if r -= s.weights[i-1]; r <= 0 {
    77  			break // Fall within item i-1.
    78  		}
    79  		i <<= 1 // Move to left child.
    80  		if d := s.heap[i-1]; r > d {
    81  			r -= d
    82  			// If enough r to pass left child
    83  			// move to right child state will
    84  			// be caught at break above.
    85  			i++
    86  		}
    87  		if i == last || left < 0 {
    88  			// No progression.
    89  			return -1, false
    90  		}
    91  		last = i
    92  		left--
    93  	}
    94  
    95  	w, idx := s.weights[i-1], i-1
    96  
    97  	s.weights[i-1] = 0
    98  	for i > 0 {
    99  		s.heap[i-1] -= w
   100  		// The following condition is necessary to
   101  		// handle floating point error. If we see
   102  		// a heap value below zero, we know we need
   103  		// to rebuild it.
   104  		if s.heap[i-1] < 0 {
   105  			s.reset()
   106  			return idx, true
   107  		}
   108  		i >>= 1
   109  	}
   110  
   111  	return idx, true
   112  }
   113  
   114  // Reweight sets the weight of item idx to w.
   115  func (s Weighted) Reweight(idx int, w float64) {
   116  	w, s.weights[idx] = s.weights[idx]-w, w
   117  	idx++
   118  	for idx > 0 {
   119  		s.heap[idx-1] -= w
   120  		idx >>= 1
   121  	}
   122  }
   123  
   124  // ReweightAll sets the weight of all items in the Weighted. ReweightAll
   125  // panics if len(w) != s.Len.
   126  func (s Weighted) ReweightAll(w []float64) {
   127  	if len(w) != s.Len() {
   128  		panic("floats: length of the slices do not match")
   129  	}
   130  	copy(s.weights, w)
   131  	s.reset()
   132  }
   133  
   134  func (s Weighted) reset() {
   135  	copy(s.heap, s.weights)
   136  	for i := len(s.heap) - 1; i > 0; i-- {
   137  		// Sometimes 1-based counting makes sense.
   138  		s.heap[((i+1)>>1)-1] += s.heap[i]
   139  	}
   140  }