gorgonia.org/gorgonia@v0.9.17/gorgonia_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"log"
     5  	"os"
     6  	"testing"
     7  
     8  	"github.com/chewxy/hm"
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/require"
    11  	"gorgonia.org/tensor"
    12  	nd "gorgonia.org/tensor"
    13  )
    14  
    15  func TestNewConstant(t *testing.T) {
    16  	assert := assert.New(t)
    17  
    18  	var expectedType hm.Type
    19  
    20  	t.Log("Testing New Constant Tensors")
    21  	backing := nd.Random(Float64, 9)
    22  	T := nd.New(nd.WithBacking(backing), nd.WithShape(3, 3))
    23  
    24  	ct := NewConstant(T)
    25  	expectedTT := makeTensorType(2, Float64)
    26  	expectedType = expectedTT
    27  
    28  	assert.Equal(nd.Shape{3, 3}, ct.shape)
    29  	assert.Equal(expectedType, ct.t)
    30  
    31  	ct = NewConstant(T, WithName("From TensorValue"))
    32  	assert.Equal(nd.Shape{3, 3}, ct.shape)
    33  	assert.Equal(expectedType, ct.t)
    34  	assert.Equal("From TensorValue", ct.name)
    35  
    36  	t.Log("Testing Constant Scalars")
    37  	cs := NewConstant(3.14)
    38  	expectedType = Float64
    39  	assert.Equal(scalarShape, cs.shape)
    40  	assert.Equal(expectedType, cs.t)
    41  }
    42  
    43  var anyNodeTest = []struct {
    44  	name string
    45  	any  interface{}
    46  
    47  	correctType  hm.Type
    48  	correctShape nd.Shape
    49  }{
    50  	{"float32", float32(3.14), Float32, scalarShape},
    51  	{"float64", float64(3.14), Float64, scalarShape},
    52  	{"int", int(3), Int, scalarShape},
    53  	{"bool", true, Bool, scalarShape},
    54  	{"nd.Tensor", nd.New(nd.Of(nd.Float64), nd.WithShape(2, 3, 4)), &TensorType{Dims: 3, Of: Float64}, nd.Shape{2, 3, 4}},
    55  	{"nd.Tensor", nd.New(nd.Of(nd.Float32), nd.WithShape(2, 3, 4)), &TensorType{Dims: 3, Of: Float32}, nd.Shape{2, 3, 4}},
    56  	{"ScalarValue", NewF64(3.14), Float64, scalarShape},
    57  	{"TensorValue", nd.New(nd.Of(nd.Float64), nd.WithShape(2, 3)), &TensorType{Dims: 2, Of: Float64}, nd.Shape{2, 3}},
    58  }
    59  
    60  func TestNodeFromAny(t *testing.T) {
    61  	assert := assert.New(t)
    62  	g := NewGraph()
    63  	for _, a := range anyNodeTest {
    64  		n := NodeFromAny(g, a.any, WithName(a.name))
    65  		assert.Equal(a.name, n.name)
    66  		assert.Equal(g, n.g)
    67  		assert.True(a.correctType.Eq(n.t), "%v type error: Want %v. Got %v", a.name, a.correctType, n.t)
    68  		assert.True(a.correctShape.Eq(n.shape), "%v shape error: Want %v. Got %v", a.name, a.correctShape, n.shape)
    69  	}
    70  }
    71  
    72  func TestOneHotVector(t *testing.T) {
    73  	assert := assert.New(t)
    74  	assert.EqualValues(
    75  		[]float32{0, 0, 0, 0, 0, 0, 1, 0, 0, 0},
    76  		OneHotVector(6, 10, nd.Float32).Value().Data())
    77  	assert.EqualValues(
    78  		[]float32{0, 1, 0, 0, 0},
    79  		OneHotVector(1, 5, nd.Float32).Value().Data())
    80  	assert.EqualValues(
    81  		[]float32{0, 1, 0, 0, 0, 0},
    82  		OneHotVector(1, 6, nd.Float32).Value().Data())
    83  	assert.EqualValues(
    84  		[]int{0, 0, 0, 1, 0},
    85  		OneHotVector(3, 5, nd.Int).Value().Data())
    86  	assert.EqualValues(
    87  		[]int32{0, 0, 0, 0, 0, 0, 1, 0, 0, 0},
    88  		OneHotVector(6, 10, nd.Int32).Value().Data())
    89  	assert.EqualValues(
    90  		[]float64{0, 1, 0, 0, 0},
    91  		OneHotVector(1, 5, nd.Float64).Value().Data())
    92  	assert.EqualValues(
    93  		[]int64{0, 1, 0, 0, 0, 0},
    94  		OneHotVector(1, 6, nd.Int64).Value().Data())
    95  }
    96  
    97  func TestRandomNodeBackprop(t *testing.T) {
    98  	g := NewGraph()
    99  	a := NewVector(g, Float64, WithShape(10), WithName("a"), WithInit(Zeroes()))
   100  	b := GaussianRandomNode(g, Float64, 0, 1, 10)
   101  	c := Must(Add(a, b))
   102  	d := Must(Sum(c))
   103  	vm := NewLispMachine(g, WithLogger(log.New(os.Stderr, "", 0)))
   104  	vm.RunAll()
   105  	t.Logf("d.Value %v", d.Value())
   106  }
   107  
   108  func TestLetErrors(t *testing.T) {
   109  	g := NewGraph()
   110  
   111  	testCases := []struct {
   112  		desc string
   113  		node *Node
   114  		val  interface{}
   115  		err  string
   116  	}{
   117  		{
   118  			desc: "DifferentShapes",
   119  			node: NewTensor(g, tensor.Float64, 2, WithShape(1, 1), WithInit(GlorotN(1.0)), WithName("x")),
   120  			val:  tensor.New(tensor.WithShape(1, 1, 1), tensor.WithBacking([]float64{0.5})),
   121  			err:  "Node's expected shape is (1, 1). Got (1, 1, 1) instead",
   122  		},
   123  		{
   124  			desc: "AssigningConst",
   125  			node: NewConstant(2, WithName("x")),
   126  			val:  tensor.New(tensor.WithShape(1, 1), tensor.WithBacking([]float64{0.5})),
   127  			err:  "Cannot bind a value to a non input node",
   128  		},
   129  	}
   130  
   131  	for _, tC := range testCases {
   132  		t.Run(tC.desc, func(t *testing.T) {
   133  			err := Let(tC.node, tC.val)
   134  			if tC.err != "" {
   135  				require.Error(t, err)
   136  				assert.Equal(t, tC.err, err.Error())
   137  			} else {
   138  				require.NoError(t, err)
   139  			}
   140  		})
   141  	}
   142  }
   143  
   144  func TestRead(t *testing.T) {
   145  	g := NewGraph()
   146  	xVal := tensor.New(tensor.WithShape(2, 4), tensor.WithBacking(tensor.Range(tensor.Float64, 0, 8)))
   147  	x := NodeFromAny(g, xVal, WithName("x"))
   148  
   149  	var v1, v2 Value
   150  	r1 := Read(x, &v1)
   151  	r2 := Read(x, &v2)
   152  	r3 := Read(x, &v1)
   153  
   154  	assert.Equal(t, r1, r3)
   155  	assert.NotEqual(t, r1, r2)
   156  }