github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/diff/fd/simplefunctions_test.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fd
     6  
     7  import (
     8  	"github.com/jingcheng-WU/gonum/floats"
     9  	"github.com/jingcheng-WU/gonum/mat"
    10  )
    11  
    12  // ConstFunc is a constant function returning the value held by the type.
    13  type ConstFunc float64
    14  
    15  func (c ConstFunc) Func(x []float64) float64 {
    16  	return float64(c)
    17  }
    18  
    19  func (c ConstFunc) Grad(grad, x []float64) {
    20  	for i := range grad {
    21  		grad[i] = 0
    22  	}
    23  }
    24  
    25  func (c ConstFunc) Hess(dst mat.MutableSymmetric, x []float64) {
    26  	n := len(x)
    27  	for i := 0; i < n; i++ {
    28  		for j := i; j < n; j++ {
    29  			dst.SetSym(i, j, 0)
    30  		}
    31  	}
    32  }
    33  
    34  // LinearFunc is a linear function returning w*x+c.
    35  type LinearFunc struct {
    36  	w []float64
    37  	c float64
    38  }
    39  
    40  func (l LinearFunc) Func(x []float64) float64 {
    41  	return floats.Dot(l.w, x) + l.c
    42  }
    43  
    44  func (l LinearFunc) Grad(grad, x []float64) {
    45  	copy(grad, l.w)
    46  }
    47  
    48  func (l LinearFunc) Hess(dst mat.MutableSymmetric, x []float64) {
    49  	n := len(x)
    50  	for i := 0; i < n; i++ {
    51  		for j := i; j < n; j++ {
    52  			dst.SetSym(i, j, 0)
    53  		}
    54  	}
    55  }
    56  
    57  // QuadFunc is a quadratic function returning 0.5*x'*a*x + b*x + c.
    58  type QuadFunc struct {
    59  	a *mat.SymDense
    60  	b *mat.VecDense
    61  	c float64
    62  }
    63  
    64  func (q QuadFunc) Func(x []float64) float64 {
    65  	v := mat.NewVecDense(len(x), x)
    66  	var tmp mat.VecDense
    67  	tmp.MulVec(q.a, v)
    68  	return 0.5*mat.Dot(&tmp, v) + mat.Dot(q.b, v) + q.c
    69  }
    70  
    71  func (q QuadFunc) Grad(grad, x []float64) {
    72  	var tmp mat.VecDense
    73  	v := mat.NewVecDense(len(x), x)
    74  	tmp.MulVec(q.a, v)
    75  	for i := range grad {
    76  		grad[i] = tmp.At(i, 0) + q.b.At(i, 0)
    77  	}
    78  }
    79  
    80  func (q QuadFunc) Hess(dst mat.MutableSymmetric, x []float64) {
    81  	n := len(x)
    82  	for i := 0; i < n; i++ {
    83  		for j := i; j < n; j++ {
    84  			dst.SetSym(i, j, q.a.At(i, j))
    85  		}
    86  	}
    87  }