gorgonia.org/gorgonia@v0.9.17/gorgonia_test.go (about) 1 package gorgonia 2 3 import ( 4 "log" 5 "os" 6 "testing" 7 8 "github.com/chewxy/hm" 9 "github.com/stretchr/testify/assert" 10 "github.com/stretchr/testify/require" 11 "gorgonia.org/tensor" 12 nd "gorgonia.org/tensor" 13 ) 14 15 func TestNewConstant(t *testing.T) { 16 assert := assert.New(t) 17 18 var expectedType hm.Type 19 20 t.Log("Testing New Constant Tensors") 21 backing := nd.Random(Float64, 9) 22 T := nd.New(nd.WithBacking(backing), nd.WithShape(3, 3)) 23 24 ct := NewConstant(T) 25 expectedTT := makeTensorType(2, Float64) 26 expectedType = expectedTT 27 28 assert.Equal(nd.Shape{3, 3}, ct.shape) 29 assert.Equal(expectedType, ct.t) 30 31 ct = NewConstant(T, WithName("From TensorValue")) 32 assert.Equal(nd.Shape{3, 3}, ct.shape) 33 assert.Equal(expectedType, ct.t) 34 assert.Equal("From TensorValue", ct.name) 35 36 t.Log("Testing Constant Scalars") 37 cs := NewConstant(3.14) 38 expectedType = Float64 39 assert.Equal(scalarShape, cs.shape) 40 assert.Equal(expectedType, cs.t) 41 } 42 43 var anyNodeTest = []struct { 44 name string 45 any interface{} 46 47 correctType hm.Type 48 correctShape nd.Shape 49 }{ 50 {"float32", float32(3.14), Float32, scalarShape}, 51 {"float64", float64(3.14), Float64, scalarShape}, 52 {"int", int(3), Int, scalarShape}, 53 {"bool", true, Bool, scalarShape}, 54 {"nd.Tensor", nd.New(nd.Of(nd.Float64), nd.WithShape(2, 3, 4)), &TensorType{Dims: 3, Of: Float64}, nd.Shape{2, 3, 4}}, 55 {"nd.Tensor", nd.New(nd.Of(nd.Float32), nd.WithShape(2, 3, 4)), &TensorType{Dims: 3, Of: Float32}, nd.Shape{2, 3, 4}}, 56 {"ScalarValue", NewF64(3.14), Float64, scalarShape}, 57 {"TensorValue", nd.New(nd.Of(nd.Float64), nd.WithShape(2, 3)), &TensorType{Dims: 2, Of: Float64}, nd.Shape{2, 3}}, 58 } 59 60 func TestNodeFromAny(t *testing.T) { 61 assert := assert.New(t) 62 g := NewGraph() 63 for _, a := range anyNodeTest { 64 n := NodeFromAny(g, a.any, WithName(a.name)) 65 assert.Equal(a.name, n.name) 66 assert.Equal(g, n.g) 67 assert.True(a.correctType.Eq(n.t), "%v type error: Want %v. Got %v", a.name, a.correctType, n.t) 68 assert.True(a.correctShape.Eq(n.shape), "%v shape error: Want %v. Got %v", a.name, a.correctShape, n.shape) 69 } 70 } 71 72 func TestOneHotVector(t *testing.T) { 73 assert := assert.New(t) 74 assert.EqualValues( 75 []float32{0, 0, 0, 0, 0, 0, 1, 0, 0, 0}, 76 OneHotVector(6, 10, nd.Float32).Value().Data()) 77 assert.EqualValues( 78 []float32{0, 1, 0, 0, 0}, 79 OneHotVector(1, 5, nd.Float32).Value().Data()) 80 assert.EqualValues( 81 []float32{0, 1, 0, 0, 0, 0}, 82 OneHotVector(1, 6, nd.Float32).Value().Data()) 83 assert.EqualValues( 84 []int{0, 0, 0, 1, 0}, 85 OneHotVector(3, 5, nd.Int).Value().Data()) 86 assert.EqualValues( 87 []int32{0, 0, 0, 0, 0, 0, 1, 0, 0, 0}, 88 OneHotVector(6, 10, nd.Int32).Value().Data()) 89 assert.EqualValues( 90 []float64{0, 1, 0, 0, 0}, 91 OneHotVector(1, 5, nd.Float64).Value().Data()) 92 assert.EqualValues( 93 []int64{0, 1, 0, 0, 0, 0}, 94 OneHotVector(1, 6, nd.Int64).Value().Data()) 95 } 96 97 func TestRandomNodeBackprop(t *testing.T) { 98 g := NewGraph() 99 a := NewVector(g, Float64, WithShape(10), WithName("a"), WithInit(Zeroes())) 100 b := GaussianRandomNode(g, Float64, 0, 1, 10) 101 c := Must(Add(a, b)) 102 d := Must(Sum(c)) 103 vm := NewLispMachine(g, WithLogger(log.New(os.Stderr, "", 0))) 104 vm.RunAll() 105 t.Logf("d.Value %v", d.Value()) 106 } 107 108 func TestLetErrors(t *testing.T) { 109 g := NewGraph() 110 111 testCases := []struct { 112 desc string 113 node *Node 114 val interface{} 115 err string 116 }{ 117 { 118 desc: "DifferentShapes", 119 node: NewTensor(g, tensor.Float64, 2, WithShape(1, 1), WithInit(GlorotN(1.0)), WithName("x")), 120 val: tensor.New(tensor.WithShape(1, 1, 1), tensor.WithBacking([]float64{0.5})), 121 err: "Node's expected shape is (1, 1). Got (1, 1, 1) instead", 122 }, 123 { 124 desc: "AssigningConst", 125 node: NewConstant(2, WithName("x")), 126 val: tensor.New(tensor.WithShape(1, 1), tensor.WithBacking([]float64{0.5})), 127 err: "Cannot bind a value to a non input node", 128 }, 129 } 130 131 for _, tC := range testCases { 132 t.Run(tC.desc, func(t *testing.T) { 133 err := Let(tC.node, tC.val) 134 if tC.err != "" { 135 require.Error(t, err) 136 assert.Equal(t, tC.err, err.Error()) 137 } else { 138 require.NoError(t, err) 139 } 140 }) 141 } 142 } 143 144 func TestRead(t *testing.T) { 145 g := NewGraph() 146 xVal := tensor.New(tensor.WithShape(2, 4), tensor.WithBacking(tensor.Range(tensor.Float64, 0, 8))) 147 x := NodeFromAny(g, xVal, WithName("x")) 148 149 var v1, v2 Value 150 r1 := Read(x, &v1) 151 r2 := Read(x, &v2) 152 r3 := Read(x, &v1) 153 154 assert.Equal(t, r1, r3) 155 assert.NotEqual(t, r1, r2) 156 }