github.com/gopherd/gonum@v0.0.4/diff/fd/hessian.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fd
     6  
     7  import (
     8  	"math"
     9  	"sync"
    10  
    11  	"github.com/gopherd/gonum/mat"
    12  )
    13  
    14  // Hessian approximates the Hessian matrix of the multivariate function f at
    15  // the location x. That is
    16  //  H_{i,j} = ∂^2 f(x)/∂x_i ∂x_j
    17  // The resulting H will be stored in dst. Finite difference formula and other
    18  // options are specified by settings. If settings is nil, the Hessian will be
    19  // estimated using the Forward formula and a default step size.
    20  //
    21  // If the dst matrix is empty it will be resized to the correct dimensions,
    22  // otherwise the dimensions of dst must match the length of x or Hessian will panic.
    23  // Hessian will panic if the derivative order of the formula is not 1.
    24  func Hessian(dst *mat.SymDense, f func(x []float64) float64, x []float64, settings *Settings) {
    25  	n := len(x)
    26  	if dst.IsEmpty() {
    27  		*dst = *(dst.GrowSym(n).(*mat.SymDense))
    28  	} else if dst.SymmetricDim() != n {
    29  		panic("hessian: dst size mismatch")
    30  	}
    31  	dst.Zero()
    32  
    33  	// Default settings.
    34  	formula := Forward
    35  	step := math.Sqrt(formula.Step) // Use the sqrt because taking derivatives of derivatives.
    36  	var originValue float64
    37  	var originKnown, concurrent bool
    38  
    39  	// Use user settings if provided.
    40  	if settings != nil {
    41  		if !settings.Formula.isZero() {
    42  			formula = settings.Formula
    43  			step = math.Sqrt(formula.Step)
    44  			checkFormula(formula)
    45  			if formula.Derivative != 1 {
    46  				panic(badDerivOrder)
    47  			}
    48  		}
    49  		if settings.Step != 0 {
    50  			if settings.Step < 0 {
    51  				panic(negativeStep)
    52  			}
    53  			step = settings.Step
    54  		}
    55  		originKnown = settings.OriginKnown
    56  		originValue = settings.OriginValue
    57  		concurrent = settings.Concurrent
    58  	}
    59  
    60  	evals := n * (n + 1) / 2 * len(formula.Stencil) * len(formula.Stencil)
    61  	for _, pt := range formula.Stencil {
    62  		if pt.Loc == 0 {
    63  			evals -= n * (n + 1) / 2
    64  			break
    65  		}
    66  	}
    67  
    68  	nWorkers := computeWorkers(concurrent, evals)
    69  	if nWorkers == 1 {
    70  		hessianSerial(dst, f, x, formula.Stencil, step, originKnown, originValue)
    71  		return
    72  	}
    73  	hessianConcurrent(dst, nWorkers, evals, f, x, formula.Stencil, step, originKnown, originValue)
    74  }
    75  
    76  func hessianSerial(dst *mat.SymDense, f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) {
    77  	n := len(x)
    78  	xCopy := make([]float64, n)
    79  	fo := func() float64 {
    80  		// Copy x in case it is modified during the call.
    81  		copy(xCopy, x)
    82  		return f(x)
    83  	}
    84  	is2 := 1 / (step * step)
    85  	origin := getOrigin(originKnown, originValue, fo, stencil)
    86  	for i := 0; i < n; i++ {
    87  		for j := i; j < n; j++ {
    88  			var hess float64
    89  			for _, pti := range stencil {
    90  				for _, ptj := range stencil {
    91  					var v float64
    92  					if pti.Loc == 0 && ptj.Loc == 0 {
    93  						v = origin
    94  					} else {
    95  						// Copying the data anew has two benefits. First, it
    96  						// avoids floating point issues where adding and then
    97  						// subtracting the step don't return to the exact same
    98  						// location. Secondly, it protects against the function
    99  						// modifying the input data.
   100  						copy(xCopy, x)
   101  						xCopy[i] += pti.Loc * step
   102  						xCopy[j] += ptj.Loc * step
   103  						v = f(xCopy)
   104  					}
   105  					hess += v * pti.Coeff * ptj.Coeff * is2
   106  				}
   107  			}
   108  			dst.SetSym(i, j, hess)
   109  		}
   110  	}
   111  }
   112  
   113  func hessianConcurrent(dst *mat.SymDense, nWorkers, evals int, f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) {
   114  	n := dst.SymmetricDim()
   115  	type run struct {
   116  		i, j       int
   117  		iIdx, jIdx int
   118  		result     float64
   119  	}
   120  
   121  	send := make(chan run, evals)
   122  	ans := make(chan run, evals)
   123  
   124  	var originWG sync.WaitGroup
   125  	hasOrigin := usesOrigin(stencil)
   126  	if hasOrigin {
   127  		originWG.Add(1)
   128  		// Launch worker to compute the origin.
   129  		go func() {
   130  			defer originWG.Done()
   131  			xCopy := make([]float64, len(x))
   132  			copy(xCopy, x)
   133  			originValue = f(xCopy)
   134  		}()
   135  	}
   136  
   137  	var workerWG sync.WaitGroup
   138  	// Launch workers.
   139  	for i := 0; i < nWorkers; i++ {
   140  		workerWG.Add(1)
   141  		go func(send <-chan run, ans chan<- run) {
   142  			defer workerWG.Done()
   143  			xCopy := make([]float64, len(x))
   144  			for r := range send {
   145  				if stencil[r.iIdx].Loc == 0 && stencil[r.jIdx].Loc == 0 {
   146  					originWG.Wait()
   147  					r.result = originValue
   148  				} else {
   149  					// See hessianSerial for comment on the copy.
   150  					copy(xCopy, x)
   151  					xCopy[r.i] += stencil[r.iIdx].Loc * step
   152  					xCopy[r.j] += stencil[r.jIdx].Loc * step
   153  					r.result = f(xCopy)
   154  				}
   155  				ans <- r
   156  			}
   157  		}(send, ans)
   158  	}
   159  
   160  	// Launch the distributor, which sends all of runs.
   161  	go func(send chan<- run) {
   162  		for i := 0; i < n; i++ {
   163  			for j := i; j < n; j++ {
   164  				for iIdx := range stencil {
   165  					for jIdx := range stencil {
   166  						send <- run{
   167  							i: i, j: j, iIdx: iIdx, jIdx: jIdx,
   168  						}
   169  					}
   170  				}
   171  			}
   172  		}
   173  		close(send)
   174  		// Wait for all the workers to quit, then close the ans channel.
   175  		workerWG.Wait()
   176  		close(ans)
   177  	}(send)
   178  
   179  	is2 := 1 / (step * step)
   180  	// Read in the results.
   181  	for r := range ans {
   182  		v := r.result * stencil[r.iIdx].Coeff * stencil[r.jIdx].Coeff * is2
   183  		v += dst.At(r.i, r.j)
   184  		dst.SetSym(r.i, r.j, v)
   185  	}
   186  }