github.com/gopherd/gonum@v0.0.4/diff/fd/gradient.go (about) 1 // Copyright ©2017 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package fd 6 7 import "github.com/gopherd/gonum/floats" 8 9 // Gradient estimates the gradient of the multivariate function f at the 10 // location x. If dst is not nil, the result will be stored in-place into dst 11 // and returned, otherwise a new slice will be allocated first. Finite 12 // difference formula and other options are specified by settings. If settings is 13 // nil, the gradient will be estimated using the Forward formula and a default 14 // step size. 15 // 16 // Gradient panics if the length of dst and x is not equal, or if the derivative 17 // order of the formula is not 1. 18 func Gradient(dst []float64, f func([]float64) float64, x []float64, settings *Settings) []float64 { 19 if dst == nil { 20 dst = make([]float64, len(x)) 21 } 22 if len(dst) != len(x) { 23 panic("fd: slice length mismatch") 24 } 25 26 // Default settings. 27 formula := Forward 28 step := formula.Step 29 var originValue float64 30 var originKnown, concurrent bool 31 32 // Use user settings if provided. 33 if settings != nil { 34 if !settings.Formula.isZero() { 35 formula = settings.Formula 36 step = formula.Step 37 checkFormula(formula) 38 if formula.Derivative != 1 { 39 panic(badDerivOrder) 40 } 41 } 42 if settings.Step != 0 { 43 step = settings.Step 44 } 45 originKnown = settings.OriginKnown 46 originValue = settings.OriginValue 47 concurrent = settings.Concurrent 48 } 49 50 evals := len(formula.Stencil) * len(x) 51 nWorkers := computeWorkers(concurrent, evals) 52 53 hasOrigin := usesOrigin(formula.Stencil) 54 // Copy x in case it is modified during the call. 55 xcopy := make([]float64, len(x)) 56 if hasOrigin && !originKnown { 57 copy(xcopy, x) 58 originValue = f(xcopy) 59 } 60 61 if nWorkers == 1 { 62 for i := range xcopy { 63 var deriv float64 64 for _, pt := range formula.Stencil { 65 if pt.Loc == 0 { 66 deriv += pt.Coeff * originValue 67 continue 68 } 69 // Copying the data anew has two benefits. First, it 70 // avoids floating point issues where adding and then 71 // subtracting the step don't return to the exact same 72 // location. Secondly, it protects against the function 73 // modifying the input data. 74 copy(xcopy, x) 75 xcopy[i] += pt.Loc * step 76 deriv += pt.Coeff * f(xcopy) 77 } 78 dst[i] = deriv / step 79 } 80 return dst 81 } 82 83 sendChan := make(chan fdrun, evals) 84 ansChan := make(chan fdrun, evals) 85 quit := make(chan struct{}) 86 defer close(quit) 87 88 // Launch workers. Workers receive an index and a step, and compute the answer. 89 for i := 0; i < nWorkers; i++ { 90 go func(sendChan <-chan fdrun, ansChan chan<- fdrun, quit <-chan struct{}) { 91 xcopy := make([]float64, len(x)) 92 for { 93 select { 94 case <-quit: 95 return 96 case run := <-sendChan: 97 // See above comment on the copy. 98 copy(xcopy, x) 99 xcopy[run.idx] += run.pt.Loc * step 100 run.result = f(xcopy) 101 ansChan <- run 102 } 103 } 104 }(sendChan, ansChan, quit) 105 } 106 107 // Launch the distributor. Distributor sends the cases to be computed. 108 go func(sendChan chan<- fdrun, ansChan chan<- fdrun) { 109 for i := range x { 110 for _, pt := range formula.Stencil { 111 if pt.Loc == 0 { 112 // Answer already known. Send the answer on the answer channel. 113 ansChan <- fdrun{ 114 idx: i, 115 pt: pt, 116 result: originValue, 117 } 118 continue 119 } 120 // Answer not known, send the answer to be computed. 121 sendChan <- fdrun{ 122 idx: i, 123 pt: pt, 124 } 125 } 126 } 127 }(sendChan, ansChan) 128 129 for i := range dst { 130 dst[i] = 0 131 } 132 // Read in all of the results. 133 for i := 0; i < evals; i++ { 134 run := <-ansChan 135 dst[run.idx] += run.pt.Coeff * run.result 136 } 137 floats.Scale(1/step, dst) 138 return dst 139 } 140 141 type fdrun struct { 142 idx int 143 pt Point 144 result float64 145 }