gorgonia.org/gorgonia@v0.9.17/operatorLinAlg_const.go (about)

     1  package gorgonia
     2  
     3  import "github.com/chewxy/hm"
     4  
     5  // āBinOpStrs is the string representation for binLAOperator
     6  // It should be held constant
     7  var āBinOpStrs = [maxĀBinaryOperator]string{
     8  	"×",   // matMulOperator
     9  	"×",   // matVecMulOperator
    10  	"⋅",   // vecDotOperator
    11  	"⊗",   // outerProdOperator
    12  	"×××", // batchedMatMulOperator
    13  }
    14  
    15  var āBinOpDiffExprs = [maxĀBinaryOperator]func(tA, tB bool, x, y, z, grad *Node) (Nodes, error){
    16  	matMulDiffExpr,
    17  	matVecMulDiffExpr,
    18  	vecDotDiffExpr,
    19  	outerProdDiffExpr,
    20  	batchedMatMulDiffExpr,
    21  }
    22  
    23  var āBinOpDiffs = [maxĀBinaryOperator]func(ctx ExecutionContext, tA, tB bool, x, y, z *Node) error{
    24  	matMulDiff,
    25  	matVecMulDiff,
    26  	vecDotDiff,
    27  	outerProdDiff,
    28  	batchedMatMulDiff,
    29  }
    30  
    31  var āBinOpTypes = [maxĀBinaryOperator]func() hm.Type{
    32  	matMulType,
    33  	matVecMulType,
    34  	vecDotType,
    35  	outerProdType,
    36  	batchedMatMulType,
    37  }
    38  
    39  /* TYPES FOR LINALG BINARY OP*/
    40  
    41  // matVecMulOp is a function with this type:
    42  //		matVecMulOp :: (Float a) ⇒ Vector a → Matrix a → Vector a
    43  //
    44  // For the moment only floats are allowed
    45  func matVecMulType() hm.Type {
    46  	a := hm.TypeVariable('a')
    47  	v := makeTensorType(1, a)
    48  	m := makeTensorType(2, a)
    49  
    50  	return hm.NewFnType(m, v, v)
    51  }
    52  
    53  // matMulOp is a function with this type:
    54  //		matMulOp :: (Float a) ⇒ Matrix a → Matrix a → Matrix a
    55  //
    56  // For the moment only floats are allowed
    57  func matMulType() hm.Type {
    58  	a := hm.TypeVariable('a')
    59  	m := makeTensorType(2, a)
    60  
    61  	return hm.NewFnType(m, m, m)
    62  }
    63  
    64  // vecDotOp is a function with this type:
    65  //		vecDotOp :: (Float a) ⇒ Vector a → Vector a → a
    66  //
    67  // For the moment only floats are allowed
    68  func vecDotType() hm.Type {
    69  	a := hm.TypeVariable('a')
    70  	v := makeTensorType(1, a)
    71  
    72  	return hm.NewFnType(v, v, a)
    73  }
    74  
    75  // outerProdOp is a function with this type:
    76  //		outerProdOp :: (Float a) ⇒ Vector a → Vector a → Matrix a
    77  //
    78  // For the moment only floats are allowed
    79  func outerProdType() hm.Type {
    80  	a := hm.TypeVariable('a')
    81  	v := makeTensorType(1, a)
    82  	m := makeTensorType(2, a)
    83  
    84  	return hm.NewFnType(v, v, m)
    85  }
    86  
    87  func batchedMatMulType() hm.Type {
    88  	a := hm.TypeVariable('a')
    89  	return hm.NewFnType(a, a, a)
    90  }