gorgonia.org/gorgonia@v0.9.17/node_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/assert"
     7  )
     8  
     9  func TestNodeBasics(t *testing.T) {
    10  	var n *Node
    11  	var c0, c1 *Node
    12  	g := NewGraph()
    13  
    14  	// withGraph
    15  	n = newNode(In(g))
    16  	if n.g == nil {
    17  		t.Error("Expected *Node to be constructed with a graph")
    18  	}
    19  	returnNode(n)
    20  
    21  	// withType
    22  	n = newNode(In(g), WithType(Float64))
    23  	if !n.t.Eq(Float64) {
    24  		t.Error("Expected *Node to be constructed with Float64")
    25  	}
    26  	returnNode(n)
    27  
    28  	// withOp
    29  	n = newNode(In(g), WithOp(newEBOByType(addOpType, Float64, Float64)))
    30  	if op, ok := n.op.(elemBinOp); ok {
    31  		if op.binOpType() != addOpType {
    32  			t.Error("expected addOpType")
    33  		}
    34  	} else {
    35  		t.Error("Expected *Node to be constructed with an addOp")
    36  	}
    37  	returnNode(n)
    38  
    39  	// withOp - statement op
    40  	n = newNode(In(g), WithOp(letOp{}))
    41  	if _, ok := n.op.(letOp); ok {
    42  		if !n.isStmt {
    43  			t.Errorf("Expected *Node.isStmt to be true when a statement op is passed in")
    44  		}
    45  	} else {
    46  		t.Error("Expected  *Node to be constructed with a letOp")
    47  	}
    48  	returnNode(n)
    49  
    50  	// WithName
    51  	n = newNode(In(g), WithName("n"))
    52  	if n.name != "n" {
    53  		t.Error("Expected *Node to be constructed with a name \"n\"")
    54  	}
    55  	returnNode(n)
    56  
    57  	// withChildren
    58  	c0 = newNode(In(g), WithName("C0"))
    59  	c1 = newNode(In(g), WithName("C1"))
    60  	n = newNode(In(g), WithChildren(Nodes{c0, c1}))
    61  	if len(n.children) == 2 {
    62  		if !n.children.Contains(c0) || !n.children.Contains(c1) {
    63  			t.Error("Expected *Node to contain those two children")
    64  		}
    65  	} else {
    66  		t.Error("Expected *Node to be constructed with 2 children")
    67  	}
    68  	if !n.isRoot() {
    69  		t.Error("n is supposed to be root")
    70  	}
    71  
    72  	returnNode(n)
    73  	returnNode(c0)
    74  	returnNode(c1)
    75  
    76  	// withChildren but they're constants
    77  	c0 = NewConstant(3.14)
    78  	n = newNode(In(g), WithChildren(Nodes{c0}))
    79  	if len(n.children) != 1 {
    80  		t.Error("Expected *Node to have 1 child")
    81  	}
    82  	returnNode(n)
    83  	returnNode(c0)
    84  
    85  	n = newNode(In(g), WithValue(F64(3.14)), WithGrad(F64(1)))
    86  	if _, ok := n.boundTo.(*dualValue); !ok {
    87  		t.Error("Expected a dual Value")
    88  	}
    89  	returnNode(n)
    90  
    91  	// WithValue but no type
    92  	n = newNode(In(g), WithValue(F64(3.14)))
    93  	if !n.t.Eq(Float64) {
    94  		t.Error("Expected a *Node to be constructed WithValue to get its type from the value if none exists")
    95  	}
    96  	if !ValueEq(n.boundTo, NewF64(3.14)) {
    97  		t.Error("Expected *Node to be bound to the correct value. Something has gone really wrong here")
    98  	}
    99  	returnNode(n)
   100  
   101  	// WithValue but with existing type that is the same
   102  	n = newNode(In(g), WithType(Float64), WithValue(F64(3.14)))
   103  	if !ValueEq(n.boundTo, NewF64(3.14)) {
   104  		t.Error("Expected *Node to be bound to the correct value. Something has gone really wrong here")
   105  	}
   106  	returnNode(n)
   107  
   108  	// This is acceptable and should not panic
   109  	n = newNode(In(g), WithType(makeTensorType(1, Float64)), WithShape(2, 1))
   110  	returnNode(n)
   111  
   112  	// Returns itsef
   113  	n = newNode(In(g), WithType(makeTensorType(2, Float32)), WithShape(2, 3))
   114  	m := n.Node()
   115  	if n != m {
   116  		t.Error("Expected n.Node() to return itself, pointers and all")
   117  	}
   118  	ns := n.Nodes()
   119  	if len(ns) != 1 {
   120  		t.Errorf("Expected Nodes() to return a slice of length 1. Got %v", ns)
   121  	}
   122  	if ns[0] != n {
   123  		t.Error("Expected first slice to be itself.")
   124  	}
   125  	m = nil
   126  	returnNode(n)
   127  
   128  	// bad stuff
   129  	var f func()
   130  
   131  	// no graph
   132  	f = func() {
   133  		n = newNode(WithType(Float64))
   134  	}
   135  	assert.Panics(t, f)
   136  
   137  	// conflicting types, types were set first
   138  	f = func() {
   139  		n = newNode(In(g), WithType(Float32), WithValue(F64(1)))
   140  	}
   141  	assert.Panics(t, f)
   142  
   143  	// type mismatch - values were set first
   144  	f = func() {
   145  		n = newNode(In(g), WithValue(F64(1)), WithType(Float32))
   146  	}
   147  	assert.Panics(t, f)
   148  
   149  	// shape type mismatch
   150  	f = func() {
   151  		n = newNode(In(g), WithType(makeTensorType(1, Float64)), WithShape(2, 2))
   152  	}
   153  	assert.Panics(t, f)
   154  
   155  	// bad grads
   156  	f = func() {
   157  		n = newNode(WithGrad(F64(3.14)))
   158  	}
   159  	assert.Panics(t, f)
   160  }
   161  
   162  func TestNewUniqueNodes(t *testing.T) {
   163  	var n *Node
   164  	var c0, c1 *Node
   165  	g := NewGraph()
   166  
   167  	// withChildren but they're constants
   168  	c0 = NewConstant(3.14)
   169  	c1 = newNode(In(g), WithValue(5.0))
   170  	n = NewUniqueNode(In(g), WithChildren(Nodes{c0, c1}))
   171  	if n.children[0].g == nil {
   172  		t.Error("Expected a cloned constant child to have graph g")
   173  	}
   174  
   175  	returnNode(n)
   176  }
   177  
   178  func TestCloneTo(t *testing.T) {
   179  	g := NewGraph()
   180  	g2 := NewGraph()
   181  
   182  	n := NewUniqueNode(WithName("n"), WithType(Float64), In(g))
   183  	n.CloneTo(g2)
   184  
   185  	assert.True(t, nodeEq(g2.AllNodes()[0], n))
   186  }