go-hep.org/x/hep@v0.38.1/fit/curve_nd_example_test.go (about)

     1  // Copyright ©2020 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fit_test
     6  
     7  import (
     8  	"fmt"
     9  	"image/color"
    10  	"log"
    11  
    12  	"go-hep.org/x/hep/fit"
    13  	"go-hep.org/x/hep/hplot"
    14  	"gonum.org/v1/gonum/floats"
    15  	"gonum.org/v1/gonum/optimize"
    16  	"gonum.org/v1/plot/plotter"
    17  	"gonum.org/v1/plot/vg"
    18  )
    19  
    20  func ExampleCurveND_plane() {
    21  	var (
    22  		m1    = 0.3
    23  		m2    = 0.1
    24  		c     = 0.2
    25  		ps    = []float64{m1, m2, c}
    26  		n0    = 10
    27  		n1    = 10
    28  		x0min = -1.
    29  		x0max = 1.
    30  		x1min = -1.
    31  		x1max = 1.
    32  	)
    33  
    34  	plane := func(x, ps []float64) float64 {
    35  		return ps[0]*x[0] + ps[1]*x[1] + ps[2]
    36  	}
    37  
    38  	xData, yData := genData2D(n0, n1, plane, ps, x0min, x0max, x1min, x1max)
    39  
    40  	res, err := fit.CurveND(
    41  		fit.FuncND{
    42  			F: func(x []float64, ps []float64) float64 {
    43  				return plane(x, ps)
    44  			},
    45  			X: xData,
    46  			Y: yData,
    47  			N: 3,
    48  		},
    49  		nil, &optimize.NelderMead{},
    50  	)
    51  	if err != nil {
    52  		log.Fatal(err)
    53  	}
    54  
    55  	if err := res.Status.Err(); err != nil {
    56  		log.Fatal(err)
    57  	}
    58  	if got, want := res.X, []float64{m1, m2, c}; !floats.EqualApprox(got, want, 0.1) {
    59  		log.Fatalf("got= %v\nwant=%v\n", got, want)
    60  	}
    61  
    62  	{
    63  		// slicing for a particular x0 value to plot y as a function of x1,
    64  		// to visualise how well the fit is working for a given x0.
    65  		x0Selection := 8
    66  		if 0 > x0Selection || x0Selection > n0 {
    67  			log.Fatalf("x0 slice, %d, is not in valid range [0 - %d]", x0Selection, n0)
    68  		}
    69  		x0SlicePos := x0min + ((x0max-x0min)/float64(n0))*float64(x0Selection)
    70  
    71  		var x1Slice []float64
    72  		var ySlice []float64
    73  
    74  		for i := range xData {
    75  			if xData[i][0] == x0SlicePos {
    76  				x1Slice = append(x1Slice, xData[i][1])
    77  				ySlice = append(ySlice, yData[i])
    78  			}
    79  		}
    80  
    81  		p := hplot.New()
    82  		p.Title.Text = fmt.Sprintf("Slice of plane at x0 = %.2f", x0SlicePos)
    83  		p.X.Label.Text = "x1"
    84  		p.Y.Label.Text = "y"
    85  		p.Y.Min = x1min
    86  		p.Y.Max = x1max
    87  		p.X.Min = x0min
    88  		p.X.Max = x0max
    89  
    90  		s := hplot.NewS2D(hplot.ZipXY(x1Slice, ySlice))
    91  		s.Color = color.RGBA{B: 255, A: 255}
    92  		p.Add(s)
    93  
    94  		shiftLine := func(x, m, c, mxOtherAxis float64) float64 {
    95  			return m*x + c + mxOtherAxis
    96  		}
    97  
    98  		f := plotter.NewFunction(func(x float64) float64 {
    99  			return shiftLine(x, res.X[1], res.X[2], res.X[0]*x0SlicePos)
   100  		})
   101  		f.Color = color.RGBA{R: 255, A: 255}
   102  		f.Samples = 1000
   103  		p.Add(f)
   104  
   105  		p.Add(plotter.NewGrid())
   106  		err := p.Save(20*vg.Centimeter, -1, "testdata/2d-plane-plot.png")
   107  		if err != nil {
   108  			log.Fatal(err)
   109  		}
   110  	}
   111  }