gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/diff/fd/crosslaplacian.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fd
     6  
     7  import (
     8  	"math"
     9  	"sync"
    10  )
    11  
    12  // CrossLaplacian computes a Laplacian-like quantity for a function of two vectors
    13  // at the locations x and y.
    14  // It computes
    15  //
    16  //	∇_y · ∇_x f(x,y) = \sum_i ∂^2 f(x,y)/∂x_i ∂y_i
    17  //
    18  // The two input vector lengths must be the same.
    19  //
    20  // Finite difference formula and other options are specified by settings. If
    21  // settings is nil, CrossLaplacian will be estimated using the Forward formula and
    22  // a default step size.
    23  //
    24  // CrossLaplacian panics if the two input vectors are not the same length, or if
    25  // the derivative order of the formula is not 1.
    26  func CrossLaplacian(f func(x, y []float64) float64, x, y []float64, settings *Settings) float64 {
    27  	n := len(x)
    28  	if n == 0 {
    29  		panic("crosslaplacian: x has zero length")
    30  	}
    31  	if len(x) != len(y) {
    32  		panic("crosslaplacian: input vector length mismatch")
    33  	}
    34  
    35  	// Default settings.
    36  	formula := Forward
    37  	step := math.Sqrt(formula.Step) // Use the sqrt because taking derivatives of derivatives.
    38  	var originValue float64
    39  	var originKnown, concurrent bool
    40  
    41  	// Use user settings if provided.
    42  	if settings != nil {
    43  		if !settings.Formula.isZero() {
    44  			formula = settings.Formula
    45  			step = math.Sqrt(formula.Step)
    46  			checkFormula(formula)
    47  			if formula.Derivative != 1 {
    48  				panic(badDerivOrder)
    49  			}
    50  		}
    51  		if settings.Step != 0 {
    52  			if settings.Step < 0 {
    53  				panic(negativeStep)
    54  			}
    55  			step = settings.Step
    56  		}
    57  		originKnown = settings.OriginKnown
    58  		originValue = settings.OriginValue
    59  		concurrent = settings.Concurrent
    60  	}
    61  
    62  	evals := n * len(formula.Stencil) * len(formula.Stencil)
    63  	if usesOrigin(formula.Stencil) {
    64  		evals -= n
    65  	}
    66  
    67  	nWorkers := computeWorkers(concurrent, evals)
    68  	if nWorkers == 1 {
    69  		return crossLaplacianSerial(f, x, y, formula.Stencil, step, originKnown, originValue)
    70  	}
    71  	return crossLaplacianConcurrent(nWorkers, evals, f, x, y, formula.Stencil, step, originKnown, originValue)
    72  }
    73  
    74  func crossLaplacianSerial(f func(x, y []float64) float64, x, y []float64, stencil []Point, step float64, originKnown bool, originValue float64) float64 {
    75  	n := len(x)
    76  	xCopy := make([]float64, len(x))
    77  	yCopy := make([]float64, len(y))
    78  	fo := func() float64 {
    79  		// Copy x and y in case they are modified during the call.
    80  		copy(xCopy, x)
    81  		copy(yCopy, y)
    82  		return f(x, y)
    83  	}
    84  	origin := getOrigin(originKnown, originValue, fo, stencil)
    85  
    86  	is2 := 1 / (step * step)
    87  	var laplacian float64
    88  	for i := 0; i < n; i++ {
    89  		for _, pty := range stencil {
    90  			for _, ptx := range stencil {
    91  				var v float64
    92  				if ptx.Loc == 0 && pty.Loc == 0 {
    93  					v = origin
    94  				} else {
    95  					// Copying the data anew has two benefits. First, it
    96  					// avoids floating point issues where adding and then
    97  					// subtracting the step don't return to the exact same
    98  					// location. Secondly, it protects against the function
    99  					// modifying the input data.
   100  					copy(yCopy, y)
   101  					copy(xCopy, x)
   102  					yCopy[i] += pty.Loc * step
   103  					xCopy[i] += ptx.Loc * step
   104  					v = f(xCopy, yCopy)
   105  				}
   106  				laplacian += v * ptx.Coeff * pty.Coeff * is2
   107  			}
   108  		}
   109  	}
   110  	return laplacian
   111  }
   112  
   113  func crossLaplacianConcurrent(nWorkers, evals int, f func(x, y []float64) float64, x, y []float64, stencil []Point, step float64, originKnown bool, originValue float64) float64 {
   114  	n := len(x)
   115  	type run struct {
   116  		i          int
   117  		xIdx, yIdx int
   118  		result     float64
   119  	}
   120  
   121  	send := make(chan run, evals)
   122  	ans := make(chan run, evals)
   123  
   124  	var originWG sync.WaitGroup
   125  	hasOrigin := usesOrigin(stencil)
   126  	if hasOrigin {
   127  		originWG.Add(1)
   128  		// Launch worker to compute the origin.
   129  		go func() {
   130  			defer originWG.Done()
   131  			xCopy := make([]float64, len(x))
   132  			yCopy := make([]float64, len(y))
   133  			copy(xCopy, x)
   134  			copy(yCopy, y)
   135  			originValue = f(xCopy, yCopy)
   136  		}()
   137  	}
   138  
   139  	var workerWG sync.WaitGroup
   140  	// Launch workers.
   141  	for i := 0; i < nWorkers; i++ {
   142  		workerWG.Add(1)
   143  		go func(send <-chan run, ans chan<- run) {
   144  			defer workerWG.Done()
   145  			xCopy := make([]float64, len(x))
   146  			yCopy := make([]float64, len(y))
   147  			for r := range send {
   148  				if stencil[r.xIdx].Loc == 0 && stencil[r.yIdx].Loc == 0 {
   149  					originWG.Wait()
   150  					r.result = originValue
   151  				} else {
   152  					// See crossLaplacianSerial for comment on the copy.
   153  					copy(xCopy, x)
   154  					copy(yCopy, y)
   155  					xCopy[r.i] += stencil[r.xIdx].Loc * step
   156  					yCopy[r.i] += stencil[r.yIdx].Loc * step
   157  					r.result = f(xCopy, yCopy)
   158  				}
   159  				ans <- r
   160  			}
   161  		}(send, ans)
   162  	}
   163  
   164  	// Launch the distributor, which sends all of runs.
   165  	go func(send chan<- run) {
   166  		for i := 0; i < n; i++ {
   167  			for xIdx := range stencil {
   168  				for yIdx := range stencil {
   169  					send <- run{
   170  						i: i, xIdx: xIdx, yIdx: yIdx,
   171  					}
   172  				}
   173  			}
   174  		}
   175  		close(send)
   176  		// Wait for all the workers to quit, then close the ans channel.
   177  		workerWG.Wait()
   178  		close(ans)
   179  	}(send)
   180  
   181  	// Read in the results.
   182  	is2 := 1 / (step * step)
   183  	var laplacian float64
   184  	for r := range ans {
   185  		laplacian += r.result * stencil[r.xIdx].Coeff * stencil[r.yIdx].Coeff * is2
   186  	}
   187  	return laplacian
   188  }