github.com/wzzhu/tensor@v0.9.24/api_arith.go (about)

     1  package tensor
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  )
     6  
// Exported API for arithmetic, including the large number of overloaded operand
// combinations (Tensor or scalar on either side) that each operation accepts.

// Add performs a pointwise a+b. a and b can each be either a float64 or a Tensor.
//
// If both operands are Tensors, their shapes are checked first. Even though the
// underlying data may have the same size (say (2,2) vs (4,1)), the operation
// errors out if the shapes differ.
//
    14  
    15  // Add performs elementwise addition on the Tensor(s). These operations are supported:
    16  //		Add(*Dense, scalar)
    17  //		Add(scalar, *Dense)
    18  //		Add(*Dense, *Dense)
    19  // If the Unsafe flag is passed in, the data of the first tensor will be overwritten
    20  func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
    21  	var adder Adder
    22  	var oe standardEngine
    23  	var ok bool
    24  	switch at := a.(type) {
    25  	case Tensor:
    26  		oe = at.standardEngine()
    27  		switch bt := b.(type) {
    28  		case Tensor:
    29  			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition
    30  				if oe != nil {
    31  					return oe.Add(at, bt, opts...)
    32  				}
    33  				if oe = bt.standardEngine(); oe != nil {
    34  					return oe.Add(at, bt, opts...)
    35  				}
    36  				if adder, ok = at.Engine().(Adder); ok {
    37  					return adder.Add(at, bt, opts...)
    38  				}
    39  				if adder, ok = bt.Engine().(Adder); ok {
    40  					return adder.Add(at, bt, opts...)
    41  				}
    42  				return nil, errors.New("Neither engines of either operand support Add")
    43  
    44  			} else { // at least one of the operands is a scalar
    45  				var leftTensor bool
    46  				if !bt.Shape().IsScalar() {
    47  					leftTensor = false // a Scalar-Tensor * b Tensor
    48  					tmp := at
    49  					at = bt
    50  					bt = tmp
    51  				} else {
    52  					leftTensor = true // a Tensor * b Scalar-Tensor
    53  				}
    54  
    55  				if oe != nil {
    56  					return oe.AddScalar(at, bt, leftTensor, opts...)
    57  				}
    58  				if oe = bt.standardEngine(); oe != nil {
    59  					return oe.AddScalar(at, bt, leftTensor, opts...)
    60  				}
    61  				if adder, ok = at.Engine().(Adder); ok {
    62  					return adder.AddScalar(at, bt, leftTensor, opts...)
    63  				}
    64  				if adder, ok = bt.Engine().(Adder); ok {
    65  					return adder.AddScalar(at, bt, leftTensor, opts...)
    66  				}
    67  				return nil, errors.New("Neither engines of either operand support Add")
    68  			}
    69  
    70  		default:
    71  			if oe != nil {
    72  				return oe.AddScalar(at, bt, true, opts...)
    73  			}
    74  			if adder, ok = at.Engine().(Adder); ok {
    75  				return adder.AddScalar(at, bt, true, opts...)
    76  			}
    77  			return nil, errors.New("Operand A's engine does not support Add")
    78  		}
    79  	default:
    80  		switch bt := b.(type) {
    81  		case Tensor:
    82  			if oe = bt.standardEngine(); oe != nil {
    83  				return oe.AddScalar(bt, at, false, opts...)
    84  			}
    85  			if adder, ok = bt.Engine().(Adder); ok {
    86  				return adder.AddScalar(bt, at, false, opts...)
    87  			}
    88  			return nil, errors.New("Operand B's engine does not support Add")
    89  		default:
    90  			return nil, errors.Errorf("Cannot perform Add of %T and %T", a, b)
    91  		}
    92  	}
    93  	panic("Unreachable")
    94  }
    95  
    96  // Sub performs elementwise subtraction on the Tensor(s). These operations are supported:
    97  //		Sub(*Dense, scalar)
    98  //		Sub(scalar, *Dense)
    99  //		Sub(*Dense, *Dense)
   100  // If the Unsafe flag is passed in, the data of the first tensor will be overwritten
   101  func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
   102  	var suber Suber
   103  	var oe standardEngine
   104  	var ok bool
   105  	switch at := a.(type) {
   106  	case Tensor:
   107  		oe = at.standardEngine()
   108  		switch bt := b.(type) {
   109  		case Tensor:
   110  			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor substraction
   111  				if oe != nil {
   112  					return oe.Sub(at, bt, opts...)
   113  				}
   114  				if oe = bt.standardEngine(); oe != nil {
   115  					return oe.Sub(at, bt, opts...)
   116  				}
   117  				if suber, ok = at.Engine().(Suber); ok {
   118  					return suber.Sub(at, bt, opts...)
   119  				}
   120  				if suber, ok = bt.Engine().(Suber); ok {
   121  					return suber.Sub(at, bt, opts...)
   122  				}
   123  				return nil, errors.New("Neither engines of either operand support Sub")
   124  
   125  			} else { // at least one of the operands is a scalar
   126  				var leftTensor bool
   127  				if !bt.Shape().IsScalar() {
   128  					leftTensor = false // a Scalar-Tensor * b Tensor
   129  					tmp := at
   130  					at = bt
   131  					bt = tmp
   132  				} else {
   133  					leftTensor = true // a Tensor * b Scalar-Tensor
   134  				}
   135  
   136  				if oe != nil {
   137  					return oe.SubScalar(at, bt, leftTensor, opts...)
   138  				}
   139  				if oe = bt.standardEngine(); oe != nil {
   140  					return oe.SubScalar(at, bt, leftTensor, opts...)
   141  				}
   142  				if suber, ok = at.Engine().(Suber); ok {
   143  					return suber.SubScalar(at, bt, leftTensor, opts...)
   144  				}
   145  				if suber, ok = bt.Engine().(Suber); ok {
   146  					return suber.SubScalar(at, bt, leftTensor, opts...)
   147  				}
   148  				return nil, errors.New("Neither engines of either operand support Sub")
   149  			}
   150  
   151  		default:
   152  			if oe != nil {
   153  				return oe.SubScalar(at, bt, true, opts...)
   154  			}
   155  			if suber, ok = at.Engine().(Suber); ok {
   156  				return suber.SubScalar(at, bt, true, opts...)
   157  			}
   158  			return nil, errors.New("Operand A's engine does not support Sub")
   159  		}
   160  	default:
   161  		switch bt := b.(type) {
   162  		case Tensor:
   163  			if oe = bt.standardEngine(); oe != nil {
   164  				return oe.SubScalar(bt, at, false, opts...)
   165  			}
   166  			if suber, ok = bt.Engine().(Suber); ok {
   167  				return suber.SubScalar(bt, at, false, opts...)
   168  			}
   169  			return nil, errors.New("Operand B's engine does not support Sub")
   170  		default:
   171  			return nil, errors.Errorf("Cannot perform Sub of %T and %T", a, b)
   172  		}
   173  	}
   174  	panic("Unreachable")
   175  }
   176  
   177  // Mul performs elementwise multiplication on the Tensor(s). These operations are supported:
   178  //		Mul(*Dense, scalar)
   179  //		Mul(scalar, *Dense)
   180  //		Mul(*Dense, *Dense)
   181  // If the Unsafe flag is passed in, the data of the first tensor will be overwritten
   182  func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
   183  	var muler Muler
   184  	var oe standardEngine
   185  	var ok bool
   186  	switch at := a.(type) {
   187  	case Tensor:
   188  		oe = at.standardEngine()
   189  		switch bt := b.(type) {
   190  		case Tensor:
   191  			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor multiplication
   192  				if oe != nil {
   193  					return oe.Mul(at, bt, opts...)
   194  				}
   195  				if oe = bt.standardEngine(); oe != nil {
   196  					return oe.Mul(at, bt, opts...)
   197  				}
   198  				if muler, ok = at.Engine().(Muler); ok {
   199  					return muler.Mul(at, bt, opts...)
   200  				}
   201  				if muler, ok = bt.Engine().(Muler); ok {
   202  					return muler.Mul(at, bt, opts...)
   203  				}
   204  				return nil, errors.New("Neither engines of either operand support Mul")
   205  
   206  			} else { // at least one of the operands is a scalar
   207  				var leftTensor bool
   208  				if !bt.Shape().IsScalar() {
   209  					leftTensor = false // a Scalar-Tensor * b Tensor
   210  					tmp := at
   211  					at = bt
   212  					bt = tmp
   213  				} else {
   214  					leftTensor = true // a Tensor * b Scalar-Tensor
   215  				}
   216  
   217  				if oe != nil {
   218  					return oe.MulScalar(at, bt, leftTensor, opts...)
   219  				}
   220  				if oe = bt.standardEngine(); oe != nil {
   221  					return oe.MulScalar(at, bt, leftTensor, opts...)
   222  				}
   223  				if muler, ok = at.Engine().(Muler); ok {
   224  					return muler.MulScalar(at, bt, leftTensor, opts...)
   225  				}
   226  				if muler, ok = bt.Engine().(Muler); ok {
   227  					return muler.MulScalar(at, bt, leftTensor, opts...)
   228  				}
   229  				return nil, errors.New("Neither engines of either operand support Mul")
   230  			}
   231  
   232  		default: // a Tensor * b interface
   233  			if oe != nil {
   234  				return oe.MulScalar(at, bt, true, opts...)
   235  			}
   236  			if muler, ok = at.Engine().(Muler); ok {
   237  				return muler.MulScalar(at, bt, true, opts...)
   238  			}
   239  			return nil, errors.New("Operand A's engine does not support Mul")
   240  		}
   241  
   242  	default:
   243  		switch bt := b.(type) {
   244  		case Tensor: // b Tensor * a interface
   245  			if oe = bt.standardEngine(); oe != nil {
   246  				return oe.MulScalar(bt, at, false, opts...)
   247  			}
   248  			if muler, ok = bt.Engine().(Muler); ok {
   249  				return muler.MulScalar(bt, at, false, opts...)
   250  			}
   251  			return nil, errors.New("Operand B's engine does not support Mul")
   252  
   253  		default: // b interface * a interface
   254  			return nil, errors.Errorf("Cannot perform Mul of %T and %T", a, b)
   255  		}
   256  	}
   257  	panic("Unreachable")
   258  }
   259  
   260  // Div performs elementwise division on the Tensor(s). These operations are supported:
   261  //		Div(*Dense, scalar)
   262  //		Div(scalar, *Dense)
   263  //		Div(*Dense, *Dense)
   264  // If the Unsafe flag is passed in, the data of the first tensor will be overwritten
   265  func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
   266  	var diver Diver
   267  	var oe standardEngine
   268  	var ok bool
   269  	switch at := a.(type) {
   270  	case Tensor:
   271  		oe = at.standardEngine()
   272  		switch bt := b.(type) {
   273  		case Tensor:
   274  			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor division
   275  				if oe != nil {
   276  					return oe.Div(at, bt, opts...)
   277  				}
   278  				if oe = bt.standardEngine(); oe != nil {
   279  					return oe.Div(at, bt, opts...)
   280  				}
   281  				if diver, ok = at.Engine().(Diver); ok {
   282  					return diver.Div(at, bt, opts...)
   283  				}
   284  				if diver, ok = bt.Engine().(Diver); ok {
   285  					return diver.Div(at, bt, opts...)
   286  				}
   287  				return nil, errors.New("Neither engines of either operand support Div")
   288  
   289  			} else { // at least one of the operands is a scalar
   290  				var leftTensor bool
   291  				if !bt.Shape().IsScalar() {
   292  					leftTensor = false // a Scalar-Tensor * b Tensor
   293  					tmp := at
   294  					at = bt
   295  					bt = tmp
   296  				} else {
   297  					leftTensor = true // a Tensor * b Scalar-Tensor
   298  				}
   299  
   300  				if oe != nil {
   301  					return oe.DivScalar(at, bt, leftTensor, opts...)
   302  				}
   303  				if oe = bt.standardEngine(); oe != nil {
   304  					return oe.DivScalar(at, bt, leftTensor, opts...)
   305  				}
   306  				if diver, ok = at.Engine().(Diver); ok {
   307  					return diver.DivScalar(at, bt, leftTensor, opts...)
   308  				}
   309  				if diver, ok = bt.Engine().(Diver); ok {
   310  					return diver.DivScalar(at, bt, leftTensor, opts...)
   311  				}
   312  				return nil, errors.New("Neither engines of either operand support Div")
   313  			}
   314  
   315  		default:
   316  			if oe != nil {
   317  				return oe.DivScalar(at, bt, true, opts...)
   318  			}
   319  			if diver, ok = at.Engine().(Diver); ok {
   320  				return diver.DivScalar(at, bt, true, opts...)
   321  			}
   322  			return nil, errors.New("Operand A's engine does not support Div")
   323  		}
   324  	default:
   325  		switch bt := b.(type) {
   326  		case Tensor:
   327  			if oe = bt.standardEngine(); oe != nil {
   328  				return oe.DivScalar(bt, at, false, opts...)
   329  			}
   330  			if diver, ok = bt.Engine().(Diver); ok {
   331  				return diver.DivScalar(bt, at, false, opts...)
   332  			}
   333  			return nil, errors.New("Operand B's engine does not support Div")
   334  		default:
   335  			return nil, errors.Errorf("Cannot perform Div of %T and %T", a, b)
   336  		}
   337  	}
   338  	panic("Unreachable")
   339  }
   340  
   341  // Pow performs elementwise exponentiation on the Tensor(s). These operations are supported:
   342  //		Pow(*Dense, scalar)
   343  //		Pow(scalar, *Dense)
   344  //		Pow(*Dense, *Dense)
   345  // If the Unsafe flag is passed in, the data of the first tensor will be overwritten
   346  func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
   347  	var power Power
   348  	var oe standardEngine
   349  	var ok bool
   350  	switch at := a.(type) {
   351  	case Tensor:
   352  		oe = at.standardEngine()
   353  		switch bt := b.(type) {
   354  		case Tensor:
   355  			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor exponentiation
   356  				if oe != nil {
   357  					return oe.Pow(at, bt, opts...)
   358  				}
   359  				if oe = bt.standardEngine(); oe != nil {
   360  					return oe.Pow(at, bt, opts...)
   361  				}
   362  				if power, ok = at.Engine().(Power); ok {
   363  					return power.Pow(at, bt, opts...)
   364  				}
   365  				if power, ok = bt.Engine().(Power); ok {
   366  					return power.Pow(at, bt, opts...)
   367  				}
   368  				return nil, errors.New("Neither engines of either operand support Pow")
   369  
   370  			} else { // at least one of the operands is a scalar
   371  				var leftTensor bool
   372  				if !bt.Shape().IsScalar() {
   373  					leftTensor = false // a Scalar-Tensor * b Tensor
   374  					tmp := at
   375  					at = bt
   376  					bt = tmp
   377  				} else {
   378  					leftTensor = true // a Tensor * b Scalar-Tensor
   379  				}
   380  
   381  				if oe != nil {
   382  					return oe.PowScalar(at, bt, leftTensor, opts...)
   383  				}
   384  				if oe = bt.standardEngine(); oe != nil {
   385  					return oe.PowScalar(at, bt, leftTensor, opts...)
   386  				}
   387  				if power, ok = at.Engine().(Power); ok {
   388  					return power.PowScalar(at, bt, leftTensor, opts...)
   389  				}
   390  				if power, ok = bt.Engine().(Power); ok {
   391  					return power.PowScalar(at, bt, leftTensor, opts...)
   392  				}
   393  				return nil, errors.New("Neither engines of either operand support Pow")
   394  			}
   395  
   396  		default:
   397  			if oe != nil {
   398  				return oe.PowScalar(at, bt, true, opts...)
   399  			}
   400  			if power, ok = at.Engine().(Power); ok {
   401  				return power.PowScalar(at, bt, true, opts...)
   402  			}
   403  			return nil, errors.New("Operand A's engine does not support Pow")
   404  		}
   405  	default:
   406  		switch bt := b.(type) {
   407  		case Tensor:
   408  			if oe = bt.standardEngine(); oe != nil {
   409  				return oe.PowScalar(bt, at, false, opts...)
   410  			}
   411  			if power, ok = bt.Engine().(Power); ok {
   412  				return power.PowScalar(bt, at, false, opts...)
   413  			}
   414  			return nil, errors.New("Operand B's engine does not support Pow")
   415  		default:
   416  			return nil, errors.Errorf("Cannot perform Pow of %T and %T", a, b)
   417  		}
   418  	}
   419  	panic("Unreachable")
   420  }
   421  
   422  // Mod performs elementwise modulo on the Tensor(s). These operations are supported:
   423  //		Mod(*Dense, scalar)
   424  //		Mod(scalar, *Dense)
   425  //		Mod(*Dense, *Dense)
   426  // If the Unsafe flag is passed in, the data of the first tensor will be overwritten
   427  func Mod(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
   428  	var moder Moder
   429  	var oe standardEngine
   430  	var ok bool
   431  	switch at := a.(type) {
   432  	case Tensor:
   433  		oe = at.standardEngine()
   434  		switch bt := b.(type) {
   435  		case Tensor:
   436  			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor modulo
   437  				if oe != nil {
   438  					return oe.Mod(at, bt, opts...)
   439  				}
   440  				if oe = bt.standardEngine(); oe != nil {
   441  					return oe.Mod(at, bt, opts...)
   442  				}
   443  				if moder, ok = at.Engine().(Moder); ok {
   444  					return moder.Mod(at, bt, opts...)
   445  				}
   446  				if moder, ok = bt.Engine().(Moder); ok {
   447  					return moder.Mod(at, bt, opts...)
   448  				}
   449  				return nil, errors.New("Neither engines of either operand support Mod")
   450  
   451  			} else { // at least one of the operands is a scalar
   452  				var leftTensor bool
   453  				if !bt.Shape().IsScalar() {
   454  					leftTensor = false // a Scalar-Tensor * b Tensor
   455  					tmp := at
   456  					at = bt
   457  					bt = tmp
   458  				} else {
   459  					leftTensor = true // a Tensor * b Scalar-Tensor
   460  				}
   461  
   462  				if oe != nil {
   463  					return oe.ModScalar(at, bt, leftTensor, opts...)
   464  				}
   465  				if oe = bt.standardEngine(); oe != nil {
   466  					return oe.ModScalar(at, bt, leftTensor, opts...)
   467  				}
   468  				if moder, ok = at.Engine().(Moder); ok {
   469  					return moder.ModScalar(at, bt, leftTensor, opts...)
   470  				}
   471  				if moder, ok = bt.Engine().(Moder); ok {
   472  					return moder.ModScalar(at, bt, leftTensor, opts...)
   473  				}
   474  				return nil, errors.New("Neither engines of either operand support Mod")
   475  			}
   476  
   477  		default:
   478  			if oe != nil {
   479  				return oe.ModScalar(at, bt, true, opts...)
   480  			}
   481  			if moder, ok = at.Engine().(Moder); ok {
   482  				return moder.ModScalar(at, bt, true, opts...)
   483  			}
   484  			return nil, errors.New("Operand A's engine does not support Mod")
   485  		}
   486  	default:
   487  		switch bt := b.(type) {
   488  		case Tensor:
   489  			if oe = bt.standardEngine(); oe != nil {
   490  				return oe.ModScalar(bt, at, false, opts...)
   491  			}
   492  			if moder, ok = bt.Engine().(Moder); ok {
   493  				return moder.ModScalar(bt, at, false, opts...)
   494  			}
   495  			return nil, errors.New("Operand B's engine does not support Mod")
   496  		default:
   497  			return nil, errors.Errorf("Cannot perform Mod of %T and %T", a, b)
   498  		}
   499  	}
   500  	panic("Unreachable")
   501  }
   502  
   503  // Dot is a highly opinionated API for performing dot product operations on two *Denses, a and b.
   504  // This function is opinionated with regard to the vector operations because of how it treats operations with vectors.
   505  // Vectors in this package comes in two flavours - column or row vectors. Column vectors have shape (x, 1), while row vectors have shape (1, x).
   506  //
   507  // As such, it is easy to assume that performing a linalg operation on vectors would follow the same rules (i.e shapes have to be aligned for things to work).
   508  // For the most part in this package, this is true. This function is one of the few notable exceptions.
   509  //
   510  // Here I give three specific examples of how the expectations of vector operations will differ.
   511  // 		Given two vectors, a, b with shapes (4, 1) and (4, 1), Dot() will perform an inner product as if the shapes were (1, 4) and (4, 1). This will result in a scalar value
   512  // 		Given matrix A and vector b with shapes (2, 4) and (1, 4), Dot() will perform a matrix-vector multiplication as if the shapes were (2,4) and (4,1). This will result in a column vector with shape (2,1)
   513  //		Given vector a and matrix B with shapes (3, 1) and (3, 2), Dot() will perform a matrix-vector multiplication as if it were Báµ€ * a
   514  //
   515  // The main reason why this opinionated route was taken was due to the author's familiarity with NumPy, and general laziness in translating existing machine learning algorithms
   516  // to fit the API of the package.
   517  func Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
   518  	if xdottir, ok := x.Engine().(Dotter); ok {
   519  		return xdottir.Dot(x, y, opts...)
   520  	}
   521  	if ydottir, ok := y.Engine().(Dotter); ok {
   522  		return ydottir.Dot(x, y, opts...)
   523  	}
   524  	return nil, errors.New("Neither x's nor y's engines support Dot")
   525  }
   526  
   527  // FMA performs Y = A * X + Y.
   528  func FMA(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) {
   529  	if xTensor, ok := x.(Tensor); ok {
   530  		if oe := a.standardEngine(); oe != nil {
   531  			return oe.FMA(a, xTensor, y)
   532  		}
   533  		if oe := xTensor.standardEngine(); oe != nil {
   534  			return oe.FMA(a, xTensor, y)
   535  		}
   536  		if oe := y.standardEngine(); oe != nil {
   537  			return oe.FMA(a, xTensor, y)
   538  		}
   539  
   540  		if e, ok := a.Engine().(FMAer); ok {
   541  			return e.FMA(a, xTensor, y)
   542  		}
   543  		if e, ok := xTensor.Engine().(FMAer); ok {
   544  			return e.FMA(a, xTensor, y)
   545  		}
   546  		if e, ok := y.Engine().(FMAer); ok {
   547  			return e.FMA(a, xTensor, y)
   548  		}
   549  	} else {
   550  		if oe := a.standardEngine(); oe != nil {
   551  			return oe.FMAScalar(a, x, y)
   552  		}
   553  		if oe := y.standardEngine(); oe != nil {
   554  			return oe.FMAScalar(a, x, y)
   555  		}
   556  
   557  		if e, ok := a.Engine().(FMAer); ok {
   558  			return e.FMAScalar(a, x, y)
   559  		}
   560  		if e, ok := y.Engine().(FMAer); ok {
   561  			return e.FMAScalar(a, x, y)
   562  		}
   563  	}
   564  	return Mul(a, x, WithIncr(y))
   565  }
   566  
   567  // MatMul performs matrix-matrix multiplication between two Tensors
   568  func MatMul(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
   569  	if a.Dtype() != b.Dtype() {
   570  		err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype())
   571  		return
   572  	}
   573  
   574  	switch at := a.(type) {
   575  	case *Dense:
   576  		bt := b.(*Dense)
   577  		return at.MatMul(bt, opts...)
   578  	}
   579  	panic("Unreachable")
   580  }
   581  
   582  // MatVecMul performs matrix-vector multiplication between two Tensors. `a` is expected to be a matrix, and `b` is expected to be a vector
   583  func MatVecMul(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
   584  	if a.Dtype() != b.Dtype() {
   585  		err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype())
   586  		return
   587  	}
   588  
   589  	switch at := a.(type) {
   590  	case *Dense:
   591  		bt := b.(*Dense)
   592  		return at.MatVecMul(bt, opts...)
   593  	}
   594  	panic("Unreachable")
   595  }
   596  
   597  // Inner finds the inner products of two vector Tensors. Both arguments to the functions are eexpected to be vectors.
   598  func Inner(a, b Tensor) (retVal interface{}, err error) {
   599  	if a.Dtype() != b.Dtype() {
   600  		err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype())
   601  		return
   602  	}
   603  
   604  	switch at := a.(type) {
   605  	case *Dense:
   606  		bt := b.(*Dense)
   607  		return at.Inner(bt)
   608  	}
   609  	panic("Unreachable")
   610  }
   611  
   612  // Outer performs the outer product of two vector Tensors. Both arguments to the functions are expected to be vectors.
   613  func Outer(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
   614  	if a.Dtype() != b.Dtype() {
   615  		err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype())
   616  		return
   617  	}
   618  
   619  	switch at := a.(type) {
   620  	case *Dense:
   621  		bt := b.(*Dense)
   622  		return at.Outer(bt, opts...)
   623  	}
   624  	panic("Unreachable")
   625  }
   626  
   627  // Contract performs a contraction of given tensors along given axes
   628  func Contract(a, b Tensor, aAxes, bAxes []int) (retVal Tensor, err error) {
   629  	if a.Dtype() != b.Dtype() {
   630  		err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype())
   631  		return
   632  	}
   633  
   634  	switch at := a.(type) {
   635  	case *Dense:
   636  		bt := b.(*Dense)
   637  		return at.TensorMul(bt, aAxes, bAxes)
   638  
   639  	default:
   640  		panic("Unreachable")
   641  	}
   642  }