gorgonia.org/gorgonia@v0.9.17/operatorLinAlg_const.go (about) 1 package gorgonia 2 3 import "github.com/chewxy/hm" 4 5 // āBinOpStrs is the string representation for binLAOperator 6 // It should be held constant 7 var āBinOpStrs = [maxĀBinaryOperator]string{ 8 "×", // matMulOperator 9 "×", // matVecMulOperator 10 "⋅", // vecDotOperator 11 "⊗", // outerProdOperator 12 "×××", // batchedMatMulOperator 13 } 14 15 var āBinOpDiffExprs = [maxĀBinaryOperator]func(tA, tB bool, x, y, z, grad *Node) (Nodes, error){ 16 matMulDiffExpr, 17 matVecMulDiffExpr, 18 vecDotDiffExpr, 19 outerProdDiffExpr, 20 batchedMatMulDiffExpr, 21 } 22 23 var āBinOpDiffs = [maxĀBinaryOperator]func(ctx ExecutionContext, tA, tB bool, x, y, z *Node) error{ 24 matMulDiff, 25 matVecMulDiff, 26 vecDotDiff, 27 outerProdDiff, 28 batchedMatMulDiff, 29 } 30 31 var āBinOpTypes = [maxĀBinaryOperator]func() hm.Type{ 32 matMulType, 33 matVecMulType, 34 vecDotType, 35 outerProdType, 36 batchedMatMulType, 37 } 38 39 /* TYPES FOR LINALG BINARY OP*/ 40 41 // matVecMulOp is a function with this type: 42 // matVecMulOp :: (Float a) ⇒ Vector a → Matrix a → Vector a 43 // 44 // For the moment only floats are allowed 45 func matVecMulType() hm.Type { 46 a := hm.TypeVariable('a') 47 v := makeTensorType(1, a) 48 m := makeTensorType(2, a) 49 50 return hm.NewFnType(m, v, v) 51 } 52 53 // matMulOp is a function with this type: 54 // matMulOp :: (Float a) ⇒ Matrix a → Matrix a → Matrix a 55 // 56 // For the moment only floats are allowed 57 func matMulType() hm.Type { 58 a := hm.TypeVariable('a') 59 m := makeTensorType(2, a) 60 61 return hm.NewFnType(m, m, m) 62 } 63 64 // vecDotOp is a function with this type: 65 // vecDotOp :: (Float a) ⇒ Vector a → Vector a → a 66 // 67 // For the moment only floats are allowed 68 func vecDotType() hm.Type { 69 a := hm.TypeVariable('a') 70 v := makeTensorType(1, a) 71 72 return hm.NewFnType(v, v, a) 73 } 74 75 // outerProdOp is a function with this type: 76 // outerProdOp :: (Float a) ⇒ Vector a → Vector a → Matrix a 77 // 78 // For the moment only floats are allowed 79 func outerProdType() hm.Type { 80 a := hm.TypeVariable('a') 81 v := makeTensorType(1, a) 82 m := makeTensorType(2, a) 83 84 return hm.NewFnType(v, v, m) 85 } 86 87 func batchedMatMulType() hm.Type { 88 a := hm.TypeVariable('a') 89 return hm.NewFnType(a, a, a) 90 }