go-hep.org/x/hep@v0.38.1/fit/curve1d_example_test.go (about)

     1  // Copyright ©2017 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fit_test
     6  
     7  import (
     8  	"fmt"
     9  	"image/color"
    10  	"log"
    11  	"math"
    12  
    13  	"go-hep.org/x/hep/fit"
    14  	"go-hep.org/x/hep/hbook"
    15  	"go-hep.org/x/hep/hplot"
    16  	"gonum.org/v1/gonum/floats"
    17  	"gonum.org/v1/gonum/mat"
    18  	"gonum.org/v1/gonum/optimize"
    19  	"gonum.org/v1/gonum/stat"
    20  	"gonum.org/v1/gonum/stat/distuv"
    21  	"gonum.org/v1/plot/plotter"
    22  	"gonum.org/v1/plot/vg"
    23  )
    24  
    25  func ExampleCurve1D_gaussian() {
    26  	var (
    27  		cst   = 3.0
    28  		mean  = 30.0
    29  		sigma = 20.0
    30  		want  = []float64{cst, mean, sigma}
    31  	)
    32  
    33  	xdata, ydata, err := readXY("testdata/gauss-data.txt")
    34  	if err != nil {
    35  		log.Fatal(err)
    36  	}
    37  
    38  	gauss := func(x, cst, mu, sigma float64) float64 {
    39  		v := (x - mu)
    40  		return cst * math.Exp(-v*v/sigma)
    41  	}
    42  
    43  	res, err := fit.Curve1D(
    44  		fit.Func1D{
    45  			F: func(x float64, ps []float64) float64 {
    46  				return gauss(x, ps[0], ps[1], ps[2])
    47  			},
    48  			X:  xdata,
    49  			Y:  ydata,
    50  			Ps: []float64{10, 10, 10},
    51  		},
    52  		nil, &optimize.NelderMead{},
    53  	)
    54  	if err != nil {
    55  		log.Fatal(err)
    56  	}
    57  
    58  	if err := res.Status.Err(); err != nil {
    59  		log.Fatal(err)
    60  	}
    61  	if got := res.X; !floats.EqualApprox(got, want, 1e-3) {
    62  		log.Fatalf("got= %v\nwant=%v\n", got, want)
    63  	}
    64  
    65  	{
    66  		p := hplot.New()
    67  		p.X.Label.Text = "Gauss"
    68  		p.Y.Label.Text = "y-data"
    69  
    70  		s := hplot.NewS2D(hplot.ZipXY(xdata, ydata))
    71  		s.Color = color.RGBA{0, 0, 255, 255}
    72  		p.Add(s)
    73  
    74  		f := plotter.NewFunction(func(x float64) float64 {
    75  			return gauss(x, res.X[0], res.X[1], res.X[2])
    76  		})
    77  		f.Color = color.RGBA{255, 0, 0, 255}
    78  		f.Samples = 1000
    79  		p.Add(f)
    80  
    81  		p.Add(plotter.NewGrid())
    82  
    83  		err := p.Save(20*vg.Centimeter, -1, "testdata/gauss-plot.png")
    84  		if err != nil {
    85  			log.Fatal(err)
    86  		}
    87  	}
    88  }
    89  
    90  func ExampleCurve1D_exponential() {
    91  	const (
    92  		a   = 0.3
    93  		b   = 0.1
    94  		ndf = 2.0
    95  	)
    96  
    97  	xdata, ydata, err := readXY("testdata/exp-data.txt")
    98  	if err != nil {
    99  		log.Fatal(err)
   100  	}
   101  
   102  	exp := func(x, a, b float64) float64 {
   103  		return math.Exp(a*x + b)
   104  	}
   105  
   106  	res, err := fit.Curve1D(
   107  		fit.Func1D{
   108  			F: func(x float64, ps []float64) float64 {
   109  				return exp(x, ps[0], ps[1])
   110  			},
   111  			X: xdata,
   112  			Y: ydata,
   113  			N: 2,
   114  		},
   115  		nil, &optimize.NelderMead{},
   116  	)
   117  	if err != nil {
   118  		log.Fatal(err)
   119  	}
   120  
   121  	if err := res.Status.Err(); err != nil {
   122  		log.Fatal(err)
   123  	}
   124  	if got, want := res.X, []float64{a, b}; !floats.EqualApprox(got, want, 0.1) {
   125  		log.Fatalf("got= %v\nwant=%v\n", got, want)
   126  	}
   127  
   128  	{
   129  		p := hplot.New()
   130  		p.X.Label.Text = "exp(a*x+b)"
   131  		p.Y.Label.Text = "y-data"
   132  		p.Y.Min = 0
   133  		p.Y.Max = 5
   134  		p.X.Min = 0
   135  		p.X.Max = 5
   136  
   137  		s := hplot.NewS2D(hplot.ZipXY(xdata, ydata))
   138  		s.Color = color.RGBA{0, 0, 255, 255}
   139  		p.Add(s)
   140  
   141  		f := plotter.NewFunction(func(x float64) float64 {
   142  			return exp(x, res.X[0], res.X[1])
   143  		})
   144  		f.Color = color.RGBA{255, 0, 0, 255}
   145  		f.Samples = 1000
   146  		p.Add(f)
   147  
   148  		p.Add(plotter.NewGrid())
   149  
   150  		err := p.Save(20*vg.Centimeter, -1, "testdata/exp-plot.png")
   151  		if err != nil {
   152  			log.Fatal(err)
   153  		}
   154  	}
   155  }
   156  
   157  func ExampleCurve1D_poly() {
   158  	var (
   159  		a    = 1.0
   160  		b    = 2.0
   161  		ps   = []float64{a, b}
   162  		want = []float64{1.38592513, 1.98485122} // from scipy.curve_fit
   163  	)
   164  
   165  	poly := func(x float64, ps []float64) float64 {
   166  		return ps[0] + ps[1]*x*x
   167  	}
   168  
   169  	xdata, ydata := genXY(100, poly, ps, -10, 10)
   170  
   171  	res, err := fit.Curve1D(
   172  		fit.Func1D{
   173  			F:  poly,
   174  			X:  xdata,
   175  			Y:  ydata,
   176  			Ps: []float64{1, 1},
   177  		},
   178  		nil, &optimize.NelderMead{},
   179  	)
   180  	if err != nil {
   181  		log.Fatal(err)
   182  	}
   183  
   184  	if err := res.Status.Err(); err != nil {
   185  		log.Fatal(err)
   186  	}
   187  
   188  	if got := res.X; !floats.EqualApprox(got, want, 1e-6) {
   189  		log.Fatalf("got= %v\nwant=%v\n", got, want)
   190  	}
   191  
   192  	{
   193  		p := hplot.New()
   194  		p.X.Label.Text = "f(x) = a + b*x*x"
   195  		p.Y.Label.Text = "y-data"
   196  		p.X.Min = -10
   197  		p.X.Max = +10
   198  		p.Y.Min = 0
   199  		p.Y.Max = 220
   200  
   201  		s := hplot.NewS2D(hplot.ZipXY(xdata, ydata))
   202  		s.Color = color.RGBA{0, 0, 255, 255}
   203  		p.Add(s)
   204  
   205  		f := plotter.NewFunction(func(x float64) float64 {
   206  			return poly(x, res.X)
   207  		})
   208  		f.Color = color.RGBA{255, 0, 0, 255}
   209  		f.Samples = 1000
   210  		p.Add(f)
   211  
   212  		p.Add(plotter.NewGrid())
   213  
   214  		err := p.Save(20*vg.Centimeter, -1, "testdata/poly-plot.png")
   215  		if err != nil {
   216  			log.Fatal(err)
   217  		}
   218  	}
   219  }
   220  
   221  func ExampleCurve1D_powerlaw() {
   222  	var (
   223  		amp   = 11.021171432949746
   224  		index = -2.027389113217428
   225  		want  = []float64{amp, index}
   226  	)
   227  
   228  	xdata, ydata, yerrs, err := readXYerr("testdata/powerlaw-data.txt")
   229  	if err != nil {
   230  		log.Fatal(err)
   231  	}
   232  
   233  	plaw := func(x, amp, index float64) float64 {
   234  		return amp * math.Pow(x, index)
   235  	}
   236  
   237  	res, err := fit.Curve1D(
   238  		fit.Func1D{
   239  			F: func(x float64, ps []float64) float64 {
   240  				return plaw(x, ps[0], ps[1])
   241  			},
   242  			X:   xdata,
   243  			Y:   ydata,
   244  			Err: yerrs,
   245  			Ps:  []float64{1, 1},
   246  		},
   247  		nil, &optimize.NelderMead{},
   248  	)
   249  	if err != nil {
   250  		log.Fatal(err)
   251  	}
   252  
   253  	if err := res.Status.Err(); err != nil {
   254  		log.Fatal(err)
   255  	}
   256  	if got := res.X; !floats.EqualApprox(got, want, 1e-3) {
   257  		log.Fatalf("got= %v\nwant=%v\n", got, want)
   258  	}
   259  
   260  	{
   261  		p := hplot.New()
   262  		p.X.Label.Text = "f(x) = a * x^b"
   263  		p.Y.Label.Text = "y-data"
   264  		p.X.Min = 0
   265  		p.X.Max = 10
   266  		p.Y.Min = 0
   267  		p.Y.Max = 10
   268  
   269  		pts := make([]hbook.Point2D, len(xdata))
   270  		for i := range pts {
   271  			pts[i].X = xdata[i]
   272  			pts[i].Y = ydata[i]
   273  			pts[i].ErrY.Min = 0.5 * yerrs[i]
   274  			pts[i].ErrY.Max = 0.5 * yerrs[i]
   275  		}
   276  
   277  		s := hplot.NewS2D(hbook.NewS2D(pts...), hplot.WithYErrBars(true))
   278  		s.Color = color.RGBA{0, 0, 255, 255}
   279  		p.Add(s)
   280  
   281  		f := plotter.NewFunction(func(x float64) float64 {
   282  			return plaw(x, res.X[0], res.X[1])
   283  		})
   284  		f.Color = color.RGBA{255, 0, 0, 255}
   285  		f.Samples = 1000
   286  		p.Add(f)
   287  
   288  		p.Add(plotter.NewGrid())
   289  
   290  		err := p.Save(20*vg.Centimeter, -1, "testdata/powerlaw-plot.png")
   291  		if err != nil {
   292  			log.Fatal(err)
   293  		}
   294  	}
   295  }
   296  
   297  func ExampleCurve1D_hessian() {
   298  	var (
   299  		cst   = 3.0
   300  		mean  = 30.0
   301  		sigma = 20.0
   302  		want  = []float64{cst, mean, sigma}
   303  	)
   304  
   305  	xdata, ydata, err := readXY("testdata/gauss-data.txt")
   306  	if err != nil {
   307  		log.Fatal(err)
   308  	}
   309  
   310  	// use a small sample
   311  	xdata = xdata[:min(25, len(xdata))]
   312  	ydata = ydata[:min(25, len(ydata))]
   313  
   314  	gauss := func(x, cst, mu, sigma float64) float64 {
   315  		v := (x - mu)
   316  		return cst * math.Exp(-v*v/sigma)
   317  	}
   318  
   319  	f1d := fit.Func1D{
   320  		F: func(x float64, ps []float64) float64 {
   321  			return gauss(x, ps[0], ps[1], ps[2])
   322  		},
   323  		X:  xdata,
   324  		Y:  ydata,
   325  		Ps: []float64{10, 10, 10},
   326  	}
   327  	res, err := fit.Curve1D(f1d, nil, &optimize.NelderMead{})
   328  	if err != nil {
   329  		log.Fatal(err)
   330  	}
   331  
   332  	if err := res.Status.Err(); err != nil {
   333  		log.Fatal(err)
   334  	}
   335  	if got := res.X; !floats.EqualApprox(got, want, 1e-3) {
   336  		log.Fatalf("got= %v\nwant=%v\n", got, want)
   337  	}
   338  
   339  	inv := mat.NewSymDense(len(res.Location.X), nil)
   340  	f1d.Hessian(inv, res.Location.X)
   341  	// fmt.Printf("hessian: %1.2e\n", mat.Formatted(inv, mat.Prefix("         ")))
   342  
   343  	popt := res.Location.X
   344  	pcov := mat.NewDense(len(popt), len(popt), nil)
   345  	{
   346  		var chol mat.Cholesky
   347  		if ok := chol.Factorize(inv); !ok {
   348  			log.Fatalf("cov-matrix not positive semi-definite")
   349  		}
   350  
   351  		err := chol.InverseTo(inv)
   352  		if err != nil {
   353  			log.Fatalf("could not inverse matrix: %+v", err)
   354  		}
   355  		pcov.Copy(inv)
   356  	}
   357  
   358  	// compute goodness-of-fit.
   359  	gof := newGoF(f1d.X, f1d.Y, popt, func(x float64) float64 {
   360  		return f1d.F(x, popt)
   361  	})
   362  
   363  	pcov.Scale(gof.SSE/float64(len(f1d.X)-len(popt)), pcov)
   364  
   365  	// fmt.Printf("pcov: %1.2e\n", mat.Formatted(pcov, mat.Prefix("      ")))
   366  
   367  	var (
   368  		n   = float64(len(f1d.X))    // number of data points
   369  		ndf = n - float64(len(popt)) // number of degrees of freedom
   370  		t   = distuv.StudentsT{
   371  			Mu:    0,
   372  			Sigma: 1,
   373  			Nu:    ndf,
   374  		}.Quantile(0.5 * (1 + 0.95))
   375  	)
   376  
   377  	for i, p := range popt {
   378  		sigma := math.Sqrt(pcov.At(i, i))
   379  		fmt.Printf("c%d: %1.5e [%1.5e, %1.5e] -- truth: %g\n", i, p, p-sigma*t, p+sigma*t, want[i])
   380  	}
   381  	// Output:
   382  	//c0: 2.99999e+00 [2.99999e+00, 3.00000e+00] -- truth: 3
   383  	//c1: 3.00000e+01 [3.00000e+01, 3.00000e+01] -- truth: 30
   384  	//c2: 2.00000e+01 [2.00000e+01, 2.00000e+01] -- truth: 20
   385  }
   386  
   387  type GoF struct {
   388  	SSE        float64 // Sum of squares due to error
   389  	Rsquare    float64 // R-Square is the square of the correlation between the response values and the predicted response values
   390  	NdF        int     // Number of degrees of freedom
   391  	AdjRsquare float64 // Degrees of freedom adjusted R-Square
   392  	RMSE       float64 // Root mean squared error
   393  }
   394  
   395  func newGoF(xs, ys, ps []float64, f func(float64) float64) GoF {
   396  	switch {
   397  	case len(xs) != len(ys):
   398  		panic("invalid lengths")
   399  	}
   400  
   401  	var gof GoF
   402  
   403  	var (
   404  		ye = make([]float64, len(ys))
   405  		nn = float64(len(xs) - 1)
   406  		vv = float64(len(xs) - len(ps))
   407  	)
   408  
   409  	for i, x := range xs {
   410  		ye[i] = f(x)
   411  		dy := ys[i] - ye[i]
   412  		gof.SSE += dy * dy
   413  		gof.RMSE += dy * dy
   414  	}
   415  
   416  	gof.Rsquare = stat.RSquaredFrom(ye, ys, nil)
   417  	gof.AdjRsquare = 1 - ((1 - gof.Rsquare) * nn / vv)
   418  	gof.RMSE = math.Sqrt(gof.RMSE / float64(len(ys)-len(ps)))
   419  	gof.NdF = len(ys) - len(ps)
   420  
   421  	return gof
   422  }