github.com/gopherd/gonum@v0.0.4/mat/inner.go (about) 1 // Copyright ©2014 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "github.com/gopherd/gonum/blas" 9 "github.com/gopherd/gonum/blas/blas64" 10 "github.com/gopherd/gonum/internal/asm/f64" 11 ) 12 13 // Inner computes the generalized inner product 14 // xᵀ A y 15 // between the vectors x and y with matrix A, where x and y are treated as 16 // column vectors. 17 // 18 // This is only a true inner product if A is symmetric positive definite, though 19 // the operation works for any matrix A. 20 // 21 // Inner panics if x.Len != m or y.Len != n when A is an m x n matrix. 22 func Inner(x Vector, a Matrix, y Vector) float64 { 23 m, n := a.Dims() 24 if x.Len() != m { 25 panic(ErrShape) 26 } 27 if y.Len() != n { 28 panic(ErrShape) 29 } 30 if m == 0 || n == 0 { 31 return 0 32 } 33 34 var sum float64 35 36 switch a := a.(type) { 37 case RawSymmetricer: 38 amat := a.RawSymmetric() 39 if amat.Uplo != blas.Upper { 40 // Panic as a string not a mat.Error. 41 panic(badSymTriangle) 42 } 43 var xmat, ymat blas64.Vector 44 if xrv, ok := x.(RawVectorer); ok { 45 xmat = xrv.RawVector() 46 } else { 47 break 48 } 49 if yrv, ok := y.(RawVectorer); ok { 50 ymat = yrv.RawVector() 51 } else { 52 break 53 } 54 for i := 0; i < x.Len(); i++ { 55 xi := x.AtVec(i) 56 if xi != 0 { 57 if ymat.Inc == 1 { 58 sum += xi * f64.DotUnitary( 59 amat.Data[i*amat.Stride+i:i*amat.Stride+n], 60 ymat.Data[i:], 61 ) 62 } else { 63 sum += xi * f64.DotInc( 64 amat.Data[i*amat.Stride+i:i*amat.Stride+n], 65 ymat.Data[i*ymat.Inc:], uintptr(n-i), 66 1, uintptr(ymat.Inc), 67 0, 0, 68 ) 69 } 70 } 71 yi := y.AtVec(i) 72 if i != n-1 && yi != 0 { 73 if xmat.Inc == 1 { 74 sum += yi * f64.DotUnitary( 75 amat.Data[i*amat.Stride+i+1:i*amat.Stride+n], 76 xmat.Data[i+1:], 77 ) 78 } else { 79 sum += yi * f64.DotInc( 80 amat.Data[i*amat.Stride+i+1:i*amat.Stride+n], 81 xmat.Data[(i+1)*xmat.Inc:], uintptr(n-i-1), 82 1, uintptr(xmat.Inc), 83 0, 0, 84 ) 85 } 86 } 87 } 88 return sum 89 case RawMatrixer: 90 amat := a.RawMatrix() 91 var ymat blas64.Vector 92 if yrv, ok := y.(RawVectorer); ok { 93 ymat = yrv.RawVector() 94 } else { 95 break 96 } 97 for i := 0; i < x.Len(); i++ { 98 xi := x.AtVec(i) 99 if xi != 0 { 100 if ymat.Inc == 1 { 101 sum += xi * f64.DotUnitary( 102 amat.Data[i*amat.Stride:i*amat.Stride+n], 103 ymat.Data, 104 ) 105 } else { 106 sum += xi * f64.DotInc( 107 amat.Data[i*amat.Stride:i*amat.Stride+n], 108 ymat.Data, uintptr(n), 109 1, uintptr(ymat.Inc), 110 0, 0, 111 ) 112 } 113 } 114 } 115 return sum 116 } 117 for i := 0; i < x.Len(); i++ { 118 xi := x.AtVec(i) 119 for j := 0; j < y.Len(); j++ { 120 sum += xi * a.At(i, j) * y.AtVec(j) 121 } 122 } 123 return sum 124 }