gorgonia.org/gorgonia@v0.9.17/differentiation_test.go (about) 1 package gorgonia 2 3 import ( 4 "testing" 5 6 "github.com/stretchr/testify/assert" 7 "gonum.org/v1/gonum/graph/iterator" 8 "gonum.org/v1/gonum/graph/topo" 9 ) 10 11 func TestForwardDiffAnalysis(t *testing.T) { 12 g := NewGraph() 13 x := NewScalar(g, Float64, WithName("x")) 14 y := NewScalar(g, Float64, WithName("y")) 15 z := NewScalar(g, Float64, WithName("z")) 16 17 res1 := Must(Log(Must(Mul(x, y)))) 18 19 sorted, err := topo.Sort(g) 20 if err != nil { 21 t.Error(err) 22 } 23 24 sortedNodes := graphNodeToNode(iterator.NewOrderedNodes(sorted)) 25 affectsOutput, err := forwardDiffAnalysis(Nodes{res1}, sortedNodes) 26 if err != nil { 27 t.Error(err) 28 } 29 30 t.Logf("%v", affectsOutput) 31 if affectsOutput.Contains(z) { 32 t.Error("It shouldn't contain res2 or z") 33 } 34 } 35 36 func TestBackwardDiffAnalysis(t *testing.T) { 37 g := NewGraph() 38 x := NewScalar(g, Float64, WithName("x")) 39 y := NewScalar(g, Float64, WithName("y")) 40 z := NewScalar(g, Float64, WithName("z")) 41 42 res1 := Must(Log(Must(Mul(x, y)))) 43 res2 := Must(Log(Must(Mul(x, y)))) // yes it's a duplicate 44 45 sorted, err := topo.Sort(g) 46 if err != nil { 47 t.Error(err) 48 } 49 50 sortedNodes := graphNodeToNode(iterator.NewOrderedNodes(sorted)) 51 affectedByOutput, err := backwardDiffAnalysis(Nodes{x, y}, sortedNodes) 52 if err != nil { 53 t.Error(err) 54 } 55 56 t.Logf("%v", affectedByOutput) 57 58 if !affectedByOutput.Contains(res1) || !affectedByOutput.Contains(res2) { 59 t.Error("Expected res1 and res2 to be affected by wrts") 60 } 61 62 if affectedByOutput.Contains(z) { 63 t.Error("z shouldn't be in the list at all") 64 } 65 } 66 67 func TestBackprop(t *testing.T) { 68 assert := assert.New(t) 69 gradOut := NewConstant(ones(Float64), WithName("GradOut")) 70 71 t.Log("Simple backprop") 72 g := NewGraph() 73 x := NewVector(g, Float64, WithName("x"), WithShape(10)) // horizontal vector 74 y := NewVector(g, Float64, WithName("y"), WithShape(10)) // horizontal vector 75 76 res := Must(Mul(x, y)) 77 78 grad := g.AddNode(gradOut) 79 inputs := Nodes{x, y} 80 ret, err := Backpropagate(Nodes{res}, Nodes{grad}, inputs) 81 if err != nil { 82 t.Error(err) 83 } 84 85 assert.Equal(Nodes{inputs[1], grad}, ret[0].children) 86 assert.Equal(Nodes{inputs[0], grad}, ret[1].children) 87 assert.Equal(mulOpType, ret[0].op.(elemBinOp).ʘBinaryOperator.binOpType()) 88 assert.Equal(mulOpType, ret[1].op.(elemBinOp).ʘBinaryOperator.binOpType()) 89 90 // reset 91 t.Log("Progressively more complex") 92 g = NewGraph() 93 x = NewMatrix(g, Float64, WithName("x"), WithShape(1, 10)) // row vector 94 w := NewMatrix(g, Float64, WithName("w"), WithShape(10, 1)) // col vector 95 96 mul := Must(Mul(x, w)) 97 res = Must(Exp(mul)) 98 99 grad = g.AddNode(gradOut) 100 inputs = Nodes{x, w} 101 if ret, err = Backpropagate(Nodes{res}, Nodes{grad}, inputs); err != nil { 102 t.Error(err) 103 } 104 105 // Notes: 106 // 107 // extra was created in the Backprop process 108 109 extra := Must(Mul(res, onef64)) 110 dzdxExpectedPath := Nodes{ret[0], w, extra, res, mul, x, w, grad} 111 dzdwExpectedPath := Nodes{ret[1], x, extra, res, mul, x, w, grad} 112 113 assert.True(dzdxExpectedPath.Equals(ret[0].seqWalk())) 114 assert.True(dzdwExpectedPath.Equals(ret[1].seqWalk())) 115 116 /* 117 ioutil.WriteFile("Test_Res.dot", []byte(res.ToDot()), 0644) 118 for i, n := range ret { 119 WithName(fmt.Sprintf("dz/d%s", inputs[i].Name()))(n) 120 ioutil.WriteFile(fmt.Sprintf("Test_Grad_%d.dot", i), []byte(n.ToDot()), 0644) 121 } 122 ioutil.WriteFile("WholeGraph.dot", []byte(g.ToDot()), 0644) 123 */ 124 } 125 126 // Compound ops (like expm1, log1p and sigmoid) have fairly complex diff results. Got bitten by log1p's diffExpr, so here's the test for them all 127 func TestCompoundOpDiff(t *testing.T) { 128 g := NewGraph() 129 130 saved := stabilization 131 stabilization = true 132 defer func() { 133 stabilization = saved 134 }() 135 136 // log1p 137 x := NewVector(g, Float64, WithName("x"), WithShape(2)) 138 p := Must(Add(x, onef64)) 139 lp := Must(Log(p)) 140 op := lp.op.(elemUnaryOp) 141 diffs, err := op.SymDiff(Nodes{x}, lp, onef64) 142 if err != nil { 143 t.Error(err) 144 } 145 146 if len(diffs) != 1 { 147 t.Fatal("Expected only one result") 148 } 149 150 diff := diffs[0] 151 ebo, ok := diff.op.(elemBinOp) 152 if !ok || ok && ebo.binOpType() != divOpType { 153 t.Error("Expected an elemBinOp") 154 t.Error("Expected divOp to be the result of differentiating log1p") 155 } 156 if diff.children[0].Hashcode() != onef64.Hashcode() { 157 t.Errorf("Expected 1 as the numerator. Got %v instead", diff.children[0]) 158 } 159 ebo, ok = diff.children[1].op.(elemBinOp) 160 if !ok || ok && ebo.binOpType() != addOpType { 161 t.Error("Expected child1 to be (+)") 162 } 163 164 }