github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/diff/fd/gradient.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fd
     6  
     7  import "github.com/jingcheng-WU/gonum/floats"
     8  
     9  // Gradient estimates the gradient of the multivariate function f at the
    10  // location x. If dst is not nil, the result will be stored in-place into dst
    11  // and returned, otherwise a new slice will be allocated first. Finite
    12  // difference formula and other options are specified by settings. If settings is
    13  // nil, the gradient will be estimated using the Forward formula and a default
    14  // step size.
    15  //
    16  // Gradient panics if the length of dst and x is not equal, or if the derivative
    17  // order of the formula is not 1.
    18  func Gradient(dst []float64, f func([]float64) float64, x []float64, settings *Settings) []float64 {
    19  	if dst == nil {
    20  		dst = make([]float64, len(x))
    21  	}
    22  	if len(dst) != len(x) {
    23  		panic("fd: slice length mismatch")
    24  	}
    25  
    26  	// Default settings.
    27  	formula := Forward
    28  	step := formula.Step
    29  	var originValue float64
    30  	var originKnown, concurrent bool
    31  
    32  	// Use user settings if provided.
    33  	if settings != nil {
    34  		if !settings.Formula.isZero() {
    35  			formula = settings.Formula
    36  			step = formula.Step
    37  			checkFormula(formula)
    38  			if formula.Derivative != 1 {
    39  				panic(badDerivOrder)
    40  			}
    41  		}
    42  		if settings.Step != 0 {
    43  			step = settings.Step
    44  		}
    45  		originKnown = settings.OriginKnown
    46  		originValue = settings.OriginValue
    47  		concurrent = settings.Concurrent
    48  	}
    49  
    50  	evals := len(formula.Stencil) * len(x)
    51  	nWorkers := computeWorkers(concurrent, evals)
    52  
    53  	hasOrigin := usesOrigin(formula.Stencil)
    54  	// Copy x in case it is modified during the call.
    55  	xcopy := make([]float64, len(x))
    56  	if hasOrigin && !originKnown {
    57  		copy(xcopy, x)
    58  		originValue = f(xcopy)
    59  	}
    60  
    61  	if nWorkers == 1 {
    62  		for i := range xcopy {
    63  			var deriv float64
    64  			for _, pt := range formula.Stencil {
    65  				if pt.Loc == 0 {
    66  					deriv += pt.Coeff * originValue
    67  					continue
    68  				}
    69  				// Copying the data anew has two benefits. First, it
    70  				// avoids floating point issues where adding and then
    71  				// subtracting the step don't return to the exact same
    72  				// location. Secondly, it protects against the function
    73  				// modifying the input data.
    74  				copy(xcopy, x)
    75  				xcopy[i] += pt.Loc * step
    76  				deriv += pt.Coeff * f(xcopy)
    77  			}
    78  			dst[i] = deriv / step
    79  		}
    80  		return dst
    81  	}
    82  
    83  	sendChan := make(chan fdrun, evals)
    84  	ansChan := make(chan fdrun, evals)
    85  	quit := make(chan struct{})
    86  	defer close(quit)
    87  
    88  	// Launch workers. Workers receive an index and a step, and compute the answer.
    89  	for i := 0; i < nWorkers; i++ {
    90  		go func(sendChan <-chan fdrun, ansChan chan<- fdrun, quit <-chan struct{}) {
    91  			xcopy := make([]float64, len(x))
    92  			for {
    93  				select {
    94  				case <-quit:
    95  					return
    96  				case run := <-sendChan:
    97  					// See above comment on the copy.
    98  					copy(xcopy, x)
    99  					xcopy[run.idx] += run.pt.Loc * step
   100  					run.result = f(xcopy)
   101  					ansChan <- run
   102  				}
   103  			}
   104  		}(sendChan, ansChan, quit)
   105  	}
   106  
   107  	// Launch the distributor. Distributor sends the cases to be computed.
   108  	go func(sendChan chan<- fdrun, ansChan chan<- fdrun) {
   109  		for i := range x {
   110  			for _, pt := range formula.Stencil {
   111  				if pt.Loc == 0 {
   112  					// Answer already known. Send the answer on the answer channel.
   113  					ansChan <- fdrun{
   114  						idx:    i,
   115  						pt:     pt,
   116  						result: originValue,
   117  					}
   118  					continue
   119  				}
   120  				// Answer not known, send the answer to be computed.
   121  				sendChan <- fdrun{
   122  					idx: i,
   123  					pt:  pt,
   124  				}
   125  			}
   126  		}
   127  	}(sendChan, ansChan)
   128  
   129  	for i := range dst {
   130  		dst[i] = 0
   131  	}
   132  	// Read in all of the results.
   133  	for i := 0; i < evals; i++ {
   134  		run := <-ansChan
   135  		dst[run.idx] += run.pt.Coeff * run.result
   136  	}
   137  	floats.Scale(1/step, dst)
   138  	return dst
   139  }
   140  
   141  type fdrun struct {
   142  	idx    int
   143  	pt     Point
   144  	result float64
   145  }