gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/diff/fd/hessian.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fd
     6  
     7  import (
     8  	"math"
     9  	"sync"
    10  
    11  	"gonum.org/v1/gonum/mat"
    12  )
    13  
    14  // Hessian approximates the Hessian matrix of the multivariate function f at
    15  // the location x. That is
    16  //
    17  //	H_{i,j} = ∂^2 f(x)/∂x_i ∂x_j
    18  //
    19  // The resulting H will be stored in dst. Finite difference formula and other
    20  // options are specified by settings. If settings is nil, the Hessian will be
    21  // estimated using the Forward formula and a default step size.
    22  //
    23  // If the dst matrix is empty it will be resized to the correct dimensions,
    24  // otherwise the dimensions of dst must match the length of x or Hessian will panic.
    25  // Hessian will panic if the derivative order of the formula is not 1.
    26  func Hessian(dst *mat.SymDense, f func(x []float64) float64, x []float64, settings *Settings) {
    27  	n := len(x)
    28  	if dst.IsEmpty() {
    29  		*dst = *(dst.GrowSym(n).(*mat.SymDense))
    30  	} else if dst.SymmetricDim() != n {
    31  		panic("hessian: dst size mismatch")
    32  	}
    33  	dst.Zero()
    34  
    35  	// Default settings.
    36  	formula := Forward
    37  	step := math.Sqrt(formula.Step) // Use the sqrt because taking derivatives of derivatives.
    38  	var originValue float64
    39  	var originKnown, concurrent bool
    40  
    41  	// Use user settings if provided.
    42  	if settings != nil {
    43  		if !settings.Formula.isZero() {
    44  			formula = settings.Formula
    45  			step = math.Sqrt(formula.Step)
    46  			checkFormula(formula)
    47  			if formula.Derivative != 1 {
    48  				panic(badDerivOrder)
    49  			}
    50  		}
    51  		if settings.Step != 0 {
    52  			if settings.Step < 0 {
    53  				panic(negativeStep)
    54  			}
    55  			step = settings.Step
    56  		}
    57  		originKnown = settings.OriginKnown
    58  		originValue = settings.OriginValue
    59  		concurrent = settings.Concurrent
    60  	}
    61  
    62  	evals := n * (n + 1) / 2 * len(formula.Stencil) * len(formula.Stencil)
    63  	for _, pt := range formula.Stencil {
    64  		if pt.Loc == 0 {
    65  			evals -= n * (n + 1) / 2
    66  			break
    67  		}
    68  	}
    69  
    70  	nWorkers := computeWorkers(concurrent, evals)
    71  	if nWorkers == 1 {
    72  		hessianSerial(dst, f, x, formula.Stencil, step, originKnown, originValue)
    73  		return
    74  	}
    75  	hessianConcurrent(dst, nWorkers, evals, f, x, formula.Stencil, step, originKnown, originValue)
    76  }
    77  
    78  func hessianSerial(dst *mat.SymDense, f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) {
    79  	n := len(x)
    80  	xCopy := make([]float64, n)
    81  	fo := func() float64 {
    82  		// Copy x in case it is modified during the call.
    83  		copy(xCopy, x)
    84  		return f(x)
    85  	}
    86  	is2 := 1 / (step * step)
    87  	origin := getOrigin(originKnown, originValue, fo, stencil)
    88  	for i := 0; i < n; i++ {
    89  		for j := i; j < n; j++ {
    90  			var hess float64
    91  			for _, pti := range stencil {
    92  				for _, ptj := range stencil {
    93  					var v float64
    94  					if pti.Loc == 0 && ptj.Loc == 0 {
    95  						v = origin
    96  					} else {
    97  						// Copying the data anew has two benefits. First, it
    98  						// avoids floating point issues where adding and then
    99  						// subtracting the step don't return to the exact same
   100  						// location. Secondly, it protects against the function
   101  						// modifying the input data.
   102  						copy(xCopy, x)
   103  						xCopy[i] += pti.Loc * step
   104  						xCopy[j] += ptj.Loc * step
   105  						v = f(xCopy)
   106  					}
   107  					hess += v * pti.Coeff * ptj.Coeff * is2
   108  				}
   109  			}
   110  			dst.SetSym(i, j, hess)
   111  		}
   112  	}
   113  }
   114  
   115  func hessianConcurrent(dst *mat.SymDense, nWorkers, evals int, f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) {
   116  	n := dst.SymmetricDim()
   117  	type run struct {
   118  		i, j       int
   119  		iIdx, jIdx int
   120  		result     float64
   121  	}
   122  
   123  	send := make(chan run, evals)
   124  	ans := make(chan run, evals)
   125  
   126  	var originWG sync.WaitGroup
   127  	hasOrigin := usesOrigin(stencil)
   128  	if hasOrigin {
   129  		originWG.Add(1)
   130  		// Launch worker to compute the origin.
   131  		go func() {
   132  			defer originWG.Done()
   133  			xCopy := make([]float64, len(x))
   134  			copy(xCopy, x)
   135  			originValue = f(xCopy)
   136  		}()
   137  	}
   138  
   139  	var workerWG sync.WaitGroup
   140  	// Launch workers.
   141  	for i := 0; i < nWorkers; i++ {
   142  		workerWG.Add(1)
   143  		go func(send <-chan run, ans chan<- run) {
   144  			defer workerWG.Done()
   145  			xCopy := make([]float64, len(x))
   146  			for r := range send {
   147  				if stencil[r.iIdx].Loc == 0 && stencil[r.jIdx].Loc == 0 {
   148  					originWG.Wait()
   149  					r.result = originValue
   150  				} else {
   151  					// See hessianSerial for comment on the copy.
   152  					copy(xCopy, x)
   153  					xCopy[r.i] += stencil[r.iIdx].Loc * step
   154  					xCopy[r.j] += stencil[r.jIdx].Loc * step
   155  					r.result = f(xCopy)
   156  				}
   157  				ans <- r
   158  			}
   159  		}(send, ans)
   160  	}
   161  
   162  	// Launch the distributor, which sends all of runs.
   163  	go func(send chan<- run) {
   164  		for i := 0; i < n; i++ {
   165  			for j := i; j < n; j++ {
   166  				for iIdx := range stencil {
   167  					for jIdx := range stencil {
   168  						send <- run{
   169  							i: i, j: j, iIdx: iIdx, jIdx: jIdx,
   170  						}
   171  					}
   172  				}
   173  			}
   174  		}
   175  		close(send)
   176  		// Wait for all the workers to quit, then close the ans channel.
   177  		workerWG.Wait()
   178  		close(ans)
   179  	}(send)
   180  
   181  	is2 := 1 / (step * step)
   182  	// Read in the results.
   183  	for r := range ans {
   184  		v := r.result * stencil[r.iIdx].Coeff * stencil[r.jIdx].Coeff * is2
   185  		v += dst.At(r.i, r.j)
   186  		dst.SetSym(r.i, r.j, v)
   187  	}
   188  }