go-hep.org/x/hep@v0.38.1/fit/curve1d_example_test.go (about) 1 // Copyright ©2017 The go-hep Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package fit_test 6 7 import ( 8 "fmt" 9 "image/color" 10 "log" 11 "math" 12 13 "go-hep.org/x/hep/fit" 14 "go-hep.org/x/hep/hbook" 15 "go-hep.org/x/hep/hplot" 16 "gonum.org/v1/gonum/floats" 17 "gonum.org/v1/gonum/mat" 18 "gonum.org/v1/gonum/optimize" 19 "gonum.org/v1/gonum/stat" 20 "gonum.org/v1/gonum/stat/distuv" 21 "gonum.org/v1/plot/plotter" 22 "gonum.org/v1/plot/vg" 23 ) 24 25 func ExampleCurve1D_gaussian() { 26 var ( 27 cst = 3.0 28 mean = 30.0 29 sigma = 20.0 30 want = []float64{cst, mean, sigma} 31 ) 32 33 xdata, ydata, err := readXY("testdata/gauss-data.txt") 34 if err != nil { 35 log.Fatal(err) 36 } 37 38 gauss := func(x, cst, mu, sigma float64) float64 { 39 v := (x - mu) 40 return cst * math.Exp(-v*v/sigma) 41 } 42 43 res, err := fit.Curve1D( 44 fit.Func1D{ 45 F: func(x float64, ps []float64) float64 { 46 return gauss(x, ps[0], ps[1], ps[2]) 47 }, 48 X: xdata, 49 Y: ydata, 50 Ps: []float64{10, 10, 10}, 51 }, 52 nil, &optimize.NelderMead{}, 53 ) 54 if err != nil { 55 log.Fatal(err) 56 } 57 58 if err := res.Status.Err(); err != nil { 59 log.Fatal(err) 60 } 61 if got := res.X; !floats.EqualApprox(got, want, 1e-3) { 62 log.Fatalf("got= %v\nwant=%v\n", got, want) 63 } 64 65 { 66 p := hplot.New() 67 p.X.Label.Text = "Gauss" 68 p.Y.Label.Text = "y-data" 69 70 s := hplot.NewS2D(hplot.ZipXY(xdata, ydata)) 71 s.Color = color.RGBA{0, 0, 255, 255} 72 p.Add(s) 73 74 f := plotter.NewFunction(func(x float64) float64 { 75 return gauss(x, res.X[0], res.X[1], res.X[2]) 76 }) 77 f.Color = color.RGBA{255, 0, 0, 255} 78 f.Samples = 1000 79 p.Add(f) 80 81 p.Add(plotter.NewGrid()) 82 83 err := p.Save(20*vg.Centimeter, -1, "testdata/gauss-plot.png") 84 if err != nil { 85 log.Fatal(err) 86 } 87 } 88 } 89 90 func ExampleCurve1D_exponential() { 91 const ( 92 a = 0.3 93 b = 0.1 94 ndf = 2.0 95 ) 96 97 xdata, ydata, err := readXY("testdata/exp-data.txt") 98 if err != nil { 99 log.Fatal(err) 100 } 101 102 exp := func(x, a, b float64) float64 { 103 return math.Exp(a*x + b) 104 } 105 106 res, err := fit.Curve1D( 107 fit.Func1D{ 108 F: func(x float64, ps []float64) float64 { 109 return exp(x, ps[0], ps[1]) 110 }, 111 X: xdata, 112 Y: ydata, 113 N: 2, 114 }, 115 nil, &optimize.NelderMead{}, 116 ) 117 if err != nil { 118 log.Fatal(err) 119 } 120 121 if err := res.Status.Err(); err != nil { 122 log.Fatal(err) 123 } 124 if got, want := res.X, []float64{a, b}; !floats.EqualApprox(got, want, 0.1) { 125 log.Fatalf("got= %v\nwant=%v\n", got, want) 126 } 127 128 { 129 p := hplot.New() 130 p.X.Label.Text = "exp(a*x+b)" 131 p.Y.Label.Text = "y-data" 132 p.Y.Min = 0 133 p.Y.Max = 5 134 p.X.Min = 0 135 p.X.Max = 5 136 137 s := hplot.NewS2D(hplot.ZipXY(xdata, ydata)) 138 s.Color = color.RGBA{0, 0, 255, 255} 139 p.Add(s) 140 141 f := plotter.NewFunction(func(x float64) float64 { 142 return exp(x, res.X[0], res.X[1]) 143 }) 144 f.Color = color.RGBA{255, 0, 0, 255} 145 f.Samples = 1000 146 p.Add(f) 147 148 p.Add(plotter.NewGrid()) 149 150 err := p.Save(20*vg.Centimeter, -1, "testdata/exp-plot.png") 151 if err != nil { 152 log.Fatal(err) 153 } 154 } 155 } 156 157 func ExampleCurve1D_poly() { 158 var ( 159 a = 1.0 160 b = 2.0 161 ps = []float64{a, b} 162 want = []float64{1.38592513, 1.98485122} // from scipy.curve_fit 163 ) 164 165 poly := func(x float64, ps []float64) float64 { 166 return ps[0] + ps[1]*x*x 167 } 168 169 xdata, ydata := genXY(100, poly, ps, -10, 10) 170 171 res, err := fit.Curve1D( 172 fit.Func1D{ 173 F: poly, 174 X: xdata, 175 Y: ydata, 176 Ps: []float64{1, 1}, 177 }, 178 nil, &optimize.NelderMead{}, 179 ) 180 if err != nil { 181 log.Fatal(err) 182 } 183 184 if err := res.Status.Err(); err != nil { 185 log.Fatal(err) 186 } 187 188 if got := res.X; !floats.EqualApprox(got, want, 1e-6) { 189 log.Fatalf("got= %v\nwant=%v\n", got, want) 190 } 191 192 { 193 p := hplot.New() 194 p.X.Label.Text = "f(x) = a + b*x*x" 195 p.Y.Label.Text = "y-data" 196 p.X.Min = -10 197 p.X.Max = +10 198 p.Y.Min = 0 199 p.Y.Max = 220 200 201 s := hplot.NewS2D(hplot.ZipXY(xdata, ydata)) 202 s.Color = color.RGBA{0, 0, 255, 255} 203 p.Add(s) 204 205 f := plotter.NewFunction(func(x float64) float64 { 206 return poly(x, res.X) 207 }) 208 f.Color = color.RGBA{255, 0, 0, 255} 209 f.Samples = 1000 210 p.Add(f) 211 212 p.Add(plotter.NewGrid()) 213 214 err := p.Save(20*vg.Centimeter, -1, "testdata/poly-plot.png") 215 if err != nil { 216 log.Fatal(err) 217 } 218 } 219 } 220 221 func ExampleCurve1D_powerlaw() { 222 var ( 223 amp = 11.021171432949746 224 index = -2.027389113217428 225 want = []float64{amp, index} 226 ) 227 228 xdata, ydata, yerrs, err := readXYerr("testdata/powerlaw-data.txt") 229 if err != nil { 230 log.Fatal(err) 231 } 232 233 plaw := func(x, amp, index float64) float64 { 234 return amp * math.Pow(x, index) 235 } 236 237 res, err := fit.Curve1D( 238 fit.Func1D{ 239 F: func(x float64, ps []float64) float64 { 240 return plaw(x, ps[0], ps[1]) 241 }, 242 X: xdata, 243 Y: ydata, 244 Err: yerrs, 245 Ps: []float64{1, 1}, 246 }, 247 nil, &optimize.NelderMead{}, 248 ) 249 if err != nil { 250 log.Fatal(err) 251 } 252 253 if err := res.Status.Err(); err != nil { 254 log.Fatal(err) 255 } 256 if got := res.X; !floats.EqualApprox(got, want, 1e-3) { 257 log.Fatalf("got= %v\nwant=%v\n", got, want) 258 } 259 260 { 261 p := hplot.New() 262 p.X.Label.Text = "f(x) = a * x^b" 263 p.Y.Label.Text = "y-data" 264 p.X.Min = 0 265 p.X.Max = 10 266 p.Y.Min = 0 267 p.Y.Max = 10 268 269 pts := make([]hbook.Point2D, len(xdata)) 270 for i := range pts { 271 pts[i].X = xdata[i] 272 pts[i].Y = ydata[i] 273 pts[i].ErrY.Min = 0.5 * yerrs[i] 274 pts[i].ErrY.Max = 0.5 * yerrs[i] 275 } 276 277 s := hplot.NewS2D(hbook.NewS2D(pts...), hplot.WithYErrBars(true)) 278 s.Color = color.RGBA{0, 0, 255, 255} 279 p.Add(s) 280 281 f := plotter.NewFunction(func(x float64) float64 { 282 return plaw(x, res.X[0], res.X[1]) 283 }) 284 f.Color = color.RGBA{255, 0, 0, 255} 285 f.Samples = 1000 286 p.Add(f) 287 288 p.Add(plotter.NewGrid()) 289 290 err := p.Save(20*vg.Centimeter, -1, "testdata/powerlaw-plot.png") 291 if err != nil { 292 log.Fatal(err) 293 } 294 } 295 } 296 297 func ExampleCurve1D_hessian() { 298 var ( 299 cst = 3.0 300 mean = 30.0 301 sigma = 20.0 302 want = []float64{cst, mean, sigma} 303 ) 304 305 xdata, ydata, err := readXY("testdata/gauss-data.txt") 306 if err != nil { 307 log.Fatal(err) 308 } 309 310 // use a small sample 311 xdata = xdata[:min(25, len(xdata))] 312 ydata = ydata[:min(25, len(ydata))] 313 314 gauss := func(x, cst, mu, sigma float64) float64 { 315 v := (x - mu) 316 return cst * math.Exp(-v*v/sigma) 317 } 318 319 f1d := fit.Func1D{ 320 F: func(x float64, ps []float64) float64 { 321 return gauss(x, ps[0], ps[1], ps[2]) 322 }, 323 X: xdata, 324 Y: ydata, 325 Ps: []float64{10, 10, 10}, 326 } 327 res, err := fit.Curve1D(f1d, nil, &optimize.NelderMead{}) 328 if err != nil { 329 log.Fatal(err) 330 } 331 332 if err := res.Status.Err(); err != nil { 333 log.Fatal(err) 334 } 335 if got := res.X; !floats.EqualApprox(got, want, 1e-3) { 336 log.Fatalf("got= %v\nwant=%v\n", got, want) 337 } 338 339 inv := mat.NewSymDense(len(res.Location.X), nil) 340 f1d.Hessian(inv, res.Location.X) 341 // fmt.Printf("hessian: %1.2e\n", mat.Formatted(inv, mat.Prefix(" "))) 342 343 popt := res.Location.X 344 pcov := mat.NewDense(len(popt), len(popt), nil) 345 { 346 var chol mat.Cholesky 347 if ok := chol.Factorize(inv); !ok { 348 log.Fatalf("cov-matrix not positive semi-definite") 349 } 350 351 err := chol.InverseTo(inv) 352 if err != nil { 353 log.Fatalf("could not inverse matrix: %+v", err) 354 } 355 pcov.Copy(inv) 356 } 357 358 // compute goodness-of-fit. 359 gof := newGoF(f1d.X, f1d.Y, popt, func(x float64) float64 { 360 return f1d.F(x, popt) 361 }) 362 363 pcov.Scale(gof.SSE/float64(len(f1d.X)-len(popt)), pcov) 364 365 // fmt.Printf("pcov: %1.2e\n", mat.Formatted(pcov, mat.Prefix(" "))) 366 367 var ( 368 n = float64(len(f1d.X)) // number of data points 369 ndf = n - float64(len(popt)) // number of degrees of freedom 370 t = distuv.StudentsT{ 371 Mu: 0, 372 Sigma: 1, 373 Nu: ndf, 374 }.Quantile(0.5 * (1 + 0.95)) 375 ) 376 377 for i, p := range popt { 378 sigma := math.Sqrt(pcov.At(i, i)) 379 fmt.Printf("c%d: %1.5e [%1.5e, %1.5e] -- truth: %g\n", i, p, p-sigma*t, p+sigma*t, want[i]) 380 } 381 // Output: 382 //c0: 2.99999e+00 [2.99999e+00, 3.00000e+00] -- truth: 3 383 //c1: 3.00000e+01 [3.00000e+01, 3.00000e+01] -- truth: 30 384 //c2: 2.00000e+01 [2.00000e+01, 2.00000e+01] -- truth: 20 385 } 386 387 type GoF struct { 388 SSE float64 // Sum of squares due to error 389 Rsquare float64 // R-Square is the square of the correlation between the response values and the predicted response values 390 NdF int // Number of degrees of freedom 391 AdjRsquare float64 // Degrees of freedom adjusted R-Square 392 RMSE float64 // Root mean squared error 393 } 394 395 func newGoF(xs, ys, ps []float64, f func(float64) float64) GoF { 396 switch { 397 case len(xs) != len(ys): 398 panic("invalid lengths") 399 } 400 401 var gof GoF 402 403 var ( 404 ye = make([]float64, len(ys)) 405 nn = float64(len(xs) - 1) 406 vv = float64(len(xs) - len(ps)) 407 ) 408 409 for i, x := range xs { 410 ye[i] = f(x) 411 dy := ys[i] - ye[i] 412 gof.SSE += dy * dy 413 gof.RMSE += dy * dy 414 } 415 416 gof.Rsquare = stat.RSquaredFrom(ye, ys, nil) 417 gof.AdjRsquare = 1 - ((1 - gof.Rsquare) * nn / vv) 418 gof.RMSE = math.Sqrt(gof.RMSE / float64(len(ys)-len(ps))) 419 gof.NdF = len(ys) - len(ps) 420 421 return gof 422 }