gorgonia.org/gorgonia@v0.9.17/operatorPointwise_unary_const.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "math" 6 7 "github.com/chewxy/math32" 8 ) 9 10 var ( 11 /* float64 */ 12 13 // non differentiable 14 absf64 = sf64UnaryOperator(math.Abs) 15 signf64 = sf64UnaryOperator(_signf64) 16 ceilf64 = sf64UnaryOperator(math.Ceil) 17 floorf64 = sf64UnaryOperator(math.Floor) 18 19 // differentiable 20 sinf64 = sf64UnaryOperator(math.Sin) 21 cosf64 = sf64UnaryOperator(math.Cos) 22 expf64 = sf64UnaryOperator(math.Exp) 23 lnf64 = sf64UnaryOperator(math.Log) 24 log2f64 = sf64UnaryOperator(math.Log2) 25 negf64 = sf64UnaryOperator(_negf64) 26 squaref64 = sf64UnaryOperator(_squaref64) 27 sqrtf64 = sf64UnaryOperator(math.Sqrt) 28 inversef64 = sf64UnaryOperator(_inversef64) 29 inverseSqrtf64 = sf64UnaryOperator(_inverseSqrtf64) 30 31 // activation functions 32 cubef64 = sf64UnaryOperator(_cubef64) 33 tanhf64 = sf64UnaryOperator(_tanhf64) 34 sigmoidf64 = sf64UnaryOperator(_sigmoidf64) 35 36 // numerical stabilization optimization 37 log1pf64 = sf64UnaryOperator(math.Log1p) 38 expm1f64 = sf64UnaryOperator(math.Expm1) 39 softplusf64 = sf64UnaryOperator(_softplusf64) 40 // softplus isn't necessarily only a numerical stabilization op 41 // (you can use it elsewhere), but I included it under numerical optimization 42 43 /* Float32 */ 44 45 // non differentiable 46 absf32 = sf32UnaryOperator(math32.Abs) 47 signf32 = sf32UnaryOperator(_signf32) 48 ceilf32 = sf32UnaryOperator(math32.Ceil) 49 floorf32 = sf32UnaryOperator(math32.Floor) 50 51 // start differentiable 52 sinf32 = sf32UnaryOperator(math32.Sin) 53 cosf32 = sf32UnaryOperator(math32.Cos) 54 expf32 = sf32UnaryOperator(math32.Exp) 55 lnf32 = sf32UnaryOperator(math32.Log) 56 log2f32 = sf32UnaryOperator(math32.Log2) 57 negf32 = sf32UnaryOperator(_negf32) 58 squaref32 = sf32UnaryOperator(_squaref32) 59 sqrtf32 = sf32UnaryOperator(math32.Sqrt) 60 inversef32 = sf32UnaryOperator(_inversef32) 61 inverseSqrtf32 = sf32UnaryOperator(_inverseSqrtf32) 62 63 // typically used in activation functions 64 cubef32 = sf32UnaryOperator(_cubef32) 65 tanhf32 = sf32UnaryOperator(_tanhf32) 66 sigmoidf32 = sf32UnaryOperator(_sigmoidf32) 67 68 // numerical stabilization optimization 69 log1pf32 = sf32UnaryOperator(math32.Log1p) 70 expm1f32 = sf32UnaryOperator(math32.Expm1) 71 softplusf32 = sf32UnaryOperator(_softplusf32) 72 ) 73 74 type ʘUnaryOperatorType byte 75 76 const ( 77 absOpType ʘUnaryOperatorType = iota 78 signOpType 79 ceilOpType 80 floorOpType 81 82 // start differentiable 83 sinOpType 84 cosOpType 85 expOpType 86 lnOpType 87 log2OpType 88 negOpType 89 squareOpType 90 sqrtOpType 91 inverseOpType // multiplicative inverse 92 inverseSqrtOpType // 1/sqrt(x) 93 94 // typically used in activation functions 95 cubeOpType 96 tanhOpType 97 sigmoidOpType 98 99 // optimization related 100 log1pOpType 101 expm1OpType 102 softplusOpType 103 104 maxʘUnaryOperator // delimits end of all possible unary ops 105 ) 106 107 func (u ʘUnaryOperatorType) String() string { 108 if u >= maxʘUnaryOperator { 109 return fmt.Sprintf("UNSUPPORTED UNARY OPERATOR (%d); max: %d", u, maxʘUnaryOperator) 110 } 111 112 return ʘUnaryOpStrs[u] 113 } 114 115 // ʘUnaryOpStrs is the string representation for a unaryOpType 116 // It should be held constant. 117 var ʘUnaryOpStrs = [maxʘUnaryOperator]string{ 118 "abs", "sign", "ceil", "floor", 119 "sin", "cos", "exp", 120 "ln", "log2", "neg", "square", "sqrt", 121 "inv", "invSqrt", 122 "cube", "tanh", "sigmoid", 123 124 "log1p", "expm1", "softplus", 125 } 126 127 // ʘUnaryOpDifferentiable is the array of whether a unary operator is differentiable 128 // It should be held constant 129 var ʘUnaryOpDifferentiable = [maxʘUnaryOperator]bool{ 130 true, false, false, false, 131 true, true, true, 132 true, true, true, true, true, 133 true, true, 134 true, true, true, 135 136 true, true, true, 137 } 138 139 var ʘUnaryOpDiffExprs = [maxʘUnaryOperator]func(x, y, gradY *Node) (*Node, error){ 140 absDiffExpr, nondiffUnaryOpExpr, nondiffUnaryOpExpr, nondiffUnaryOpExpr, 141 sinDiffExpr, cosDiffExpr, expDiffExpr, 142 lnDiffExpr, log2DiffExpr, negDiffExpr, squareDiffExpr, sqrtDiffExpr, 143 inverseDiffExpr, inverseSqrtDiffExpr, cubeDiffExpr, tanhDiffExpr, sigmoidDiffExpr, 144 145 log1pDiffExpr, expm1DiffExpr, softplusDiffExpr, 146 } 147 148 var ʘUnaryOpDiffFns = [maxʘUnaryOperator]func(x, y *Node) error{ 149 absDiff, nondiffUnaryOp, nondiffUnaryOp, nondiffUnaryOp, 150 sinDiff, cosDiff, expDiff, 151 lnDiff, log2Diff, negDiff, squareDiff, sqrtDiff, 152 inverseDiff, inverseSqrtDiff, cubeDiff, tanhDiff, sigmoidDiff, 153 154 log1pDiff, expm1Diff, softplusDiff, 155 } 156 157 var sf64UnaryOperators = [maxʘUnaryOperator]*sf64UnaryOperator{ 158 &absf64, 159 &signf64, 160 &ceilf64, 161 &floorf64, 162 &sinf64, 163 &cosf64, 164 &expf64, 165 &lnf64, 166 &log2f64, 167 &negf64, 168 &squaref64, 169 &sqrtf64, 170 &inversef64, 171 &inverseSqrtf64, 172 &cubef64, 173 &tanhf64, 174 &sigmoidf64, 175 176 &log1pf64, 177 &expm1f64, 178 &softplusf64, 179 } 180 181 var sf32UnaryOperators = [maxʘUnaryOperator]*sf32UnaryOperator{ 182 &absf32, 183 &signf32, 184 &ceilf32, 185 &floorf32, 186 &sinf32, 187 &cosf32, 188 &expf32, 189 &lnf32, 190 &log2f32, 191 &negf32, 192 &squaref32, 193 &sqrtf32, 194 &inversef32, 195 &inverseSqrtf32, 196 &cubef32, 197 &tanhf32, 198 &sigmoidf32, 199 200 &log1pf32, 201 &expm1f32, 202 &softplusf32, 203 }