github.com/gopherd/gonum@v0.0.4/diff/fd/crosslaplacian.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fd
     6  
     7  import (
     8  	"math"
     9  	"sync"
    10  )
    11  
    12  // CrossLaplacian computes a Laplacian-like quantity for a function of two vectors
    13  // at the locations x and y.
    14  // It computes
    15  //  ∇_y · ∇_x f(x,y) = \sum_i ∂^2 f(x,y)/∂x_i ∂y_i
    16  // The two input vector lengths must be the same.
    17  //
    18  // Finite difference formula and other options are specified by settings. If
    19  // settings is nil, CrossLaplacian will be estimated using the Forward formula and
    20  // a default step size.
    21  //
    22  // CrossLaplacian panics if the two input vectors are not the same length, or if
    23  // the derivative order of the formula is not 1.
    24  func CrossLaplacian(f func(x, y []float64) float64, x, y []float64, settings *Settings) float64 {
    25  	n := len(x)
    26  	if n == 0 {
    27  		panic("crosslaplacian: x has zero length")
    28  	}
    29  	if len(x) != len(y) {
    30  		panic("crosslaplacian: input vector length mismatch")
    31  	}
    32  
    33  	// Default settings.
    34  	formula := Forward
    35  	step := math.Sqrt(formula.Step) // Use the sqrt because taking derivatives of derivatives.
    36  	var originValue float64
    37  	var originKnown, concurrent bool
    38  
    39  	// Use user settings if provided.
    40  	if settings != nil {
    41  		if !settings.Formula.isZero() {
    42  			formula = settings.Formula
    43  			step = math.Sqrt(formula.Step)
    44  			checkFormula(formula)
    45  			if formula.Derivative != 1 {
    46  				panic(badDerivOrder)
    47  			}
    48  		}
    49  		if settings.Step != 0 {
    50  			if settings.Step < 0 {
    51  				panic(negativeStep)
    52  			}
    53  			step = settings.Step
    54  		}
    55  		originKnown = settings.OriginKnown
    56  		originValue = settings.OriginValue
    57  		concurrent = settings.Concurrent
    58  	}
    59  
    60  	evals := n * len(formula.Stencil) * len(formula.Stencil)
    61  	if usesOrigin(formula.Stencil) {
    62  		evals -= n
    63  	}
    64  
    65  	nWorkers := computeWorkers(concurrent, evals)
    66  	if nWorkers == 1 {
    67  		return crossLaplacianSerial(f, x, y, formula.Stencil, step, originKnown, originValue)
    68  	}
    69  	return crossLaplacianConcurrent(nWorkers, evals, f, x, y, formula.Stencil, step, originKnown, originValue)
    70  }
    71  
    72  func crossLaplacianSerial(f func(x, y []float64) float64, x, y []float64, stencil []Point, step float64, originKnown bool, originValue float64) float64 {
    73  	n := len(x)
    74  	xCopy := make([]float64, len(x))
    75  	yCopy := make([]float64, len(y))
    76  	fo := func() float64 {
    77  		// Copy x and y in case they are modified during the call.
    78  		copy(xCopy, x)
    79  		copy(yCopy, y)
    80  		return f(x, y)
    81  	}
    82  	origin := getOrigin(originKnown, originValue, fo, stencil)
    83  
    84  	is2 := 1 / (step * step)
    85  	var laplacian float64
    86  	for i := 0; i < n; i++ {
    87  		for _, pty := range stencil {
    88  			for _, ptx := range stencil {
    89  				var v float64
    90  				if ptx.Loc == 0 && pty.Loc == 0 {
    91  					v = origin
    92  				} else {
    93  					// Copying the data anew has two benefits. First, it
    94  					// avoids floating point issues where adding and then
    95  					// subtracting the step don't return to the exact same
    96  					// location. Secondly, it protects against the function
    97  					// modifying the input data.
    98  					copy(yCopy, y)
    99  					copy(xCopy, x)
   100  					yCopy[i] += pty.Loc * step
   101  					xCopy[i] += ptx.Loc * step
   102  					v = f(xCopy, yCopy)
   103  				}
   104  				laplacian += v * ptx.Coeff * pty.Coeff * is2
   105  			}
   106  		}
   107  	}
   108  	return laplacian
   109  }
   110  
   111  func crossLaplacianConcurrent(nWorkers, evals int, f func(x, y []float64) float64, x, y []float64, stencil []Point, step float64, originKnown bool, originValue float64) float64 {
   112  	n := len(x)
   113  	type run struct {
   114  		i          int
   115  		xIdx, yIdx int
   116  		result     float64
   117  	}
   118  
   119  	send := make(chan run, evals)
   120  	ans := make(chan run, evals)
   121  
   122  	var originWG sync.WaitGroup
   123  	hasOrigin := usesOrigin(stencil)
   124  	if hasOrigin {
   125  		originWG.Add(1)
   126  		// Launch worker to compute the origin.
   127  		go func() {
   128  			defer originWG.Done()
   129  			xCopy := make([]float64, len(x))
   130  			yCopy := make([]float64, len(y))
   131  			copy(xCopy, x)
   132  			copy(yCopy, y)
   133  			originValue = f(xCopy, yCopy)
   134  		}()
   135  	}
   136  
   137  	var workerWG sync.WaitGroup
   138  	// Launch workers.
   139  	for i := 0; i < nWorkers; i++ {
   140  		workerWG.Add(1)
   141  		go func(send <-chan run, ans chan<- run) {
   142  			defer workerWG.Done()
   143  			xCopy := make([]float64, len(x))
   144  			yCopy := make([]float64, len(y))
   145  			for r := range send {
   146  				if stencil[r.xIdx].Loc == 0 && stencil[r.yIdx].Loc == 0 {
   147  					originWG.Wait()
   148  					r.result = originValue
   149  				} else {
   150  					// See crossLaplacianSerial for comment on the copy.
   151  					copy(xCopy, x)
   152  					copy(yCopy, y)
   153  					xCopy[r.i] += stencil[r.xIdx].Loc * step
   154  					yCopy[r.i] += stencil[r.yIdx].Loc * step
   155  					r.result = f(xCopy, yCopy)
   156  				}
   157  				ans <- r
   158  			}
   159  		}(send, ans)
   160  	}
   161  
   162  	// Launch the distributor, which sends all of runs.
   163  	go func(send chan<- run) {
   164  		for i := 0; i < n; i++ {
   165  			for xIdx := range stencil {
   166  				for yIdx := range stencil {
   167  					send <- run{
   168  						i: i, xIdx: xIdx, yIdx: yIdx,
   169  					}
   170  				}
   171  			}
   172  		}
   173  		close(send)
   174  		// Wait for all the workers to quit, then close the ans channel.
   175  		workerWG.Wait()
   176  		close(ans)
   177  	}(send)
   178  
   179  	// Read in the results.
   180  	is2 := 1 / (step * step)
   181  	var laplacian float64
   182  	for r := range ans {
   183  		laplacian += r.result * stencil[r.xIdx].Coeff * stencil[r.yIdx].Coeff * is2
   184  	}
   185  	return laplacian
   186  }