github.com/gopherd/gonum@v0.0.4/interp/interp.go (about)

     1  // Copyright ©2020 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package interp
     6  
     7  import "sort"
     8  
// Panic messages used by the Fit methods and calculateSlopes in this file:
//   - differentLengths: len(xs) != len(ys).
//   - tooFewPoints: fewer than two (X, Y) pairs were supplied.
//   - xsNotStrictlyIncreasing: some xs[i] is not greater than xs[i-1].
const (
	differentLengths        = "interp: input slices have different lengths"
	tooFewPoints            = "interp: too few points for interpolation"
	xsNotStrictlyIncreasing = "interp: xs values not strictly increasing"
)
    14  
// Predictor predicts the value of a function. It handles both
// interpolation and extrapolation.
//
// Constant, Function, PiecewiseLinear and PiecewiseConstant in this file all
// provide a Predict method with this signature.
type Predictor interface {
	// Predict returns the predicted value at x.
	Predict(x float64) float64
}
    21  
// Fitter fits a predictor to data.
//
// In this file, *PiecewiseLinear and *PiecewiseConstant implement Fitter.
// They use pointer receivers because Fit stores the fitted data in the
// receiver.
type Fitter interface {
	// Fit fits a predictor to (X, Y) value pairs provided as two slices.
	// It panics if len(xs) < 2, elements of xs are not strictly increasing
	// or len(xs) != len(ys). Returns an error if fitting fails.
	Fit(xs, ys []float64) error
}
    29  
// FittablePredictor is a Predictor which can fit itself to data.
// The intended usage is to call Fit once and then call Predict as often as
// needed.
type FittablePredictor interface {
	Fitter
	Predictor
}
    35  
// DerivativePredictor predicts both the value and the derivative of
// a function. It handles both interpolation and extrapolation.
type DerivativePredictor interface {
	// Predictor supplies Predict, the value estimate.
	Predictor

	// PredictDerivative returns the predicted derivative at x.
	PredictDerivative(x float64) float64
}
    44  
    45  // Constant predicts a constant value.
    46  type Constant float64
    47  
    48  // Predict returns the predicted value at x.
    49  func (c Constant) Predict(x float64) float64 {
    50  	return float64(c)
    51  }
    52  
    53  // Function predicts by evaluating itself.
    54  type Function func(float64) float64
    55  
    56  // Predict returns the predicted value at x by evaluating fn(x).
    57  func (fn Function) Predict(x float64) float64 {
    58  	return fn(x)
    59  }
    60  
    61  // PiecewiseLinear is a piecewise linear 1-dimensional interpolator.
    62  type PiecewiseLinear struct {
    63  	// Interpolated X values.
    64  	xs []float64
    65  
    66  	// Interpolated Y data values, same len as ys.
    67  	ys []float64
    68  
    69  	// Slopes of Y between neighbouring X values. len(slopes) + 1 == len(xs) == len(ys).
    70  	slopes []float64
    71  }
    72  
    73  // Fit fits a predictor to (X, Y) value pairs provided as two slices.
    74  // It panics if len(xs) < 2, elements of xs are not strictly increasing
    75  // or len(xs) != len(ys). Always returns nil.
    76  func (pl *PiecewiseLinear) Fit(xs, ys []float64) error {
    77  	n := len(xs)
    78  	if len(ys) != n {
    79  		panic(differentLengths)
    80  	}
    81  	if n < 2 {
    82  		panic(tooFewPoints)
    83  	}
    84  	pl.slopes = calculateSlopes(xs, ys)
    85  	pl.xs = make([]float64, n)
    86  	pl.ys = make([]float64, n)
    87  	copy(pl.xs, xs)
    88  	copy(pl.ys, ys)
    89  	return nil
    90  }
    91  
    92  // Predict returns the interpolation value at x.
    93  func (pl PiecewiseLinear) Predict(x float64) float64 {
    94  	i := findSegment(pl.xs, x)
    95  	if i < 0 {
    96  		return pl.ys[0]
    97  	}
    98  	xI := pl.xs[i]
    99  	if x == xI {
   100  		return pl.ys[i]
   101  	}
   102  	n := len(pl.xs)
   103  	if i == n-1 {
   104  		return pl.ys[n-1]
   105  	}
   106  	return pl.ys[i] + pl.slopes[i]*(x-xI)
   107  }
   108  
   109  // PiecewiseConstant is a left-continous, piecewise constant
   110  // 1-dimensional interpolator.
   111  type PiecewiseConstant struct {
   112  	// Interpolated X values.
   113  	xs []float64
   114  
   115  	// Interpolated Y data values, same len as ys.
   116  	ys []float64
   117  }
   118  
   119  // Fit fits a predictor to (X, Y) value pairs provided as two slices.
   120  // It panics if len(xs) < 2, elements of xs are not strictly increasing
   121  // or len(xs) != len(ys). Always returns nil.
   122  func (pc *PiecewiseConstant) Fit(xs, ys []float64) error {
   123  	n := len(xs)
   124  	if len(ys) != n {
   125  		panic(differentLengths)
   126  	}
   127  	if n < 2 {
   128  		panic(tooFewPoints)
   129  	}
   130  	for i := 1; i < n; i++ {
   131  		if xs[i] <= xs[i-1] {
   132  			panic(xsNotStrictlyIncreasing)
   133  		}
   134  	}
   135  	pc.xs = make([]float64, n)
   136  	pc.ys = make([]float64, n)
   137  	copy(pc.xs, xs)
   138  	copy(pc.ys, ys)
   139  	return nil
   140  }
   141  
   142  // Predict returns the interpolation value at x.
   143  func (pc PiecewiseConstant) Predict(x float64) float64 {
   144  	i := findSegment(pc.xs, x)
   145  	if i < 0 {
   146  		return pc.ys[0]
   147  	}
   148  	if x == pc.xs[i] {
   149  		return pc.ys[i]
   150  	}
   151  	n := len(pc.xs)
   152  	if i == n-1 {
   153  		return pc.ys[n-1]
   154  	}
   155  	return pc.ys[i+1]
   156  }
   157  
   158  // findSegment returns 0 <= i < len(xs) such that xs[i] <= x < xs[i + 1], where xs[len(xs)]
   159  // is assumed to be +Inf. If no such i is found, it returns -1. It assumes that len(xs) >= 2
   160  // without checking.
   161  func findSegment(xs []float64, x float64) int {
   162  	return sort.Search(len(xs), func(i int) bool { return xs[i] > x }) - 1
   163  }
   164  
   165  // calculateSlopes calculates slopes (ys[i+1] - ys[i]) / (xs[i+1] - xs[i]).
   166  // It panics if len(xs) < 2, elements of xs are not strictly increasing
   167  // or len(xs) != len(ys).
   168  func calculateSlopes(xs, ys []float64) []float64 {
   169  	n := len(xs)
   170  	if n < 2 {
   171  		panic(tooFewPoints)
   172  	}
   173  	if len(ys) != n {
   174  		panic(differentLengths)
   175  	}
   176  	m := n - 1
   177  	slopes := make([]float64, m)
   178  	prevX := xs[0]
   179  	prevY := ys[0]
   180  	for i := 0; i < m; i++ {
   181  		x := xs[i+1]
   182  		y := ys[i+1]
   183  		dx := x - prevX
   184  		if dx <= 0 {
   185  			panic(xsNotStrictlyIncreasing)
   186  		}
   187  		slopes[i] = (y - prevY) / dx
   188  		prevX = x
   189  		prevY = y
   190  	}
   191  	return slopes
   192  }