gorgonia.org/gorgonia@v0.9.17/stabilization.go (about) 1 package gorgonia 2 3 import "github.com/pkg/errors" 4 5 var unaryOpStabilizationFns = make(map[ʘUnaryOperatorType][]func(*Node) (*Node, error)) 6 var binOpStabilizationFns = make(map[ʘBinaryOperatorType][]func(*Node, *Node) (*Node, error)) 7 8 func init() { 9 unaryOpStabilizationFns[lnOpType] = []func(*Node) (*Node, error){ 10 logSigmoidStabilization, 11 logStabilization, 12 logSoftmaxStabilization, 13 } 14 binOpStabilizationFns[subOpType] = []func(*Node, *Node) (*Node, error){ 15 exp1mStabilization, 16 oneMinusSigmoidStabilization, 17 } 18 unaryOpStabilizationFns[log1pOpType] = []func(*Node) (*Node, error){ 19 log1pExpStabilization, 20 log1pNegSigmoidStabilization, 21 } 22 unaryOpStabilizationFns[negOpType] = []func(*Node) (*Node, error){negNegOptimization} 23 } 24 25 // logStabilization converts 26 // log(1+a) or log(a+1) to log1p(a) 27 // log(1-a) to log1p(-a) 28 // place before log; a should be positive. 29 func logStabilization(a *Node) (retVal *Node, err error) { 30 stabLogf("Stabilizing log(1+a) of %v", a) 31 enterLogScope() 32 defer leaveLogScope() 33 34 var x *Node 35 var aop elemBinOp 36 var ok bool 37 38 if aop, ok = a.op.(elemBinOp); !ok { 39 return a, noStabilizationErr{} 40 } 41 input0 := a.children[0] 42 input1 := a.children[1] 43 44 stabLogf("input0: %v", input0.Name()) 45 stabLogf("input1: %v", input1.Name()) 46 bot := aop.ʘBinaryOperator.binOpType() 47 switch bot { 48 case addOpType: 49 if cnst, ok := input0.op.(constant); ok { 50 if constEq(cnst, onef32ConstOp) || constEq(cnst, onef64ConstOp) { 51 x = input1 52 break 53 } 54 } 55 56 if cnst, ok := input1.op.(constant); ok { 57 if constEq(cnst, onef32ConstOp) || constEq(cnst, onef64ConstOp) { 58 x = input0 59 break 60 } 61 } 62 63 return a, noStabilizationErr{} 64 case subOpType: 65 if cnst, ok := input0.op.(constant); !ok || (ok && !constEq(cnst, onef32ConstOp) && !constEq(cnst, onef64ConstOp)) { 66 return a, noStabilizationErr{} 67 } 68 x = input1 69 default: 70 return a, noStabilizationErr{} 71 } 72 73 g := a.g 74 g.removeAllEdgesFrom(a) // remove all references 75 g.RemoveNode(a) 76 defer returnNode(a) // send it back to the pool, since it is literally useless now 77 78 if bot == subOpType { 79 if retVal, err = Neg(x); err == nil { 80 return Log1p(retVal) 81 } 82 return nil, errors.Wrap(err, negFail) 83 } 84 return Log1p(x) 85 } 86 87 // expStabilization converts exp(x)-1 to expm1(x) 88 // place before sub; i0 should be exp(x); i1 should be 1 89 func exp1mStabilization(a, b *Node) (retVal *Node, err error) { 90 stabLogf("Stabilizing exp(x)-1 to expm1(x) of %v and %v", a, b) 91 enterLogScope() 92 defer leaveLogScope() 93 94 if cnst, ok := b.op.(constant); !ok || (ok && !constEq(cnst, onef32ConstOp) && !constEq(cnst, onef64ConstOp)) { 95 return nil, noStabilizationErr{} 96 } 97 98 if euo, ok := a.op.(elemUnaryOp); !ok || euo.unaryOpType() != expOpType { 99 return nil, noStabilizationErr{} 100 } 101 102 op := newElemUnaryOp(expm1OpType, a.children[0]) 103 return ApplyOp(op, a.children[0]) 104 } 105 106 // oneMinusSigmoidStabilization stabilizes 1-sigmoid(x) by replacing it with sigmoid(-x) 107 // place before sub 108 func oneMinusSigmoidStabilization(a, b *Node) (retVal *Node, err error) { 109 stabLogf("Stabilizing 1-sigmoid(x) to sigmoid(-x) of %v and %v", a, b) 110 enterLogScope() 111 defer leaveLogScope() 112 113 if cnst, ok := a.op.(constant); !ok || (ok && !constEq(cnst, onef32ConstOp) && !constEq(cnst, onef64ConstOp)) { 114 return nil, noStabilizationErr{} 115 } 116 117 if euo, ok := b.op.(elemUnaryOp); !ok || euo.unaryOpType() != sigmoidOpType { 118 return nil, noStabilizationErr{} 119 } 120 121 x := b.children[0] 122 if retVal, err = Neg(x); err == nil { 123 return Sigmoid(retVal) 124 } 125 return nil, errors.Wrap(err, negFail) 126 } 127 128 // logSigmoidStabilization stabilizes log(sigmoid(x)) by replacing it with -softplus(-x) 129 // place before log; a should be sigmoid(x) 130 func logSigmoidStabilization(a *Node) (retVal *Node, err error) { 131 stabLogf("Stabilizing log sigmoid of %v", a) 132 enterLogScope() 133 defer leaveLogScope() 134 135 if euo, ok := a.op.(elemUnaryOp); !ok || (ok && euo.unaryOpType() != sigmoidOpType) { 136 return a, noStabilizationErr{} 137 } 138 139 x := a.children[0] 140 stabLogf("x : %v", x.Name()) 141 142 if retVal, err = Neg(x); err == nil { 143 if retVal, err = Softplus(retVal); err == nil { 144 retVal, err = Neg(retVal) 145 if err != nil { 146 return nil, errors.Wrap(err, negFail) 147 } 148 return retVal, nil 149 } 150 return nil, errors.Wrap(err, softplusFail) 151 } 152 return nil, errors.Wrap(err, negFail) 153 } 154 155 // log1pExpStabilization stabilizes log1p(exp(x)) by substituting it with softplus(x) 156 // place before log1p; a should be exp(x) 157 func log1pExpStabilization(a *Node) (retVal *Node, err error) { 158 stabLogf("Stabilizing log1p(exp(x)) of %v", a) 159 enterLogScope() 160 defer leaveLogScope() 161 162 if euo, ok := a.op.(elemUnaryOp); !ok || (ok && euo.unaryOpType() != expOpType) { 163 stabLogf("op: %v; %v", a.op, a.children) 164 return a, noStabilizationErr{} 165 } 166 167 x := a.children[0] 168 stabLogf("OKKKKK") 169 return Softplus(x) 170 } 171 172 // log1pNegSigmoidStabilization stabilizes log1p(-sigmoid(x)) by substituting it with -softplus(x) 173 // place before log1p; a should be -sigmoid(x) 174 func log1pNegSigmoidStabilization(a *Node) (retVal *Node, err error) { 175 stabLogf("Stabilizing log1p(-sigmoid(x)) : %v", a) 176 enterLogScope() 177 defer leaveLogScope() 178 179 if euo, ok := a.op.(elemUnaryOp); !ok || (ok && euo.unaryOpType() != negOpType) { 180 return a, noStabilizationErr{} 181 } 182 183 if euo, ok := a.children[0].op.(elemUnaryOp); !ok || (ok && euo.unaryOpType() != sigmoidOpType) { 184 return a, noStabilizationErr{} 185 } 186 187 x := a.children[0].children[0] 188 189 stabLogf("x : %v", x.Name()) 190 191 if retVal, err = Softplus(x); err == nil { 192 retVal, err = Neg(retVal) 193 if err != nil { 194 return nil, errors.Wrap(err, negFail) 195 } 196 return retVal, nil 197 } 198 return nil, errors.Wrap(err, softplusFail) 199 } 200 201 // logSoftmaxStabilization converts 202 // log(softmax(a)) to softmax{isLog: true}(a) 203 // log(a * softmax(b)) to log(a) + softmax{isLog: true}(b) 204 func logSoftmaxStabilization(a *Node) (retVal *Node, err error) { 205 stabLogf("Stabilizing log(softmax) of %v", a) 206 enterLogScope() 207 defer leaveLogScope() 208 209 switch op := a.op.(type) { 210 case *softmaxOp: 211 op.isLog = true 212 return a, nil 213 case elemBinOp: 214 if op.ʘBinaryOperator.binOpType() == mulOpType { 215 fst := a.children[0] 216 snd := a.children[1] 217 218 var hasSm bool 219 var smop1, smop2 *softmaxOp 220 if smop, ok := fst.op.(*softmaxOp); ok { 221 hasSm = true 222 smop1 = smop 223 } 224 if smop, ok := snd.op.(*softmaxOp); ok { 225 hasSm = true 226 smop2 = smop 227 } 228 229 if hasSm { 230 var newFst, newSnd *Node 231 switch { 232 case smop1 != nil && smop2 == nil: 233 smop1.isLog = true 234 newFst = fst 235 if newSnd, err = Log(snd); err != nil { 236 return nil, err 237 } 238 case smop1 == nil && smop2 != nil: 239 smop2.isLog = true 240 newSnd = snd 241 if newFst, err = Log(fst); err != nil { 242 return nil, err 243 } 244 case smop1 != nil && smop2 != nil: 245 smop1.isLog = true 246 smop2.isLog = true 247 newFst = fst 248 newSnd = snd 249 default: 250 return a, noStabilizationErr{} 251 } 252 253 // g := a.g 254 // g.removeAllEdgesFrom(a) // remove all references 255 // g.RemoveNode(a) 256 // returnNode(a) // send it back to the pool, since it is literally useless now 257 return Add(newFst, newSnd) 258 } 259 260 } 261 } 262 return a, noStabilizationErr{} 263 264 } 265 266 /* Graph Optimizations */ 267 268 // negNegOptimization optimizes away -(-x) to just return x 269 // place before neg 270 func negNegOptimization(a *Node) (retVal *Node, err error) { 271 stabLogf("Optimizing -(-x)") 272 enterLogScope() 273 defer leaveLogScope() 274 275 if euo, ok := a.op.(elemUnaryOp); !ok || (ok && euo.unaryOpType() != negOpType) { 276 return a, noStabilizationErr{} 277 } 278 279 x := a.children[0] 280 return x, nil 281 }