gorgonia.org/gorgonia@v0.9.17/node_test.go (about) 1 package gorgonia 2 3 import ( 4 "testing" 5 6 "github.com/stretchr/testify/assert" 7 ) 8 9 func TestNodeBasics(t *testing.T) { 10 var n *Node 11 var c0, c1 *Node 12 g := NewGraph() 13 14 // withGraph 15 n = newNode(In(g)) 16 if n.g == nil { 17 t.Error("Expected *Node to be constructed with a graph") 18 } 19 returnNode(n) 20 21 // withType 22 n = newNode(In(g), WithType(Float64)) 23 if !n.t.Eq(Float64) { 24 t.Error("Expected *Node to be constructed with Float64") 25 } 26 returnNode(n) 27 28 // withOp 29 n = newNode(In(g), WithOp(newEBOByType(addOpType, Float64, Float64))) 30 if op, ok := n.op.(elemBinOp); ok { 31 if op.binOpType() != addOpType { 32 t.Error("expected addOpType") 33 } 34 } else { 35 t.Error("Expected *Node to be constructed with an addOp") 36 } 37 returnNode(n) 38 39 // withOp - statement op 40 n = newNode(In(g), WithOp(letOp{})) 41 if _, ok := n.op.(letOp); ok { 42 if !n.isStmt { 43 t.Errorf("Expected *Node.isStmt to be true when a statement op is passed in") 44 } 45 } else { 46 t.Error("Expected *Node to be constructed with a letOp") 47 } 48 returnNode(n) 49 50 // WithName 51 n = newNode(In(g), WithName("n")) 52 if n.name != "n" { 53 t.Error("Expected *Node to be constructed with a name \"n\"") 54 } 55 returnNode(n) 56 57 // withChildren 58 c0 = newNode(In(g), WithName("C0")) 59 c1 = newNode(In(g), WithName("C1")) 60 n = newNode(In(g), WithChildren(Nodes{c0, c1})) 61 if len(n.children) == 2 { 62 if !n.children.Contains(c0) || !n.children.Contains(c1) { 63 t.Error("Expected *Node to contain those two children") 64 } 65 } else { 66 t.Error("Expected *Node to be constructed with 2 children") 67 } 68 if !n.isRoot() { 69 t.Error("n is supposed to be root") 70 } 71 72 returnNode(n) 73 returnNode(c0) 74 returnNode(c1) 75 76 // withChildren but they're constants 77 c0 = NewConstant(3.14) 78 n = newNode(In(g), WithChildren(Nodes{c0})) 79 if len(n.children) != 1 { 80 t.Error("Expected *Node to have 1 child") 81 } 82 returnNode(n) 83 returnNode(c0) 84 85 n = newNode(In(g), WithValue(F64(3.14)), WithGrad(F64(1))) 86 if _, ok := n.boundTo.(*dualValue); !ok { 87 t.Error("Expected a dual Value") 88 } 89 returnNode(n) 90 91 // WithValue but no type 92 n = newNode(In(g), WithValue(F64(3.14))) 93 if !n.t.Eq(Float64) { 94 t.Error("Expected a *Node to be constructed WithValue to get its type from the value if none exists") 95 } 96 if !ValueEq(n.boundTo, NewF64(3.14)) { 97 t.Error("Expected *Node to be bound to the correct value. Something has gone really wrong here") 98 } 99 returnNode(n) 100 101 // WithValue but with existing type that is the same 102 n = newNode(In(g), WithType(Float64), WithValue(F64(3.14))) 103 if !ValueEq(n.boundTo, NewF64(3.14)) { 104 t.Error("Expected *Node to be bound to the correct value. Something has gone really wrong here") 105 } 106 returnNode(n) 107 108 // This is acceptable and should not panic 109 n = newNode(In(g), WithType(makeTensorType(1, Float64)), WithShape(2, 1)) 110 returnNode(n) 111 112 // Returns itsef 113 n = newNode(In(g), WithType(makeTensorType(2, Float32)), WithShape(2, 3)) 114 m := n.Node() 115 if n != m { 116 t.Error("Expected n.Node() to return itself, pointers and all") 117 } 118 ns := n.Nodes() 119 if len(ns) != 1 { 120 t.Errorf("Expected Nodes() to return a slice of length 1. Got %v", ns) 121 } 122 if ns[0] != n { 123 t.Error("Expected first slice to be itself.") 124 } 125 m = nil 126 returnNode(n) 127 128 // bad stuff 129 var f func() 130 131 // no graph 132 f = func() { 133 n = newNode(WithType(Float64)) 134 } 135 assert.Panics(t, f) 136 137 // conflicting types, types were set first 138 f = func() { 139 n = newNode(In(g), WithType(Float32), WithValue(F64(1))) 140 } 141 assert.Panics(t, f) 142 143 // type mismatch - values were set first 144 f = func() { 145 n = newNode(In(g), WithValue(F64(1)), WithType(Float32)) 146 } 147 assert.Panics(t, f) 148 149 // shape type mismatch 150 f = func() { 151 n = newNode(In(g), WithType(makeTensorType(1, Float64)), WithShape(2, 2)) 152 } 153 assert.Panics(t, f) 154 155 // bad grads 156 f = func() { 157 n = newNode(WithGrad(F64(3.14))) 158 } 159 assert.Panics(t, f) 160 } 161 162 func TestNewUniqueNodes(t *testing.T) { 163 var n *Node 164 var c0, c1 *Node 165 g := NewGraph() 166 167 // withChildren but they're constants 168 c0 = NewConstant(3.14) 169 c1 = newNode(In(g), WithValue(5.0)) 170 n = NewUniqueNode(In(g), WithChildren(Nodes{c0, c1})) 171 if n.children[0].g == nil { 172 t.Error("Expected a cloned constant child to have graph g") 173 } 174 175 returnNode(n) 176 } 177 178 func TestCloneTo(t *testing.T) { 179 g := NewGraph() 180 g2 := NewGraph() 181 182 n := NewUniqueNode(WithName("n"), WithType(Float64), In(g)) 183 n.CloneTo(g2) 184 185 assert.True(t, nodeEq(g2.AllNodes()[0], n)) 186 }