go-hep.org/x/hep@v0.38.1/fit/fit.go (about) 1 // Copyright ©2017 The go-hep Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package fit provides functions to fit data. 6 package fit // import "go-hep.org/x/hep/fit" 7 8 import ( 9 "gonum.org/v1/gonum/diff/fd" 10 "gonum.org/v1/gonum/mat" 11 ) 12 13 //go:generate go tool github.com/campoy/embedmd -w README.md 14 15 // Func1D describes a 1D function to fit some data. 16 type Func1D struct { 17 // F is the function to minimize. 18 // ps is the slice of parameters to optimize during the fit. 19 F func(x float64, ps []float64) float64 20 21 // N is the number of parameters to optimize during the fit. 22 // If N is 0, Ps must not be nil. 23 N int 24 25 // Ps is the initial values for the parameters. 26 // If Ps is nil, the set of initial parameters values is a slice of 27 // length N filled with zeros. 28 Ps []float64 29 30 X []float64 31 Y []float64 32 Err []float64 33 34 sig2 []float64 // inverse of squares of measurement errors along Y. 35 36 fct func(ps []float64) float64 // cost function (objective function) 37 grad func(grad, ps []float64) 38 hess func(hess *mat.SymDense, x []float64) 39 } 40 41 func (f *Func1D) init() { 42 43 f.sig2 = make([]float64, len(f.Y)) 44 switch { 45 default: 46 for i := range f.Y { 47 f.sig2[i] = 1 48 } 49 case f.Err != nil: 50 for i, v := range f.Err { 51 f.sig2[i] = 1 / (v * v) 52 } 53 } 54 55 if f.Ps == nil { 56 f.Ps = make([]float64, f.N) 57 } 58 59 if len(f.Ps) == 0 { 60 panic("fit: invalid number of initial parameters") 61 } 62 63 if len(f.X) != len(f.Y) { 64 panic("fit: mismatch length") 65 } 66 67 if len(f.sig2) != len(f.Y) { 68 panic("fit: mismatch length") 69 } 70 71 f.fct = func(ps []float64) float64 { 72 var chi2 float64 73 for i := range f.X { 74 res := f.F(f.X[i], ps) - f.Y[i] 75 chi2 += res * res * f.sig2[i] 76 } 77 return 0.5 * chi2 78 } 79 80 f.grad = func(grad, ps []float64) { 81 fd.Gradient(grad, f.fct, ps, nil) 82 } 83 84 f.hess = func(hess *mat.SymDense, x []float64) { 85 fd.Hessian(hess, f.fct, x, nil) 86 } 87 } 88 89 // Hessian computes the hessian matrix at the provided x point. 90 func (f *Func1D) Hessian(hess *mat.SymDense, x []float64) { 91 if f.hess == nil { 92 f.init() 93 } 94 f.hess(hess, x) 95 } 96 97 // FuncND describes a multivariate function F(x0, x1... xn; p0, p1... pn) 98 // for which the parameters ps can be found with a fit. 99 type FuncND struct { 100 // F is the function to minimize. 101 // ps is the slice of parameters to optimize during the fit. 102 // x is the slice of independent variables. 103 F func(x []float64, ps []float64) float64 104 105 // N is the number of parameters to optimize during the fit. 106 // If N is 0, Ps must not be nil. 107 N int 108 109 // Ps is the initial values for the parameters. 110 // If Ps is nil, the set of initial parameters values is a slice of 111 // length N filled with zeros. 112 Ps []float64 113 114 // X is the multidimensional slice of the independent variables, 115 // it must be structured so that the X[i] is a list of values for the 116 // independent variables that corresponds to a single Y value. 117 // In other words, the sequence of rows must correspond to the sequence 118 // of independent variable values. 119 X [][]float64 120 Y []float64 121 Err []float64 122 123 sig2 []float64 // inverse of squares of measurement errors along Y. 124 125 fct func(ps []float64) float64 // cost function (objective function) 126 grad func(grad, ps []float64) 127 hess func(hess *mat.SymDense, x []float64) // hessian matrix 128 } 129 130 func (f *FuncND) init() { 131 132 f.sig2 = make([]float64, len(f.Y)) 133 switch { 134 default: 135 for i := range f.Y { 136 f.sig2[i] = 1 137 } 138 case f.Err != nil: 139 for i, v := range f.Err { 140 f.sig2[i] = 1 / (v * v) 141 } 142 } 143 144 if f.Ps == nil { 145 f.Ps = make([]float64, f.N) 146 } 147 148 if len(f.Ps) == 0 { 149 panic("fit: invalid number of initial parameters") 150 } 151 152 if len(f.X) != len(f.Y) { 153 panic("fit: mismatch length") 154 } 155 156 if len(f.sig2) != len(f.Y) { 157 panic("fit: mismatch length") 158 } 159 160 f.fct = func(ps []float64) float64 { 161 var chi2 float64 162 for i := range f.X { 163 res := f.F(f.X[i], ps) - f.Y[i] 164 chi2 += res * res * f.sig2[i] 165 } 166 return 0.5 * chi2 167 } 168 169 f.grad = func(grad []float64, ps []float64) { 170 fd.Gradient(grad, f.fct, ps, nil) 171 } 172 173 f.hess = func(hess *mat.SymDense, x []float64) { 174 fd.Hessian(hess, f.fct, x, nil) 175 } 176 }