gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/diff/fd/laplacian.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fd
     6  
     7  import "sync"
     8  
     9  // Laplacian computes the Laplacian of the multivariate function f at the location
    10  // x. That is, Laplacian returns
    11  //
    12  //	∆ f(x) = ∇ · ∇ f(x) = \sum_i ∂^2 f(x)/∂x_i^2
    13  //
    14  // The finite difference formula and other options are specified by settings.
    15  // The order of the difference formula must be 2 or Laplacian will panic.
    16  func Laplacian(f func(x []float64) float64, x []float64, settings *Settings) float64 {
    17  	n := len(x)
    18  	if n == 0 {
    19  		panic("laplacian: x has zero length")
    20  	}
    21  
    22  	// Default settings.
    23  	formula := Central2nd
    24  	step := formula.Step
    25  	var originValue float64
    26  	var originKnown, concurrent bool
    27  
    28  	// Use user settings if provided.
    29  	if settings != nil {
    30  		if !settings.Formula.isZero() {
    31  			formula = settings.Formula
    32  			step = formula.Step
    33  			checkFormula(formula)
    34  			if formula.Derivative != 2 {
    35  				panic(badDerivOrder)
    36  			}
    37  		}
    38  		if settings.Step != 0 {
    39  			if settings.Step < 0 {
    40  				panic(negativeStep)
    41  			}
    42  			step = settings.Step
    43  		}
    44  		originKnown = settings.OriginKnown
    45  		originValue = settings.OriginValue
    46  		concurrent = settings.Concurrent
    47  	}
    48  
    49  	evals := n * len(formula.Stencil)
    50  	if usesOrigin(formula.Stencil) {
    51  		evals -= n
    52  	}
    53  
    54  	nWorkers := computeWorkers(concurrent, evals)
    55  	if nWorkers == 1 {
    56  		return laplacianSerial(f, x, formula.Stencil, step, originKnown, originValue)
    57  	}
    58  	return laplacianConcurrent(nWorkers, evals, f, x, formula.Stencil, step, originKnown, originValue)
    59  }
    60  
    61  func laplacianSerial(f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) float64 {
    62  	n := len(x)
    63  	xCopy := make([]float64, n)
    64  	fo := func() float64 {
    65  		// Copy x in case it is modified during the call.
    66  		copy(xCopy, x)
    67  		return f(x)
    68  	}
    69  	is2 := 1 / (step * step)
    70  	origin := getOrigin(originKnown, originValue, fo, stencil)
    71  	var laplacian float64
    72  	for i := 0; i < n; i++ {
    73  		for _, pt := range stencil {
    74  			var v float64
    75  			if pt.Loc == 0 {
    76  				v = origin
    77  			} else {
    78  				// Copying the data anew has two benefits. First, it
    79  				// avoids floating point issues where adding and then
    80  				// subtracting the step don't return to the exact same
    81  				// location. Secondly, it protects against the function
    82  				// modifying the input data.
    83  				copy(xCopy, x)
    84  				xCopy[i] += pt.Loc * step
    85  				v = f(xCopy)
    86  			}
    87  			laplacian += v * pt.Coeff * is2
    88  		}
    89  	}
    90  	return laplacian
    91  }
    92  
    93  func laplacianConcurrent(nWorkers, evals int, f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) float64 {
    94  	type run struct {
    95  		i      int
    96  		idx    int
    97  		result float64
    98  	}
    99  	n := len(x)
   100  	send := make(chan run, evals)
   101  	ans := make(chan run, evals)
   102  
   103  	var originWG sync.WaitGroup
   104  	hasOrigin := usesOrigin(stencil)
   105  	if hasOrigin {
   106  		originWG.Add(1)
   107  		// Launch worker to compute the origin.
   108  		go func() {
   109  			defer originWG.Done()
   110  			xCopy := make([]float64, len(x))
   111  			copy(xCopy, x)
   112  			originValue = f(xCopy)
   113  		}()
   114  	}
   115  
   116  	var workerWG sync.WaitGroup
   117  	// Launch workers.
   118  	for i := 0; i < nWorkers; i++ {
   119  		workerWG.Add(1)
   120  		go func(send <-chan run, ans chan<- run) {
   121  			defer workerWG.Done()
   122  			xCopy := make([]float64, len(x))
   123  			for r := range send {
   124  				if stencil[r.idx].Loc == 0 {
   125  					originWG.Wait()
   126  					r.result = originValue
   127  				} else {
   128  					// See laplacianSerial for comment on the copy.
   129  					copy(xCopy, x)
   130  					xCopy[r.i] += stencil[r.idx].Loc * step
   131  					r.result = f(xCopy)
   132  				}
   133  				ans <- r
   134  			}
   135  		}(send, ans)
   136  	}
   137  
   138  	// Launch the distributor, which sends all of runs.
   139  	go func(send chan<- run) {
   140  		for i := 0; i < n; i++ {
   141  			for idx := range stencil {
   142  				send <- run{
   143  					i: i, idx: idx,
   144  				}
   145  			}
   146  		}
   147  		close(send)
   148  		// Wait for all the workers to quit, then close the ans channel.
   149  		workerWG.Wait()
   150  		close(ans)
   151  	}(send)
   152  
   153  	// Read in the results.
   154  	is2 := 1 / (step * step)
   155  	var laplacian float64
   156  	for r := range ans {
   157  		laplacian += r.result * stencil[r.idx].Coeff * is2
   158  	}
   159  	return laplacian
   160  }