gonum.org/v1/gonum@v0.14.0/mat/inner.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"gonum.org/v1/gonum/blas"
     9  	"gonum.org/v1/gonum/blas/blas64"
    10  	"gonum.org/v1/gonum/internal/asm/f64"
    11  )
    12  
    13  // Inner computes the generalized inner product
    14  //
    15  //	xᵀ A y
    16  //
    17  // between the vectors x and y with matrix A, where x and y are treated as
    18  // column vectors.
    19  //
    20  // This is only a true inner product if A is symmetric positive definite, though
    21  // the operation works for any matrix A.
    22  //
    23  // Inner panics if x.Len != m or y.Len != n when A is an m x n matrix.
    24  func Inner(x Vector, a Matrix, y Vector) float64 {
    25  	m, n := a.Dims()
    26  	if x.Len() != m {
    27  		panic(ErrShape)
    28  	}
    29  	if y.Len() != n {
    30  		panic(ErrShape)
    31  	}
    32  	if m == 0 || n == 0 {
    33  		return 0
    34  	}
    35  
    36  	var sum float64
    37  
    38  	switch a := a.(type) {
    39  	case RawSymmetricer:
    40  		amat := a.RawSymmetric()
    41  		if amat.Uplo != blas.Upper {
    42  			// Panic as a string not a mat.Error.
    43  			panic(badSymTriangle)
    44  		}
    45  		var xmat, ymat blas64.Vector
    46  		if xrv, ok := x.(RawVectorer); ok {
    47  			xmat = xrv.RawVector()
    48  		} else {
    49  			break
    50  		}
    51  		if yrv, ok := y.(RawVectorer); ok {
    52  			ymat = yrv.RawVector()
    53  		} else {
    54  			break
    55  		}
    56  		for i := 0; i < x.Len(); i++ {
    57  			xi := x.AtVec(i)
    58  			if xi != 0 {
    59  				if ymat.Inc == 1 {
    60  					sum += xi * f64.DotUnitary(
    61  						amat.Data[i*amat.Stride+i:i*amat.Stride+n],
    62  						ymat.Data[i:],
    63  					)
    64  				} else {
    65  					sum += xi * f64.DotInc(
    66  						amat.Data[i*amat.Stride+i:i*amat.Stride+n],
    67  						ymat.Data[i*ymat.Inc:], uintptr(n-i),
    68  						1, uintptr(ymat.Inc),
    69  						0, 0,
    70  					)
    71  				}
    72  			}
    73  			yi := y.AtVec(i)
    74  			if i != n-1 && yi != 0 {
    75  				if xmat.Inc == 1 {
    76  					sum += yi * f64.DotUnitary(
    77  						amat.Data[i*amat.Stride+i+1:i*amat.Stride+n],
    78  						xmat.Data[i+1:],
    79  					)
    80  				} else {
    81  					sum += yi * f64.DotInc(
    82  						amat.Data[i*amat.Stride+i+1:i*amat.Stride+n],
    83  						xmat.Data[(i+1)*xmat.Inc:], uintptr(n-i-1),
    84  						1, uintptr(xmat.Inc),
    85  						0, 0,
    86  					)
    87  				}
    88  			}
    89  		}
    90  		return sum
    91  	case RawMatrixer:
    92  		amat := a.RawMatrix()
    93  		var ymat blas64.Vector
    94  		if yrv, ok := y.(RawVectorer); ok {
    95  			ymat = yrv.RawVector()
    96  		} else {
    97  			break
    98  		}
    99  		for i := 0; i < x.Len(); i++ {
   100  			xi := x.AtVec(i)
   101  			if xi != 0 {
   102  				if ymat.Inc == 1 {
   103  					sum += xi * f64.DotUnitary(
   104  						amat.Data[i*amat.Stride:i*amat.Stride+n],
   105  						ymat.Data,
   106  					)
   107  				} else {
   108  					sum += xi * f64.DotInc(
   109  						amat.Data[i*amat.Stride:i*amat.Stride+n],
   110  						ymat.Data, uintptr(n),
   111  						1, uintptr(ymat.Inc),
   112  						0, 0,
   113  					)
   114  				}
   115  			}
   116  		}
   117  		return sum
   118  	}
   119  	for i := 0; i < x.Len(); i++ {
   120  		xi := x.AtVec(i)
   121  		for j := 0; j < y.Len(); j++ {
   122  			sum += xi * a.At(i, j) * y.AtVec(j)
   123  		}
   124  	}
   125  	return sum
   126  }