github.com/gopherd/gonum@v0.0.4/interp/interp.go (about) 1 // Copyright ©2020 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package interp 6 7 import "sort" 8 9 const ( 10 differentLengths = "interp: input slices have different lengths" 11 tooFewPoints = "interp: too few points for interpolation" 12 xsNotStrictlyIncreasing = "interp: xs values not strictly increasing" 13 ) 14 15 // Predictor predicts the value of a function. It handles both 16 // interpolation and extrapolation. 17 type Predictor interface { 18 // Predict returns the predicted value at x. 19 Predict(x float64) float64 20 } 21 22 // Fitter fits a predictor to data. 23 type Fitter interface { 24 // Fit fits a predictor to (X, Y) value pairs provided as two slices. 25 // It panics if len(xs) < 2, elements of xs are not strictly increasing 26 // or len(xs) != len(ys). Returns an error if fitting fails. 27 Fit(xs, ys []float64) error 28 } 29 30 // FittablePredictor is a Predictor which can fit itself to data. 31 type FittablePredictor interface { 32 Fitter 33 Predictor 34 } 35 36 // DerivativePredictor predicts both the value and the derivative of 37 // a function. It handles both interpolation and extrapolation. 38 type DerivativePredictor interface { 39 Predictor 40 41 // PredictDerivative returns the predicted derivative at x. 42 PredictDerivative(x float64) float64 43 } 44 45 // Constant predicts a constant value. 46 type Constant float64 47 48 // Predict returns the predicted value at x. 49 func (c Constant) Predict(x float64) float64 { 50 return float64(c) 51 } 52 53 // Function predicts by evaluating itself. 54 type Function func(float64) float64 55 56 // Predict returns the predicted value at x by evaluating fn(x). 57 func (fn Function) Predict(x float64) float64 { 58 return fn(x) 59 } 60 61 // PiecewiseLinear is a piecewise linear 1-dimensional interpolator. 62 type PiecewiseLinear struct { 63 // Interpolated X values. 64 xs []float64 65 66 // Interpolated Y data values, same len as ys. 67 ys []float64 68 69 // Slopes of Y between neighbouring X values. len(slopes) + 1 == len(xs) == len(ys). 70 slopes []float64 71 } 72 73 // Fit fits a predictor to (X, Y) value pairs provided as two slices. 74 // It panics if len(xs) < 2, elements of xs are not strictly increasing 75 // or len(xs) != len(ys). Always returns nil. 76 func (pl *PiecewiseLinear) Fit(xs, ys []float64) error { 77 n := len(xs) 78 if len(ys) != n { 79 panic(differentLengths) 80 } 81 if n < 2 { 82 panic(tooFewPoints) 83 } 84 pl.slopes = calculateSlopes(xs, ys) 85 pl.xs = make([]float64, n) 86 pl.ys = make([]float64, n) 87 copy(pl.xs, xs) 88 copy(pl.ys, ys) 89 return nil 90 } 91 92 // Predict returns the interpolation value at x. 93 func (pl PiecewiseLinear) Predict(x float64) float64 { 94 i := findSegment(pl.xs, x) 95 if i < 0 { 96 return pl.ys[0] 97 } 98 xI := pl.xs[i] 99 if x == xI { 100 return pl.ys[i] 101 } 102 n := len(pl.xs) 103 if i == n-1 { 104 return pl.ys[n-1] 105 } 106 return pl.ys[i] + pl.slopes[i]*(x-xI) 107 } 108 109 // PiecewiseConstant is a left-continous, piecewise constant 110 // 1-dimensional interpolator. 111 type PiecewiseConstant struct { 112 // Interpolated X values. 113 xs []float64 114 115 // Interpolated Y data values, same len as ys. 116 ys []float64 117 } 118 119 // Fit fits a predictor to (X, Y) value pairs provided as two slices. 120 // It panics if len(xs) < 2, elements of xs are not strictly increasing 121 // or len(xs) != len(ys). Always returns nil. 122 func (pc *PiecewiseConstant) Fit(xs, ys []float64) error { 123 n := len(xs) 124 if len(ys) != n { 125 panic(differentLengths) 126 } 127 if n < 2 { 128 panic(tooFewPoints) 129 } 130 for i := 1; i < n; i++ { 131 if xs[i] <= xs[i-1] { 132 panic(xsNotStrictlyIncreasing) 133 } 134 } 135 pc.xs = make([]float64, n) 136 pc.ys = make([]float64, n) 137 copy(pc.xs, xs) 138 copy(pc.ys, ys) 139 return nil 140 } 141 142 // Predict returns the interpolation value at x. 143 func (pc PiecewiseConstant) Predict(x float64) float64 { 144 i := findSegment(pc.xs, x) 145 if i < 0 { 146 return pc.ys[0] 147 } 148 if x == pc.xs[i] { 149 return pc.ys[i] 150 } 151 n := len(pc.xs) 152 if i == n-1 { 153 return pc.ys[n-1] 154 } 155 return pc.ys[i+1] 156 } 157 158 // findSegment returns 0 <= i < len(xs) such that xs[i] <= x < xs[i + 1], where xs[len(xs)] 159 // is assumed to be +Inf. If no such i is found, it returns -1. It assumes that len(xs) >= 2 160 // without checking. 161 func findSegment(xs []float64, x float64) int { 162 return sort.Search(len(xs), func(i int) bool { return xs[i] > x }) - 1 163 } 164 165 // calculateSlopes calculates slopes (ys[i+1] - ys[i]) / (xs[i+1] - xs[i]). 166 // It panics if len(xs) < 2, elements of xs are not strictly increasing 167 // or len(xs) != len(ys). 168 func calculateSlopes(xs, ys []float64) []float64 { 169 n := len(xs) 170 if n < 2 { 171 panic(tooFewPoints) 172 } 173 if len(ys) != n { 174 panic(differentLengths) 175 } 176 m := n - 1 177 slopes := make([]float64, m) 178 prevX := xs[0] 179 prevY := ys[0] 180 for i := 0; i < m; i++ { 181 x := xs[i+1] 182 y := ys[i+1] 183 dx := x - prevX 184 if dx <= 0 { 185 panic(xsNotStrictlyIncreasing) 186 } 187 slopes[i] = (y - prevY) / dx 188 prevX = x 189 prevY = y 190 } 191 return slopes 192 }