gorgonia.org/gorgonia@v0.9.17/graph_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/assert"
     7  	"gonum.org/v1/gonum/graph"
     8  	"gonum.org/v1/gonum/graph/iterator"
     9  	"gonum.org/v1/gonum/graph/topo"
    10  	"gorgonia.org/tensor"
    11  )
    12  
    13  func TestGraphBasics(t *testing.T) {
    14  	assert := assert.New(t)
    15  	g, x, y, xy := simpleEqn()
    16  
    17  	// basic stuff
    18  	assert.Equal(g, xy.g)
    19  	assert.Contains(g.AllNodes(), x)
    20  	assert.Contains(g.AllNodes(), y)
    21  	assert.Contains(g.AllNodes(), xy)
    22  
    23  	assert.Equal(Nodes{x, y}, g.leaves)
    24  
    25  	// Node/addressing stuff
    26  	xid := x.ID()
    27  	xFromID := g.Node(xid)
    28  	assert.Equal(x, xFromID)
    29  
    30  	var correctTo Nodes
    31  	correctTo = Nodes{xy}
    32  	assert.Equal(correctTo, g.to[x])
    33  	assert.Equal(correctTo, g.to[y])
    34  
    35  	// test Uniquifying ability of ExprGraph
    36  	newX := g.AddNode(x)
    37  	assert.Equal(x, newX)
    38  
    39  	newY := g.AddNode(y)
    40  	assert.Equal(y, newY)
    41  
    42  	newXY := Must(Add(x, y))
    43  	correctTo = append(correctTo, xy) // note this is correct. .Set() will be called when graph.To() is called
    44  	assert.Equal(xy, newXY)
    45  	assert.Equal(correctTo, g.to[y])
    46  	assert.Equal(correctTo, g.to[x])
    47  
    48  	correctTo = Nodes{xy}
    49  	assert.Equal(correctTo, sliceNodesToNodes(graph.NodesOf(g.To(y.ID()))))
    50  	assert.Equal(correctTo, sliceNodesToNodes(graph.NodesOf(g.To(x.ID()))))
    51  
    52  	assert.Equal(3, g.Nodes().Len())
    53  
    54  	// Now, time to deal with constants
    55  	xy1 := Must(Add(xy, onef64))
    56  	assert.Nil(onef64.g)
    57  	assert.Equal(g, xy1.g)
    58  
    59  	var containsOne bool
    60  
    61  	it := g.Nodes()
    62  	for it.Next() {
    63  		node := it.Node()
    64  		n := node.(*Node)
    65  		if n.Hashcode() == onef64.Hashcode() {
    66  			containsOne = true
    67  			break
    68  		}
    69  	}
    70  	if !containsOne {
    71  		t.Errorf("graph does not contain a clone of onef64: %v", g.Nodes())
    72  	}
    73  
    74  	// duplicate constants
    75  	one := NewConstant(1.0)
    76  	newOne := g.AddNode(one)
    77  	if one == newOne {
    78  		t.Error("one should not have been added to the graph")
    79  	}
    80  	assert.NotNil(newOne.g)
    81  	assert.NotEqual(one, newOne)
    82  }
    83  
    84  // This test is added to make sure I'm sane when dealing with sorted graphs
    85  // because sometimes Eobard Thawne is needed
    86  func TestGraphSort(t *testing.T) {
    87  	assert := assert.New(t)
    88  	g, _, _, z := simpleVecEqn()
    89  	WithName("z")(z)
    90  
    91  	var sortedNodes []graph.Node
    92  	var err error
    93  
    94  	// stability tests
    95  	for i := 0; i < 100; i++ {
    96  		if sortedNodes, err = topo.Sort(g); err != nil {
    97  			t.Error(err)
    98  		}
    99  		// expected := Nodes{z, y, x} // the old version of ExprGraph was stable with topo.Sort, but the new version ain't
   100  		// assert.Equal(expected, sortedNodes)
   101  		assert.Equal(z, sortedNodes[0])
   102  	}
   103  
   104  	// this is to remind myself how this thing sorts:
   105  	t.Logf("%v", graphNodeToNode(iterator.NewOrderedNodes(sortedNodes)))
   106  }
   107  
   108  // test that collisions are handled correctly
   109  func TestGraphCollisions(t *testing.T) {
   110  	assert := assert.New(t)
   111  	g, _, _, xy := simpleEqn()
   112  	delete(g.byHash, xy.hash)
   113  	g.byHash[0xdeadbeef] = xy
   114  	xy.hash = 0xdeadbeef
   115  	xy.name = "original"
   116  	t.Logf("original: %p, hash %x", xy, xy.Hashcode())
   117  
   118  	col := new(Node)
   119  	col.name = "COLIN THE COLLISION"
   120  	col.hash = 0xdeadbeef
   121  	col.hashed = true
   122  	col2 := g.AddNode(col)
   123  
   124  	assert.Equal(col, col2)
   125  	assert.Equal(4, len(g.AllNodes()), "%v", g.AllNodes())
   126  	assert.True(g.Has(col.ID()))
   127  
   128  	colleen := new(Node)
   129  	colleen.name = "COLLEEN THE COLLISION"
   130  	colleen.hash = 0xdeadbeef
   131  	colleen.hashed = true
   132  	colleen2 := g.AddNode(colleen)
   133  
   134  	assert.Equal(colleen, colleen2)
   135  	assert.Equal(5, len(g.AllNodes()), "%v", g.AllNodes())
   136  	assert.True(g.Has(colleen.ID()))
   137  
   138  }
   139  
   140  func TestGraphEquality(t *testing.T) {
   141  	_, x, y, z := simpleVecEqn()
   142  
   143  	xh1 := x.Hashcode()
   144  	yh1 := y.Hashcode()
   145  	if xh1 == yh1 {
   146  		t.Error("Different nodes, should have different hashes")
   147  	}
   148  
   149  	_, x2, y2, z2 := simpleVecEqn()
   150  
   151  	if x.Hashcode() != x2.Hashcode() {
   152  		t.Error("They should have the same hash")
   153  	}
   154  
   155  	if y.Hashcode() != y2.Hashcode() {
   156  		t.Error("They should have the same hash")
   157  	}
   158  
   159  	if z.Hashcode() != z2.Hashcode() {
   160  		t.Error("They should have the same hash")
   161  	}
   162  }
   163  
   164  func TestGraphSubgraph(t *testing.T) {
   165  	var err error
   166  	var sortedNodes Nodes
   167  	assert := assert.New(t)
   168  
   169  	g, x, y, z := simpleVecEqn()
   170  
   171  	sub := Nodes{x, y}
   172  	g2 := g.subgraph(sub, true)
   173  
   174  	t.Logf("%v", g2.AllNodes())
   175  
   176  	if sortedNodes, err = Sort(g2); err != nil {
   177  		t.Fatal(err)
   178  	}
   179  	assert.NotContains(sortedNodes, z)
   180  	assert.Contains(g2.roots, x)
   181  	assert.Contains(g2.roots, y)
   182  	assert.Equal(2, len(g2.roots))
   183  }
   184  
   185  func TestGraph_SubgraphRoots(t *testing.T) {
   186  	assert := assert.New(t)
   187  	g, x, y, z := simpleVecEqn()
   188  	sz := Must(Sum(z))
   189  	a := NewVector(g, Float64, WithName("a"), WithShape(2))
   190  	b := NewVector(g, Float64, WithName("b"), WithShape(2))
   191  	c := Must(Add(a, b))
   192  	sc := Must(Sum(c))
   193  
   194  	var szVal, scVal Value
   195  	readSZ := Read(sz, &szVal)
   196  	readSC := Read(sc, &scVal)
   197  
   198  	// check that stmt nodes aren't included in the roots
   199  	sg := g.SubgraphRoots(readSZ, readSC)
   200  	assert.Contains(sg.roots, sz)
   201  	assert.Contains(sg.roots, sc)
   202  	assert.Equal(2, len(sg.roots))
   203  
   204  	// check that subgrapphing actually works
   205  	sg = g.SubgraphRoots(c)
   206  	ns := sg.AllNodes()
   207  	assert.NotContains(ns, sc)
   208  	assert.NotContains(ns, readSC)
   209  	assert.NotContains(ns, x)
   210  	assert.NotContains(ns, y)
   211  	assert.NotContains(ns, z)
   212  	assert.NotContains(ns, sz)
   213  	assert.NotContains(ns, readSZ)
   214  }
   215  
   216  func TestGraph_ExactSubgraphRoots(t *testing.T) {
   217  	assert := assert.New(t)
   218  	g, x, y, z := simpleVecEqn()
   219  	sz := Must(Sum(z))
   220  	setXtoZ := Set(x, z) // setting x = z
   221  
   222  	sg0 := g.SubgraphRoots(sz)
   223  	sg1 := g.ExactSubgraphRoots(sz)
   224  	ns0 := sg0.AllNodes()
   225  	ns1 := sg1.AllNodes()
   226  	assert.Contains(ns0, setXtoZ)
   227  	assert.NotContains(ns1, setXtoZ)
   228  	assert.Contains(ns0, x)
   229  	assert.Contains(ns0, y)
   230  	assert.Contains(ns0, z)
   231  	assert.Contains(ns0, sz)
   232  
   233  }
   234  
   235  func TestGraph_Constant(t *testing.T) {
   236  	g := NewGraph()
   237  
   238  	v1 := NewF64(1.0)
   239  	c0 := g.Constant(v1)
   240  	c1 := g.Constant(v1)
   241  
   242  	if c0 != c1 {
   243  		t.Errorf("Expected c0 and c1 to be the same (pointer and all that)")
   244  	}
   245  }
   246  
   247  func TestGraph_Clone(t *testing.T) {
   248  	g, x, y, z := simpleVecEqn()
   249  	z2 := Must(Square(z))
   250  
   251  	// add a collided
   252  	z2t := z2.Type()
   253  	delete(g.byHash, z2.hash)
   254  	g.byHash[0xdeadbeef] = z2
   255  	col := new(Node)
   256  	col.g = g
   257  	col.name = "COLIN THE COLLISION"
   258  	col.hash = 0xdeadbeef
   259  	col.hashed = true
   260  	col.boundTo = NewF64(0)
   261  	col.t = z2t
   262  	g.AddNode(col)
   263  
   264  	colleen := new(Node)
   265  	colleen.g = g
   266  	colleen.name = "COLLEEN THE COLLISION"
   267  	colleen.hash = 0xdeadbeef
   268  	colleen.hashed = true
   269  	colleen.boundTo = NewF64(0)
   270  	colleen.t = z2t
   271  	g.AddNode(colleen)
   272  
   273  	one := onef64
   274  	z2p1 := Must(Add(z2, one))                                    // add a constant
   275  	rando := UniformRandomNode(g, Float64, 0, 1, z2p1.Shape()...) // add a weird node
   276  	blah := Must(HadamardProd(z2p1, rando))
   277  	cost := Must(Sum(blah))
   278  	_, err := Grad(cost, x, y)
   279  	if err != nil {
   280  		t.Fatal(err)
   281  	}
   282  
   283  	g.Roots() // call it to populate the roots field
   284  
   285  	// clone with nil values
   286  	g2 := g.Clone().(*ExprGraph)
   287  	for i, n := range g.all {
   288  		cloned := g2.all[i]
   289  		if !deepNodeEq(n, cloned) {
   290  			t.Errorf("Expected %d of all to be %v. Got %v instead", i, n, cloned)
   291  			break
   292  		}
   293  	}
   294  	if len(g.evac) != len(g2.evac) && len(g.evac) > 0 {
   295  		t.Errorf("Expected the evacs to have the same length")
   296  	}
   297  	for k, v := range g.evac {
   298  		var v2 Nodes
   299  		var ok bool
   300  		if v2, ok = g2.evac[k]; !ok {
   301  			t.Errorf("Key %v not found in cloned evac", k)
   302  			break
   303  		}
   304  		for i, n := range v {
   305  			if !deepNodeEq(n, v2[i]) {
   306  				t.Errorf("Expected v[%d] to have equal values", i)
   307  				break
   308  			}
   309  		}
   310  		if t.Failed() {
   311  			break
   312  		}
   313  	}
   314  	if len(g.roots) != len(g2.roots) {
   315  		t.Errorf("Expected roots to be %d. Got %d instead", len(g.roots), len(g2.roots))
   316  	}
   317  	for i, root := range g.roots {
   318  		if !deepNodeEq(root, g2.roots[i]) {
   319  			t.Errorf("Expected roots[%d] to have equal nodes", i)
   320  			break
   321  		}
   322  	}
   323  
   324  	if len(g.leaves) != len(g2.leaves) {
   325  		t.Errorf("Expected leaves to be %d. Got %d instead", len(g.leaves), len(g2.leaves))
   326  	}
   327  	for i, leaf := range g.leaves {
   328  		if !deepNodeEq(leaf, g2.leaves[i]) {
   329  			t.Errorf("Expected leaves[%d] to be equal", i)
   330  			break
   331  		}
   332  	}
   333  
   334  	Let(x, tensor.New(tensor.WithBacking([]float64{1, 2})))
   335  	Let(y, tensor.New(tensor.WithBacking([]float64{3, 4})))
   336  	m := NewLispMachine(g, ExecuteFwdOnly()) // the gradient has been precalculated
   337  	defer m.Close()
   338  	if err := m.RunAll(); err != nil {
   339  		t.Fatal(err)
   340  	}
   341  
   342  	g2 = g.Clone().(*ExprGraph)
   343  	for i, n := range g.all {
   344  		cloned := g2.all[i]
   345  		if !deepNodeEq(n, cloned) {
   346  			t.Errorf("Expected %d of all to be %v. Got %v instead", i, n, cloned)
   347  			break
   348  		}
   349  	}
   350  	if len(g.evac) != len(g2.evac) && len(g.evac) > 0 {
   351  		t.Errorf("Expected the evacs to have the same length")
   352  	}
   353  	for k, v := range g.evac {
   354  		var v2 Nodes
   355  		var ok bool
   356  		if v2, ok = g2.evac[k]; !ok {
   357  			t.Errorf("Key %v not found in cloned evac", k)
   358  			break
   359  		}
   360  		for i, n := range v {
   361  			if !deepNodeEq(n, v2[i]) {
   362  				t.Errorf("Expected v[%d] to have equal values", i)
   363  				break
   364  			}
   365  		}
   366  		if t.Failed() {
   367  			break
   368  		}
   369  	}
   370  	if len(g.roots) != len(g2.roots) {
   371  		t.Errorf("Expected roots to be %d. Got %d instead", len(g.roots), len(g2.roots))
   372  	}
   373  	for i, root := range g.roots {
   374  		if !deepNodeEq(root, g2.roots[i]) {
   375  			t.Errorf("Expected roots[%d] to have equal nodes", i)
   376  			break
   377  		}
   378  	}
   379  
   380  	if len(g.leaves) != len(g2.leaves) {
   381  		t.Errorf("Expected leaves to be %d. Got %d instead", len(g.leaves), len(g2.leaves))
   382  	}
   383  	for i, leaf := range g.leaves {
   384  		if !deepNodeEq(leaf, g2.leaves[i]) {
   385  			t.Errorf("Expected leaves[%d] to be equal", i)
   386  			break
   387  		}
   388  	}
   389  }
   390  
   391  func TestExprGraph_Edges(t *testing.T) {
   392  	g := NewGraph()
   393  
   394  	var x, y *Node
   395  
   396  	// define the expression
   397  	x = NewScalar(g, Float64, WithName("x"))
   398  	y = NewScalar(g, Float64, WithName("y"))
   399  	Add(x, y)
   400  	edgesIT := g.Edges()
   401  	if edgesIT.Len() != 2 {
   402  		t.Fail()
   403  	}
   404  }