github.com/gopherd/gonum@v0.0.4/diff/fd/simplefunctions_test.go (about) 1 // Copyright ©2017 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package fd 6 7 import ( 8 "github.com/gopherd/gonum/floats" 9 "github.com/gopherd/gonum/mat" 10 ) 11 12 // ConstFunc is a constant function returning the value held by the type. 13 type ConstFunc float64 14 15 func (c ConstFunc) Func(x []float64) float64 { 16 return float64(c) 17 } 18 19 func (c ConstFunc) Grad(grad, x []float64) { 20 for i := range grad { 21 grad[i] = 0 22 } 23 } 24 25 func (c ConstFunc) Hess(dst mat.MutableSymmetric, x []float64) { 26 n := len(x) 27 for i := 0; i < n; i++ { 28 for j := i; j < n; j++ { 29 dst.SetSym(i, j, 0) 30 } 31 } 32 } 33 34 // LinearFunc is a linear function returning w*x+c. 35 type LinearFunc struct { 36 w []float64 37 c float64 38 } 39 40 func (l LinearFunc) Func(x []float64) float64 { 41 return floats.Dot(l.w, x) + l.c 42 } 43 44 func (l LinearFunc) Grad(grad, x []float64) { 45 copy(grad, l.w) 46 } 47 48 func (l LinearFunc) Hess(dst mat.MutableSymmetric, x []float64) { 49 n := len(x) 50 for i := 0; i < n; i++ { 51 for j := i; j < n; j++ { 52 dst.SetSym(i, j, 0) 53 } 54 } 55 } 56 57 // QuadFunc is a quadratic function returning 0.5*x'*a*x + b*x + c. 58 type QuadFunc struct { 59 a *mat.SymDense 60 b *mat.VecDense 61 c float64 62 } 63 64 func (q QuadFunc) Func(x []float64) float64 { 65 v := mat.NewVecDense(len(x), x) 66 var tmp mat.VecDense 67 tmp.MulVec(q.a, v) 68 return 0.5*mat.Dot(&tmp, v) + mat.Dot(q.b, v) + q.c 69 } 70 71 func (q QuadFunc) Grad(grad, x []float64) { 72 var tmp mat.VecDense 73 v := mat.NewVecDense(len(x), x) 74 tmp.MulVec(q.a, v) 75 for i := range grad { 76 grad[i] = tmp.At(i, 0) + q.b.At(i, 0) 77 } 78 } 79 80 func (q QuadFunc) Hess(dst mat.MutableSymmetric, x []float64) { 81 n := len(x) 82 for i := 0; i < n; i++ { 83 for j := i; j < n; j++ { 84 dst.SetSym(i, j, q.a.At(i, j)) 85 } 86 } 87 }