gorgonia.org/gorgonia@v0.9.17/op_reduction_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"runtime"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"gorgonia.org/tensor"
    10  )
    11  
    12  func TestSumOpGrad(t *testing.T) {
    13  	t.SkipNow()
    14  	assert := assert.New(t)
    15  	// var g *ExprGraph
    16  	var z, sz *Node
    17  	var grads Nodes
    18  	var err error
    19  	var op sumOp
    20  
    21  	_, _, _, z = simpleVecEqn()
    22  	sz = Must(Sum(z))
    23  	// t.Logf(" %v  %v %v %v", g, x, y, z)
    24  
    25  	diffWRT := sz.diffWRT()
    26  	assert.Equal([]bool{true}, diffWRT)
    27  
    28  	op = sz.op.(sumOp)
    29  	grads, err = op.SymDiff(Nodes{z}, sz, onef64)
    30  	assert.Nilf(err, "Got %+v", err)
    31  	assert.Equal(1, len(grads))
    32  	t.Logf("%v", grads[0])
    33  }
    34  
    35  func TestSumOpFakeVec(t *testing.T) {
    36  	g := NewGraph()
    37  
    38  	xv := tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(2, 1))
    39  	yv := tensor.New(tensor.WithBacking([]float64{10, 20}), tensor.WithShape(1, 2))
    40  	x := NewMatrix(g, Float64, WithName("x"), WithShape(2, 1), WithValue(xv))
    41  	y := NewMatrix(g, Float64, WithName("y"), WithShape(1, 2), WithValue(yv))
    42  	sx, _ := Sum(x)
    43  	sy, _ := Sum(y)
    44  
    45  	assert.True(t, sx.Shape().Eq(tensor.ScalarShape()))
    46  	assert.True(t, sy.Shape().Eq(tensor.ScalarShape()))
    47  
    48  	sx2, _ := Sum(x, 1)
    49  	assert.True(t, sx2.Shape().Eq(tensor.Shape{2}))
    50  
    51  	vm := NewTapeMachine(g)
    52  	vm.RunAll()
    53  
    54  	assert.Equal(t, 3.0, sx.Value().Data(), "Expected sx to be 3.0")
    55  	assert.Equal(t, 30.0, sy.Value().Data(), "Expected sy to be 30.0")
    56  	assert.Equal(t, []float64{1, 2}, sx2.Value().Data(), "sx2 should be a flat array")
    57  }
    58  
    59  func TestSumOpDiff(t *testing.T) {
    60  	defer runtime.GC()
    61  	assert := assert.New(t)
    62  	var g, g2 *ExprGraph
    63  	var x, y, z, a, b, c *Node
    64  	// var x, y, a, b *Node
    65  	var xG, yG, aG, bG Value
    66  	// var xG, aG Value
    67  	// var prog *program
    68  	// var locMap map[*Node]register
    69  	var m *tapeMachine
    70  	var m2 *lispMachine
    71  	var err error
    72  
    73  	// Basic Test case: a vector is summed
    74  
    75  	g = NewGraph()
    76  	x = NewVector(g, Float64, WithName("x"), WithShape(5), WithInit(RangedFrom(0)))
    77  	y = Must(Sum(x))
    78  	WithName("y")(y)
    79  
    80  	Grad(y, x)
    81  
    82  	// ioutil.WriteFile("SumOp.dot", []byte(g.ToDot()), 0644)
    83  
    84  	m = NewTapeMachine(g)
    85  	defer m.Close()
    86  	if err = m.RunAll(); err != nil {
    87  		t.Error(err)
    88  	}
    89  
    90  	g2 = NewGraph()
    91  	a = NewVector(g2, Float64, WithShape(5), WithInit(RangedFrom(0)))
    92  	b = Must(Sum(a))
    93  
    94  	m2 = NewLispMachine(g2, WithWatchlist())
    95  	defer m2.Close()
    96  	if err = m2.RunAll(); err != nil {
    97  		t.Error(err)
    98  	}
    99  
   100  	if aG, err = a.Grad(); err != nil {
   101  		t.Error(err)
   102  	}
   103  
   104  	if xG, err = x.Grad(); err != nil {
   105  		t.Error(err)
   106  	}
   107  
   108  	if bG, err = b.Grad(); err != nil {
   109  		t.Error(err)
   110  	}
   111  
   112  	if yG, err = y.Grad(); err != nil {
   113  		t.Error(err)
   114  	}
   115  
   116  	assert.True(ValueEq(x.Value(), a.Value()))
   117  	assert.True(ValueEq(xG, aG))
   118  	assert.True(ValueEq(y.Value(), b.Value()))
   119  	assert.True(ValueEq(yG, bG))
   120  
   121  	// long standing bug: sometimes the derivation will get executed in the machine first
   122  	// for example, the deriv of y is 1, and occasionally, the machine will choose to
   123  	// execute const 1 into register 0
   124  	// It would then fail to bind to y's boundTo, because at that point in time, y is still unknown.
   125  
   126  	// assert.Equal(y.Grad(), b.Grad())
   127  
   128  	// Slightly more advanced test case: A matrix is summed
   129  	g = NewGraph()
   130  	x = NewMatrix(g, Float64, WithName("x"), WithShape(11, 7), WithInit(RangedFrom(0)))
   131  	y = Must(Sum(x))
   132  	WithName("y")(y)
   133  
   134  	Grad(y, x)
   135  
   136  	m = NewTapeMachine(g)
   137  	defer m.Close()
   138  	if err = m.RunAll(); err != nil {
   139  		t.Error(err)
   140  	}
   141  
   142  	g2 = NewGraph()
   143  	a = NewMatrix(g2, Float64, WithName("x"), WithShape(11, 7), WithInit(RangedFrom(0)))
   144  	b = Must(Sum(a))
   145  
   146  	m2 = NewLispMachine(g2)
   147  	defer m2.Close()
   148  	if err = m2.RunAll(); err != nil {
   149  		t.Error(err)
   150  	}
   151  
   152  	if aG, err = a.Grad(); err != nil {
   153  		t.Error(err)
   154  	}
   155  
   156  	if xG, err = x.Grad(); err != nil {
   157  		t.Error(err)
   158  	}
   159  	if bG, err = b.Grad(); err != nil {
   160  		t.Error(err)
   161  	}
   162  
   163  	if yG, err = y.Grad(); err != nil {
   164  		t.Error(err)
   165  	}
   166  	assert.True(ValueEq(x.Value(), a.Value()))
   167  	assert.True(ValueEq(xG, aG))
   168  	assert.True(ValueEq(y.Value(), b.Value()))
   169  	assert.True(ValueEq(yG, bG))
   170  
   171  	/* Sum is not the root node */
   172  
   173  	g = NewGraph()
   174  	x = NewMatrix(g, Float64, WithName("x"), WithShape(11, 7), WithInit(RangedFrom(0)))
   175  	y = Must(Sum(x))
   176  	z = Must(Add(y, twof64))
   177  
   178  	if _, err = Grad(z, x); err != nil {
   179  		t.Fatal(err)
   180  	}
   181  
   182  	m = NewTapeMachine(g)
   183  	defer m.Close()
   184  	if err = m.RunAll(); err != nil {
   185  		t.Errorf("%v", m.Prog())
   186  		t.Error(err)
   187  	}
   188  
   189  	g2 = NewGraph()
   190  	a = NewMatrix(g2, Float64, WithName("x"), WithShape(11, 7), WithInit(RangedFrom(0)))
   191  	b = Must(Sum(a))
   192  	c = Must(Add(b, twof64))
   193  
   194  	m2 = NewLispMachine(g2)
   195  	defer m2.Close()
   196  	if err = m2.RunAll(); err != nil {
   197  		t.Fatalf("%+v", err)
   198  	}
   199  
   200  	if aG, err = a.Grad(); err != nil {
   201  		t.Error(err)
   202  	}
   203  
   204  	if xG, err = x.Grad(); err != nil {
   205  		t.Error(err)
   206  	}
   207  
   208  	if bG, err = b.Grad(); err != nil {
   209  		t.Error(err)
   210  	}
   211  
   212  	if yG, err = b.Grad(); err != nil {
   213  		t.Error(err)
   214  	}
   215  
   216  	assert.True(ValueEq(x.Value(), a.Value()))
   217  	assert.True(ValueEq(xG, aG))
   218  	assert.True(ValueEq(y.Value(), b.Value()))
   219  	assert.True(ValueEq(yG, bG))
   220  	assert.True(ValueEq(z.Value(), c.Value()))
   221  
   222  	runtime.GC()
   223  }
   224  
   225  func TestMaxOp(t *testing.T) {
   226  	subTests := []reductionTest{
   227  		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Max, along: []int{0}, wantShape: []int{2}, wantData: []float32{5, 6}},
   228  		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Max, along: []int{1}, wantShape: []int{3}, wantData: []float32{2, 4, 6}},
   229  		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Max, along: []int{}, wantShape: []int{}, wantData: float32(6)},
   230  		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Max, along: []int{0, 1}, wantShape: []int{}, wantData: float32(6)},
   231  		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Max, along: []int{1, 0}, wantShape: []int{}, wantData: float32(6)},
   232  		//{dt: Float32, inShape: []int{1, 6}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Max, along: []int{1}, wantShape: []int{}, wantData: float32(6)},
   233  		{
   234  			dt:        Float32,
   235  			inShape:   []int{2, 2, 2, 2},
   236  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   237  			op:        Max,
   238  			along:     []int{0, 1, 2, 3},
   239  			wantShape: []int{},
   240  			wantData:  float32(16),
   241  		},
   242  		{
   243  			dt:        Float32,
   244  			inShape:   []int{2, 2, 2, 2},
   245  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   246  			op:        Max,
   247  			along:     []int{},
   248  			wantShape: []int{},
   249  			wantData:  float32(16),
   250  		},
   251  		{
   252  			dt:        Float32,
   253  			inShape:   []int{2, 2, 2, 2},
   254  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   255  			op:        Max,
   256  			along:     []int{0},
   257  			wantShape: []int{2, 2, 2},
   258  			wantData:  []float32{9, 10, 11, 12, 13, 14, 15, 16},
   259  		},
   260  		{
   261  			dt:        Float32,
   262  			inShape:   []int{2, 2, 2, 2},
   263  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   264  			op:        Max,
   265  			along:     []int{1},
   266  			wantShape: []int{2, 2, 2},
   267  			wantData:  []float32{5, 6, 7, 8, 13, 14, 15, 16},
   268  		},
   269  		{
   270  			dt:        Float32,
   271  			inShape:   []int{2, 2, 2, 2},
   272  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   273  			op:        Max,
   274  			along:     []int{2},
   275  			wantShape: []int{2, 2, 2},
   276  			wantData:  []float32{3, 4, 7, 8, 11, 12, 15, 16},
   277  		},
   278  		{
   279  			dt:        Float32,
   280  			inShape:   []int{2, 2, 2, 2},
   281  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   282  			op:        Max,
   283  			along:     []int{3},
   284  			wantShape: []int{2, 2, 2},
   285  			wantData:  []float32{2, 4, 6, 8, 10, 12, 14, 16},
   286  		},
   287  		{
   288  			dt:        Float32,
   289  			inShape:   []int{2, 2, 2, 2},
   290  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   291  			op:        Max,
   292  			along:     []int{1, 3},
   293  			wantShape: []int{2, 2},
   294  			wantData:  []float32{6, 8, 14, 16},
   295  		},
   296  		{
   297  			dt:        Float32,
   298  			inShape:   []int{2, 2, 2, 2},
   299  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   300  			op:        Max,
   301  			along:     []int{0, 2, 3},
   302  			wantShape: []int{2},
   303  			wantData:  []float32{12, 16},
   304  		},
   305  	}
   306  
   307  	for _, subTest := range subTests {
   308  		t.Run(fmt.Sprintf("along %v", subTest.along), func(t *testing.T) {
   309  			testReductionOp(t, subTest)
   310  		})
   311  	}
   312  }
   313  
   314  func TestSumOp(t *testing.T) {
   315  	subTests := []reductionTest{
   316  		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Sum, along: []int{0}, wantShape: []int{2}, wantData: []float32{9, 12}},
   317  		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Sum, along: []int{1}, wantShape: []int{3}, wantData: []float32{3, 7, 11}},
   318  		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Sum, along: []int{}, wantShape: []int{}, wantData: float32(21)},
   319  		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Sum, along: []int{0, 1}, wantShape: []int{}, wantData: float32(21)},
   320  		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Sum, along: []int{1, 0}, wantShape: []int{}, wantData: float32(21)},
   321  		{
   322  			dt:        Float32,
   323  			inShape:   []int{2, 2, 2, 2},
   324  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   325  			op:        Sum,
   326  			along:     []int{0, 1, 2, 3},
   327  			wantShape: []int{},
   328  			wantData:  float32(136),
   329  		},
   330  		{
   331  			dt:        Float32,
   332  			inShape:   []int{2, 2, 2, 2},
   333  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   334  			op:        Sum,
   335  			along:     []int{},
   336  			wantShape: []int{},
   337  			wantData:  float32(136),
   338  		},
   339  		{
   340  			dt:        Float32,
   341  			inShape:   []int{2, 2, 2, 2},
   342  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   343  			op:        Sum,
   344  			along:     []int{0},
   345  			wantShape: []int{2, 2, 2},
   346  			wantData:  []float32{10, 12, 14, 16, 18, 20, 22, 24},
   347  		},
   348  		{
   349  			dt:        Float32,
   350  			inShape:   []int{2, 2, 2, 2},
   351  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   352  			op:        Sum,
   353  			along:     []int{1},
   354  			wantShape: []int{2, 2, 2},
   355  			wantData:  []float32{6, 8, 10, 12, 22, 24, 26, 28},
   356  		},
   357  		{
   358  			dt:        Float32,
   359  			inShape:   []int{2, 2, 2, 2},
   360  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   361  			op:        Sum,
   362  			along:     []int{2},
   363  			wantShape: []int{2, 2, 2},
   364  			wantData:  []float32{4, 6, 12, 14, 20, 22, 28, 30},
   365  		},
   366  		{
   367  			dt:        Float32,
   368  			inShape:   []int{2, 2, 2, 2},
   369  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   370  			op:        Sum,
   371  			along:     []int{3},
   372  			wantShape: []int{2, 2, 2},
   373  			wantData:  []float32{3, 7, 11, 15, 19, 23, 27, 31},
   374  		},
   375  		{
   376  			dt:        Float32,
   377  			inShape:   []int{2, 2, 2, 2},
   378  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   379  			op:        Sum,
   380  			along:     []int{1, 3},
   381  			wantShape: []int{2, 2},
   382  			wantData:  []float32{14, 22, 46, 54},
   383  		},
   384  		{
   385  			dt:        Float32,
   386  			inShape:   []int{2, 2, 2, 2},
   387  			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   388  			op:        Sum,
   389  			along:     []int{0, 2, 3},
   390  			wantShape: []int{2},
   391  			wantData:  []float32{52, 84},
   392  		},
   393  	}
   394  
   395  	for _, subTest := range subTests {
   396  		t.Run(fmt.Sprintf("along %v", subTest.along), func(t *testing.T) {
   397  			testReductionOp(t, subTest)
   398  		})
   399  	}
   400  }
   401  
   402  type reductionTest struct {
   403  	dt        tensor.Dtype
   404  	inShape   tensor.Shape
   405  	inData    interface{}
   406  	op        func(*Node, ...int) (*Node, error)
   407  	along     []int
   408  	wantShape tensor.Shape
   409  	wantData  interface{}
   410  }
   411  
   412  func testReductionOp(t *testing.T, test reductionTest) {
   413  	g := NewGraph()
   414  	Xn := NewTensor(g, test.dt, len(test.inShape), WithShape(test.inShape...))
   415  	got := Must(test.op(Xn, test.along...))
   416  
   417  	xT := tensor.New(tensor.WithShape(test.inShape...), tensor.WithBacking(test.inData))
   418  	vm := NewTapeMachine(g)
   419  	defer vm.Close()
   420  	vm.Let(Xn, xT)
   421  	err := vm.RunAll()
   422  	if err != nil {
   423  		t.Fatal(err)
   424  	}
   425  	assert := assert.New(t)
   426  	assert.Equal(test.wantShape, got.Value().Shape(), "shape mismatch")
   427  	assert.Equal(test.wantData, got.Value().Data(), "data mismatch")
   428  }
   429  
   430  func TestMaxOpGrad(t *testing.T) {
   431  	subTests := []reductionGradTest{
   432  		{
   433  			dt:           Float64,
   434  			inShape:      tensor.Shape{6},
   435  			inData:       []float64{1, 2, 3, 4, 5, 6},
   436  			op:           Max,
   437  			along:        []int{},
   438  			outGradShape: tensor.Shape{1},
   439  			outGrad:      []float64{1},
   440  			wantInGrad:   []float64{0, 0, 0, 0, 0, 1},
   441  		},
   442  		{
   443  			dt:           Float32,
   444  			inShape:      tensor.Shape{6},
   445  			inData:       []float32{1, 2, 3, 4, 5, 6},
   446  			op:           Max,
   447  			along:        []int{0},
   448  			outGradShape: tensor.Shape{1},
   449  			outGrad:      []float32{1},
   450  			wantInGrad:   []float32{0, 0, 0, 0, 0, 1},
   451  		},
   452  		{
   453  			dt:           Float32,
   454  			inShape:      tensor.Shape{6},
   455  			inData:       []float32{1, 2, 3, 4, 5, 6},
   456  			op:           Max,
   457  			along:        []int{},
   458  			outGradShape: tensor.Shape{1},
   459  			outGrad:      []float32{1},
   460  			wantInGrad:   []float32{0, 0, 0, 0, 0, 1},
   461  		},
   462  		{
   463  			dt:           Float32,
   464  			inShape:      tensor.Shape{3, 2},
   465  			inData:       []float32{1, 2, 3, 4, 5, 6},
   466  			op:           Max,
   467  			along:        []int{0},
   468  			outGradShape: tensor.Shape{2},
   469  			outGrad:      []float32{0.2, 0.8},
   470  			wantInGrad:   []float32{0, 0, 0, 0, 0.2, 0.8},
   471  		},
   472  		{
   473  			dt:           Float32,
   474  			inShape:      tensor.Shape{3, 2},
   475  			inData:       []float32{1, 2, 3, 4, 5, 6},
   476  			op:           Max,
   477  			along:        []int{1},
   478  			outGradShape: tensor.Shape{3},
   479  			outGrad:      []float32{0.1, 0.3, 0.6},
   480  			wantInGrad:   []float32{0, 0.1, 0, 0.3, 0, 0.6},
   481  		},
   482  		{
   483  			dt:           Float32,
   484  			inShape:      tensor.Shape{3, 2},
   485  			inData:       []float32{1, 2, 3, 4, 5, 6},
   486  			op:           Max,
   487  			along:        []int{0, 1},
   488  			outGradShape: tensor.Shape{1},
   489  			outGrad:      []float32{1},
   490  			wantInGrad:   []float32{0, 0, 0, 0, 0, 1},
   491  		},
   492  		//{
   493  		//	dt:           Float32,
   494  		//	inShape:      tensor.Shape{1, 6},
   495  		//	inData:       []float32{1, 2, 3, 4, 5, 6},
   496  		//	op:           Max,
   497  		//	along:        []int{1},
   498  		//	outGradShape: tensor.Shape{6},
   499  		//	outGrad:      []float32{1},
   500  		//	wantInGrad:   []float32{0, 0, 0, 0, 0, 1},
   501  		//},
   502  	}
   503  
   504  	for _, subTest := range subTests {
   505  		t.Run(fmt.Sprintf("%v along %v %v", subTest.inShape, subTest.along, subTest.dt), func(t *testing.T) {
   506  			testReductionOpGrad(t, subTest)
   507  		})
   508  	}
   509  }
   510  
   511  type reductionGradTest struct {
   512  	dt           tensor.Dtype
   513  	inShape      tensor.Shape
   514  	inData       interface{}
   515  	op           func(*Node, ...int) (*Node, error)
   516  	along        []int
   517  	outGradShape tensor.Shape
   518  	outGrad      interface{}
   519  	wantInGrad   interface{}
   520  }
   521  
   522  func testReductionOpGrad(t *testing.T, test reductionGradTest) {
   523  	assert := assert.New(t)
   524  
   525  	var xG Value
   526  	var err error
   527  
   528  	// Run op
   529  	g := NewGraph()
   530  	xN := NewTensor(g, test.dt, len(test.inShape), WithShape(test.inShape...))
   531  	y := Must(test.op(xN, test.along...))
   532  
   533  	outGrad := NewTensor(g, test.dt, len(test.outGradShape), WithValue(tensor.New(tensor.WithShape(test.outGradShape...), tensor.WithBacking(test.outGrad))))
   534  	if _, err = Backpropagate(Nodes{y}, Nodes{outGrad}, Nodes{xN}); err != nil {
   535  		t.Fatal(err)
   536  	}
   537  
   538  	xT := tensor.New(tensor.WithShape(test.inShape...), tensor.WithBacking(test.inData))
   539  	vm := NewTapeMachine(g)
   540  	defer vm.Close()
   541  	vm.Let(xN, xT)
   542  	if err = vm.RunAll(); err != nil {
   543  		t.Fatal(err)
   544  	}
   545  
   546  	// Test grad functions
   547  	diffWRT := y.diffWRT()
   548  	assert.Equal([]bool{true}, diffWRT)
   549  
   550  	if xG, err = xN.Grad(); err != nil {
   551  		t.Fatal(err)
   552  	}
   553  	assert.Equal(test.inShape, xG.Shape(), "grad shape mismatch")
   554  	assert.Equal(test.wantInGrad, xG.Data(), "grad data mismatch")
   555  }
   556  
   557  // TestFollowupOp confirms that an element-wise binary op will work as expected after a sum/max.
   558  // The underlying reduction on the tensor changes the number of dimensions, but the gorgonia node does not.
   559  // We therefore confirm that the resulting nodes actually work.
   560  func TestFollowupOp(t *testing.T) {
   561  	g := NewGraph()
   562  	Xn := NewTensor(g, tensor.Float64, 4, WithShape(2, 2, 2, 2), WithInit(RangedFrom(1)))
   563  	mx := Must(Max(Xn, 1, 2))
   564  	sx := Must(Sum(Xn, 1, 2))
   565  	y := NewTensor(g, tensor.Float64, 2, WithShape(2, 2), WithInit(RangedFrom(1)))
   566  
   567  	amx := Must(Add(mx, y))
   568  	asx := Must(Add(sx, y))
   569  	assert.Equal(t, amx.Shape(), tensor.Shape{2, 2})
   570  	assert.Equal(t, asx.Shape(), tensor.Shape{2, 2})
   571  	vm := NewTapeMachine(g)
   572  	defer vm.Close()
   573  	err := vm.RunAll()
   574  	if err != nil {
   575  		t.Error(err)
   576  	}
   577  	assert.Equal(t, []float64{8, 10, 18, 20}, amx.Value().Data(), "data mismatch")
   578  	assert.Equal(t, []float64{17, 22, 51, 56}, asx.Value().Data(), "data mismatch")
   579  }