gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/diff/fd/hessian.go (about) 1 // Copyright ©2017 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package fd 6 7 import ( 8 "math" 9 "sync" 10 11 "gonum.org/v1/gonum/mat" 12 ) 13 14 // Hessian approximates the Hessian matrix of the multivariate function f at 15 // the location x. That is 16 // 17 // H_{i,j} = ∂^2 f(x)/∂x_i ∂x_j 18 // 19 // The resulting H will be stored in dst. Finite difference formula and other 20 // options are specified by settings. If settings is nil, the Hessian will be 21 // estimated using the Forward formula and a default step size. 22 // 23 // If the dst matrix is empty it will be resized to the correct dimensions, 24 // otherwise the dimensions of dst must match the length of x or Hessian will panic. 25 // Hessian will panic if the derivative order of the formula is not 1. 26 func Hessian(dst *mat.SymDense, f func(x []float64) float64, x []float64, settings *Settings) { 27 n := len(x) 28 if dst.IsEmpty() { 29 *dst = *(dst.GrowSym(n).(*mat.SymDense)) 30 } else if dst.SymmetricDim() != n { 31 panic("hessian: dst size mismatch") 32 } 33 dst.Zero() 34 35 // Default settings. 36 formula := Forward 37 step := math.Sqrt(formula.Step) // Use the sqrt because taking derivatives of derivatives. 38 var originValue float64 39 var originKnown, concurrent bool 40 41 // Use user settings if provided. 42 if settings != nil { 43 if !settings.Formula.isZero() { 44 formula = settings.Formula 45 step = math.Sqrt(formula.Step) 46 checkFormula(formula) 47 if formula.Derivative != 1 { 48 panic(badDerivOrder) 49 } 50 } 51 if settings.Step != 0 { 52 if settings.Step < 0 { 53 panic(negativeStep) 54 } 55 step = settings.Step 56 } 57 originKnown = settings.OriginKnown 58 originValue = settings.OriginValue 59 concurrent = settings.Concurrent 60 } 61 62 evals := n * (n + 1) / 2 * len(formula.Stencil) * len(formula.Stencil) 63 for _, pt := range formula.Stencil { 64 if pt.Loc == 0 { 65 evals -= n * (n + 1) / 2 66 break 67 } 68 } 69 70 nWorkers := computeWorkers(concurrent, evals) 71 if nWorkers == 1 { 72 hessianSerial(dst, f, x, formula.Stencil, step, originKnown, originValue) 73 return 74 } 75 hessianConcurrent(dst, nWorkers, evals, f, x, formula.Stencil, step, originKnown, originValue) 76 } 77 78 func hessianSerial(dst *mat.SymDense, f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) { 79 n := len(x) 80 xCopy := make([]float64, n) 81 fo := func() float64 { 82 // Copy x in case it is modified during the call. 83 copy(xCopy, x) 84 return f(x) 85 } 86 is2 := 1 / (step * step) 87 origin := getOrigin(originKnown, originValue, fo, stencil) 88 for i := 0; i < n; i++ { 89 for j := i; j < n; j++ { 90 var hess float64 91 for _, pti := range stencil { 92 for _, ptj := range stencil { 93 var v float64 94 if pti.Loc == 0 && ptj.Loc == 0 { 95 v = origin 96 } else { 97 // Copying the data anew has two benefits. First, it 98 // avoids floating point issues where adding and then 99 // subtracting the step don't return to the exact same 100 // location. Secondly, it protects against the function 101 // modifying the input data. 102 copy(xCopy, x) 103 xCopy[i] += pti.Loc * step 104 xCopy[j] += ptj.Loc * step 105 v = f(xCopy) 106 } 107 hess += v * pti.Coeff * ptj.Coeff * is2 108 } 109 } 110 dst.SetSym(i, j, hess) 111 } 112 } 113 } 114 115 func hessianConcurrent(dst *mat.SymDense, nWorkers, evals int, f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) { 116 n := dst.SymmetricDim() 117 type run struct { 118 i, j int 119 iIdx, jIdx int 120 result float64 121 } 122 123 send := make(chan run, evals) 124 ans := make(chan run, evals) 125 126 var originWG sync.WaitGroup 127 hasOrigin := usesOrigin(stencil) 128 if hasOrigin { 129 originWG.Add(1) 130 // Launch worker to compute the origin. 131 go func() { 132 defer originWG.Done() 133 xCopy := make([]float64, len(x)) 134 copy(xCopy, x) 135 originValue = f(xCopy) 136 }() 137 } 138 139 var workerWG sync.WaitGroup 140 // Launch workers. 141 for i := 0; i < nWorkers; i++ { 142 workerWG.Add(1) 143 go func(send <-chan run, ans chan<- run) { 144 defer workerWG.Done() 145 xCopy := make([]float64, len(x)) 146 for r := range send { 147 if stencil[r.iIdx].Loc == 0 && stencil[r.jIdx].Loc == 0 { 148 originWG.Wait() 149 r.result = originValue 150 } else { 151 // See hessianSerial for comment on the copy. 152 copy(xCopy, x) 153 xCopy[r.i] += stencil[r.iIdx].Loc * step 154 xCopy[r.j] += stencil[r.jIdx].Loc * step 155 r.result = f(xCopy) 156 } 157 ans <- r 158 } 159 }(send, ans) 160 } 161 162 // Launch the distributor, which sends all of runs. 163 go func(send chan<- run) { 164 for i := 0; i < n; i++ { 165 for j := i; j < n; j++ { 166 for iIdx := range stencil { 167 for jIdx := range stencil { 168 send <- run{ 169 i: i, j: j, iIdx: iIdx, jIdx: jIdx, 170 } 171 } 172 } 173 } 174 } 175 close(send) 176 // Wait for all the workers to quit, then close the ans channel. 177 workerWG.Wait() 178 close(ans) 179 }(send) 180 181 is2 := 1 / (step * step) 182 // Read in the results. 183 for r := range ans { 184 v := r.result * stencil[r.iIdx].Coeff * stencil[r.jIdx].Coeff * is2 185 v += dst.At(r.i, r.j) 186 dst.SetSym(r.i, r.j, v) 187 } 188 }