go-hep.org/x/hep@v0.38.1/fit/curve1d_test.go (about) 1 // Copyright ©2017 The go-hep Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package fit_test 6 7 import ( 8 "bufio" 9 "image/color" 10 "math" 11 "math/rand" 12 "os" 13 "strconv" 14 "strings" 15 "testing" 16 17 "go-hep.org/x/hep/fit" 18 "go-hep.org/x/hep/hplot" 19 "gonum.org/v1/gonum/floats" 20 "gonum.org/v1/plot/plotter" 21 "gonum.org/v1/plot/vg" 22 ) 23 24 func TestCurve1D(t *testing.T) { 25 checkPlot(ExampleCurve1D_gaussian, t, "gauss-plot.png") 26 checkPlotApprox(ExampleCurve1D_exponential, t, 0.1, "exp-plot.png") 27 checkPlot(ExampleCurve1D_poly, t, "poly-plot.png") 28 checkPlot(ExampleCurve1D_powerlaw, t, "powerlaw-plot.png") 29 } 30 31 func TestCurve1DGaussianDefaultOpt(t *testing.T) { 32 checkPlot(func() { 33 var ( 34 cst = 3.0 35 mean = 30.0 36 sigma = 20.0 37 want = []float64{cst, mean, sigma} 38 ) 39 40 xdata, ydata, err := readXY("testdata/gauss-data.txt") 41 if err != nil { 42 t.Fatal(err) 43 } 44 45 gauss := func(x, cst, mu, sigma float64) float64 { 46 v := (x - mu) 47 return cst * math.Exp(-v*v/sigma) 48 } 49 50 res, err := fit.Curve1D( 51 fit.Func1D{ 52 F: func(x float64, ps []float64) float64 { 53 return gauss(x, ps[0], ps[1], ps[2]) 54 }, 55 X: xdata, 56 Y: ydata, 57 Ps: []float64{10, 10, 10}, 58 }, 59 nil, nil, 60 ) 61 if err != nil { 62 t.Fatal(err) 63 } 64 65 if err := res.Status.Err(); err != nil { 66 t.Fatal(err) 67 } 68 if got := res.X; !floats.EqualApprox(got, want, 1e-3) { 69 t.Fatalf("got= %v\nwant=%v\n", got, want) 70 } 71 72 { 73 p := hplot.New() 74 p.X.Label.Text = "Gauss" 75 p.Y.Label.Text = "y-data" 76 77 s := hplot.NewS2D(hplot.ZipXY(xdata, ydata)) 78 s.Color = color.RGBA{0, 0, 255, 255} 79 p.Add(s) 80 81 f := plotter.NewFunction(func(x float64) float64 { 82 return gauss(x, res.X[0], res.X[1], res.X[2]) 83 }) 84 f.Color = color.RGBA{255, 0, 0, 255} 85 f.Samples = 1000 86 p.Add(f) 87 88 p.Add(plotter.NewGrid()) 89 90 err := p.Save(20*vg.Centimeter, -1, "testdata/gauss-plot-default-opt.png") 91 if err != nil { 92 t.Fatal(err) 93 } 94 } 95 }, t, "gauss-plot-default-opt.png") 96 } 97 98 func genXY(n int, f func(x float64, ps []float64) float64, ps []float64, xmin, xmax float64) ([]float64, []float64) { 99 xdata := make([]float64, n) 100 ydata := make([]float64, n) 101 rnd := rand.New(rand.NewSource(1234)) 102 xstep := (xmax - xmin) / float64(n) 103 p := make([]float64, len(ps)) 104 for i := range n { 105 x := xmin + xstep*float64(i) 106 for j := range p { 107 v := rnd.NormFloat64() 108 p[j] = ps[j] + v*0.2 109 } 110 xdata[i] = x 111 ydata[i] = f(x, p) 112 } 113 return xdata, ydata 114 } 115 116 func readXY(fname string) (xs, ys []float64, err error) { 117 f, err := os.Open(fname) 118 if err != nil { 119 return xs, ys, err 120 } 121 defer f.Close() 122 123 scan := bufio.NewScanner(f) 124 for scan.Scan() { 125 line := scan.Text() 126 toks := strings.Split(line, " ") 127 x, err := strconv.ParseFloat(toks[0], 64) 128 if err != nil { 129 return xs, ys, err 130 } 131 xs = append(xs, x) 132 133 y, err := strconv.ParseFloat(toks[1], 64) 134 if err != nil { 135 return xs, ys, err 136 } 137 ys = append(ys, y) 138 } 139 140 return 141 } 142 143 func readXYerr(fname string) (xs, ys, yerrs []float64, err error) { 144 f, err := os.Open(fname) 145 if err != nil { 146 return xs, ys, yerrs, err 147 } 148 defer f.Close() 149 150 scan := bufio.NewScanner(f) 151 for scan.Scan() { 152 line := scan.Text() 153 toks := strings.Split(line, " ") 154 x, err := strconv.ParseFloat(toks[0], 64) 155 if err != nil { 156 return xs, ys, yerrs, err 157 } 158 xs = append(xs, x) 159 160 y, err := strconv.ParseFloat(toks[1], 64) 161 if err != nil { 162 return xs, ys, yerrs, err 163 } 164 ys = append(ys, y) 165 166 yerr, err := strconv.ParseFloat(toks[2], 64) 167 if err != nil { 168 return xs, ys, yerrs, err 169 } 170 yerrs = append(yerrs, yerr) 171 } 172 173 return 174 }