gorgonia.org/gorgonia@v0.9.17/operatorPointwise_binary_const.go (about) 1 package gorgonia 2 3 import "gorgonia.org/tensor" 4 5 var ( 6 /* scalar-tensor float64 and vice versa */ 7 8 // arith 9 tadd = denseBinOp(tensor.Add) 10 tsub = denseBinOp(tensor.Sub) 11 tmul = denseBinOp(tensor.Mul) 12 tdiv = denseBinOp(tensor.Div) 13 tpow = denseBinOp(tensor.Pow) 14 15 // cmp 16 tlt = denseCmpOp(tensor.Lt) 17 tgt = denseCmpOp(tensor.Gt) 18 tlte = denseCmpOp(tensor.Lte) 19 tgte = denseCmpOp(tensor.Gte) 20 teq = denseCmpOp(tensor.ElEq) 21 tne = denseCmpOp(tensor.ElNe) 22 ) 23 24 type denseBinOp func(a, b interface{}, opts ...tensor.FuncOpt) (tensor.Tensor, error) 25 type denseCmpOp func(a, b interface{}, opts ...tensor.FuncOpt) (tensor.Tensor, error) 26 27 type ʘBinaryOperatorType byte 28 29 const ( 30 // arith 31 addOpType ʘBinaryOperatorType = iota 32 subOpType 33 mulOpType 34 divOpType 35 powOpType 36 37 // cmp 38 ltOpType 39 gtOpType 40 lteOpType 41 gteOpType 42 eqOpType 43 neOpType 44 45 maxʘBinaryOpType // delimits the end of all possible binOpType 46 ) 47 48 func (op ʘBinaryOperatorType) String() string { 49 return ʘBinOpStrs[op] 50 } 51 52 // ʘBinOpStrs is the string representation for a binOpType 53 // It should be held constant. 54 var ʘBinOpStrs = [maxʘBinaryOpType]string{ 55 // arith ops 56 "+", 57 "-", 58 "⊙", 59 "÷", 60 "^", 61 62 // cmp ops 63 "<", 64 ">", 65 "<=", 66 ">=", 67 "==", 68 "!=", 69 } 70 71 // ʘBinOpNames is the string representation for a binOpType 72 // It should be held constant. 73 var ʘBinOpNames = [maxʘBinaryOpType]string{ 74 // arith ops 75 "add", 76 "sub", 77 "mul", 78 "div", 79 "pow", 80 81 // cmp ops 82 "lt", 83 "gt", 84 "lte", 85 "gte", 86 "eq", 87 "ne", 88 } 89 90 // ʘBinOpCommutative is the array that stores whether a binary operator is commutative 91 // It should be held constant. 92 var ʘBinOpCommutative = [maxʘBinaryOpType]bool{ 93 true, false, true, false, false, 94 false, false, false, false, true, true, 95 } 96 97 var ʘBinOpDiffExprs = [maxʘBinaryOpType]func(x, y, z, gradZ *Node) (Nodes, error){ 98 addDiffExpr, subDiffExpr, hadamardProdDiffExpr, hadamardDivDiffExpr, hadamardPowDiffExpr, 99 nondiffBinOpExpr, nondiffBinOpExpr, nondiffBinOpExpr, nondiffBinOpExpr, nondiffBinOpExpr, nondiffBinOpExpr, 100 } 101 102 var ʘBinOpDiffFns = [maxʘBinaryOpType]func(ctx ExecutionContext, x, y, z *Node) error{ 103 addDiff, subDiff, hadamardProdDiff, hadamardDivDiff, hadamardPowDiff, 104 nondiffBinOp, nondiffBinOp, nondiffBinOp, nondiffBinOp, nondiffBinOp, nondiffBinOp, 105 } 106 107 // isCommutative gives info about whether the operator is commutative 108 // For example: 109 // a + b == b + a 110 // will ALWAYS evaluate to true. The same cannot be said about subtraction: 111 // a - b != b - a 112 // While a-b *may* be equal to b-a, it is not guaranteed. Therefore subtraction 113 // is not commutative 114 func (op ʘBinaryOperatorType) isCommutative() bool { 115 if op >= maxʘBinaryOpType { 116 panic("isCommutative() for unsupported BinOp undefined") 117 } 118 return ʘBinOpCommutative[op] 119 } 120 121 func (op ʘBinaryOperatorType) diffWRT(inputs int) []bool { 122 if inputs != 2 { 123 panic("binary operator only supports 2 inputs") 124 } 125 126 if op.isArith() { 127 return []bool{true, true} 128 } 129 return []bool{false, false} 130 } 131 132 // isArith indicates if the binary operator is an arithmetic type 133 func (op ʘBinaryOperatorType) isArith() bool { 134 switch op { 135 case addOpType, subOpType, mulOpType, divOpType, powOpType: 136 return true 137 default: 138 return false 139 } 140 } 141 142 var binOps = [maxʘBinaryOpType]*denseBinOp{ 143 &tadd, 144 &tsub, 145 &tmul, 146 &tdiv, 147 &tpow, 148 nil, // lt 149 nil, // gt 150 nil, // lte 151 nil, // gte 152 nil, // eq 153 nil, // ne 154 } 155 156 var cmpOps = [maxʘBinaryOpType]*denseCmpOp{ 157 nil, // add 158 nil, // sub 159 nil, // mul 160 nil, // div 161 nil, // pow 162 &tlt, 163 &tgt, 164 &tlte, 165 &tgte, 166 &teq, 167 &tne, 168 }