github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/diff/fd/hessian.go (about) 1 // Copyright ©2017 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package fd 6 7 import ( 8 "math" 9 "sync" 10 11 "github.com/jingcheng-WU/gonum/mat" 12 ) 13 14 // Hessian approximates the Hessian matrix of the multivariate function f at 15 // the location x. That is 16 // H_{i,j} = ∂^2 f(x)/∂x_i ∂x_j 17 // The resulting H will be stored in dst. Finite difference formula and other 18 // options are specified by settings. If settings is nil, the Hessian will be 19 // estimated using the Forward formula and a default step size. 20 // 21 // If the dst matrix is empty it will be resized to the correct dimensions, 22 // otherwise the dimensions of dst must match the length of x or Hessian will panic. 23 // Hessian will panic if the derivative order of the formula is not 1. 24 func Hessian(dst *mat.SymDense, f func(x []float64) float64, x []float64, settings *Settings) { 25 n := len(x) 26 if dst.IsEmpty() { 27 *dst = *(dst.GrowSym(n).(*mat.SymDense)) 28 } else if dst.Symmetric() != n { 29 panic("hessian: dst size mismatch") 30 } 31 dst.Zero() 32 33 // Default settings. 34 formula := Forward 35 step := math.Sqrt(formula.Step) // Use the sqrt because taking derivatives of derivatives. 36 var originValue float64 37 var originKnown, concurrent bool 38 39 // Use user settings if provided. 40 if settings != nil { 41 if !settings.Formula.isZero() { 42 formula = settings.Formula 43 step = math.Sqrt(formula.Step) 44 checkFormula(formula) 45 if formula.Derivative != 1 { 46 panic(badDerivOrder) 47 } 48 } 49 if settings.Step != 0 { 50 if settings.Step < 0 { 51 panic(negativeStep) 52 } 53 step = settings.Step 54 } 55 originKnown = settings.OriginKnown 56 originValue = settings.OriginValue 57 concurrent = settings.Concurrent 58 } 59 60 evals := n * (n + 1) / 2 * len(formula.Stencil) * len(formula.Stencil) 61 for _, pt := range formula.Stencil { 62 if pt.Loc == 0 { 63 evals -= n * (n + 1) / 2 64 break 65 } 66 } 67 68 nWorkers := computeWorkers(concurrent, evals) 69 if nWorkers == 1 { 70 hessianSerial(dst, f, x, formula.Stencil, step, originKnown, originValue) 71 return 72 } 73 hessianConcurrent(dst, nWorkers, evals, f, x, formula.Stencil, step, originKnown, originValue) 74 } 75 76 func hessianSerial(dst *mat.SymDense, f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) { 77 n := len(x) 78 xCopy := make([]float64, n) 79 fo := func() float64 { 80 // Copy x in case it is modified during the call. 81 copy(xCopy, x) 82 return f(x) 83 } 84 is2 := 1 / (step * step) 85 origin := getOrigin(originKnown, originValue, fo, stencil) 86 for i := 0; i < n; i++ { 87 for j := i; j < n; j++ { 88 var hess float64 89 for _, pti := range stencil { 90 for _, ptj := range stencil { 91 var v float64 92 if pti.Loc == 0 && ptj.Loc == 0 { 93 v = origin 94 } else { 95 // Copying the data anew has two benefits. First, it 96 // avoids floating point issues where adding and then 97 // subtracting the step don't return to the exact same 98 // location. Secondly, it protects against the function 99 // modifying the input data. 100 copy(xCopy, x) 101 xCopy[i] += pti.Loc * step 102 xCopy[j] += ptj.Loc * step 103 v = f(xCopy) 104 } 105 hess += v * pti.Coeff * ptj.Coeff * is2 106 } 107 } 108 dst.SetSym(i, j, hess) 109 } 110 } 111 } 112 113 func hessianConcurrent(dst *mat.SymDense, nWorkers, evals int, f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) { 114 n := dst.Symmetric() 115 type run struct { 116 i, j int 117 iIdx, jIdx int 118 result float64 119 } 120 121 send := make(chan run, evals) 122 ans := make(chan run, evals) 123 124 var originWG sync.WaitGroup 125 hasOrigin := usesOrigin(stencil) 126 if hasOrigin { 127 originWG.Add(1) 128 // Launch worker to compute the origin. 129 go func() { 130 defer originWG.Done() 131 xCopy := make([]float64, len(x)) 132 copy(xCopy, x) 133 originValue = f(xCopy) 134 }() 135 } 136 137 var workerWG sync.WaitGroup 138 // Launch workers. 139 for i := 0; i < nWorkers; i++ { 140 workerWG.Add(1) 141 go func(send <-chan run, ans chan<- run) { 142 defer workerWG.Done() 143 xCopy := make([]float64, len(x)) 144 for r := range send { 145 if stencil[r.iIdx].Loc == 0 && stencil[r.jIdx].Loc == 0 { 146 originWG.Wait() 147 r.result = originValue 148 } else { 149 // See hessianSerial for comment on the copy. 150 copy(xCopy, x) 151 xCopy[r.i] += stencil[r.iIdx].Loc * step 152 xCopy[r.j] += stencil[r.jIdx].Loc * step 153 r.result = f(xCopy) 154 } 155 ans <- r 156 } 157 }(send, ans) 158 } 159 160 // Launch the distributor, which sends all of runs. 161 go func(send chan<- run) { 162 for i := 0; i < n; i++ { 163 for j := i; j < n; j++ { 164 for iIdx := range stencil { 165 for jIdx := range stencil { 166 send <- run{ 167 i: i, j: j, iIdx: iIdx, jIdx: jIdx, 168 } 169 } 170 } 171 } 172 } 173 close(send) 174 // Wait for all the workers to quit, then close the ans channel. 175 workerWG.Wait() 176 close(ans) 177 }(send) 178 179 is2 := 1 / (step * step) 180 // Read in the results. 181 for r := range ans { 182 v := r.result * stencil[r.iIdx].Coeff * stencil[r.jIdx].Coeff * is2 183 v += dst.At(r.i, r.j) 184 dst.SetSym(r.i, r.j, v) 185 } 186 }