gorgonia.org/gorgonia@v0.9.17/differentiation_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/assert"
     7  	"gonum.org/v1/gonum/graph/iterator"
     8  	"gonum.org/v1/gonum/graph/topo"
     9  )
    10  
    11  func TestForwardDiffAnalysis(t *testing.T) {
    12  	g := NewGraph()
    13  	x := NewScalar(g, Float64, WithName("x"))
    14  	y := NewScalar(g, Float64, WithName("y"))
    15  	z := NewScalar(g, Float64, WithName("z"))
    16  
    17  	res1 := Must(Log(Must(Mul(x, y))))
    18  
    19  	sorted, err := topo.Sort(g)
    20  	if err != nil {
    21  		t.Error(err)
    22  	}
    23  
    24  	sortedNodes := graphNodeToNode(iterator.NewOrderedNodes(sorted))
    25  	affectsOutput, err := forwardDiffAnalysis(Nodes{res1}, sortedNodes)
    26  	if err != nil {
    27  		t.Error(err)
    28  	}
    29  
    30  	t.Logf("%v", affectsOutput)
    31  	if affectsOutput.Contains(z) {
    32  		t.Error("It shouldn't contain res2 or z")
    33  	}
    34  }
    35  
    36  func TestBackwardDiffAnalysis(t *testing.T) {
    37  	g := NewGraph()
    38  	x := NewScalar(g, Float64, WithName("x"))
    39  	y := NewScalar(g, Float64, WithName("y"))
    40  	z := NewScalar(g, Float64, WithName("z"))
    41  
    42  	res1 := Must(Log(Must(Mul(x, y))))
    43  	res2 := Must(Log(Must(Mul(x, y)))) // yes it's a duplicate
    44  
    45  	sorted, err := topo.Sort(g)
    46  	if err != nil {
    47  		t.Error(err)
    48  	}
    49  
    50  	sortedNodes := graphNodeToNode(iterator.NewOrderedNodes(sorted))
    51  	affectedByOutput, err := backwardDiffAnalysis(Nodes{x, y}, sortedNodes)
    52  	if err != nil {
    53  		t.Error(err)
    54  	}
    55  
    56  	t.Logf("%v", affectedByOutput)
    57  
    58  	if !affectedByOutput.Contains(res1) || !affectedByOutput.Contains(res2) {
    59  		t.Error("Expected res1 and res2 to be affected by wrts")
    60  	}
    61  
    62  	if affectedByOutput.Contains(z) {
    63  		t.Error("z shouldn't be in the list at all")
    64  	}
    65  }
    66  
    67  func TestBackprop(t *testing.T) {
    68  	assert := assert.New(t)
    69  	gradOut := NewConstant(ones(Float64), WithName("GradOut"))
    70  
    71  	t.Log("Simple backprop")
    72  	g := NewGraph()
    73  	x := NewVector(g, Float64, WithName("x"), WithShape(10)) // horizontal vector
    74  	y := NewVector(g, Float64, WithName("y"), WithShape(10)) // horizontal vector
    75  
    76  	res := Must(Mul(x, y))
    77  
    78  	grad := g.AddNode(gradOut)
    79  	inputs := Nodes{x, y}
    80  	ret, err := Backpropagate(Nodes{res}, Nodes{grad}, inputs)
    81  	if err != nil {
    82  		t.Error(err)
    83  	}
    84  
    85  	assert.Equal(Nodes{inputs[1], grad}, ret[0].children)
    86  	assert.Equal(Nodes{inputs[0], grad}, ret[1].children)
    87  	assert.Equal(mulOpType, ret[0].op.(elemBinOp).ʘBinaryOperator.binOpType())
    88  	assert.Equal(mulOpType, ret[1].op.(elemBinOp).ʘBinaryOperator.binOpType())
    89  
    90  	// reset
    91  	t.Log("Progressively more complex")
    92  	g = NewGraph()
    93  	x = NewMatrix(g, Float64, WithName("x"), WithShape(1, 10))  // row vector
    94  	w := NewMatrix(g, Float64, WithName("w"), WithShape(10, 1)) // col vector
    95  
    96  	mul := Must(Mul(x, w))
    97  	res = Must(Exp(mul))
    98  
    99  	grad = g.AddNode(gradOut)
   100  	inputs = Nodes{x, w}
   101  	if ret, err = Backpropagate(Nodes{res}, Nodes{grad}, inputs); err != nil {
   102  		t.Error(err)
   103  	}
   104  
   105  	// Notes:
   106  	//
   107  	// extra was created in the Backprop process
   108  
   109  	extra := Must(Mul(res, onef64))
   110  	dzdxExpectedPath := Nodes{ret[0], w, extra, res, mul, x, w, grad}
   111  	dzdwExpectedPath := Nodes{ret[1], x, extra, res, mul, x, w, grad}
   112  
   113  	assert.True(dzdxExpectedPath.Equals(ret[0].seqWalk()))
   114  	assert.True(dzdwExpectedPath.Equals(ret[1].seqWalk()))
   115  
   116  	/*
   117  		ioutil.WriteFile("Test_Res.dot", []byte(res.ToDot()), 0644)
   118  		for i, n := range ret {
   119  			WithName(fmt.Sprintf("dz/d%s", inputs[i].Name()))(n)
   120  			ioutil.WriteFile(fmt.Sprintf("Test_Grad_%d.dot", i), []byte(n.ToDot()), 0644)
   121  		}
   122  		ioutil.WriteFile("WholeGraph.dot", []byte(g.ToDot()), 0644)
   123  	*/
   124  }
   125  
   126  // Compound ops (like expm1, log1p and sigmoid) have fairly complex diff results. Got bitten by log1p's diffExpr, so here's the test for them all
   127  func TestCompoundOpDiff(t *testing.T) {
   128  	g := NewGraph()
   129  
   130  	saved := stabilization
   131  	stabilization = true
   132  	defer func() {
   133  		stabilization = saved
   134  	}()
   135  
   136  	// log1p
   137  	x := NewVector(g, Float64, WithName("x"), WithShape(2))
   138  	p := Must(Add(x, onef64))
   139  	lp := Must(Log(p))
   140  	op := lp.op.(elemUnaryOp)
   141  	diffs, err := op.SymDiff(Nodes{x}, lp, onef64)
   142  	if err != nil {
   143  		t.Error(err)
   144  	}
   145  
   146  	if len(diffs) != 1 {
   147  		t.Fatal("Expected only one result")
   148  	}
   149  
   150  	diff := diffs[0]
   151  	ebo, ok := diff.op.(elemBinOp)
   152  	if !ok || ok && ebo.binOpType() != divOpType {
   153  		t.Error("Expected an elemBinOp")
   154  		t.Error("Expected divOp to be the result of differentiating log1p")
   155  	}
   156  	if diff.children[0].Hashcode() != onef64.Hashcode() {
   157  		t.Errorf("Expected 1 as the numerator. Got %v instead", diff.children[0])
   158  	}
   159  	ebo, ok = diff.children[1].op.(elemBinOp)
   160  	if !ok || ok && ebo.binOpType() != addOpType {
   161  		t.Error("Expected child1 to be (+)")
   162  	}
   163  
   164  }