gorgonia.org/gorgonia@v0.9.17/graph_test.go (about) 1 package gorgonia 2 3 import ( 4 "testing" 5 6 "github.com/stretchr/testify/assert" 7 "gonum.org/v1/gonum/graph" 8 "gonum.org/v1/gonum/graph/iterator" 9 "gonum.org/v1/gonum/graph/topo" 10 "gorgonia.org/tensor" 11 ) 12 13 func TestGraphBasics(t *testing.T) { 14 assert := assert.New(t) 15 g, x, y, xy := simpleEqn() 16 17 // basic stuff 18 assert.Equal(g, xy.g) 19 assert.Contains(g.AllNodes(), x) 20 assert.Contains(g.AllNodes(), y) 21 assert.Contains(g.AllNodes(), xy) 22 23 assert.Equal(Nodes{x, y}, g.leaves) 24 25 // Node/addressing stuff 26 xid := x.ID() 27 xFromID := g.Node(xid) 28 assert.Equal(x, xFromID) 29 30 var correctTo Nodes 31 correctTo = Nodes{xy} 32 assert.Equal(correctTo, g.to[x]) 33 assert.Equal(correctTo, g.to[y]) 34 35 // test Uniquifying ability of ExprGraph 36 newX := g.AddNode(x) 37 assert.Equal(x, newX) 38 39 newY := g.AddNode(y) 40 assert.Equal(y, newY) 41 42 newXY := Must(Add(x, y)) 43 correctTo = append(correctTo, xy) // note this is correct. .Set() will be called when graph.To() is called 44 assert.Equal(xy, newXY) 45 assert.Equal(correctTo, g.to[y]) 46 assert.Equal(correctTo, g.to[x]) 47 48 correctTo = Nodes{xy} 49 assert.Equal(correctTo, sliceNodesToNodes(graph.NodesOf(g.To(y.ID())))) 50 assert.Equal(correctTo, sliceNodesToNodes(graph.NodesOf(g.To(x.ID())))) 51 52 assert.Equal(3, g.Nodes().Len()) 53 54 // Now, time to deal with constants 55 xy1 := Must(Add(xy, onef64)) 56 assert.Nil(onef64.g) 57 assert.Equal(g, xy1.g) 58 59 var containsOne bool 60 61 it := g.Nodes() 62 for it.Next() { 63 node := it.Node() 64 n := node.(*Node) 65 if n.Hashcode() == onef64.Hashcode() { 66 containsOne = true 67 break 68 } 69 } 70 if !containsOne { 71 t.Errorf("graph does not contain a clone of onef64: %v", g.Nodes()) 72 } 73 74 // duplicate constants 75 one := NewConstant(1.0) 76 newOne := g.AddNode(one) 77 if one == newOne { 78 t.Error("one should not have been added to the graph") 79 } 80 assert.NotNil(newOne.g) 81 assert.NotEqual(one, newOne) 82 } 83 84 // This test is added to make sure I'm sane when dealing with sorted graphs 85 // because sometimes Eobard Thawne is needed 86 func TestGraphSort(t *testing.T) { 87 assert := assert.New(t) 88 g, _, _, z := simpleVecEqn() 89 WithName("z")(z) 90 91 var sortedNodes []graph.Node 92 var err error 93 94 // stability tests 95 for i := 0; i < 100; i++ { 96 if sortedNodes, err = topo.Sort(g); err != nil { 97 t.Error(err) 98 } 99 // expected := Nodes{z, y, x} // the old version of ExprGraph was stable with topo.Sort, but the new version ain't 100 // assert.Equal(expected, sortedNodes) 101 assert.Equal(z, sortedNodes[0]) 102 } 103 104 // this is to remind myself how this thing sorts: 105 t.Logf("%v", graphNodeToNode(iterator.NewOrderedNodes(sortedNodes))) 106 } 107 108 // test that collisions are handled correctly 109 func TestGraphCollisions(t *testing.T) { 110 assert := assert.New(t) 111 g, _, _, xy := simpleEqn() 112 delete(g.byHash, xy.hash) 113 g.byHash[0xdeadbeef] = xy 114 xy.hash = 0xdeadbeef 115 xy.name = "original" 116 t.Logf("original: %p, hash %x", xy, xy.Hashcode()) 117 118 col := new(Node) 119 col.name = "COLIN THE COLLISION" 120 col.hash = 0xdeadbeef 121 col.hashed = true 122 col2 := g.AddNode(col) 123 124 assert.Equal(col, col2) 125 assert.Equal(4, len(g.AllNodes()), "%v", g.AllNodes()) 126 assert.True(g.Has(col.ID())) 127 128 colleen := new(Node) 129 colleen.name = "COLLEEN THE COLLISION" 130 colleen.hash = 0xdeadbeef 131 colleen.hashed = true 132 colleen2 := g.AddNode(colleen) 133 134 assert.Equal(colleen, colleen2) 135 assert.Equal(5, len(g.AllNodes()), "%v", g.AllNodes()) 136 assert.True(g.Has(colleen.ID())) 137 138 } 139 140 func TestGraphEquality(t *testing.T) { 141 _, x, y, z := simpleVecEqn() 142 143 xh1 := x.Hashcode() 144 yh1 := y.Hashcode() 145 if xh1 == yh1 { 146 t.Error("Different nodes, should have different hashes") 147 } 148 149 _, x2, y2, z2 := simpleVecEqn() 150 151 if x.Hashcode() != x2.Hashcode() { 152 t.Error("They should have the same hash") 153 } 154 155 if y.Hashcode() != y2.Hashcode() { 156 t.Error("They should have the same hash") 157 } 158 159 if z.Hashcode() != z2.Hashcode() { 160 t.Error("They should have the same hash") 161 } 162 } 163 164 func TestGraphSubgraph(t *testing.T) { 165 var err error 166 var sortedNodes Nodes 167 assert := assert.New(t) 168 169 g, x, y, z := simpleVecEqn() 170 171 sub := Nodes{x, y} 172 g2 := g.subgraph(sub, true) 173 174 t.Logf("%v", g2.AllNodes()) 175 176 if sortedNodes, err = Sort(g2); err != nil { 177 t.Fatal(err) 178 } 179 assert.NotContains(sortedNodes, z) 180 assert.Contains(g2.roots, x) 181 assert.Contains(g2.roots, y) 182 assert.Equal(2, len(g2.roots)) 183 } 184 185 func TestGraph_SubgraphRoots(t *testing.T) { 186 assert := assert.New(t) 187 g, x, y, z := simpleVecEqn() 188 sz := Must(Sum(z)) 189 a := NewVector(g, Float64, WithName("a"), WithShape(2)) 190 b := NewVector(g, Float64, WithName("b"), WithShape(2)) 191 c := Must(Add(a, b)) 192 sc := Must(Sum(c)) 193 194 var szVal, scVal Value 195 readSZ := Read(sz, &szVal) 196 readSC := Read(sc, &scVal) 197 198 // check that stmt nodes aren't included in the roots 199 sg := g.SubgraphRoots(readSZ, readSC) 200 assert.Contains(sg.roots, sz) 201 assert.Contains(sg.roots, sc) 202 assert.Equal(2, len(sg.roots)) 203 204 // check that subgrapphing actually works 205 sg = g.SubgraphRoots(c) 206 ns := sg.AllNodes() 207 assert.NotContains(ns, sc) 208 assert.NotContains(ns, readSC) 209 assert.NotContains(ns, x) 210 assert.NotContains(ns, y) 211 assert.NotContains(ns, z) 212 assert.NotContains(ns, sz) 213 assert.NotContains(ns, readSZ) 214 } 215 216 func TestGraph_ExactSubgraphRoots(t *testing.T) { 217 assert := assert.New(t) 218 g, x, y, z := simpleVecEqn() 219 sz := Must(Sum(z)) 220 setXtoZ := Set(x, z) // setting x = z 221 222 sg0 := g.SubgraphRoots(sz) 223 sg1 := g.ExactSubgraphRoots(sz) 224 ns0 := sg0.AllNodes() 225 ns1 := sg1.AllNodes() 226 assert.Contains(ns0, setXtoZ) 227 assert.NotContains(ns1, setXtoZ) 228 assert.Contains(ns0, x) 229 assert.Contains(ns0, y) 230 assert.Contains(ns0, z) 231 assert.Contains(ns0, sz) 232 233 } 234 235 func TestGraph_Constant(t *testing.T) { 236 g := NewGraph() 237 238 v1 := NewF64(1.0) 239 c0 := g.Constant(v1) 240 c1 := g.Constant(v1) 241 242 if c0 != c1 { 243 t.Errorf("Expected c0 and c1 to be the same (pointer and all that)") 244 } 245 } 246 247 func TestGraph_Clone(t *testing.T) { 248 g, x, y, z := simpleVecEqn() 249 z2 := Must(Square(z)) 250 251 // add a collided 252 z2t := z2.Type() 253 delete(g.byHash, z2.hash) 254 g.byHash[0xdeadbeef] = z2 255 col := new(Node) 256 col.g = g 257 col.name = "COLIN THE COLLISION" 258 col.hash = 0xdeadbeef 259 col.hashed = true 260 col.boundTo = NewF64(0) 261 col.t = z2t 262 g.AddNode(col) 263 264 colleen := new(Node) 265 colleen.g = g 266 colleen.name = "COLLEEN THE COLLISION" 267 colleen.hash = 0xdeadbeef 268 colleen.hashed = true 269 colleen.boundTo = NewF64(0) 270 colleen.t = z2t 271 g.AddNode(colleen) 272 273 one := onef64 274 z2p1 := Must(Add(z2, one)) // add a constant 275 rando := UniformRandomNode(g, Float64, 0, 1, z2p1.Shape()...) // add a weird node 276 blah := Must(HadamardProd(z2p1, rando)) 277 cost := Must(Sum(blah)) 278 _, err := Grad(cost, x, y) 279 if err != nil { 280 t.Fatal(err) 281 } 282 283 g.Roots() // call it to populate the roots field 284 285 // clone with nil values 286 g2 := g.Clone().(*ExprGraph) 287 for i, n := range g.all { 288 cloned := g2.all[i] 289 if !deepNodeEq(n, cloned) { 290 t.Errorf("Expected %d of all to be %v. Got %v instead", i, n, cloned) 291 break 292 } 293 } 294 if len(g.evac) != len(g2.evac) && len(g.evac) > 0 { 295 t.Errorf("Expected the evacs to have the same length") 296 } 297 for k, v := range g.evac { 298 var v2 Nodes 299 var ok bool 300 if v2, ok = g2.evac[k]; !ok { 301 t.Errorf("Key %v not found in cloned evac", k) 302 break 303 } 304 for i, n := range v { 305 if !deepNodeEq(n, v2[i]) { 306 t.Errorf("Expected v[%d] to have equal values", i) 307 break 308 } 309 } 310 if t.Failed() { 311 break 312 } 313 } 314 if len(g.roots) != len(g2.roots) { 315 t.Errorf("Expected roots to be %d. Got %d instead", len(g.roots), len(g2.roots)) 316 } 317 for i, root := range g.roots { 318 if !deepNodeEq(root, g2.roots[i]) { 319 t.Errorf("Expected roots[%d] to have equal nodes", i) 320 break 321 } 322 } 323 324 if len(g.leaves) != len(g2.leaves) { 325 t.Errorf("Expected leaves to be %d. Got %d instead", len(g.leaves), len(g2.leaves)) 326 } 327 for i, leaf := range g.leaves { 328 if !deepNodeEq(leaf, g2.leaves[i]) { 329 t.Errorf("Expected leaves[%d] to be equal", i) 330 break 331 } 332 } 333 334 Let(x, tensor.New(tensor.WithBacking([]float64{1, 2}))) 335 Let(y, tensor.New(tensor.WithBacking([]float64{3, 4}))) 336 m := NewLispMachine(g, ExecuteFwdOnly()) // the gradient has been precalculated 337 defer m.Close() 338 if err := m.RunAll(); err != nil { 339 t.Fatal(err) 340 } 341 342 g2 = g.Clone().(*ExprGraph) 343 for i, n := range g.all { 344 cloned := g2.all[i] 345 if !deepNodeEq(n, cloned) { 346 t.Errorf("Expected %d of all to be %v. Got %v instead", i, n, cloned) 347 break 348 } 349 } 350 if len(g.evac) != len(g2.evac) && len(g.evac) > 0 { 351 t.Errorf("Expected the evacs to have the same length") 352 } 353 for k, v := range g.evac { 354 var v2 Nodes 355 var ok bool 356 if v2, ok = g2.evac[k]; !ok { 357 t.Errorf("Key %v not found in cloned evac", k) 358 break 359 } 360 for i, n := range v { 361 if !deepNodeEq(n, v2[i]) { 362 t.Errorf("Expected v[%d] to have equal values", i) 363 break 364 } 365 } 366 if t.Failed() { 367 break 368 } 369 } 370 if len(g.roots) != len(g2.roots) { 371 t.Errorf("Expected roots to be %d. Got %d instead", len(g.roots), len(g2.roots)) 372 } 373 for i, root := range g.roots { 374 if !deepNodeEq(root, g2.roots[i]) { 375 t.Errorf("Expected roots[%d] to have equal nodes", i) 376 break 377 } 378 } 379 380 if len(g.leaves) != len(g2.leaves) { 381 t.Errorf("Expected leaves to be %d. Got %d instead", len(g.leaves), len(g2.leaves)) 382 } 383 for i, leaf := range g.leaves { 384 if !deepNodeEq(leaf, g2.leaves[i]) { 385 t.Errorf("Expected leaves[%d] to be equal", i) 386 break 387 } 388 } 389 } 390 391 func TestExprGraph_Edges(t *testing.T) { 392 g := NewGraph() 393 394 var x, y *Node 395 396 // define the expression 397 x = NewScalar(g, Float64, WithName("x")) 398 y = NewScalar(g, Float64, WithName("y")) 399 Add(x, y) 400 edgesIT := g.Edges() 401 if edgesIT.Len() != 2 { 402 t.Fail() 403 } 404 }