gorgonia.org/gorgonia@v0.9.17/stabilization_test.go (about) 1 package gorgonia 2 3 import ( 4 "io/ioutil" 5 "testing" 6 ) 7 8 func TestLogStabilization(t *testing.T) { 9 g := NewGraph() 10 11 // log(a+1) 12 x := NewVector(g, Float64, WithName("x"), WithShape(2)) 13 p := Must(Add(x, onef64)) 14 lp := Must(Log(p)) 15 if lp.children[0] != x { 16 t.Error("Oops.") 17 ioutil.WriteFile("log(a+1).dot", []byte(lp.ToDot()), 0644) 18 } 19 20 // log(1+a) 21 p = Must(Add(onef64, x)) 22 lp = Must(Log(p)) 23 if lp.children[0] != x { 24 t.Error("Oops.") 25 ioutil.WriteFile("log(1+a).dot", []byte(lp.ToDot()), 0644) 26 } 27 28 //log(1-a) 29 m := Must(Sub(onef64, x)) 30 lp = Must(Log(m)) 31 if euo, ok := lp.children[0].op.(elemUnaryOp); !ok { 32 t.Error("Oops.") 33 } else { 34 if euo.unaryOpType() != negOpType { 35 t.Error("Expected Neg Op") 36 } 37 38 if lp.children[0].children[0] != x { 39 t.Error("Oops.") 40 } 41 } 42 43 if t.Failed() { 44 ioutil.WriteFile("log(1-a).dot", []byte(lp.ToDot()), 0644) 45 } 46 47 //log(a-1) 48 m = Must(Sub(x, onef64)) 49 lp = Must(Log(m)) 50 //TODO: surely there is a better way to test? 51 if lp.children[0] == x { 52 t.Error("Oops.") 53 } 54 55 // log(a+2) 56 // We expect to keep the same operation tree, without stabilization 57 p = Must(Add(x, twof64)) 58 lp = Must(Log(p)) 59 if lp.children[0] != p { 60 t.Error("Oops.") 61 ioutil.WriteFile("log(a+2).dot", []byte(lp.ToDot()), 0644) 62 } 63 } 64 65 func TestExpStabilization(t *testing.T) { 66 g := NewGraph() 67 68 x := NewVector(g, Float64, WithName("x"), WithShape(2)) 69 e := Must(Exp(x)) 70 s := Must(Sub(e, onef64)) 71 72 if s.children[0] != x { 73 t.Error("oops") 74 } 75 76 if euo, ok := s.op.(elemUnaryOp); !ok || (ok && euo.unaryOpType() != expm1OpType) { 77 t.Error("oops") 78 } 79 80 if t.Failed() { 81 ioutil.WriteFile("exp(a)-1.dot", []byte(s.ToDot()), 0644) 82 } 83 } 84 85 func TestLogSigmoidStabilization(t *testing.T) { 86 g := NewGraph() 87 88 stabilization = true 89 x := NewVector(g, Float64, WithName("x"), WithShape(2)) 90 y := Must(Sigmoid(x)) 91 WithName("y")(y) 92 logY := Must(Log(y)) 93 WithName("log(sigmoid(x))")(logY) 94 95 if euo, ok := logY.op.(elemUnaryOp); !ok || (ok && euo.unaryOpType() != negOpType) { 96 t.Error("Oops") 97 } 98 99 if euo, ok := logY.children[0].op.(elemUnaryOp); !ok || (ok && euo.unaryOpType() != softplusOpType) { 100 t.Error("Oops2") 101 } 102 103 if euo, ok := logY.children[0].children[0].op.(elemUnaryOp); !ok || (ok && euo.unaryOpType() != negOpType) { 104 t.Error("Oops3") 105 } 106 107 if logY.children[0].children[0].children[0] != x { 108 t.Errorf("Oops4: %v", logY.children[0].children[0].children[0].Name()) 109 } 110 111 if t.Failed() { 112 ioutil.WriteFile("fullGraph.dot", []byte(g.ToDot()), 0644) 113 ioutil.WriteFile("logY.dot", []byte(logY.ToDot()), 0644) 114 } 115 }