github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/distuv/categorical.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package distuv
     6  
     7  import (
     8  	"math"
     9  
    10  	"golang.org/x/exp/rand"
    11  )
    12  
// Categorical is an extension of the Bernoulli distribution in which x takes
// values {0, 1, ..., len(w)-1}, where w is the weight vector. Categorical must
// be initialized with NewCategorical.
//
// The weights and heap fields are slices, so every copy of a Categorical
// value shares them: Reweight and ReweightAll called on one copy are visible
// through all copies.
type Categorical struct {
	// weights holds the unnormalized weight of each item; the probability
	// of item i is weights[i] divided by the total weight heap[0].
	weights []float64

	// heap is a weight heap.
	//
	// It keeps a heap-organised sum of the item weights
	// so that sampling and reweighting only touch one
	// root-to-leaf path.
	//
	// Each element holds the weight for the
	// corresponding index plus the values held by
	// its children, i.e. the total weight of the
	// subtree rooted there; the children of
	// an element i can be found at positions
	// 2*(i+1)-1 and 2*(i+1). The root of the
	// weight heap is at element 0 and therefore
	// holds the sum of all weights.
	//
	// See comments in container/heap for an
	// explanation of the layout of a heap.
	heap []float64

	// src is the source of randomness used by Rand. When it is nil, Rand
	// falls back to the package-level generator of golang.org/x/exp/rand.
	src rand.Source
}
    38  
    39  // NewCategorical constructs a new categorical distribution where the probability
    40  // that x equals i is proportional to w[i]. All of the weights must be
    41  // nonnegative, and at least one of the weights must be positive.
    42  func NewCategorical(w []float64, src rand.Source) Categorical {
    43  	c := Categorical{
    44  		weights: make([]float64, len(w)),
    45  		heap:    make([]float64, len(w)),
    46  		src:     src,
    47  	}
    48  	c.ReweightAll(w)
    49  	return c
    50  }
    51  
    52  // CDF computes the value of the cumulative density function at x.
    53  func (c Categorical) CDF(x float64) float64 {
    54  	var cdf float64
    55  	for i, w := range c.weights {
    56  		if x < float64(i) {
    57  			break
    58  		}
    59  		cdf += w
    60  	}
    61  	return cdf / c.heap[0]
    62  }
    63  
    64  // Entropy returns the entropy of the distribution.
    65  func (c Categorical) Entropy() float64 {
    66  	var ent float64
    67  	for _, w := range c.weights {
    68  		if w == 0 {
    69  			continue
    70  		}
    71  		p := w / c.heap[0]
    72  		ent += p * math.Log(p)
    73  	}
    74  	return -ent
    75  }
    76  
    77  // Len returns the number of values x could possibly take (the length of the
    78  // initial supplied weight vector).
    79  func (c Categorical) Len() int {
    80  	return len(c.weights)
    81  }
    82  
    83  // Mean returns the mean of the probability distribution.
    84  func (c Categorical) Mean() float64 {
    85  	var mean float64
    86  	for i, v := range c.weights {
    87  		mean += float64(i) * v
    88  	}
    89  	return mean / c.heap[0]
    90  }
    91  
    92  // Prob computes the value of the probability density function at x.
    93  func (c Categorical) Prob(x float64) float64 {
    94  	xi := int(x)
    95  	if float64(xi) != x {
    96  		return 0
    97  	}
    98  	if xi < 0 || xi > len(c.weights)-1 {
    99  		return 0
   100  	}
   101  	return c.weights[xi] / c.heap[0]
   102  }
   103  
   104  // LogProb computes the natural logarithm of the value of the probability density function at x.
   105  func (c Categorical) LogProb(x float64) float64 {
   106  	return math.Log(c.Prob(x))
   107  }
   108  
   109  // Rand returns a random draw from the categorical distribution.
   110  func (c Categorical) Rand() float64 {
   111  	var r float64
   112  	if c.src == nil {
   113  		r = c.heap[0] * rand.Float64()
   114  	} else {
   115  		r = c.heap[0] * rand.New(c.src).Float64()
   116  	}
   117  	i := 1
   118  	last := -1
   119  	left := len(c.weights)
   120  	for {
   121  		if r -= c.weights[i-1]; r <= 0 {
   122  			break // Fall within item i-1.
   123  		}
   124  		i <<= 1 // Move to left child.
   125  		if d := c.heap[i-1]; r > d {
   126  			r -= d
   127  			// If enough r to pass left child,
   128  			// move to right child state will
   129  			// be caught at break above.
   130  			i++
   131  		}
   132  		if i == last || left < 0 {
   133  			panic("categorical: bad sample")
   134  		}
   135  		last = i
   136  		left--
   137  	}
   138  	return float64(i - 1)
   139  }
   140  
   141  // Reweight sets the weight of item idx to w. The input weight must be
   142  // non-negative, and after reweighting at least one of the weights must be
   143  // positive.
   144  func (c Categorical) Reweight(idx int, w float64) {
   145  	if w < 0 {
   146  		panic("categorical: negative weight")
   147  	}
   148  	w, c.weights[idx] = c.weights[idx]-w, w
   149  	idx++
   150  	for idx > 0 {
   151  		c.heap[idx-1] -= w
   152  		idx >>= 1
   153  	}
   154  	if c.heap[0] <= 0 {
   155  		panic("categorical: sum of the weights non-positive")
   156  	}
   157  }
   158  
   159  // ReweightAll resets the weights of the distribution. ReweightAll panics if
   160  // len(w) != c.Len. All of the weights must be nonnegative, and at least one of
   161  // the weights must be positive.
   162  func (c Categorical) ReweightAll(w []float64) {
   163  	if len(w) != c.Len() {
   164  		panic("categorical: length of the slices do not match")
   165  	}
   166  	for _, v := range w {
   167  		if v < 0 {
   168  			panic("categorical: negative weight")
   169  		}
   170  	}
   171  	copy(c.weights, w)
   172  	c.reset()
   173  }
   174  
   175  func (c Categorical) reset() {
   176  	copy(c.heap, c.weights)
   177  	for i := len(c.heap) - 1; i > 0; i-- {
   178  		// Sometimes 1-based counting makes sense.
   179  		c.heap[((i+1)>>1)-1] += c.heap[i]
   180  	}
   181  	// TODO(btracey): Renormalization for weird weights?
   182  	if c.heap[0] <= 0 {
   183  		panic("categorical: sum of the weights non-positive")
   184  	}
   185  }