go-hep.org/x/hep@v0.38.1/fit/curve1d_test.go (about)

     1  // Copyright ©2017 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fit_test
     6  
import (
	"bufio"
	"fmt"
	"image/color"
	"math"
	"math/rand"
	"os"
	"strconv"
	"strings"
	"testing"

	"go-hep.org/x/hep/fit"
	"go-hep.org/x/hep/hplot"
	"gonum.org/v1/gonum/floats"
	"gonum.org/v1/plot/plotter"
	"gonum.org/v1/plot/vg"
)
    23  
    24  func TestCurve1D(t *testing.T) {
    25  	checkPlot(ExampleCurve1D_gaussian, t, "gauss-plot.png")
    26  	checkPlotApprox(ExampleCurve1D_exponential, t, 0.1, "exp-plot.png")
    27  	checkPlot(ExampleCurve1D_poly, t, "poly-plot.png")
    28  	checkPlot(ExampleCurve1D_powerlaw, t, "powerlaw-plot.png")
    29  }
    30  
    31  func TestCurve1DGaussianDefaultOpt(t *testing.T) {
    32  	checkPlot(func() {
    33  		var (
    34  			cst   = 3.0
    35  			mean  = 30.0
    36  			sigma = 20.0
    37  			want  = []float64{cst, mean, sigma}
    38  		)
    39  
    40  		xdata, ydata, err := readXY("testdata/gauss-data.txt")
    41  		if err != nil {
    42  			t.Fatal(err)
    43  		}
    44  
    45  		gauss := func(x, cst, mu, sigma float64) float64 {
    46  			v := (x - mu)
    47  			return cst * math.Exp(-v*v/sigma)
    48  		}
    49  
    50  		res, err := fit.Curve1D(
    51  			fit.Func1D{
    52  				F: func(x float64, ps []float64) float64 {
    53  					return gauss(x, ps[0], ps[1], ps[2])
    54  				},
    55  				X:  xdata,
    56  				Y:  ydata,
    57  				Ps: []float64{10, 10, 10},
    58  			},
    59  			nil, nil,
    60  		)
    61  		if err != nil {
    62  			t.Fatal(err)
    63  		}
    64  
    65  		if err := res.Status.Err(); err != nil {
    66  			t.Fatal(err)
    67  		}
    68  		if got := res.X; !floats.EqualApprox(got, want, 1e-3) {
    69  			t.Fatalf("got= %v\nwant=%v\n", got, want)
    70  		}
    71  
    72  		{
    73  			p := hplot.New()
    74  			p.X.Label.Text = "Gauss"
    75  			p.Y.Label.Text = "y-data"
    76  
    77  			s := hplot.NewS2D(hplot.ZipXY(xdata, ydata))
    78  			s.Color = color.RGBA{0, 0, 255, 255}
    79  			p.Add(s)
    80  
    81  			f := plotter.NewFunction(func(x float64) float64 {
    82  				return gauss(x, res.X[0], res.X[1], res.X[2])
    83  			})
    84  			f.Color = color.RGBA{255, 0, 0, 255}
    85  			f.Samples = 1000
    86  			p.Add(f)
    87  
    88  			p.Add(plotter.NewGrid())
    89  
    90  			err := p.Save(20*vg.Centimeter, -1, "testdata/gauss-plot-default-opt.png")
    91  			if err != nil {
    92  				t.Fatal(err)
    93  			}
    94  		}
    95  	}, t, "gauss-plot-default-opt.png")
    96  }
    97  
    98  func genXY(n int, f func(x float64, ps []float64) float64, ps []float64, xmin, xmax float64) ([]float64, []float64) {
    99  	xdata := make([]float64, n)
   100  	ydata := make([]float64, n)
   101  	rnd := rand.New(rand.NewSource(1234))
   102  	xstep := (xmax - xmin) / float64(n)
   103  	p := make([]float64, len(ps))
   104  	for i := range n {
   105  		x := xmin + xstep*float64(i)
   106  		for j := range p {
   107  			v := rnd.NormFloat64()
   108  			p[j] = ps[j] + v*0.2
   109  		}
   110  		xdata[i] = x
   111  		ydata[i] = f(x, p)
   112  	}
   113  	return xdata, ydata
   114  }
   115  
   116  func readXY(fname string) (xs, ys []float64, err error) {
   117  	f, err := os.Open(fname)
   118  	if err != nil {
   119  		return xs, ys, err
   120  	}
   121  	defer f.Close()
   122  
   123  	scan := bufio.NewScanner(f)
   124  	for scan.Scan() {
   125  		line := scan.Text()
   126  		toks := strings.Split(line, " ")
   127  		x, err := strconv.ParseFloat(toks[0], 64)
   128  		if err != nil {
   129  			return xs, ys, err
   130  		}
   131  		xs = append(xs, x)
   132  
   133  		y, err := strconv.ParseFloat(toks[1], 64)
   134  		if err != nil {
   135  			return xs, ys, err
   136  		}
   137  		ys = append(ys, y)
   138  	}
   139  
   140  	return
   141  }
   142  
   143  func readXYerr(fname string) (xs, ys, yerrs []float64, err error) {
   144  	f, err := os.Open(fname)
   145  	if err != nil {
   146  		return xs, ys, yerrs, err
   147  	}
   148  	defer f.Close()
   149  
   150  	scan := bufio.NewScanner(f)
   151  	for scan.Scan() {
   152  		line := scan.Text()
   153  		toks := strings.Split(line, " ")
   154  		x, err := strconv.ParseFloat(toks[0], 64)
   155  		if err != nil {
   156  			return xs, ys, yerrs, err
   157  		}
   158  		xs = append(xs, x)
   159  
   160  		y, err := strconv.ParseFloat(toks[1], 64)
   161  		if err != nil {
   162  			return xs, ys, yerrs, err
   163  		}
   164  		ys = append(ys, y)
   165  
   166  		yerr, err := strconv.ParseFloat(toks[2], 64)
   167  		if err != nil {
   168  			return xs, ys, yerrs, err
   169  		}
   170  		yerrs = append(yerrs, yerr)
   171  	}
   172  
   173  	return
   174  }