gorgonia.org/gorgonia@v0.9.17/op_math.go

package gorgonia

/*
This file holds all the Ops that do math-related work. Because mathematical operations are so
numerous, they're classified into 3 main types:
	elemBinOp - a representation of a binary mathematical operation that is performed elementwise (examples: +, *, -, >, <)
	elemUnaryOp - a representation of a mathematical operation that is performed elementwise
	linAlgBinOp - a representation of a binary mathematical operation that is performed on matrices

The individual operators are further expanded upon in the operator*.go files. Their datatypes are often embedded in the datatypes here.

For all data types, the methods are arranged in the order in which the Op interface defines them.
Any additional interfaces that the data type fulfils will be declared AFTER the Op interface methods.
*/

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"hash"

	"github.com/chewxy/hm"
	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

/* ELEMENTWISE BINARY OPERATION */

// elemBinOp is the representation of an operation that is to be performed elementwise
type elemBinOp struct {
	ʘBinaryOperator
	arg0, arg1 hm.Type // pruned types only plz
	retSame    bool    // for comparison ops, return same type?
}

func newEBOByType(ot ʘBinaryOperatorType, at, bt hm.Type) elemBinOp {
	var binOp ʘBinaryOperator
	switch att := at.(type) {
	case tensor.Dtype:
		switch bt.(type) {
		case tensor.Dtype:
			binOp = scalarBinOp{
				ʘBinaryOperatorType: ot,
				t:                   att,
			}
		case TensorType:
			binOp = tBinOp{
				ʘBinaryOperatorType: ot,
				tensorLeft:          false,
			}
		default:
			panic(fmt.Sprintf("Unsupported type of b %v!", bt))
		}
	case TensorType:
		binOp = tBinOp{
			ʘBinaryOperatorType: ot,
			tensorLeft:          true,
		}
	default:
		panic(fmt.Sprintf("Unsupported type of a %v!", at))
	}
	return elemBinOp{
		ʘBinaryOperator: binOp,
		arg0:            at,
		arg1:            bt,
	}
}

func newElemBinOp(ot ʘBinaryOperatorType, a, b *Node) elemBinOp {
	// at := hm.Prune(a.t)
	// bt := hm.Prune(b.t)

	return newEBOByType(ot, a.t, b.t)
}
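
// A minimal usage sketch (not part of the original file): constructing an
// elementwise addition op directly by type. Float64 is the package-level
// dtype used elsewhere in this file; newTensorType is assumed here to build
// the TensorType that the dispatch above switches on.
//
//	add := newEBOByType(addOpType, Float64, newTensorType(2, Float64))
//	// at is a tensor.Dtype and bt a TensorType, so the dispatch above
//	// picks tBinOp{tensorLeft: false}: scalar on the left, tensor on the right.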

func (op elemBinOp) Arity() int { return 2 }

// elemBinOp has one of these types:
//	elemBinOp :: (Floats a) ⇒ Tensor a → Tensor a → Tensor a
//	elemBinOp :: (Floats a) ⇒ Tensor a → a → Tensor a
//	elemBinOp :: (Floats a) ⇒ a → Tensor a → Tensor a
//	elemBinOp :: (Floats a) ⇒ a → a → a
//	elemBinOp :: (Floats a) ⇒ a → a → Bool
//	elemBinOp :: (Floats a) ⇒ Tensor a → Tensor a → Tensor Bool
//	elemBinOp :: (Floats a) ⇒ Tensor a → a → Tensor Bool
//	elemBinOp :: (Floats a) ⇒ a → Tensor a → Tensor Bool
//
// To make things clearer, it helps to consider elemBinOp to be the representation of
// a dispatch table for different functions. In a sense it's "overloading" functions.
//
// At the moment, due to my refusal to create a sum type (which would require more fiddling
// with data constructors), Type() is resolved very close to run time
func (op elemBinOp) Type() hm.Type {
	a := hm.TypeVariable('a')

	var a0, a1, retType hm.Type
	var arg0Dims int
	switch arg0 := op.arg0.(type) {
	case TensorType:
		arg0Dims = arg0.Dims
		a0 = makeFromTensorType(arg0, a)
		retType = makeFromTensorType(arg0, a)
	default:
		a0 = a
		retType = a
	}

	switch arg1 := op.arg1.(type) {
	case TensorType:
		if arg1.Dims >= arg0Dims {
			retType = makeFromTensorType(arg1, a)
		}
		a1 = makeFromTensorType(arg1, a)
	default:
		a1 = a
	}

	if op.isArith() || (!op.isArith() && op.retSame) {
		return hm.NewFnType(a0, a1, retType)
	}

	switch rt := retType.(type) {
	case TensorType:
		rt.Of = Bool
		retType = rt
	default:
		retType = Bool
	}

	return hm.NewFnType(a0, a1, retType)
}
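
// For example, a sketch of what Type() resolves to (notation as in the
// comment above; the operand types are hypothetical):
//
//	arg0 = Tensor-2 a, arg1 = a, arithmetic op ⇒ Tensor-2 a → a → Tensor-2 a
//	arg0 = a, arg1 = Tensor-2 a, comparison op
//	with retSame = false                       ⇒ a → Tensor-2 a → Tensor-2 Bool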

// elemBinOp has these allowed shapes:
//	op :: () → () → ()
//	op :: () → (...) → (...)
//	op :: (...) → () → (...)
func (op elemBinOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err error) {
	shapeLogf("Inferring shape of %v", op)
	enterLogScope()
	defer leaveLogScope()

	if inputs[0] == nil || inputs[1] == nil {
		return nil, errors.Errorf(nyiFail, "elemBinOp.inferShape", "runtime impl")
	}

	switch x := inputs[0].(type) {
	case tensor.Shape:
		switch y := inputs[1].(type) {
		case tensor.Shape:
			switch {
			case x.IsScalarEquiv() && y.IsScalarEquiv():
				// preserve ambiguous scalar shape
				switch {
				case len(x) > 0 && x[0] == 1:
					retVal = x
				case len(y) > 0 && y[0] == 1:
					retVal = y
				default:
					// both are genuinely scalar
					retVal = scalarShape
				}
			case x.IsScalar() && !y.IsScalar():
				retVal = y
			case !x.IsScalar() && y.IsScalar():
				retVal = x
			case !x.IsScalar() && !y.IsScalar():
				if !x.Eq(y) {
					return nil, errors.Errorf("Shape mismatch: %v and %v", x, y)
				}
				if x.Dims() > y.Dims() {
					retVal = x
				} else {
					retVal = y
				}
			}
		default:
			retVal = x
		}
	default:
		switch y := inputs[1].(type) {
		case tensor.Shape:
			retVal = y
		default:
			retVal = scalarShape
		}
	}
	return
}
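
// A minimal sketch of the inference rules above (shapes are hypothetical):
//
//	op.InferShape(tensor.Shape{2, 3}, scalarShape)        // ⇒ (2, 3)
//	op.InferShape(tensor.Shape{2, 3}, tensor.Shape{2, 3}) // ⇒ (2, 3)
//	op.InferShape(tensor.Shape{2, 3}, tensor.Shape{3, 2}) // ⇒ shape mismatch error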

// DiffWRT gives info on whether or not the operation is actually differentiable.
// For example, this is differentiable:
//	c = a ** b
// The result of differentiating with respect to a and b would be:
//	dc/da = b * a ** (b-1)
//	dc/db = a ** b * ln(a)
//
// However, operators like < and > are NOT differentiable
//
// This method returns a slice of bools indicating whether differentiation with regard to
// each operand can be done. Since a binOp has 2 operands, the slice has length 2
func (op elemBinOp) DiffWRT(inputs int) []bool {
	if inputs != 2 {
		panic(fmt.Sprintf(binOpFail, inputs))
	}

	b := op.ʘBinaryOperator.binOpType()

	if b >= maxʘBinaryOpType {
		panic("Unsupported binary operator is not differentiable")
	}

	if b.isArith() {
		return []bool{true, true}
	}
	return []bool{false, false}
}
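
// For instance (a sketch, reusing newEBOByType from above):
//
//	newEBOByType(addOpType, Float64, Float64).DiffWRT(2) // ⇒ []bool{true, true}
//	// comparison operators such as < and > report []bool{false, false}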

func (op elemBinOp) SymDiff(inputs Nodes, output, gradNode *Node) (retVal Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	b := op.ʘBinaryOperator.binOpType()

	if retVal, err = ʘBinOpDiffExprs[b](inputs[0], inputs[1], output, gradNode); err == nil {
		for _, n := range retVal {
			n.setGroup(gradClust)
		}
	}

	// needed to handle scalar gradients such as b in the logit regression example
	for i, grad := range retVal {
		if inputs[i].IsScalar() && !grad.IsScalar() {
			if retVal[i], err = Sum(grad); err != nil {
				err = errors.Wrap(err, operationError)
				return
			}
		}
	}

	return
}

func (op elemBinOp) Do(values ...Value) (Value, error) {
	return op.ʘBinaryOperator.Do(op.retSame, values...)
}

func (op elemBinOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	b := op.ʘBinaryOperator.binOpType()
	if err = ʘBinOpDiffFns[b](ctx, inputs[0], inputs[1], output); err != nil {
		if _, ok := err.(AutoDiffError); !ok {
			return errors.Wrapf(err, autodiffFail, b)
		}
		err = nil
	}

	// handle scalar gradients
	for _, in := range inputs {
		indv := in.boundTo.(*dualValue)
		if _, ok := indv.d.(Scalar); in.IsScalar() && !ok {
			indvdT := indv.d.(tensor.Tensor)
			defer returnTensor(indvdT)

			var d Value
			var t tensor.Tensor
			if t, err = tensor.Sum(indvdT); err != nil {
				return errors.Wrap(err, operationError)
			}
			defer returnTensor(t)

			d, _ = anyToScalar(t.ScalarValue())
			indv.SetDeriv(d)
		}
	}
	return
}

func (op elemBinOp) ReturnsPtr() bool { return true }

func (op elemBinOp) OverwritesInput() int {
	if _, ok := op.arg0.(TensorType); ok {
		return 0
	}

	if _, ok := op.arg1.(TensorType); ok {
		return 1
	}
	return -1
}

func (op elemBinOp) WriteHash(h hash.Hash) {
	if err := binary.Write(h, binary.LittleEndian, op.binOpType()); err != nil {
		panic(err)
	}

	fmt.Fprintf(h, "%v,%v", op.arg0, op.arg1)
}

func (op elemBinOp) Hashcode() uint32 { return simpleHash(op) }

// Fulfils UsePreallocDoer interface
func (op elemBinOp) UsePreallocDo(prealloc Value, inputs ...Value) (retVal Value, err error) {
	if !op.ReturnsPtr() {
		return op.Do(inputs...)
	}

	if pd, ok := op.ʘBinaryOperator.(usePreallocDoerBinOp); ok {
		return pd.UsePreallocDo(prealloc, op.retSame, inputs...)
	}

	if retVal, err = op.Do(inputs...); err != nil {
		return
	}
	return Copy(prealloc, retVal)
}
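
// A minimal usage sketch, assuming a, b and prealloc are *tensor.Dense values
// of compatible shapes (all names here are hypothetical):
//
//	add := newEBOByType(addOpType, TypeOf(a), TypeOf(b))
//	out, err := add.UsePreallocDo(prealloc, a, b)
//	// out is written into prealloc where the underlying operator supports
//	// it; otherwise the result is computed and then copied into prealloc.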

// Fulfils UnsafeDoer interface
func (op elemBinOp) UnsafeDo(inputs ...Value) (retVal Value, err error) {
	if !op.ReturnsPtr() {
		return op.Do(inputs...)
	}

	if ud, ok := op.ʘBinaryOperator.(unsafeDoerBinOp); ok {
		return ud.UnsafeDo(op.retSame, inputs...)
	}
	return op.Do(inputs...)
}

// Fulfils the IncrDoer interface
func (op elemBinOp) IncrDo(incr Value, inputs ...Value) (err error) {
	if id, ok := op.ʘBinaryOperator.(incrDoerBinOp); ok {
		return id.IncrDo(incr, op.retSame, inputs...)
	}

	// if !op.ReturnsPtr() {
	var retVal Value
	if retVal, err = op.Do(inputs...); err != nil {
		return errors.Wrapf(err, doFail, op)
	}

	add := newEBOByType(addOpType, TypeOf(incr), TypeOf(retVal))
	if retVal, err = add.UnsafeDo(incr, retVal); err != nil {
		return errors.Wrapf(err, unsafeDoFail, add)
	}
	err = noIncrErr{retVal}
	return
	// }
}
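
// To illustrate the fallback path above (a sketch; incr and retVal are
// hypothetical Values of matching types):
//
//	add := newEBOByType(addOpType, TypeOf(incr), TypeOf(retVal))
//	sum, err := add.UnsafeDo(incr, retVal) // incr is expected to be mutated in place
//	// The noIncrErr returned by IncrDo wraps the freshly computed value,
//	// signalling to the caller that the increment was not done natively.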

func (op elemBinOp) String() string { return fmt.Sprintf("%v %t", op.ʘBinaryOperator, op.retSame) }

// Fulfils the BinaryOp interface
func (op elemBinOp) IsBinary() bool { return true }

/* ELEMENTWISE UNARY OP */

type elemUnaryOp struct {
	ʘUnaryOperator

	argTensor     bool
	numericResult bool // indicates whether boolean results should be converted to 1 and 0 in the respective Dtype
}

func newElemUnaryOp(op ʘUnaryOperatorType, a *Node) elemUnaryOp {
	dt, err := dtypeOf(a.t)
	if err != nil {
		panic(err)
	}

	_, isTensor := a.t.(TensorType)

	var operator ʘUnaryOperator
	switch dt {
	case Float32:
		operator = sf32UnaryOperators[op]
	case Float64:
		operator = sf64UnaryOperators[op]
	}

	return elemUnaryOp{
		ʘUnaryOperator: operator,
		argTensor:      isTensor,
	}
}

func (op elemUnaryOp) Arity() int { return 1 }

// all pointwise unary operations have this type:
//	op :: (Arithable a) ⇒ a → a
func (op elemUnaryOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	return hm.NewFnType(a, a)
}

func (op elemUnaryOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err error) {
	if inputs[0] == nil {
		return nil, errors.Errorf(nyiFail, "inferShape", "nil shape")
	}

	return inputs[0].(tensor.Shape), nil
}

// DiffWRT gives info on whether or not the operation is actually differentiable with respect to its inputs
//
// some operations, such as ceil(), sign() and floor(), cannot be differentiated with respect to their inputs (or I don't actually know how to do them)
func (op elemUnaryOp) DiffWRT(inputs int) []bool {
	if inputs != 1 {
		panic(fmt.Sprintf("unary operator only supports one input, got %d instead", inputs))
	}

	u := op.ʘUnaryOperator.unaryOpType()

	if u >= maxʘUnaryOperator {
		panic("Unsupported unary operator is not differentiable")
	}
	return []bool{ʘUnaryOpDifferentiable[u]}
}

func (op elemUnaryOp) SymDiff(inputs Nodes, output, gradNode *Node) (retVal Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	u := op.ʘUnaryOperator.unaryOpType()

	var n *Node
	if n, err = ʘUnaryOpDiffExprs[u](inputs[0], output, gradNode); err == nil {
		n.setGroup(gradClust)
		retVal = Nodes{n}
	}
	return
}

func (op elemUnaryOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	u := op.ʘUnaryOperator.unaryOpType()
	return ʘUnaryOpDiffFns[u](inputs[0], output)
}

func (op elemUnaryOp) Do(inputs ...Value) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	return op.do(inputs[0])
}

func (op elemUnaryOp) ReturnsPtr() bool { return true }

func (op elemUnaryOp) OverwritesInput() int {
	if op.argTensor {
		return 0
	}
	return -1
}

func (op elemUnaryOp) WriteHash(h hash.Hash) {
	if err := binary.Write(h, binary.LittleEndian, op.unaryOpType()); err != nil {
		panic(err)
	}

	if op.argTensor {
		h.Write([]byte{1})
	} else {
		h.Write([]byte{0})
	}
}

func (op elemUnaryOp) Hashcode() uint32 { return simpleHash(op) }

// fulfils UnsafeDoer interface
func (op elemUnaryOp) UnsafeDo(inputs ...Value) (Value, error) {
	if err := checkArity(op, len(inputs)); err != nil {
		return nil, err
	}
	return op.do(inputs[0], tensor.UseUnsafe())
}

// fulfils UnaryOp interface

func (op elemUnaryOp) isUnary() bool { return true }

// misc private methods

func (op elemUnaryOp) do(a Value, opts ...tensor.FuncOpt) (retVal Value, err error) {
	switch v := a.(type) {
	case tensor.Tensor:
		return unaryCheckApply(op.ʘUnaryOperator, v, opts...)
	case Scalar:
		vt := v.Dtype()
		switch vt {
		case tensor.Float32:
			vs := v.(*F32)
			f := float32(*vs)
			opFn := op.ʘUnaryOperator.(*sf32UnaryOperator)
			retVal, _ = anyToScalar((*opFn)(f))
		case tensor.Float64:
			vs := v.(*F64)
			f := float64(*vs)
			opFn := op.ʘUnaryOperator.(*sf64UnaryOperator)
			retVal, _ = anyToScalar((*opFn)(f))
		default:
			return nil, errors.Errorf(nyiFail, "elemUnaryOp.do", vt)
		}
	}
	return
}
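
// A minimal sketch of the two paths through do (values are hypothetical):
//
//	s, _ := anyToScalar(2.0) // *F64
//	out, err := op.do(s)     // dispatches through the *sf64UnaryOperator kernel
//
//	t := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(2, 3))
//	out, err = op.do(t)      // routed through unaryCheckApply instead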

/* LINEAR ALGEBRA RELATED OPERATIONS */

type linAlgBinOp struct {
	āBinaryOperator
	transA, transB bool
}

func (op linAlgBinOp) Arity() int { return 2 }

func (op linAlgBinOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err error) {
	shapeLogf("Inferring shape of %v", op)
	enterLogScope()
	defer leaveLogScope()

	if inputs[0] == nil || inputs[1] == nil {
		return nil, nyi("InferShape for linAlgBinOp", "runtime impl")
	}

	x, y := inputs[0].(tensor.Shape), inputs[1].(tensor.Shape)
	if x == nil || y == nil {
		return nil, errors.Errorf("Cannot infer shape from %v %v", x, y)
	}

	shapeLogf("x.shape: %v; y.shape: %v", x, y)
	// TODO: add checks for tensors greater than 2D

	switch op.āBinaryOperator {
	case matMulOperator:
		if op.transA {
			x = transpose2D(x)
			defer tensor.ReturnInts(x)
		}
		if op.transB {
			y = transpose2D(y)
			defer tensor.ReturnInts(y)
		}

		if x[1] != y[0] {
			return nil, errors.Errorf("Inner dimensions do not match up")
		}

		retVal = tensor.Shape{x[0], y[1]}
	case matVecMulOperator:
		if op.transA {
			x = transpose2D(x)
			defer tensor.ReturnInts(x)
		}
		if x[0] != y[0] && x[1] != y[0] {
			return nil, errors.Errorf("Incompatible shapes: %v and %v", x, y)
		}

		switch {
		case x[0] == y[0]:
			retVal = tensor.Shape{x[1]}
		case x[1] == y[0]:
			retVal = tensor.Shape{x[0]}
		}

	case vecDotOperator:
		retVal = scalarShape
	case outerProdOperator:
		// outer products only handle vec × vec for now
		retVal = tensor.Shape{x.TotalSize(), y.TotalSize()}
	case batchedMatMulOperator:
		x = x.Clone()
		y = y.Clone()
		innerX := x[len(x)-2:]
		outerX := x[:len(x)-2]
		innerY := y[len(y)-2:]
		outerY := y[:len(y)-2]
		if !outerX.Eq(outerY) {
			return nil, errors.Errorf("Expected outer dimensions of %v and %v to match. Got %v and %v", x, y, outerX, outerY)
		}

		// batchSize := outerX.TotalSize()
		if op.transA {
			innerX = transpose2D(innerX)
			defer tensor.ReturnInts(innerX)
		}
		if op.transB {
			innerY = transpose2D(innerY)
			defer tensor.ReturnInts(innerY)
		}
		retVal = append(outerX, innerX[0], innerY[1])
	}
	return
}
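
// A minimal sketch of the matMul case above (shapes are hypothetical):
//
//	op := linAlgBinOp{āBinaryOperator: matMulOperator, transB: true}
//	op.InferShape(tensor.Shape{2, 3}, tensor.Shape{4, 3}) // ⇒ (2, 4)
//	// y = (4, 3) is first transposed to (3, 4); the inner dimensions
//	// then match (3 == 3), so the result is (x[0], y[1]) = (2, 4).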

func (op linAlgBinOp) SymDiff(inputs Nodes, output, gradNode *Node) (retVal Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	o := op.āBinaryOperator

	if retVal, err = āBinOpDiffExprs[o](op.transA, op.transB, inputs[0], inputs[1], output, gradNode); err != nil {
		return nil, errors.Wrap(err, "Failed to differentiate expressions")
	}

	for _, n := range retVal {
		n.setGroup(gradClust)
	}
	return
}

func (op linAlgBinOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	o := op.āBinaryOperator
	return āBinOpDiffs[o](ctx, op.transA, op.transB, inputs[0], inputs[1], output)
}

func (op linAlgBinOp) Do(inputs ...Value) (retVal Value, err error) { return op.do(inputs) }
func (op linAlgBinOp) ReturnsPtr() bool                             { return true }
func (op linAlgBinOp) OverwritesInput() int                         { return -1 }

func (op linAlgBinOp) WriteHash(h hash.Hash) {
	if err := binary.Write(h, binary.LittleEndian, op.āBinaryOperator); err != nil {
		panic(err)
	}

	if op.transA {
		h.Write([]byte{1})
	} else {
		h.Write([]byte{0})
	}

	if op.transB {
		h.Write([]byte{1})
	} else {
		h.Write([]byte{0})
	}
}

func (op linAlgBinOp) Hashcode() uint32 { return simpleHash(op) }

func (op linAlgBinOp) String() string {
	var buf bytes.Buffer

	switch op.āBinaryOperator {
	case matMulOperator, matVecMulOperator, batchedMatMulOperator:
		buf.WriteString("A")
	case vecDotOperator, outerProdOperator:
		buf.WriteString("a")
	}

	if op.transA {
		buf.WriteString("ᵀ")
	}

	switch op.āBinaryOperator {
	case matMulOperator, batchedMatMulOperator:
		fmt.Fprintf(&buf, " %v B", op.āBinaryOperator)
	case matVecMulOperator, vecDotOperator, outerProdOperator:
		fmt.Fprintf(&buf, " %v b", op.āBinaryOperator)
	}

	if op.transB {
		buf.WriteString("ᵀ")
	}

	return buf.String()
}

// fulfils IncrDoer
func (op linAlgBinOp) IncrDo(incr Value, inputs ...Value) (err error) {
	t, ok := incr.(tensor.Tensor)

	switch {
	case ok && op.āBinaryOperator != batchedMatMulOperator:
		_, err = op.do(inputs, tensor.WithIncr(t))
		return
	case ok && op.āBinaryOperator == batchedMatMulOperator:
		_, err = op.preallocBatchMatMul(true, incr, inputs...)
		return
	}

	var retVal Value
	if retVal, err = op.do(inputs); err != nil {
		return errors.Wrapf(err, doFail, op)
	}

	add := newEBOByType(addOpType, TypeOf(incr), TypeOf(retVal))
	if retVal, err = add.UnsafeDo(incr, retVal); err != nil {
		return errors.Wrapf(err, unsafeDoFail, add)
	}

	err = noIncrErr{retVal}
	return
}

// fulfils UsePreallocDoer
func (op linAlgBinOp) UsePreallocDo(prealloc Value, inputs ...Value) (retVal Value, err error) {
	t, ok := prealloc.(tensor.Tensor)
	if !ok {
		return nil, errors.Errorf("Expected Tensor as preallocated value. Got %v of %T instead", prealloc, prealloc)
	}
	if op.āBinaryOperator == batchedMatMulOperator {
		return op.preallocBatchMatMul(false, prealloc, inputs...)
	}
	return op.do(inputs, tensor.WithReuse(t))
}

// fulfils BinaryOp
func (op linAlgBinOp) IsBinary() bool { return true }

/* PRIVATE METHODS */

func (op linAlgBinOp) do(inputs []Value, opts ...tensor.FuncOpt) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	a, b := inputs[0].(tensor.Tensor), inputs[1].(tensor.Tensor)

	if op.transA && op.āBinaryOperator != batchedMatMulOperator {
		if err = a.T(); err != nil {
			return nil, errors.Wrap(err, tFail)
		}

		// untranspose
		defer a.T()
	}

	if op.transB && op.āBinaryOperator != batchedMatMulOperator {
		if err = b.T(); err != nil {
			return nil, errors.Wrap(err, tFail)
		}

		// untranspose
		defer b.T()
	}

	switch op.āBinaryOperator {
	case matMulOperator:
		retVal, err = tensor.MatMul(a, b, opts...)
	case matVecMulOperator:
		retVal, err = tensor.MatVecMul(a, b, opts...)
	case vecDotOperator:
		var ret interface{}

		if ret, err = tensor.Inner(a, b); err != nil {
			return nil, errors.Wrapf(err, "Failed to carry out linAlgBinOp operation %v", op)
		}

		retVal, _ = anyToScalar(ret)
	case outerProdOperator:
		retVal, err = tensor.Outer(a, b, opts...)
	case batchedMatMulOperator:
		// checks were done when the op was created
		retVal, err = batchedMatMul(a, b, nil, op.transA, op.transB, false)
	}

	if err != nil {
		return nil, fmt.Errorf("linAlgBinOp %v %s %v error: %w", a.Shape(), op.āBinaryOperator, b.Shape(), err)
	}

	return retVal, nil
}

func (op linAlgBinOp) preallocBatchMatMul(incr bool, prealloc Value, inputs ...Value) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	a, b := inputs[0].(tensor.Tensor), inputs[1].(tensor.Tensor)
	c := prealloc.(tensor.Tensor)
	return batchedMatMul(a, b, c, op.transA, op.transB, incr)
}

type tensordotOp struct {
	aAxes   []int
	bAxes   []int
	aDims   int
	bDims   int
	retDims int // Dimension of the tensor resulting from the operation
}

func makeTensordotOp(a, b *Node, aAxes, bAxes []int) tensordotOp {
	aDims := a.Shape().Dims()
	bDims := b.Shape().Dims()
	retDims := a.Shape().Dims() + b.Shape().Dims() - 2*len(aAxes)
	if retDims < 0 {
		retDims = 0
	}
	return tensordotOp{
		aAxes:   aAxes,
		bAxes:   bAxes,
		aDims:   aDims,
		bDims:   bDims,
		retDims: retDims,
	}
}
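
// A worked example of the dimension arithmetic above (the nodes a and b are
// hypothetical): contracting axis 2 of a with axis 0 of b,
//
//	// a: (2, 3, 4), b: (4, 5)
//	op := makeTensordotOp(a, b, []int{2}, []int{0})
//	// retDims = 3 + 2 - 2*1 = 3; InferShape below yields (2, 3, 5):
//	// the non-contracted axes of a, followed by those of b.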

func (op tensordotOp) Arity() int { return 2 }

func (op tensordotOp) Type() hm.Type {
	var tRet hm.Type
	if op.retDims == 0 {
		tRet = hm.TypeVariable('a')
	} else {
		tRet = newTensorType(op.retDims, hm.TypeVariable('a'))
	}
	ta := newTensorType(op.aDims, hm.TypeVariable('a'))
	tb := newTensorType(op.bDims, hm.TypeVariable('a'))

	return hm.NewFnType(ta, tb, tRet)
}

func (op tensordotOp) InferShape(ds ...DimSizer) (tensor.Shape, error) {
	if err := checkArity(op, len(ds)); err != nil {
		return nil, errors.Wrap(err, "tensordot")
	}

	shapes, err := DimSizersToShapes(ds)
	if err != nil {
		return nil, err
	}

	aShape := shapes[0]
	bShape := shapes[1]

	aAxes := op.aAxes
	bAxes := op.bAxes

	shapeBackingLen := op.retDims

	shapeBacking := make([]int, shapeBackingLen)

	shapeBackingPos := 0

	// a negative result from contains means the axis is not contracted,
	// so its size survives into the output shape
	for aShapeIndex, aShapeValue := range aShape {
		if 0 > contains(aAxes, aShapeIndex) {
			shapeBacking[shapeBackingPos] = aShapeValue
			shapeBackingPos++
		}
	}

	for bShapeIndex, bShapeValue := range bShape {
		if 0 > contains(bAxes, bShapeIndex) {
			shapeBacking[shapeBackingPos] = bShapeValue
			shapeBackingPos++
		}
	}

	return tensor.Shape(shapeBacking), nil
}

func (op tensordotOp) Do(vals ...Value) (Value, error) {
	if err := checkArity(op, len(vals)); err != nil {
		return nil, errors.Wrap(err, "tensordot")
	}

	ts, err := valuesToTensors(vals)
	if err != nil {
		return nil, errors.Wrap(err, "tensordot - valuesToTensors failed")
	}

	return tensor.Contract(ts[0], ts[1], op.aAxes, op.bAxes)
}

func (op tensordotOp) ReturnsPtr() bool { return true }

func (op tensordotOp) CallsExtern() bool { return false }

func (op tensordotOp) OverwritesInput() int { return -1 }

func (op tensordotOp) WriteHash(h hash.Hash) {
	h.Write([]byte("tensordotOp"))
	fmt.Fprintf(h, "aAxes: %d, bAxes: %d, dims: %d", op.aAxes, op.bAxes, op.retDims)
}

func (op tensordotOp) Hashcode() uint32 { return simpleHash(op) }

func (op tensordotOp) String() string {
	return fmt.Sprintf("Tensordot(aAxes=%d, bAxes=%d)", op.aAxes, op.bAxes)
}

func (op tensordotOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error {
	odv := output.boundTo.(*dualValue)
	odvd := odv.d.(tensor.Tensor)

	for inNr, in := range inputs {
		// abuse of language below: "i" up front will refer to the current "in",
		// "other" to the other input (there are only two)

		// Whose derivative are we calculating?
		var iAxes []int
		var otherAxes []int
		var otherdv *dualValue
		var iWasFirstArgument bool

		if 0 == inNr {
			iAxes = op.aAxes
			otherAxes = op.bAxes
			otherdv = inputs[1].boundTo.(*dualValue)
			iWasFirstArgument = true
		} else {
			iAxes = op.bAxes
			otherAxes = op.aAxes
			otherdv = inputs[0].boundTo.(*dualValue)
			iWasFirstArgument = false
		}

		idv := in.boundTo.(*dualValue)
		idvd := idv.d.(tensor.Tensor)

		otherdvv := otherdv.Value.(tensor.Tensor)

		// Below a tensordot will be performed: its output axes will be in the wrong order w.r.t. the input.
		// What is the correct permutation/pattern?
		iAxesCoSorted := make([]int, len(iAxes))
		copy(iAxesCoSorted, iAxes)

		otherAxesSorted := make([]int, len(otherAxes))
		copy(otherAxesSorted, otherAxes)

		sortUniqueIntWithImitator(otherAxesSorted, iAxesCoSorted)
		pattern := make([]int, len(in.Shape()))
		counter := len(iAxes)

		for patternIndex := 0; patternIndex < len(pattern); patternIndex++ {
			iAxesCoSortedIndex := contains(iAxesCoSorted, patternIndex)
			if 0 <= iAxesCoSortedIndex {
				pattern[patternIndex] = iAxesCoSortedIndex
			} else {
				pattern[patternIndex] = counter
				counter++
			}
		}
		// if the shape is scalar equivalent, then we'll not have any transforms
		if in.Shape().IsScalarEquiv() {
			pattern = pattern[:0]
		}
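
		// A worked example of the pattern computation above (hypothetical
		// shapes): with iAxes = [2], otherAxes = [0] and in.Shape() = (2, 3, 4),
		// iAxesCoSorted = [2] and pattern = [1, 2, 0]: the contracted axis
		// comes out of the backward tensordot first and is permuted back to
		// position 2 by the tensor.T call below.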

		// Which axes of the other tensor and the output should be contracted?
		// Other tensor: All axes that weren't contracted (with i ;-) ) in the original tensordot
		// With the exception of scalars
		dOtherAxes := make([]int, otherdvv.Dims())

		if !otherdvv.Shape().IsScalarEquiv() {
			var dOtherAxesIndex int

			for axis := 0; axis < otherdvv.Dims(); axis++ {
				if 0 > contains(otherAxes, axis) {
					dOtherAxes[dOtherAxesIndex] = axis
					dOtherAxesIndex++
				}
			}

			dOtherAxes = dOtherAxes[0:dOtherAxesIndex]
		}

		// Output: All axes which belong to other in the output of the original tensordot, so this depends on input ordering
		dOutputAxes := make([]int, len(dOtherAxes))
		if iWasFirstArgument {
			outputOtherAxesStart := odvd.Dims() - len(dOtherAxes)

			for axis := 0; axis < len(dOtherAxes); axis++ {
				dOutputAxes[axis] = outputOtherAxesStart + axis
			}
		} else {
			for axis := 0; axis < len(dOtherAxes); axis++ {
				dOutputAxes[axis] = axis
			}
		}

		// perform tensordot
		switch st := odvd.(type) {
		case *tensor.Dense:

			otherdvvDense := otherdvv.(*tensor.Dense)
			odvdDense := odvd.(*tensor.Dense)
			var tensordot *tensor.Dense
			var err error

			switch {
			case odvdDense.Shape().IsScalarEquiv():
				tensordot, err = otherdvvDense.MulScalar(odvdDense, true)
			case otherdvvDense.IsVector() && odvdDense.IsVector() && 0 == len(dOtherAxes): // TensorMul does not support creating a matrix from two vectors
				// Reformat vectors, so that MatMul will create a matrix from them
				var otherdvvDenseShapeOld tensor.Shape
				var odvdDenseShapeOld tensor.Shape

				otherdvvDenseReshaped := false
				if !otherdvvDense.IsColVec() {
					otherdvvDenseShapeOld = otherdvvDense.Shape().Clone()

					otherdvvVecDims, err := (otherdvvDense.AP.Shape()).DimSize(0)
					if err != nil {
						return err
					}

					otherdvvDenseReshaped = true
					otherdvvDense.Reshape(otherdvvVecDims, 1)
				}

				odvdDenseReshaped := false
				if !odvdDense.IsRowVec() {
					odvdDenseShapeOld = odvdDense.Shape().Clone()
					odvdDenseVecDims, err := (odvdDense.AP.Shape()).DimSize(0)

					if err != nil {
						return err
					}

					odvdDenseReshaped = true
					odvdDense.Reshape(1, odvdDenseVecDims)
				}

				tensordot, err = otherdvvDense.MatMul(odvdDense)

				// Undo the reshape
				if otherdvvDenseReshaped {
					otherdvvDense.Reshape(otherdvvDenseShapeOld...)
				}

				if odvdDenseReshaped {
					odvdDense.Reshape(odvdDenseShapeOld...)
				}

			default:
				tensordot, err = otherdvvDense.TensorMul(odvdDense, dOtherAxes, dOutputAxes)
			}

			if err != nil {
				return err
			}
			tensordotPerm, err := tensor.T(tensordot, pattern...)
			if err != nil {
				return err
			}

			tensordotPermDense := tensordotPerm.(*tensor.Dense)

			d := idvd.(*tensor.Dense)
			d.Add(tensordotPermDense, tensor.UseUnsafe()) // TODO: Should output directly into d and save the add

		default:
			return errors.Errorf(nyiTypeFail, "Do Diff (hack)", st)
		}
	}

	return nil
}

func (op tensordotOp) DiffWRT(inputs int) []bool {
	retVal := make([]bool, inputs)
	for i := range retVal {
		retVal[i] = true
	}
	return retVal
}

func (op tensordotOp) SymDiff(inputs Nodes, output *Node, grad *Node) (retVal Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	retVal = make(Nodes, len(inputs))

	for inNr, in := range inputs {
		// abuse of language below: "i" up front will refer to the current "in",
		// "other" to the other input (there are only two)

		// Whose derivative are we calculating?
		var iAxes []int
		var otherAxes []int
		var iWasFirstArgument bool
		var other *Node

		if 0 == inNr {
			iAxes = op.aAxes
			otherAxes = op.bAxes
			other = inputs[1]
			iWasFirstArgument = true
		} else {
			iAxes = op.bAxes
			otherAxes = op.aAxes
			other = inputs[0]
			iWasFirstArgument = false
		}

		// Below a tensordot will be performed: its output axes will be in the wrong order w.r.t. the input.
		// What is the correct permutation/pattern?
		iAxesCoSorted := make([]int, len(iAxes))
		copy(iAxesCoSorted, iAxes)

		otherAxesSorted := make([]int, len(otherAxes))
		copy(otherAxesSorted, otherAxes)

		sortUniqueIntWithImitator(otherAxesSorted, iAxesCoSorted)

		pattern := make([]int, len(in.shape))
		counter := len(iAxes)

		for patternIndex := 0; patternIndex < len(pattern); patternIndex++ {
			iAxesCoSortedIndex := contains(iAxesCoSorted, patternIndex)
			if 0 <= iAxesCoSortedIndex {
				pattern[patternIndex] = iAxesCoSortedIndex
			} else {
				pattern[patternIndex] = counter
				counter++
			}
		}

		// Which axes of the other tensor and the output should be contracted?
		// Other tensor: All axes that weren't contracted (with i ;-) ) in the original tensordot
		// With the exception of scalars
		dOtherAxes := make([]int, other.Dims())
		if !other.Shape().IsScalarEquiv() {
			var dOtherAxesIndex int

			for axis := 0; axis < other.Dims(); axis++ {
				if 0 > contains(otherAxes, axis) {
					dOtherAxes[dOtherAxesIndex] = axis
					dOtherAxesIndex++
				}
			}
			dOtherAxes = dOtherAxes[0:dOtherAxesIndex]
		}

		// Grad: All axes which belong to other in the output of the original tensordot, so this depends on input ordering
		dGradAxes := make([]int, len(dOtherAxes))
		if iWasFirstArgument {
			gradAxesStart := grad.Dims() - len(dOtherAxes)

			for axis := 0; axis < len(dOtherAxes); axis++ {
				dGradAxes[axis] = gradAxesStart + axis
			}
		} else {
			for axis := 0; axis < len(dOtherAxes); axis++ {
				dGradAxes[axis] = axis
			}
		}

		// perform tensordot
		var tensordot *Node
		switch {
		case grad.Shape().IsScalarEquiv():
			if tensordot, err = HadamardProd(other, grad); err != nil {
				err = SymDiffError{
					nodes:  inputs,
					single: other,
					grad:   grad,
					err:    errors.Wrap(err, "While performing tensordot of (other × grad) in SymDiff of `tensordotOp`. Nodes() returns the inputs. Node() returns the `other`, Grad() returns grad`"),
				}
				return nil, err
			}

		case other.Shape().IsVector() && grad.Shape().IsVector() && 0 == len(dOtherAxes): // TensorMul does not support creating a matrix from two vectors
			// Reformat vectors, so that MatMul will create a matrix from them
			otherCorrectShape := other
			if !other.IsColVec() {
				otherVecDims, err := (other.Shape()).DimSize(0)
				if err != nil {
					err = SymDiffError{
						nodes:  inputs,
						single: other,
						err:    errors.Wrap(err, "While getting .DimSize(0) of other, while SymDiff-ing. Nodes() returns the inputs, Node() returns `other`. There is no Grad or Grad map."),
					}
					return nil, err
				}

				if otherCorrectShape, err = Reshape(other, tensor.Shape{otherVecDims, 1}); err != nil {
					return nil, err
				}
			}

			gradCorrectShape := grad
			if !grad.IsRowVec() {
				gradVecDims, err := (grad.Shape()).DimSize(0)

				if err != nil {
					return nil, err
				}

				if gradCorrectShape, err = Reshape(grad, tensor.Shape{1, gradVecDims}); err != nil {
					return nil, err
				}
			}

			op := linAlgBinOp{āBinaryOperator: matMulOperator}
			if tensordot, err = binOpNode(op, otherCorrectShape, gradCorrectShape); err != nil {
				return nil, err
			}

		default:
			tensordot, err = Tensordot(dOtherAxes, dGradAxes, other, grad)
		}

		if err != nil {
			return nil, err
		}

		ret, err := Transpose(tensordot, pattern...)

		if err != nil {
			return nil, err
		}

		retVal[inNr] = ret
	}

	return retVal, nil
}