github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/diff/fd/laplacian.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fd
     6  
     7  import "sync"
     8  
     9  // Laplacian computes the Laplacian of the multivariate function f at the location
    10  // x. That is, Laplacian returns
    11  //  ∆ f(x) = ∇ · ∇ f(x) = \sum_i ∂^2 f(x)/∂x_i^2
    12  // The finite difference formula and other options are specified by settings.
    13  // The order of the difference formula must be 2 or Laplacian will panic.
    14  func Laplacian(f func(x []float64) float64, x []float64, settings *Settings) float64 {
    15  	n := len(x)
    16  	if n == 0 {
    17  		panic("laplacian: x has zero length")
    18  	}
    19  
    20  	// Default settings.
    21  	formula := Central2nd
    22  	step := formula.Step
    23  	var originValue float64
    24  	var originKnown, concurrent bool
    25  
    26  	// Use user settings if provided.
    27  	if settings != nil {
    28  		if !settings.Formula.isZero() {
    29  			formula = settings.Formula
    30  			step = formula.Step
    31  			checkFormula(formula)
    32  			if formula.Derivative != 2 {
    33  				panic(badDerivOrder)
    34  			}
    35  		}
    36  		if settings.Step != 0 {
    37  			if settings.Step < 0 {
    38  				panic(negativeStep)
    39  			}
    40  			step = settings.Step
    41  		}
    42  		originKnown = settings.OriginKnown
    43  		originValue = settings.OriginValue
    44  		concurrent = settings.Concurrent
    45  	}
    46  
    47  	evals := n * len(formula.Stencil)
    48  	if usesOrigin(formula.Stencil) {
    49  		evals -= n
    50  	}
    51  
    52  	nWorkers := computeWorkers(concurrent, evals)
    53  	if nWorkers == 1 {
    54  		return laplacianSerial(f, x, formula.Stencil, step, originKnown, originValue)
    55  	}
    56  	return laplacianConcurrent(nWorkers, evals, f, x, formula.Stencil, step, originKnown, originValue)
    57  }
    58  
    59  func laplacianSerial(f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) float64 {
    60  	n := len(x)
    61  	xCopy := make([]float64, n)
    62  	fo := func() float64 {
    63  		// Copy x in case it is modified during the call.
    64  		copy(xCopy, x)
    65  		return f(x)
    66  	}
    67  	is2 := 1 / (step * step)
    68  	origin := getOrigin(originKnown, originValue, fo, stencil)
    69  	var laplacian float64
    70  	for i := 0; i < n; i++ {
    71  		for _, pt := range stencil {
    72  			var v float64
    73  			if pt.Loc == 0 {
    74  				v = origin
    75  			} else {
    76  				// Copying the data anew has two benefits. First, it
    77  				// avoids floating point issues where adding and then
    78  				// subtracting the step don't return to the exact same
    79  				// location. Secondly, it protects against the function
    80  				// modifying the input data.
    81  				copy(xCopy, x)
    82  				xCopy[i] += pt.Loc * step
    83  				v = f(xCopy)
    84  			}
    85  			laplacian += v * pt.Coeff * is2
    86  		}
    87  	}
    88  	return laplacian
    89  }
    90  
    91  func laplacianConcurrent(nWorkers, evals int, f func(x []float64) float64, x []float64, stencil []Point, step float64, originKnown bool, originValue float64) float64 {
    92  	type run struct {
    93  		i      int
    94  		idx    int
    95  		result float64
    96  	}
    97  	n := len(x)
    98  	send := make(chan run, evals)
    99  	ans := make(chan run, evals)
   100  
   101  	var originWG sync.WaitGroup
   102  	hasOrigin := usesOrigin(stencil)
   103  	if hasOrigin {
   104  		originWG.Add(1)
   105  		// Launch worker to compute the origin.
   106  		go func() {
   107  			defer originWG.Done()
   108  			xCopy := make([]float64, len(x))
   109  			copy(xCopy, x)
   110  			originValue = f(xCopy)
   111  		}()
   112  	}
   113  
   114  	var workerWG sync.WaitGroup
   115  	// Launch workers.
   116  	for i := 0; i < nWorkers; i++ {
   117  		workerWG.Add(1)
   118  		go func(send <-chan run, ans chan<- run) {
   119  			defer workerWG.Done()
   120  			xCopy := make([]float64, len(x))
   121  			for r := range send {
   122  				if stencil[r.idx].Loc == 0 {
   123  					originWG.Wait()
   124  					r.result = originValue
   125  				} else {
   126  					// See laplacianSerial for comment on the copy.
   127  					copy(xCopy, x)
   128  					xCopy[r.i] += stencil[r.idx].Loc * step
   129  					r.result = f(xCopy)
   130  				}
   131  				ans <- r
   132  			}
   133  		}(send, ans)
   134  	}
   135  
   136  	// Launch the distributor, which sends all of runs.
   137  	go func(send chan<- run) {
   138  		for i := 0; i < n; i++ {
   139  			for idx := range stencil {
   140  				send <- run{
   141  					i: i, idx: idx,
   142  				}
   143  			}
   144  		}
   145  		close(send)
   146  		// Wait for all the workers to quit, then close the ans channel.
   147  		workerWG.Wait()
   148  		close(ans)
   149  	}(send)
   150  
   151  	// Read in the results.
   152  	is2 := 1 / (step * step)
   153  	var laplacian float64
   154  	for r := range ans {
   155  		laplacian += r.result * stencil[r.idx].Coeff * is2
   156  	}
   157  	return laplacian
   158  }