gorgonia.org/gorgonia@v0.9.17/operations_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"io/ioutil"
     5  	"log"
     6  	"runtime"
     7  	"testing"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/require"
    11  	"gorgonia.org/tensor"
    12  )
    13  
    14  func TestApplyOp(t *testing.T) {
    15  	assert := assert.New(t)
    16  	g := NewGraph()
    17  
    18  	var cpi *Node
    19  	var ct *Node
    20  	var op Op
    21  
    22  	t.Log("Simple Constant Scalar test")
    23  	cpi = NewConstant(3.1415, WithName("constantPi"))
    24  	cpi = g.AddNode(cpi)
    25  
    26  	t.Logf("g: %v", cpi.g)
    27  
    28  	op = newElemBinOp(addOpType, cpi, cpi)
    29  	added, err := ApplyOpWithName(op, "+ pi pi", cpi, cpi)
    30  	if err != nil {
    31  		t.Fatal(err)
    32  	}
    33  	assert.Equal(g, added.g)
    34  	assert.Equal(Float64, added.t)
    35  
    36  	ct = NewConstant(tensor.Ones(tensor.Float64, 3, 3)) // no graph set for ct
    37  	op = newElemBinOp(addOpType, cpi, ct)
    38  	if added, err = ApplyOpWithName(op, "+ pi constTensor(3,3)_ones", cpi, ct); err != nil {
    39  		t.Error(err)
    40  	}
    41  }
    42  
    43  var mulTests = []struct {
    44  	name   string
    45  	xshape tensor.Shape
    46  	wshape tensor.Shape
    47  
    48  	gradX []float64
    49  	gradW []float64
    50  }{
    51  	{"x vector", tensor.Shape{2}, tensor.Shape{2, 3}, []float64{3, 12}, []float64{0, 0, 0, 1, 1, 1}},
    52  	{"x mat", tensor.Shape{3, 2}, tensor.Shape{2, 3}, []float64{3, 12, 3, 12, 3, 12}, []float64{6, 6, 6, 9, 9, 9}},
    53  	{"x_vec_w_vec", tensor.Shape{6}, tensor.Shape{6}, []float64{0, 1, 2, 3, 4, 5}, []float64{0, 1, 2, 3, 4, 5}},
    54  }
    55  
    56  func TestMul(t *testing.T) {
    57  	defer runtime.GC()
    58  	assert := assert.New(t)
    59  	for _, mts := range mulTests {
    60  		g := NewGraph()
    61  		x := NewTensor(g, Float64, mts.xshape.Dims(), WithName(mts.name), WithShape(mts.xshape...), WithInit(RangedFrom(0)))
    62  		w := NewTensor(g, Float64, mts.wshape.Dims(), WithName("w"), WithShape(mts.wshape...), WithInit(RangedFrom(0)))
    63  
    64  		xw, err := Mul(x, w)
    65  		if err != nil {
    66  			t.Errorf("Error when testing %q. Err: %v", mts.name, err)
    67  			continue
    68  		}
    69  
    70  		if mts.xshape.IsVector() && mts.wshape.IsVector() {
    71  			if _, err = Grad(xw, x, w); err != nil {
    72  				t.Errorf("Error while differentiating %q, Err: %v", mts.name, err)
    73  				continue
    74  			}
    75  		} else {
    76  			cost, err := Sum(xw)
    77  			if err != nil {
    78  				t.Errorf("Error when summing %q. Err: %v", mts.name, err)
    79  				continue
    80  			}
    81  
    82  			if _, err = Grad(cost, x, w); err != nil {
    83  				t.Errorf("Error while differentiating %q, Err: %v", mts.name, err)
    84  				continue
    85  			}
    86  		}
    87  
    88  		m := NewTapeMachine(g)
    89  		if err = m.RunAll(); err != nil {
    90  			t.Errorf("Error while executing %q. Err: %v", mts.name, err)
    91  			continue
    92  		}
    93  
    94  		gradX, err := x.Grad()
    95  		if err != nil {
    96  			t.Errorf("Error while getting gradient of x %q. Err: %v", mts.name, err)
    97  		}
    98  
    99  		gradW, err := w.Grad()
   100  		if err != nil {
   101  			t.Errorf("Error while getting gradient of w %q. Err: %v", mts.name, err)
   102  		}
   103  
   104  		assert.Equal(mts.gradX, gradX.Data().([]float64))
   105  		assert.Equal(mts.gradW, gradW.Data().([]float64))
   106  		assert.True(mts.xshape.Eq(gradX.Shape()))
   107  		assert.True(mts.wshape.Eq(gradW.Shape()))
   108  		m.Close()
   109  	}
   110  
   111  	t.Logf("Testing Mul with LispMachine")
   112  	for _, mts := range mulTests {
   113  		g := NewGraph()
   114  		x := NewTensor(g, Float64, mts.xshape.Dims(), WithName(mts.name), WithShape(mts.xshape...), WithInit(RangedFrom(0)))
   115  		w := NewTensor(g, Float64, mts.wshape.Dims(), WithName("w"), WithShape(mts.wshape...), WithInit(RangedFrom(0)))
   116  
   117  		xw, err := Mul(x, w)
   118  		if err != nil {
   119  			t.Errorf("Error when testing %q. Err: %v", mts.name, err)
   120  			continue
   121  		}
   122  
   123  		if mts.xshape.IsVector() && mts.wshape.IsVector() {
   124  
   125  		} else {
   126  			if _, err = Sum(xw); err != nil {
   127  				t.Errorf("Error when summing %q. Err: %v", mts.name, err)
   128  				continue
   129  			}
   130  		}
   131  
   132  		m := NewLispMachine(g)
   133  
   134  		if err = m.RunAll(); err != nil {
   135  			// ioutil.WriteFile(fmt.Sprintf("fullGraph_%v.dot", mts.name), []byte(g.ToDot()), 0644)
   136  			t.Errorf("Error while executing %q. Err: %v", mts.name, err)
   137  			continue
   138  		}
   139  
   140  		gradX, err := x.Grad()
   141  		if err != nil {
   142  			t.Errorf("Error while getting gradient of x %q. Err: %v", mts.name, err)
   143  		}
   144  
   145  		gradW, err := w.Grad()
   146  		if err != nil {
   147  			t.Errorf("Error while getting gradient of w %q. Err: %v", mts.name, err)
   148  		}
   149  
   150  		assert.Equal(mts.gradX, gradX.Data().([]float64))
   151  		assert.Equal(mts.gradW, gradW.Data().([]float64))
   152  		assert.True(mts.xshape.Eq(gradX.Shape()))
   153  		assert.True(mts.wshape.Eq(gradW.Shape()))
   154  		m.Close()
   155  	}
   156  }
   157  
   158  var gtTests = []struct {
   159  	a, b    Value
   160  	retSame bool
   161  
   162  	expected Value
   163  	err      bool
   164  }{
   165  	// s-s
   166  	{NewF64(float64(1)), NewF64(float64(0)), true, NewF64(1.0), false},
   167  	{NewF64(float64(0)), NewF64(float64(1)), true, NewF64(0.0), false},
   168  	{NewF64(float64(1)), NewF64(float64(0)), false, NewB(true), false},
   169  	{NewF32(float32(0)), NewF32(float32(1)), false, NewB(false), false},
   170  
   171  	// s-t
   172  	{
   173  		NewF64(float64(1)), tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{0, 2})),
   174  		true,
   175  		tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{1, 0})),
   176  		false,
   177  	},
   178  
   179  	{
   180  		NewF32(float32(1)), tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{0, 2})),
   181  		false,
   182  		tensor.New(tensor.WithShape(2), tensor.WithBacking([]bool{true, false})),
   183  		false,
   184  	},
   185  
   186  	// t-s
   187  	{
   188  		tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{0, 2})), NewF64(float64(1)),
   189  		true,
   190  		tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{0, 1})),
   191  		false,
   192  	},
   193  
   194  	{
   195  		tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{0, 2})), NewF32(float32(1)),
   196  		false,
   197  		tensor.New(tensor.WithShape(2), tensor.WithBacking([]bool{false, true})),
   198  		false,
   199  	},
   200  
   201  	// t-t
   202  	{
   203  		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{0, 1, 2, 3, 4, 5})),
   204  		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{5, 4, 3, 2, 1, 0})),
   205  		true,
   206  
   207  		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{0, 0, 0, 1, 1, 1})),
   208  		false,
   209  	},
   210  
   211  	{
   212  		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{0, 1, 2, 3, 4, 5})),
   213  		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{5, 4, 3, 2, 1, 0})),
   214  		false,
   215  
   216  		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]bool{false, false, false, true, true, true})),
   217  		false,
   218  	},
   219  
   220  	// stupids
   221  
   222  	// different shapes
   223  	{
   224  		tensor.New(tensor.Of(tensor.Float32), tensor.WithShape(2)), tensor.New(tensor.Of(tensor.Float32), tensor.WithShape(4)),
   225  		true, nil, true,
   226  	},
   227  
   228  	// different dtypes
   229  	{
   230  		tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(2)), tensor.New(tensor.Of(tensor.Float32), tensor.WithShape(2)),
   231  		true, nil, true,
   232  	},
   233  }
   234  
   235  func TestGt(t *testing.T) {
   236  	defer runtime.GC()
   237  	for i, gtts := range gtTests {
   238  		// if i != 11 {
   239  		// 	continue
   240  		// }
   241  		g := NewGraph()
   242  		a := NodeFromAny(g, gtts.a, WithName("a"))
   243  		b := NodeFromAny(g, gtts.b, WithName("b"))
   244  
   245  		var ret *Node
   246  		var err error
   247  		ret, err = Gt(a, b, gtts.retSame)
   248  
   249  		switch {
   250  		case gtts.err:
   251  			if err == nil {
   252  				t.Errorf("Expected an error in Test %d", i)
   253  			}
   254  			continue
   255  		case !gtts.err && err != nil:
   256  			t.Errorf("Test %d: %+v", i, err)
   257  			continue
   258  		}
   259  
   260  		if gtts.retSame {
   261  			cost := Must(Sum(ret))
   262  			Grad(cost, a, b)
   263  		}
   264  
   265  		m1 := NewTapeMachine(g)
   266  		if err = m1.RunAll(); err != nil {
   267  			ioutil.WriteFile("fail.dot", []byte(g.ToDot()), 0644)
   268  			t.Errorf("%v", m1.Prog())
   269  			t.Errorf("Test %d: %+v", i, err)
   270  			continue
   271  		}
   272  
   273  		if !ValueEq(gtts.expected, ret.Value()) {
   274  			t.Errorf("Test %d Expected %v. Got %v", i, gtts.expected, ret.Value())
   275  		}
   276  
   277  		// Test LispMachine implementation
   278  		h := NewGraph()
   279  		x := NodeFromAny(h, gtts.a, WithName("x"))
   280  		y := NodeFromAny(h, gtts.b, WithName("y"))
   281  		ret2, _ := Gt(x, y, gtts.retSame)
   282  
   283  		var m2 VM
   284  		if gtts.retSame {
   285  			Must(Sum(ret2))
   286  			m2 = NewLispMachine(h)
   287  		} else {
   288  			m2 = NewLispMachine(h, ExecuteFwdOnly())
   289  		}
   290  		if err = m2.RunAll(); err != nil {
   291  			t.Errorf("Test %d LispMachine: %+v", i, err)
   292  			continue
   293  		}
   294  
   295  		if !ValueEq(ret.Value(), ret2.Value()) {
   296  			t.Errorf("Test %d. Expected %v. Got  %v", i, ret.Value(), ret2.Value())
   297  		}
   298  		m1.Close()
   299  		m2.Close()
   300  		runtime.GC()
   301  	}
   302  
   303  	// other special cases
   304  	g := NewGraph()
   305  	c := NewConstant(F64(1))
   306  	// T := NewTensor(g, Float64, 1, WithShape(2), WithInit(RangedFrom(0)))
   307  	T := UniformRandomNode(g, Float64, 0, 1, 2)
   308  
   309  	var gt *Node
   310  	var err error
   311  	if gt, err = Gt(c, T, true); err != nil {
   312  		t.Error(err)
   313  	}
   314  	cost := Must(Sum(gt))
   315  	Grad(cost, T)
   316  
   317  	m1 := NewTapeMachine(g)
   318  	defer m1.Close()
   319  	if err = m1.RunAll(); err != nil {
   320  		t.Error(err)
   321  	}
   322  
   323  	if (TensorType{Dims: 1, Of: Float64}) != TypeOf(gt.Value()) {
   324  		t.Error("Expected a tensor type of float64")
   325  	}
   326  
   327  	// Same test as above, but using *lispMachine
   328  
   329  	h := NewGraph()
   330  	d := NewConstant(F64(1))
   331  	U := UniformRandomNode(h, Float64, 0, 1, 2)
   332  	var gt2 *Node
   333  	if gt2, err = Gt(d, U, true); err != nil {
   334  		t.Error(err)
   335  	}
   336  	Must(Sum(gt2))
   337  
   338  	m2 := NewLispMachine(h)
   339  	defer m2.Close()
   340  	if err = m2.RunAll(); err != nil {
   341  		t.Error(err)
   342  	}
   343  
   344  	if (TensorType{Dims: 1, Of: Float64}) != TypeOf(gt2.Value()) {
   345  		t.Error("Expected a tensor type of float64")
   346  	}
   347  
   348  	t.Logf("%v", gt2.Value())
   349  	runtime.GC()
   350  
   351  }
   352  
   353  func TestMisha(t *testing.T) {
   354  	defer runtime.GC()
   355  	assert := assert.New(t)
   356  	g := NewGraph()
   357  	var err error
   358  	var x0, x1, x2, f0, f1, f2 *Node
   359  	var grad0, grad1, grad2 Nodes
   360  
   361  	x0 = NewScalar(g, Float64, WithName("x0"))
   362  	x1 = NewScalar(g, Float64, WithName("x1"))
   363  	x2 = NewScalar(g, Float64, WithName("x2"))
   364  
   365  	Let(x0, -2.5)
   366  	Let(x1, -2.2)
   367  	Let(x2, 1.0)
   368  
   369  	f0 = Must(Mish(x0))
   370  	f1 = Must(Mish(x1))
   371  	f2 = Must(Mish(x2))
   372  
   373  	if grad0, err = Grad(f0, x0); err != nil {
   374  		t.Error(err)
   375  	}
   376  	if grad1, err = Grad(f1, x1); err != nil {
   377  		t.Error(err)
   378  	}
   379  	if grad2, err = Grad(f2, x2); err != nil {
   380  		t.Error(err)
   381  	}
   382  
   383  	machine := NewTapeMachine(g)
   384  	defer machine.Close()
   385  	if err = machine.RunAll(); err != nil {
   386  		t.Error(err)
   387  	}
   388  
   389  	// assert non-monotonicity of Mish
   390  	// x0 < x1 < x2 && f0 > f1 < f2
   391  	assert.Less(extractF64(x0.Value()), extractF64(x1.Value()))
   392  	assert.Less(extractF64(x1.Value()), extractF64(x2.Value()))
   393  	assert.Greater(extractF64(f0.Value()), extractF64(f1.Value()))
   394  	assert.Less(extractF64(f1.Value()), extractF64(f2.Value()))
   395  
   396  	// assert non-monotonocity of Mish'
   397  	assert.Greater(extractF64(grad0[0].Value()), extractF64(grad1[0].Value()))
   398  	assert.Less(extractF64(grad1[0].Value()), extractF64(grad2[0].Value()))
   399  }
   400  
   401  func TestSoftMax(t *testing.T) {
   402  	defer runtime.GC()
   403  	g := NewGraph()
   404  	xT := tensor.New(tensor.WithBacking([]float64{0.1, 0.2, -0.3, 0.4, 0.5}))
   405  	x := NewVector(g, Float64, WithShape(5), WithValue(xT))
   406  	sm := Must(SoftMax(x))
   407  	logsm := Must(Neg(Must(Log(sm))))
   408  	cost := Must(Slice(logsm, S(2)))
   409  
   410  	if _, err := Grad(cost, x); err != nil {
   411  		t.Error(err)
   412  	}
   413  
   414  	m := NewTapeMachine(g, TraceExec())
   415  	defer m.Close()
   416  	if err := m.RunAll(); err != nil {
   417  		t.Error(err)
   418  	}
   419  	ioutil.WriteFile("fullGraph.dot", []byte(g.ToDot()), 0644)
   420  	var xG Value
   421  	var err error
   422  	if xG, err = x.Grad(); err != nil {
   423  		t.Error(err)
   424  	}
   425  
   426  	// machine 2, graph 2
   427  	h := NewGraph()
   428  	xT2 := tensor.New(tensor.WithBacking([]float64{0.1, 0.2, -0.3, 0.4, 0.5}))
   429  	x2 := NewVector(h, Float64, WithShape(5), WithValue(xT2))
   430  	sm2 := Must(SoftMax(x2))
   431  	logsm2 := Must(Neg(Must(Log(sm2))))
   432  	Must(Slice(logsm2, S(2)))
   433  
   434  	m2 := NewLispMachine(h)
   435  	defer m2.Close()
   436  	if err = m2.RunAll(); err != nil {
   437  		log.Printf("ERR %v", err)
   438  		t.Error(err)
   439  	}
   440  
   441  	var x2G Value
   442  	if x2G, err = x2.Grad(); err != nil {
   443  		t.Error(err)
   444  	}
   445  
   446  	if !floatsEqual64(xG.Data().([]float64), x2G.Data().([]float64)) {
   447  		t.Errorf("Expected both gradients of X to be the same.")
   448  	}
   449  	t.Logf("\n%v\n%v\n%v", sm.Value(), logsm.Value(), cost.Value())
   450  	correctXGrad := []float64{
   451  		0.178025447751409, 0.1967485475322529, -0.8806659736677602, 0.24030921861990098, 0.2655827597641975,
   452  	}
   453  
   454  	if !floatsEqual64(correctXGrad, x2G.Data().([]float64)) {
   455  		t.Errorf("Expected results to be %v. Got %v.", correctXGrad, x2G.Data())
   456  	}
   457  	if !floatsEqual64(correctXGrad, xG.Data().([]float64)) {
   458  		t.Errorf("Expected results to be %v. Got %v.", correctXGrad, xG.Data())
   459  	}
   460  }
   461  
   462  var sliceTests = []struct {
   463  	name   string
   464  	shape  tensor.Shape
   465  	slices []tensor.Slice
   466  
   467  	expected tensor.Shape
   468  	data     interface{}
   469  	err      bool
   470  }{
   471  	{"vec[0]", tensor.Shape{2}, []tensor.Slice{S(0)}, scalarShape, float64(0), false},
   472  	{"vec[0:2]", tensor.Shape{2}, []tensor.Slice{S(0, 2)}, tensor.Shape{2}, []float64{0, 1}, false},
   473  	{"Mat[0]", tensor.Shape{2, 3}, []tensor.Slice{S(0)}, tensor.Shape{3}, []float64{0, 1, 2}, false},
   474  	{"Mat[:, 0]", tensor.Shape{2, 3}, []tensor.Slice{nil, S(0)}, tensor.Shape{2}, []float64{0, 3}, false},
   475  	{"3Tensor[0]", tensor.Shape{2, 3, 4}, []tensor.Slice{S(0)}, tensor.Shape{3, 4}, tensor.Range(tensor.Float64, 0, 12), false},
   476  	{"3Tensor[0:2]", tensor.Shape{2, 3, 4}, []tensor.Slice{S(0, 2)}, tensor.Shape{2, 3, 4}, tensor.Range(tensor.Float64, 0, 24), false},
   477  	{"3Tensor[:, 0]", tensor.Shape{2, 3, 4}, []tensor.Slice{nil, S(0)}, tensor.Shape{2, 4}, []float64{0, 1, 2, 3, 12, 13, 14, 15}, false},
   478  	{"3Tensor[0, :, 0]", tensor.Shape{2, 3, 4}, []tensor.Slice{S(0), nil, S(0)}, tensor.Shape{3}, []float64{0, 4, 8}, false},
   479  
   480  	{"vec[:, 0]", tensor.Shape{2}, []tensor.Slice{nil, S(0)}, nil, nil, true},
   481  }
   482  
   483  func TestSlice(t *testing.T) {
   484  	defer runtime.GC()
   485  	for _, sts := range sliceTests {
   486  		g := NewGraph()
   487  		x := NewTensor(g, Float64, len(sts.shape), WithShape(sts.shape...), WithInit(RangedFrom(0)))
   488  		sliced, err := Slice(x, sts.slices...)
   489  		switch {
   490  		case sts.err:
   491  			if err == nil {
   492  				t.Errorf("Expected an error while running test %q", sts.name)
   493  			}
   494  			continue
   495  		case !sts.err && err != nil:
   496  			t.Errorf("Error in %q: %+v", sts.name, err)
   497  			continue
   498  		}
   499  
   500  		// test expected shapes:
   501  		if !sts.expected.Eq(sliced.shape) {
   502  			t.Errorf("Test %q - Expected %v. Got %v instead", sts.name, sts.expected, sliced.shape)
   503  			continue
   504  		}
   505  
   506  		// test forwards and backwards prop
   507  		cost := Must(Sum(sliced))
   508  		if _, err := Grad(cost, x); err != nil {
   509  			t.Errorf("Test %q failed to backprop: %+v", sts.name, err)
   510  			continue
   511  		}
   512  
   513  		m1 := NewTapeMachine(g)
   514  		if err = m1.RunAll(); err != nil {
   515  			t.Errorf("Test %q Runtime error %+v ", sts.name, err)
   516  			continue
   517  		}
   518  
   519  		sV := sliced.Value()
   520  		if !sts.expected.Eq(sV.Shape()) {
   521  			t.Errorf("Test %q For TapeMachine. Expected sliced value to have the shape %v. Got %v instead", sts.name, sts.expected, sV.Shape())
   522  		}
   523  
   524  		assert.Equal(t, sts.data, sV.Data(), "Test %q For TapeMachine data expected %v, Got %v instead. Formatted:\n %+v", sts.name, sts.data, sV.Data(), sV)
   525  
   526  		// Test Lisp Machine for equivalence of gradients
   527  
   528  		h := NewGraph()
   529  		a := NewTensor(g, Float64, len(sts.shape), WithShape(sts.shape...), WithInit(RangedFrom(0)))
   530  		sliced2 := Must(Slice(a, sts.slices...))
   531  		Must(Sum(sliced2))
   532  
   533  		m2 := NewLispMachine(h)
   534  		if err = m2.RunAll(); err != nil {
   535  			t.Errorf("Test %q Lispmachine Runtime error: %+v", sts.name, err)
   536  			continue
   537  		}
   538  
   539  		s2V := sliced2.Value()
   540  		if !sts.expected.Eq(s2V.Shape()) {
   541  			t.Errorf("Test %q For LispMachine. Expected sliced value to have the shape %v. Got %v instead", sts.name, sts.expected, s2V.Shape())
   542  		}
   543  
   544  		assert.Equal(t, sts.data, s2V.Data(), "Test %q For TapeMachine data expected %v, Got %v instead. Formatted:\n %+v", sts.name, sts.data, s2V.Data(), s2V)
   545  
   546  		sG, err := sliced.Grad()
   547  		if err != nil {
   548  			t.Errorf("Test %q sliced has no grad: %+v", sts.name, err)
   549  			continue
   550  		}
   551  
   552  		s2G, err := sliced2.Grad()
   553  		if err != nil {
   554  			t.Errorf("Test %q sliced2 has no grad: %+v", sts.name, err)
   555  			continue
   556  		}
   557  
   558  		if !ValueEq(sG, s2G) {
   559  			t.Errorf("Test %q - Expected sG and s2G to have the same value", sts.name)
   560  		}
   561  
   562  		m1.Close()
   563  		m2.Close()
   564  
   565  		// For visual checks
   566  		// xG, err := x.Grad()
   567  		// t.Logf("Test  %q x: \n%+v,\n%+v", sts.name, x.Value(), xG)
   568  	}
   569  
   570  	// special cases with UnsafeLet
   571  	g := NewGraph()
   572  	x := NewTensor(g, Float64, 2, WithShape(2, 3), WithInit(RangedFrom(0)))
   573  	sliced, _ := Slice(x, S(0))
   574  	cost := Must(Slice(sliced, S(0)))
   575  	Grad(cost, x)
   576  
   577  	m := NewTapeMachine(g)
   578  	defer m.Close()
   579  	// mutate the graph before running
   580  	UnsafeLet(sliced, S(1))
   581  	UnsafeLet(cost, S(2))
   582  	if err := m.RunAll(); err != nil {
   583  		t.Fatal(err)
   584  	}
   585  
   586  	xG, err := x.Grad()
   587  	if err != nil {
   588  		t.Fatal(err)
   589  	}
   590  
   591  	// ioutil.WriteFile("blah.dot", []byte(g.ToDot()), 0644)
   592  	assert.Equal(t, []float64{0, 0, 0, 0, 0, 1}, xG.Data())
   593  	// visual inspection
   594  	// t.Logf("x: \n%+v,\n%+v", x.Value(), xG)
   595  
   596  }
   597  
   598  var sumTests = []struct {
   599  	name  string
   600  	shape tensor.Shape
   601  	along []int
   602  
   603  	expectedShape tensor.Shape
   604  	expectedVal   Value
   605  	expectedGrad  Value
   606  	err           bool
   607  }{
   608  	{"Sum(vec)", tensor.Shape{2}, nil, scalarShape, NewF64(1.0), NewF64(1.0), false},
   609  	{"Sum(vec, 0)", tensor.Shape{2}, []int{0}, scalarShape, NewF64(1), NewF64(1.0), false},
   610  	{"Sum(Mat)", tensor.Shape{2, 3}, nil, scalarShape, NewF64(15.0), tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{1, 1, 1, 1, 1, 1})), false},
   611  	{"Sum(Mat, 0)", tensor.Shape{2, 3}, []int{0}, tensor.Shape{3},
   612  		tensor.New(tensor.WithShape(3), tensor.WithBacking([]float64{3, 5, 7})),
   613  		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{1, 1, 1, 1, 1, 1})), false,
   614  	},
   615  	{"Sum(Mat, 1)", tensor.Shape{2, 3}, []int{1}, tensor.Shape{2},
   616  		tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{3, 12})),
   617  		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{1, 1, 1, 1, 1, 1})), false,
   618  	},
   619  
   620  	// TODO: tests for 3-Tensors
   621  	// TODO: negative and stupids cases.
   622  }
   623  
   624  func TestSum(t *testing.T) {
   625  	defer runtime.GC()
   626  	for _, sts := range sumTests {
   627  		g := NewGraph()
   628  		x := NewTensor(g, Float64, len(sts.shape), WithShape(sts.shape...), WithInit(RangedFrom(0)))
   629  		var s *Node
   630  		var err error
   631  
   632  		if len(sts.along) == 0 {
   633  			s, err = Sum(x)
   634  		} else {
   635  			s, err = Sum(x, sts.along...)
   636  		}
   637  
   638  		switch {
   639  		case sts.err:
   640  			if err == nil {
   641  				t.Errorf("Expected an error in %q", sts.name)
   642  			}
   643  			continue
   644  		case !sts.err && err != nil:
   645  			t.Errorf("Test %q errored while Sum() %+v", sts.name, err)
   646  			continue
   647  		}
   648  
   649  		if !sts.expectedShape.Eq(s.shape) {
   650  			t.Errorf("Test %q has wrong shape. Want %v, got %v instead", sts.name, sts.expectedShape, s.shape)
   651  			continue
   652  		}
   653  
   654  		cost := s
   655  		if len(sts.along) < len(sts.shape) && len(sts.along) > 0 {
   656  			cost = Must(Sum(s))
   657  		}
   658  
   659  		if _, err = Grad(cost, x); err != nil {
   660  			t.Errorf("Test %q - Unable to back prop. Err : %+v", sts.name, err)
   661  			continue
   662  		}
   663  
   664  		m := NewTapeMachine(g)
   665  		if err = m.RunAll(); err != nil {
   666  			t.Errorf("Test %q - Runtime error: %v", sts.name, err)
   667  			continue
   668  		}
   669  
   670  		if !ValueEq(sts.expectedVal, s.Value()) {
   671  			t.Errorf("Test %q Expected %v. Got %v", sts.name, sts.expectedVal, s.Value())
   672  		}
   673  
   674  		sG, err := s.Grad()
   675  		if err != nil {
   676  			t.Errorf("Test %q Grad() error: %+v", sts.name, err)
   677  			continue
   678  		}
   679  
   680  		// LISP MACHINE TO TEST GRAD EQUIVALENCE
   681  		h := NewGraph()
   682  		a := NewTensor(h, Float64, len(sts.shape), WithShape(sts.shape...), WithInit(RangedFrom(0)))
   683  		var b *Node
   684  		if len(sts.along) == 0 {
   685  			b = Must(Sum(a))
   686  		} else {
   687  			b = Must(Sum(a, sts.along...))
   688  		}
   689  
   690  		if len(sts.along) < len(sts.shape) && len(sts.along) > 0 {
   691  			Must(Sum(b))
   692  		}
   693  
   694  		m2 := NewLispMachine(h)
   695  		if err = m2.RunAll(); err != nil {
   696  			t.Errorf("Test %q Lisp machine runtime error %+v", sts.name, err)
   697  			continue
   698  		}
   699  
   700  		if !ValueEq(sts.expectedVal, b.Value()) {
   701  			t.Errorf("Test %q LispMachine Run. Expected %v. Got %v instead", sts.name, sts.expectedVal, b.Value())
   702  		}
   703  
   704  		bG, err := b.Grad()
   705  		if err != nil {
   706  			t.Errorf("Test %q Grad() err in lispmachine run %+v", sts.name, err)
   707  			continue
   708  		}
   709  
   710  		if !ValueEq(sG, bG) {
   711  			t.Errorf("Expected the values of the partial derivatives of both machines to be the same")
   712  		}
   713  
   714  		m.Close()
   715  		m2.Close()
   716  	}
   717  }
   718  
   719  func TestNorm(t *testing.T) {
   720  	assert := assert.New(t)
   721  	g := NewGraph()
   722  	x := NewMatrix(g, Float64, WithShape(3, 3))
   723  	norm, err := Norm(x, 0, 2)
   724  	if err != nil {
   725  		t.Error(err)
   726  		return
   727  	}
   728  	m := NewLispMachine(g, ExecuteFwdOnly())
   729  	defer m.Close()
   730  
   731  	xT := tensor.New(tensor.WithShape(3, 3), tensor.WithBacking(tensor.Range(tensor.Float64, 0, 9)))
   732  	Let(x, xT)
   733  	m.RunAll()
   734  
   735  	correct := []float64{6.708203932499369, 8.12403840463596, 9.643650760992955}
   736  	assert.Equal(correct, extractF64s(norm.Value()))
   737  
   738  }
   739  
   740  func TestMean(t *testing.T) {
   741  	g := NewGraph()
   742  	x := NewMatrix(g, Float64, WithShape(3, 3))
   743  	m, err := Mean(x)
   744  	if err != nil {
   745  		t.Fatal(err)
   746  	}
   747  
   748  	if !m.IsScalar() {
   749  		t.Error("Expected result to be scalar")
   750  	}
   751  }
   752  
   753  func TestTensordot(t *testing.T) {
   754  	assert := assert.New(t)
   755  
   756  	// Scalars
   757  	g := NewGraph()
   758  
   759  	a := NewTensor(g, Float64, 0, WithName("a"), WithShape(1), WithInit(RangedFrom(2)))
   760  	b := NewTensor(g, Float64, 0, WithName("b"), WithShape(1), WithInit(RangedFrom(21)))
   761  	c := NewTensor(g, Float64, 0, WithName("c"), WithShape(1), WithInit(ValuesOf(1.0)))
   762  
   763  	tensordot, err := Tensordot([]int{0}, []int{0}, a, b)
   764  	if err == nil {
   765  		t.Fatal("Expected scalars to fail")
   766  	}
   767  
   768  	// Scalar-like
   769  	g = NewGraph()
   770  	a = NewTensor(g, Float64, 1, WithName("a"), WithShape(1), WithInit(RangedFrom(2)))
   771  	b = NewTensor(g, Float64, 1, WithName("b"), WithShape(1), WithInit(RangedFrom(21)))
   772  	c = NewTensor(g, Float64, 1, WithName("c"), WithShape(1), WithInit(ValuesOf(1.0)))
   773  
   774  	tensordot, err = Tensordot([]int{0}, []int{0}, a, b)
   775  	if err != nil {
   776  		t.Fatal(err)
   777  	}
   778  	log.Printf("SHAPE a %v b %v c %v tensordot %v", a.Shape(), b.Shape(), c.Shape(), tensordot.Shape())
   779  
   780  	dtensordot, err := Backpropagate(Nodes{tensordot}, Nodes{c}, Nodes{a, b})
   781  
   782  	if err != nil {
   783  		t.Fatalf("%+v", err)
   784  	}
   785  
   786  	m := NewTapeMachine(g)
   787  	defer m.Close()
   788  	if err = m.RunAll(); err != nil {
   789  		t.Fatal(err)
   790  	}
   791  
   792  	correctScalarlike := []float64{42.0}
   793  	value := tensordot.Value().Data()
   794  	assert.Equal(correctScalarlike, value)
   795  
   796  	dtensordotCorrectScalarlike0 := []float64{21}
   797  	dtensordotCorrectScalarlike1 := []float64{2}
   798  
   799  	assert.Equal(dtensordotCorrectScalarlike0, dtensordot[0].Value().Data())
   800  	assert.Equal(dtensordotCorrectScalarlike1, dtensordot[1].Value().Data())
   801  
   802  	// Vectors
   803  
   804  	g = NewGraph()
   805  	a = NewTensor(g, Float64, 1, WithName("a"), WithShape(2), WithInit(RangedFrom(1)))
   806  	b = NewTensor(g, Float64, 1, WithName("b"), WithShape(2), WithInit(RangedFrom(3)))
   807  	c = NewTensor(g, Float64, 0, WithName("c"), WithShape(), WithInit(ValuesOf(1.0)))
   808  
   809  	if tensordot, err = Tensordot([]int{0}, []int{0}, a, b); err != nil {
   810  		t.Fatal(err)
   811  	}
   812  
   813  	if dtensordot, err = Backpropagate(Nodes{tensordot}, Nodes{c}, Nodes{a, b}); err != nil {
   814  		t.Fatalf("%+v", err)
   815  	}
   816  
   817  	// Need to multiply dtensordot with identiy matrix, otherwise the transpose action in symdiff is not performed
   818  	id := NewConstant(tensor.I(Float64, 2, 2, 0))
   819  
   820  	dtensordot0 := Must(Mul(id, dtensordot[0]))
   821  	dtensordot1 := Must(Mul(id, dtensordot[1]))
   822  
   823  	m = NewTapeMachine(g)
   824  	defer m.Close()
   825  	if err = m.RunAll(); err != nil {
   826  		t.Fatal(err)
   827  	}
   828  
   829  	log.Printf("TensorDot %v | %v", tensordot.Value().Shape(), tensordot.Type())
   830  	correctScalarlike = []float64{11}
   831  	assert.Equal(correctScalarlike, tensordot.Value().Data())
   832  
   833  	dcorrect0 := []float64{3, 4}
   834  	dcorrect1 := []float64{1, 2}
   835  
   836  	assert.Equal(dcorrect0, extractF64s(dtensordot[0].Value()))
   837  	assert.Equal(dcorrect1, extractF64s(dtensordot[1].Value()))
   838  
   839  	// Vector and Matrix
   840  	g = NewGraph()
   841  	a = NewTensor(g, Float64, 2, WithName("a"), WithShape(2, 2), WithInit(RangedFrom(0)))
   842  	b = NewTensor(g, Float64, 1, WithName("b"), WithShape(2), WithInit(RangedFrom(0)))
   843  
   844  	c = NewTensor(g, Float64, 1, WithName("c"), WithShape(2), WithInit(ValuesOf(1.0)))
   845  
   846  	if tensordot, err = Tensordot([]int{1}, []int{0}, a, b); err != nil {
   847  		t.Fatal(err)
   848  	}
   849  
   850  	if dtensordot, err = Backpropagate(Nodes{tensordot}, Nodes{c}, Nodes{a, b}); err != nil {
   851  		t.Fatal(err)
   852  	}
   853  
   854  	// Need to multiply dtensordot with identiy matrix, otherwise the transpose action in symdiff is not performed
   855  	id = NewConstant(tensor.I(Float64, 2, 2, 0))
   856  
   857  	if dtensordot0, err = Mul(id, dtensordot[0]); err != nil {
   858  		t.Fatal(err)
   859  	}
   860  	if dtensordot1, err = Mul(id, dtensordot[1]); err != nil {
   861  		t.Fatal(err)
   862  	}
   863  
   864  	m = NewTapeMachine(g)
   865  	defer m.Close()
   866  	if err = m.RunAll(); err != nil {
   867  		t.Fatal(err)
   868  	}
   869  
   870  	correct := []float64{1, 3}
   871  	assert.Equal(correct, extractF64s(tensordot.Value()))
   872  
   873  	dcorrect0 = []float64{0, 1, 0, 1}
   874  	dcorrect1 = []float64{2, 4}
   875  
   876  	assert.Equal(dcorrect0, extractF64s(dtensordot0.Value()))
   877  	assert.Equal(dcorrect1, extractF64s(dtensordot1.Value()))
   878  
   879  	// Matrices
   880  	g = NewGraph()
   881  
   882  	a = NewTensor(g, Float64, 2, WithName("a"), WithShape(2, 2), WithInit(RangedFrom(0)))
   883  	b = NewTensor(g, Float64, 2, WithName("b"), WithShape(2, 2), WithInit(RangedFrom(0)))
   884  
   885  	c = NewTensor(g, Float64, 2, WithName("c"), WithShape(2, 2), WithInit(ValuesOf(1.0)))
   886  
   887  	if tensordot, err = Tensordot([]int{1}, []int{1}, a, b); err != nil {
   888  		t.Fatal(err)
   889  	}
   890  
   891  	if dtensordot, err = Backpropagate(Nodes{tensordot}, Nodes{c}, Nodes{a, b}); err != nil {
   892  		t.Fatal(err)
   893  	}
   894  
   895  	// Need to multiply dtensordot with identiy matrix, otherwise the transpose action in symdiff is not performed
   896  	id = NewConstant(tensor.I(Float64, 2, 2, 0))
   897  
   898  	if dtensordot0, err = Mul(id, dtensordot[0]); err != nil {
   899  		t.Fatal(err)
   900  	}
   901  	if dtensordot1, err = Mul(id, dtensordot[1]); err != nil {
   902  		t.Fatal(err)
   903  	}
   904  
   905  	m = NewTapeMachine(g)
   906  	if err = m.RunAll(); err != nil {
   907  		t.Fatal(err)
   908  	}
   909  
   910  	correct = []float64{1, 3, 3, 13}
   911  	assert.Equal(correct, extractF64s(tensordot.Value()))
   912  
   913  	dcorrect := []float64{2, 4, 2, 4}
   914  	assert.Equal(dcorrect, extractF64s(dtensordot0.Value()))
   915  	assert.Equal(dcorrect, extractF64s(dtensordot1.Value()))
   916  
   917  	// Total matrix contraction
   918  	g = NewGraph()
   919  
   920  	a = NewTensor(g, Float64, 2, WithName("a"), WithShape(2, 2), WithInit(RangedFrom(0)))
   921  	b = NewTensor(g, Float64, 2, WithName("b"), WithShape(2, 2), WithInit(RangedFrom(0)))
   922  
   923  	c = NewTensor(g, Float64, 0, WithName("c"), WithShape(), WithInit(ValuesOf(1.0)))
   924  
   925  	if tensordot, err = Tensordot([]int{0, 1}, []int{0, 1}, a, b); err != nil {
   926  		t.Fatal(err)
   927  	}
   928  
   929  	if dtensordot, err = Backpropagate(Nodes{tensordot}, Nodes{c}, Nodes{a, b}); err != nil {
   930  		t.Fatal(err)
   931  	}
   932  
   933  	// Need to multiply dtensordot with identiy matrix, otherwise the transpose action in symdiff is not performed
   934  	id = NewConstant(tensor.I(Float64, 2, 2, 0))
   935  
   936  	if dtensordot0, err = Mul(id, dtensordot[0]); err != nil {
   937  		t.Fatal(err)
   938  	}
   939  	if dtensordot1, err = Mul(id, dtensordot[1]); err != nil {
   940  		t.Fatal(err)
   941  	}
   942  
   943  	m = NewTapeMachine(g)
   944  	defer m.Close()
   945  	if err = m.RunAll(); err != nil {
   946  		t.Fatal(err)
   947  	}
   948  
   949  	correctScalarlike = []float64{14}
   950  	assert.Equal(correctScalarlike, tensordot.Value().Data())
   951  
   952  	dcorrect = []float64{0, 1, 2, 3}
   953  	assert.Equal(dcorrect, extractF64s(dtensordot0.Value()))
   954  	assert.Equal(dcorrect, extractF64s(dtensordot1.Value()))
   955  
   956  }
   957  
   958  var reshapeTests = []struct {
   959  	testName string
   960  	input    tensor.Shape
   961  	to       tensor.Shape
   962  	output   tensor.Shape
   963  	err      bool
   964  }{
   965  	{"simple", tensor.Shape{2, 2}, tensor.Shape{4}, tensor.Shape{4}, false},
   966  	{"simple big tensor", tensor.Shape{200, 200}, tensor.Shape{200 * 200}, tensor.Shape{200 * 200}, false},
   967  	{"negative dim1 1", tensor.Shape{3, 2}, tensor.Shape{6, -1}, tensor.Shape{6, 1}, false},
   968  	{"negative dim1 2", tensor.Shape{3, 2}, tensor.Shape{2, -1}, tensor.Shape{2, 3}, false},
   969  	{"negative dim0 1", tensor.Shape{3, 2}, tensor.Shape{-1, 3}, tensor.Shape{2, 3}, false},
   970  	{"negative dims0.1 with error", tensor.Shape{3, 2}, tensor.Shape{-1, -1}, nil, true},
   971  	{"devative dim0 with error", tensor.Shape{3, 2}, tensor.Shape{4, -1}, nil, true},
   972  }
   973  
   974  func TestReshape(t *testing.T) {
   975  	for _, rst := range reshapeTests {
   976  		g := NewGraph()
   977  		T := NewTensor(g, Float64, len(rst.input), WithShape(rst.input.Clone()...))
   978  		T2, err := Reshape(T, rst.to.Clone())
   979  		t.Log(T2)
   980  		switch {
   981  		case rst.err && err == nil:
   982  			t.Fatalf("Expected Error when testing %v", rst)
   983  		case rst.err:
   984  			continue
   985  		case err != nil:
   986  			t.Fatal(err)
   987  		default:
   988  			assert.True(t, rst.output.Eq(T2.Shape()), "expected both to be the same")
   989  		}
   990  
   991  	}
   992  }
   993  func TestReshape_Dense(t *testing.T) {
   994  	for _, rst := range reshapeTests {
   995  		g := NewGraph()
   996  		tT := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(rst.input.Clone()...))
   997  		T := NodeFromAny(g, tT)
   998  		T2, err := Reshape(T, rst.to.Clone())
   999  		switch {
  1000  		case rst.err && err == nil:
  1001  			t.Fatalf("Expected Error when testing %v", rst)
  1002  		case rst.err:
  1003  			continue
  1004  		case err != nil:
  1005  			t.Fatal(err)
  1006  		default:
  1007  			assert.True(t, rst.output.Eq(T2.Shape()), "expected both to be the same")
  1008  		}
  1009  		m := NewTapeMachine(g)
  1010  		if err := m.RunAll(); err != nil {
  1011  			t.Errorf("Error while executing %q. Err: %v", rst.testName, err)
  1012  			continue
  1013  		}
  1014  
  1015  	}
  1016  }
  1017  
  1018  func TestReshapeRuntime(t *testing.T) {
  1019  	g := NewGraph()
  1020  	x := NewMatrix(g, tensor.Float64, WithName("x"), WithShape(28, 28), WithInit(GlorotU(1)))
  1021  	w := NewMatrix(g, tensor.Float64, WithName("W"), WithShape(50, 784), WithInit(GlorotU(1)))
  1022  	x2 := Must(Reshape(x, tensor.Shape{784}))
  1023  	wx := Must(Mul(w, x2))
  1024  	wx2 := Must(Reshape(wx, tensor.Shape{5, 10}))
  1025  
  1026  	cost := Must(Sum(wx2))
  1027  	if _, err := Grad(cost, w); err != nil {
  1028  		t.Fatal(err)
  1029  	}
  1030  	m := NewTapeMachine(g)
  1031  	if err := m.RunAll(); err != nil {
  1032  		t.Fatal(err)
  1033  	}
  1034  
  1035  	if !x.Value().Shape().Eq(tensor.Shape{28, 28}) {
  1036  		t.Errorf("A mutation of shape has occurred")
  1037  	}
  1038  }
  1039  
  1040  var ravelTests = []struct {
  1041  	input  tensor.Shape
  1042  	output tensor.Shape
  1043  }{
  1044  	{
  1045  		tensor.Shape{3, 3},
  1046  		tensor.Shape{9},
  1047  	},
  1048  	{
  1049  		tensor.Shape{2, 3},
  1050  		tensor.Shape{6},
  1051  	},
  1052  	{
  1053  		tensor.Shape{2, 1, 3},
  1054  		tensor.Shape{6},
  1055  	},
  1056  	{
  1057  		tensor.Shape{1, 1, 1},
  1058  		tensor.Shape{1},
  1059  	},
  1060  }
  1061  
  1062  func TestRavel(t *testing.T) {
  1063  	c := require.New(t)
  1064  
  1065  	for i, rst := range ravelTests {
  1066  		g := NewGraph()
  1067  		t := NewTensor(g, Float64, len(rst.input), WithShape(rst.input...))
  1068  		t2, err := Ravel(t)
  1069  
  1070  		c.NoError(err)
  1071  		c.Equal(rst.output, t2.Shape(), "expected to be flatten in test case: %d", i)
  1072  	}
  1073  }
  1074  
  1075  func TestAuto(t *testing.T) {
  1076  	testCases := []struct {
  1077  		desc          string
  1078  		shapeA        tensor.Shape
  1079  		shapeB        tensor.Shape
  1080  		expectedShape tensor.Shape
  1081  		expectedErr   string
  1082  	}{
  1083  		{
  1084  			desc:        "Example 0",
  1085  			shapeA:      tensor.Shape{12},
  1086  			shapeB:      tensor.Shape{1, 11},
  1087  			expectedErr: "shapes (12) and (1, 11) should have the same dimensions",
  1088  		},
  1089  		{
  1090  			desc:          "Example 1",
  1091  			shapeA:        tensor.Shape{12, 1},
  1092  			shapeB:        tensor.Shape{12, 11},
  1093  			expectedShape: tensor.Shape{12, 11},
  1094  			expectedErr:   "",
  1095  		},
  1096  		{
  1097  			desc:          "Example 2",
  1098  			shapeA:        tensor.Shape{1, 12},
  1099  			shapeB:        tensor.Shape{11, 12},
  1100  			expectedShape: tensor.Shape{11, 12},
  1101  			expectedErr:   "",
  1102  		},
  1103  		{
  1104  			desc:          "Example 3",
  1105  			shapeA:        tensor.Shape{2, 3, 5},
  1106  			shapeB:        tensor.Shape{2, 3, 1},
  1107  			expectedShape: tensor.Shape{2, 3, 5},
  1108  			expectedErr:   "",
  1109  		},
  1110  		{
  1111  			desc:          "Example 4",
  1112  			shapeA:        tensor.Shape{2, 1, 5},
  1113  			shapeB:        tensor.Shape{2, 3, 5},
  1114  			expectedShape: tensor.Shape{2, 3, 5},
  1115  			expectedErr:   "",
  1116  		},
  1117  		{
  1118  			desc:          "Example 5",
  1119  			shapeA:        tensor.Shape{2, 1, 1},
  1120  			shapeB:        tensor.Shape{2, 5, 3},
  1121  			expectedShape: tensor.Shape{2, 5, 3},
  1122  			expectedErr:   "",
  1123  		},
  1124  	}
  1125  	for _, tC := range testCases {
  1126  		t.Run(tC.desc, func(t *testing.T) {
  1127  			c := require.New(t)
  1128  
  1129  			g := NewGraph()
  1130  			a := NewTensor(g, Float64, tC.shapeA.Dims(), WithShape(tC.shapeA...), WithInit(RangedFrom(0)))
  1131  			b := NewTensor(g, Float64, tC.shapeB.Dims(), WithShape(tC.shapeB...), WithInit(RangedFrom(0)))
  1132  
  1133  			out, err := Auto(BroadcastHadamardProd, a, b)
  1134  
  1135  			if tC.expectedErr != "" {
  1136  				c.Error(err)
  1137  				c.Equal(tC.expectedErr, err.Error())
  1138  				return
  1139  			} else {
  1140  				c.NoError(err)
  1141  			}
  1142  
  1143  			c.Equal(tC.expectedShape, out.Shape())
  1144  
  1145  			out, err = Auto(BroadcastHadamardProd, b, a)
  1146  			c.NoError(err)
  1147  			c.Equal(tC.expectedShape, out.Shape())
  1148  		})
  1149  	}
  1150  }