gorgonia.org/gorgonia@v0.9.17/stabilization.go (about)

     1  package gorgonia
     2  
     3  import "github.com/pkg/errors"
     4  
     5  var unaryOpStabilizationFns = make(map[ʘUnaryOperatorType][]func(*Node) (*Node, error))
     6  var binOpStabilizationFns = make(map[ʘBinaryOperatorType][]func(*Node, *Node) (*Node, error))
     7  
     8  func init() {
     9  	unaryOpStabilizationFns[lnOpType] = []func(*Node) (*Node, error){
    10  		logSigmoidStabilization,
    11  		logStabilization,
    12  		logSoftmaxStabilization,
    13  	}
    14  	binOpStabilizationFns[subOpType] = []func(*Node, *Node) (*Node, error){
    15  		exp1mStabilization,
    16  		oneMinusSigmoidStabilization,
    17  	}
    18  	unaryOpStabilizationFns[log1pOpType] = []func(*Node) (*Node, error){
    19  		log1pExpStabilization,
    20  		log1pNegSigmoidStabilization,
    21  	}
    22  	unaryOpStabilizationFns[negOpType] = []func(*Node) (*Node, error){negNegOptimization}
    23  }
    24  
    25  // logStabilization converts
    26  // 	log(1+a) or log(a+1) to log1p(a)
    27  //	log(1-a) to log1p(-a)
    28  // place before log; a should be positive.
    29  func logStabilization(a *Node) (retVal *Node, err error) {
    30  	stabLogf("Stabilizing log(1+a) of %v", a)
    31  	enterLogScope()
    32  	defer leaveLogScope()
    33  
    34  	var x *Node
    35  	var aop elemBinOp
    36  	var ok bool
    37  
    38  	if aop, ok = a.op.(elemBinOp); !ok {
    39  		return a, noStabilizationErr{}
    40  	}
    41  	input0 := a.children[0]
    42  	input1 := a.children[1]
    43  
    44  	stabLogf("input0: %v", input0.Name())
    45  	stabLogf("input1: %v", input1.Name())
    46  	bot := aop.ʘBinaryOperator.binOpType()
    47  	switch bot {
    48  	case addOpType:
    49  		if cnst, ok := input0.op.(constant); ok {
    50  			if constEq(cnst, onef32ConstOp) || constEq(cnst, onef64ConstOp) {
    51  				x = input1
    52  				break
    53  			}
    54  		}
    55  
    56  		if cnst, ok := input1.op.(constant); ok {
    57  			if constEq(cnst, onef32ConstOp) || constEq(cnst, onef64ConstOp) {
    58  				x = input0
    59  				break
    60  			}
    61  		}
    62  
    63  		return a, noStabilizationErr{}
    64  	case subOpType:
    65  		if cnst, ok := input0.op.(constant); !ok || (ok && !constEq(cnst, onef32ConstOp) && !constEq(cnst, onef64ConstOp)) {
    66  			return a, noStabilizationErr{}
    67  		}
    68  		x = input1
    69  	default:
    70  		return a, noStabilizationErr{}
    71  	}
    72  
    73  	g := a.g
    74  	g.removeAllEdgesFrom(a) // remove all references
    75  	g.RemoveNode(a)
    76  	defer returnNode(a) // send it back to the pool, since it is literally useless now
    77  
    78  	if bot == subOpType {
    79  		if retVal, err = Neg(x); err == nil {
    80  			return Log1p(retVal)
    81  		}
    82  		return nil, errors.Wrap(err, negFail)
    83  	}
    84  	return Log1p(x)
    85  }
    86  
    87  // expStabilization converts exp(x)-1 to expm1(x)
    88  // place before sub; i0 should be exp(x); i1 should be 1
    89  func exp1mStabilization(a, b *Node) (retVal *Node, err error) {
    90  	stabLogf("Stabilizing exp(x)-1 to expm1(x) of %v and %v", a, b)
    91  	enterLogScope()
    92  	defer leaveLogScope()
    93  
    94  	if cnst, ok := b.op.(constant); !ok || (ok && !constEq(cnst, onef32ConstOp) && !constEq(cnst, onef64ConstOp)) {
    95  		return nil, noStabilizationErr{}
    96  	}
    97  
    98  	if euo, ok := a.op.(elemUnaryOp); !ok || euo.unaryOpType() != expOpType {
    99  		return nil, noStabilizationErr{}
   100  	}
   101  
   102  	op := newElemUnaryOp(expm1OpType, a.children[0])
   103  	return ApplyOp(op, a.children[0])
   104  }
   105  
   106  // oneMinusSigmoidStabilization stabilizes 1-sigmoid(x) by replacing it with sigmoid(-x)
   107  // place before sub
   108  func oneMinusSigmoidStabilization(a, b *Node) (retVal *Node, err error) {
   109  	stabLogf("Stabilizing 1-sigmoid(x) to sigmoid(-x) of %v and %v", a, b)
   110  	enterLogScope()
   111  	defer leaveLogScope()
   112  
   113  	if cnst, ok := a.op.(constant); !ok || (ok && !constEq(cnst, onef32ConstOp) && !constEq(cnst, onef64ConstOp)) {
   114  		return nil, noStabilizationErr{}
   115  	}
   116  
   117  	if euo, ok := b.op.(elemUnaryOp); !ok || euo.unaryOpType() != sigmoidOpType {
   118  		return nil, noStabilizationErr{}
   119  	}
   120  
   121  	x := b.children[0]
   122  	if retVal, err = Neg(x); err == nil {
   123  		return Sigmoid(retVal)
   124  	}
   125  	return nil, errors.Wrap(err, negFail)
   126  }
   127  
   128  // logSigmoidStabilization stabilizes log(sigmoid(x)) by replacing it with -softplus(-x)
   129  // place before log; a should be sigmoid(x)
   130  func logSigmoidStabilization(a *Node) (retVal *Node, err error) {
   131  	stabLogf("Stabilizing log sigmoid of %v", a)
   132  	enterLogScope()
   133  	defer leaveLogScope()
   134  
   135  	if euo, ok := a.op.(elemUnaryOp); !ok || (ok && euo.unaryOpType() != sigmoidOpType) {
   136  		return a, noStabilizationErr{}
   137  	}
   138  
   139  	x := a.children[0]
   140  	stabLogf("x : %v", x.Name())
   141  
   142  	if retVal, err = Neg(x); err == nil {
   143  		if retVal, err = Softplus(retVal); err == nil {
   144  			retVal, err = Neg(retVal)
   145  			if err != nil {
   146  				return nil, errors.Wrap(err, negFail)
   147  			}
   148  			return retVal, nil
   149  		}
   150  		return nil, errors.Wrap(err, softplusFail)
   151  	}
   152  	return nil, errors.Wrap(err, negFail)
   153  }
   154  
   155  // log1pExpStabilization stabilizes log1p(exp(x)) by substituting it with softplus(x)
   156  // place before log1p; a should be exp(x)
   157  func log1pExpStabilization(a *Node) (retVal *Node, err error) {
   158  	stabLogf("Stabilizing log1p(exp(x)) of %v", a)
   159  	enterLogScope()
   160  	defer leaveLogScope()
   161  
   162  	if euo, ok := a.op.(elemUnaryOp); !ok || (ok && euo.unaryOpType() != expOpType) {
   163  		stabLogf("op: %v; %v", a.op, a.children)
   164  		return a, noStabilizationErr{}
   165  	}
   166  
   167  	x := a.children[0]
   168  	stabLogf("OKKKKK")
   169  	return Softplus(x)
   170  }
   171  
   172  // log1pNegSigmoidStabilization stabilizes log1p(-sigmoid(x)) by substituting it with -softplus(x)
   173  // place before log1p;  a should be -sigmoid(x)
   174  func log1pNegSigmoidStabilization(a *Node) (retVal *Node, err error) {
   175  	stabLogf("Stabilizing log1p(-sigmoid(x)) : %v", a)
   176  	enterLogScope()
   177  	defer leaveLogScope()
   178  
   179  	if euo, ok := a.op.(elemUnaryOp); !ok || (ok && euo.unaryOpType() != negOpType) {
   180  		return a, noStabilizationErr{}
   181  	}
   182  
   183  	if euo, ok := a.children[0].op.(elemUnaryOp); !ok || (ok && euo.unaryOpType() != sigmoidOpType) {
   184  		return a, noStabilizationErr{}
   185  	}
   186  
   187  	x := a.children[0].children[0]
   188  
   189  	stabLogf("x : %v", x.Name())
   190  
   191  	if retVal, err = Softplus(x); err == nil {
   192  		retVal, err = Neg(retVal)
   193  		if err != nil {
   194  			return nil, errors.Wrap(err, negFail)
   195  		}
   196  		return retVal, nil
   197  	}
   198  	return nil, errors.Wrap(err, softplusFail)
   199  }
   200  
   201  // logSoftmaxStabilization converts
   202  // 	log(softmax(a)) to softmax{isLog: true}(a)
   203  //	log(a * softmax(b)) to log(a) + softmax{isLog: true}(b)
   204  func logSoftmaxStabilization(a *Node) (retVal *Node, err error) {
   205  	stabLogf("Stabilizing log(softmax) of %v", a)
   206  	enterLogScope()
   207  	defer leaveLogScope()
   208  
   209  	switch op := a.op.(type) {
   210  	case *softmaxOp:
   211  		op.isLog = true
   212  		return a, nil
   213  	case elemBinOp:
   214  		if op.ʘBinaryOperator.binOpType() == mulOpType {
   215  			fst := a.children[0]
   216  			snd := a.children[1]
   217  
   218  			var hasSm bool
   219  			var smop1, smop2 *softmaxOp
   220  			if smop, ok := fst.op.(*softmaxOp); ok {
   221  				hasSm = true
   222  				smop1 = smop
   223  			}
   224  			if smop, ok := snd.op.(*softmaxOp); ok {
   225  				hasSm = true
   226  				smop2 = smop
   227  			}
   228  
   229  			if hasSm {
   230  				var newFst, newSnd *Node
   231  				switch {
   232  				case smop1 != nil && smop2 == nil:
   233  					smop1.isLog = true
   234  					newFst = fst
   235  					if newSnd, err = Log(snd); err != nil {
   236  						return nil, err
   237  					}
   238  				case smop1 == nil && smop2 != nil:
   239  					smop2.isLog = true
   240  					newSnd = snd
   241  					if newFst, err = Log(fst); err != nil {
   242  						return nil, err
   243  					}
   244  				case smop1 != nil && smop2 != nil:
   245  					smop1.isLog = true
   246  					smop2.isLog = true
   247  					newFst = fst
   248  					newSnd = snd
   249  				default:
   250  					return a, noStabilizationErr{}
   251  				}
   252  
   253  				// g := a.g
   254  				// g.removeAllEdgesFrom(a) // remove all references
   255  				// g.RemoveNode(a)
   256  				// returnNode(a) // send it back to the pool, since it is literally useless now
   257  				return Add(newFst, newSnd)
   258  			}
   259  
   260  		}
   261  	}
   262  	return a, noStabilizationErr{}
   263  
   264  }
   265  
   266  /* Graph Optimizations */
   267  
   268  // negNegOptimization optimizes away -(-x) to just return x
   269  // place before neg
   270  func negNegOptimization(a *Node) (retVal *Node, err error) {
   271  	stabLogf("Optimizing -(-x)")
   272  	enterLogScope()
   273  	defer leaveLogScope()
   274  
   275  	if euo, ok := a.op.(elemUnaryOp); !ok || (ok && euo.unaryOpType() != negOpType) {
   276  		return a, noStabilizationErr{}
   277  	}
   278  
   279  	x := a.children[0]
   280  	return x, nil
   281  }