github.com/gopherd/gonum@v0.0.4/mat/inner.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"github.com/gopherd/gonum/blas"
     9  	"github.com/gopherd/gonum/blas/blas64"
    10  	"github.com/gopherd/gonum/internal/asm/f64"
    11  )
    12  
    13  // Inner computes the generalized inner product
    14  //  xᵀ A y
    15  // between the vectors x and y with matrix A, where x and y are treated as
    16  // column vectors.
    17  //
    18  // This is only a true inner product if A is symmetric positive definite, though
    19  // the operation works for any matrix A.
    20  //
    21  // Inner panics if x.Len != m or y.Len != n when A is an m x n matrix.
    22  func Inner(x Vector, a Matrix, y Vector) float64 {
    23  	m, n := a.Dims()
    24  	if x.Len() != m {
    25  		panic(ErrShape)
    26  	}
    27  	if y.Len() != n {
    28  		panic(ErrShape)
    29  	}
    30  	if m == 0 || n == 0 {
    31  		return 0
    32  	}
    33  
    34  	var sum float64
    35  
    36  	switch a := a.(type) {
    37  	case RawSymmetricer:
    38  		amat := a.RawSymmetric()
    39  		if amat.Uplo != blas.Upper {
    40  			// Panic as a string not a mat.Error.
    41  			panic(badSymTriangle)
    42  		}
    43  		var xmat, ymat blas64.Vector
    44  		if xrv, ok := x.(RawVectorer); ok {
    45  			xmat = xrv.RawVector()
    46  		} else {
    47  			break
    48  		}
    49  		if yrv, ok := y.(RawVectorer); ok {
    50  			ymat = yrv.RawVector()
    51  		} else {
    52  			break
    53  		}
    54  		for i := 0; i < x.Len(); i++ {
    55  			xi := x.AtVec(i)
    56  			if xi != 0 {
    57  				if ymat.Inc == 1 {
    58  					sum += xi * f64.DotUnitary(
    59  						amat.Data[i*amat.Stride+i:i*amat.Stride+n],
    60  						ymat.Data[i:],
    61  					)
    62  				} else {
    63  					sum += xi * f64.DotInc(
    64  						amat.Data[i*amat.Stride+i:i*amat.Stride+n],
    65  						ymat.Data[i*ymat.Inc:], uintptr(n-i),
    66  						1, uintptr(ymat.Inc),
    67  						0, 0,
    68  					)
    69  				}
    70  			}
    71  			yi := y.AtVec(i)
    72  			if i != n-1 && yi != 0 {
    73  				if xmat.Inc == 1 {
    74  					sum += yi * f64.DotUnitary(
    75  						amat.Data[i*amat.Stride+i+1:i*amat.Stride+n],
    76  						xmat.Data[i+1:],
    77  					)
    78  				} else {
    79  					sum += yi * f64.DotInc(
    80  						amat.Data[i*amat.Stride+i+1:i*amat.Stride+n],
    81  						xmat.Data[(i+1)*xmat.Inc:], uintptr(n-i-1),
    82  						1, uintptr(xmat.Inc),
    83  						0, 0,
    84  					)
    85  				}
    86  			}
    87  		}
    88  		return sum
    89  	case RawMatrixer:
    90  		amat := a.RawMatrix()
    91  		var ymat blas64.Vector
    92  		if yrv, ok := y.(RawVectorer); ok {
    93  			ymat = yrv.RawVector()
    94  		} else {
    95  			break
    96  		}
    97  		for i := 0; i < x.Len(); i++ {
    98  			xi := x.AtVec(i)
    99  			if xi != 0 {
   100  				if ymat.Inc == 1 {
   101  					sum += xi * f64.DotUnitary(
   102  						amat.Data[i*amat.Stride:i*amat.Stride+n],
   103  						ymat.Data,
   104  					)
   105  				} else {
   106  					sum += xi * f64.DotInc(
   107  						amat.Data[i*amat.Stride:i*amat.Stride+n],
   108  						ymat.Data, uintptr(n),
   109  						1, uintptr(ymat.Inc),
   110  						0, 0,
   111  					)
   112  				}
   113  			}
   114  		}
   115  		return sum
   116  	}
   117  	for i := 0; i < x.Len(); i++ {
   118  		xi := x.AtVec(i)
   119  		for j := 0; j < y.Len(); j++ {
   120  			sum += xi * a.At(i, j) * y.AtVec(j)
   121  		}
   122  	}
   123  	return sum
   124  }