gorgonia.org/gorgonia@v0.9.17/op_math_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"log"
     5  	"runtime"
     6  	"testing"
     7  
     8  	"github.com/pkg/errors"
     9  	"github.com/stretchr/testify/assert"
    10  	"gorgonia.org/tensor"
    11  )
    12  
// binOpTest describes one binary-arithmetic test case: the op under test,
// its two operands, and the expected forward result, gradients and shape.
// The gradients are those of the *summed* output with respect to each
// operand (the test harnesses reduce the result with Sum before
// differentiating — see testOneArithTape/testOneArithLisp).
type binOpTest struct {
	binOp func(*Node, *Node) (*Node, error) // operation under test (Add, Sub, Mul, Div, ...)
	a, b  Value                             // operand values: tensors or scalars

	correct       Value        // expected forward result of binOp(a, b)
	correctDerivA Value        // expected d(sum(result))/da
	correctDerivB Value        // expected d(sum(result))/db
	correctShape  tensor.Shape // expected shape of the forward result
}
    22  
// binOpTests enumerates the cases run by TestBasicArithmetic. For each of
// Add, Sub, (Hadamard)Prod/Mul and (Hadamard)Div it covers the four operand
// combinations tensor∘tensor, tensor∘scalar, scalar∘tensor and
// scalar∘scalar, first for float64 then for float32, followed by one
// batched matrix-multiplication case. Scalar gradients are the sum of the
// per-element gradients (e.g. 4.0 for a length-4 tensor of ones).
var binOpTests = []binOpTest{

	// Float64

	// Add: tensor + tensor
	{Add,
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),

		tensor.New(tensor.WithBacking([]float64{2, 4, 6, 8})),
		tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})),
		tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})),
		tensor.Shape{4},
	},

	// Add: tensor + scalar
	{Add,
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),
		NewF64(1.0),

		tensor.New(tensor.WithBacking([]float64{2, 3, 4, 5})),
		tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})),
		NewF64(4.0),
		tensor.Shape{4},
	},

	// Add: scalar + tensor
	{Add,
		NewF64(1.0),
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),

		tensor.New(tensor.WithBacking([]float64{2, 3, 4, 5})),
		NewF64(4.0),
		tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})),
		tensor.Shape{4},
	},

	// Add: scalar + scalar
	{Add,
		NewF64(1.0),
		NewF64(1.0),

		NewF64(2.0),
		NewF64(1.0),
		NewF64(1.0),
		scalarShape,
	},

	// Sub: tensor - tensor
	{Sub,
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),

		tensor.New(tensor.WithBacking([]float64{0, 0, 0, 0})),
		tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})),
		tensor.New(tensor.WithBacking([]float64{-1, -1, -1, -1})),
		tensor.Shape{4},
	},

	// Sub: tensor - scalar
	{Sub,
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),
		NewF64(1.0),

		tensor.New(tensor.WithBacking([]float64{0, 1, 2, 3})),
		tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})),
		NewF64(-4.0),
		tensor.Shape{4},
	},

	// Sub: scalar - tensor
	{Sub,
		NewF64(1.0),
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),

		tensor.New(tensor.WithBacking([]float64{0, -1, -2, -3})),
		NewF64(4.0),
		tensor.New(tensor.WithBacking([]float64{-1, -1, -1, -1})),
		tensor.Shape{4},
	},

	// Sub: scalar - scalar
	{Sub,
		NewF64(1.0),
		NewF64(1.0),

		NewF64(0.0),
		NewF64(1.0),
		NewF64(-1.0),
		scalarShape,
	},

	// HadamardProd: elementwise tensor * tensor
	{HadamardProd,
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),

		tensor.New(tensor.WithBacking([]float64{1, 4, 9, 16})),
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),
		tensor.Shape{4},
	},

	// Mul: tensor * scalar (scalar grad is sum(a) = 10)
	{Mul,
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),
		NewF64(1.0),

		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})),
		NewF64(10),
		tensor.Shape{4},
	},

	// Mul: scalar * tensor
	{Mul,
		NewF64(1.0),
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),

		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),
		NewF64(10),
		tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})),
		tensor.Shape{4},
	},

	// Mul: scalar * scalar
	{Mul,
		NewF64(1.0),
		NewF64(1.0),

		NewF64(1.0),
		NewF64(1.0),
		NewF64(1.0),
		scalarShape,
	},

	// HadamardDiv: elementwise tensor / tensor (d/db = -a/b²)
	{HadamardDiv,
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})),

		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})),
		tensor.New(tensor.WithBacking([]float64{-1, -2, -3, -4})),
		tensor.Shape{4},
	},

	// Div: tensor / scalar
	{Div,
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),
		NewF64(1.0),

		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})),
		NewF64(-10),
		tensor.Shape{4},
	},

	// Div: scalar / tensor
	{Div,
		NewF64(1),
		tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})),

		tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})),
		NewF64(4),
		tensor.New(tensor.WithBacking([]float64{-1, -1, -1, -1})),
		tensor.Shape{4},
	},

	// Div: scalar / scalar
	{Div,
		NewF64(1.0),
		NewF64(1.0),

		NewF64(1.0),
		NewF64(1.0),
		NewF64(-1.0),
		scalarShape,
	},

	// Float32 — same cases as above, float32 operands.

	{Add,
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),

		tensor.New(tensor.WithBacking([]float32{2, 4, 6, 8})),
		tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})),
		tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})),
		tensor.Shape{4},
	},

	{Add,
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),
		NewF32(1.0),

		tensor.New(tensor.WithBacking([]float32{2, 3, 4, 5})),
		tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})),
		NewF32(4.0),
		tensor.Shape{4},
	},

	{Add,
		NewF32(1.0),
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),

		tensor.New(tensor.WithBacking([]float32{2, 3, 4, 5})),
		NewF32(4.0),
		tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})),
		tensor.Shape{4},
	},

	{Add,
		NewF32(1.0),
		NewF32(1.0),

		NewF32(2.0),
		NewF32(1.0),
		NewF32(1.0),
		scalarShape,
	},

	{Sub,
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),

		tensor.New(tensor.WithBacking([]float32{0, 0, 0, 0})),
		tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})),
		tensor.New(tensor.WithBacking([]float32{-1, -1, -1, -1})),
		tensor.Shape{4},
	},

	{Sub,
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),
		NewF32(1.0),

		tensor.New(tensor.WithBacking([]float32{0, 1, 2, 3})),
		tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})),
		NewF32(-4.0),
		tensor.Shape{4},
	},

	{Sub,
		NewF32(1.0),
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),

		tensor.New(tensor.WithBacking([]float32{0, -1, -2, -3})),
		NewF32(4.0),
		tensor.New(tensor.WithBacking([]float32{-1, -1, -1, -1})),
		tensor.Shape{4},
	},

	{Sub,
		NewF32(1.0),
		NewF32(1.0),

		NewF32(0.0),
		NewF32(1.0),
		NewF32(-1.0),
		scalarShape,
	},

	{HadamardProd,
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),

		tensor.New(tensor.WithBacking([]float32{1, 4, 9, 16})),
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),
		tensor.Shape{4},
	},

	{Mul,
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),
		NewF32(1.0),

		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})),
		NewF32(10),
		tensor.Shape{4},
	},

	{Mul,
		NewF32(1.0),
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),

		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),
		NewF32(10),
		tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})),
		tensor.Shape{4},
	},

	{Mul,
		NewF32(1.0),
		NewF32(1.0),

		NewF32(1.0),
		NewF32(1.0),
		NewF32(1.0),
		scalarShape,
	},

	{HadamardDiv,
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})),

		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})),
		tensor.New(tensor.WithBacking([]float32{-1, -2, -3, -4})),
		tensor.Shape{4},
	},

	{Div,
		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),
		NewF32(1.0),

		tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})),
		NewF32(-10),
		tensor.Shape{4},
	},

	{Div,
		NewF32(1),
		tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})),

		tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})),
		NewF32(4),
		tensor.New(tensor.WithBacking([]float32{-1, -1, -1, -1})),
		tensor.Shape{4},
	},

	{Div,
		NewF32(1.0),
		NewF32(1.0),

		NewF32(1.0),
		NewF32(1.0),
		NewF32(-1.0),
		scalarShape,
	},

	// BatchedMatMul: (2,3,4) × (2,4,1) → (2,3,1), no transposes.
	{
		func(a *Node, b *Node) (*Node, error) {
			return BatchedMatMul(a, b, false, false)
		},
		tensor.New(tensor.WithShape(2, 3, 4), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})),
		tensor.New(tensor.WithShape(2, 4, 1), tensor.WithBacking([]float64{1, 2, 3, 4, 1, 2, 3, 4})),

		tensor.New(tensor.WithBacking([]float64{30, 70, 110, 30, 70, 110})),
		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4})),
		tensor.New(tensor.WithBacking([]float64{15, 18, 21, 24, 15, 18, 21, 24})),
		tensor.Shape{2, 3, 1},
	},
}
   360  
   361  func TestBasicArithmetic(t *testing.T) {
   362  	for i, bot := range binOpTests {
   363  		if err := testOneArithTape(t, bot, i); err != nil {
   364  			t.Fatalf("Test %d, Err: %+v", i, err)
   365  		}
   366  		runtime.GC()
   367  	}
   368  
   369  	for i, bot := range binOpTests {
   370  		// log.Printf("Test %d", i)
   371  		if err := testOneArithLisp(t, bot, i); err != nil {
   372  			t.Fatalf("Test %d, Err: %+v", i, err)
   373  		}
   374  		runtime.GC()
   375  	}
   376  }
   377  
   378  func testOneArithLisp(t *testing.T, bot binOpTest, i int) error {
   379  	g := NewGraph()
   380  	xV, _ := CloneValue(bot.a)
   381  	yV, _ := CloneValue(bot.b)
   382  	x := NodeFromAny(g, xV, WithName("x"))
   383  	y := NodeFromAny(g, yV, WithName("y"))
   384  
   385  	var ret *Node
   386  	var retVal Value
   387  	var err error
   388  	if ret, err = bot.binOp(x, y); err != nil {
   389  		return errors.Wrapf(err, "do binop failure")
   390  	}
   391  	Read(ret, &retVal)
   392  
   393  	if !(xV.Shape().IsScalar() && yV.Shape().IsScalar()) {
   394  		Must(Sum(ret))
   395  	}
   396  	m1 := NewLispMachine(g)
   397  	defer m1.Close()
   398  	if err = m1.RunAll(); err != nil {
   399  		return errors.Wrapf(err, "Error while running")
   400  	}
   401  
   402  	as := newAssertState(assert.New(t))
   403  	as.Equal(bot.correct.Data(), retVal.Data(), "Test %d result", i)
   404  	as.True(bot.correctShape.Eq(ret.Shape()))
   405  
   406  	var xG, yG Value
   407  	if xG, err = x.Grad(); err != nil {
   408  		return errors.Wrapf(err, "Failed to get the grad of x")
   409  	}
   410  
   411  	if yG, err = y.Grad(); err != nil {
   412  		return errors.Wrapf(err, "Failed to get the grad of y")
   413  	}
   414  
   415  	as.Equal(bot.correctDerivA.Data(), xG.Data(), "Test %v xgrad", i)
   416  	as.Equal(bot.correctDerivB.Data(), yG.Data(), "Test %v ygrad. Expected %v. Got %v", i, bot.correctDerivB, yG)
   417  	if !as.cont {
   418  		t.Errorf("an error occurred")
   419  	}
   420  
   421  	if assertGraphEngine(t, g, stdengType); t.Failed() {
   422  		return errors.New("Lisp Machine Graph Engine expected")
   423  	}
   424  	return nil
   425  }
   426  
   427  func testOneArithTape(t *testing.T, bot binOpTest, i int) error {
   428  	g := NewGraph()
   429  	xV, _ := CloneValue(bot.a)
   430  	yV, _ := CloneValue(bot.b)
   431  	x := NodeFromAny(g, xV, WithName("x"))
   432  	y := NodeFromAny(g, yV, WithName("y"))
   433  
   434  	var ret *Node
   435  	var retVal Value
   436  	var err error
   437  	if ret, err = bot.binOp(x, y); err != nil {
   438  		return errors.Wrapf(err, "binOp() failed")
   439  	}
   440  	Read(ret, &retVal)
   441  
   442  	cost := Must(Sum(ret))
   443  	var grads Nodes
   444  	if grads, err = Grad(cost, x, y); err != nil {
   445  		return errors.Wrapf(err, "Grad failed")
   446  	}
   447  
   448  	m1 := NewTapeMachine(g)
   449  	defer m1.Close()
   450  	if err = m1.RunAll(); err != nil {
   451  		t.Logf("%v", m1.Prog())
   452  		return errors.Wrapf(err, "Error while running")
   453  	}
   454  
   455  	as := newAssertState(assert.New(t))
   456  	as.True(bot.a.Shape().Eq(x.Shape()), "Test op doesn't change shape of input node")
   457  	as.True(bot.b.Shape().Eq(y.Shape()), "Test op doesn't change shape of input node")
   458  	as.Equal(bot.correct.Data(), retVal.Data(), "Test %d result", i)
   459  	as.True(bot.correctShape.Eq(ret.Shape()))
   460  	as.Equal(2, len(grads))
   461  	as.Equal(bot.correctDerivA.Data(), grads[0].Value().Data(), "Test %v xgrad", i)
   462  	as.Equal(bot.correctDerivB.Data(), grads[1].Value().Data(), "Test %v ygrad. Expected %v. Got %v", i, bot.correctDerivB, grads[1].Value())
   463  	if !as.cont {
   464  		prog := m1.Prog()
   465  		return errors.Errorf("Failed. Prog %v", prog)
   466  	}
   467  
   468  	if assertGraphEngine(t, g, stdengType); t.Failed() {
   469  		return errors.Errorf("BasicArithmetic. Engine of Graph is not stdengType.")
   470  	}
   471  	return nil
   472  }
   473  
// TestTensordotOpDoDiff drives tensordotOp.DoDiff directly — binding
// dual values onto the nodes by hand instead of running a VM — for five
// contraction patterns: scalar-like, vector·vector, matrix·vector,
// matrix·matrix, and a total matrix contraction. For each pattern it
// checks the resulting gradients of both inputs.
//
// NOTE(review): the ApplyOp/DoDiff failure paths below use
// log.Fatal/log.Fatalf, which call os.Exit and bypass the testing
// framework (no cleanup, test not marked failed); t.Fatalf would be the
// conventional choice. Left as-is here since removing all log calls
// would orphan the "log" import.
func TestTensordotOpDoDiff(t *testing.T) {
	assert := assert.New(t)

	// Scalars (rank-0 contraction: aDims/bDims/retDims are all 0)
	g := NewGraph()
	a := NewTensor(g, Float64, 1, WithName("a"), WithShape(1))
	b := NewTensor(g, Float64, 1, WithName("b"), WithShape(1))

	tensordot := tensordotOp{
		aAxes:   []int{0},
		bAxes:   []int{0},
		aDims:   0,
		bDims:   0,
		retDims: 0,
	}

	c, err := ApplyOp(tensordot, a, b)

	if err != nil {
		log.Fatalf("scalars: Cannot ApplyOp: %+v", err)
		return
	}

	aT := tensor.New(tensor.WithShape(), tensor.WithBacking([]float64{2}))
	bT := tensor.New(tensor.WithShape(), tensor.WithBacking([]float64{21}))
	cT := tensor.New(tensor.WithShape(), tensor.WithBacking([]float64{1})) // Backing doesn't matter as long as it is set

	aVal, _, _, _ := anyToValue(aT)
	bVal, _, _, _ := anyToValue(bT)
	cVal, _, _, _ := anyToValue(cT)

	// Bind forward values to the inputs and a dual value (with output
	// derivative set) to the result, then differentiate manually.
	a.bind(dvUnit(aVal))
	b.bind(dvUnit(bVal))
	c.bind(dvUnitVar(cVal)) // Will set Output derivative to all ones

	if err := tensordot.DoDiff(ExecutionContext{}, Nodes{a, b}, c); err != nil {
		t.Fatalf("scalars: Cannot DoDiff: %+v", err)
	}

	aG, _ := a.Grad()
	aGfloat := aG.Data()

	bG, _ := b.Grad()
	bGfloat := bG.Data()

	// d(a*b)/da = b = 21; d(a*b)/db = a = 2
	aGcorrect := 21.0
	bGcorrect := 2.0

	assert.Equal(aGcorrect, aGfloat)
	assert.Equal(bGcorrect, bGfloat)

	// Vectors (dot product of two length-2 vectors)

	g = NewGraph()
	a = NewTensor(g, Float64, 1, WithName("a"), WithShape(2))
	b = NewTensor(g, Float64, 1, WithName("b"), WithShape(2))

	tensordot = tensordotOp{
		aAxes:   []int{0},
		bAxes:   []int{0},
		aDims:   1,
		bDims:   1,
		retDims: 1,
	}

	if c, err = ApplyOp(tensordot, a, b); err != nil {
		log.Fatal("vectors: Cannot ApplyOp:", err)
		return
	}

	aT = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{1, 2}))
	bT = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{3, 4}))
	cT = tensor.New(tensor.WithShape(1), tensor.WithBacking([]float64{1})) // Backing doesn't matter as long as it is set

	aVal, _, _, _ = anyToValue(aT)
	bVal, _, _, _ = anyToValue(bT)
	cVal, _, _, _ = anyToValue(cT)

	a.bind(dvUnit(aVal))
	b.bind(dvUnit(bVal))
	c.bind(dvUnitVar(cVal)) // Will set Output derivative to all ones

	if err := tensordot.DoDiff(ExecutionContext{}, Nodes{a, b}, c); err != nil {
		log.Fatal("vectors: Cannot DoDiff:", err)
		return
	}

	aG, _ = a.Grad()
	bG, _ = b.Grad()

	aGfloats := extractF64s(aG)
	bGfloats := extractF64s(bG)

	// d(a·b)/da = b; d(a·b)/db = a
	aGcorrectFloats := []float64{3, 4}
	bGcorrectFloats := []float64{1, 2}

	assert.Equal(aGcorrectFloats, aGfloats)
	assert.Equal(bGcorrectFloats, bGfloats)

	// Matrix and Vector (contract a's axis 1 with b's axis 0)

	g = NewGraph()
	a = NewTensor(g, Float64, 2, WithName("a"), WithShape(2, 2))
	b = NewTensor(g, Float64, 1, WithName("b"), WithShape(2))

	tensordot = tensordotOp{
		aAxes:   []int{1},
		bAxes:   []int{0},
		aDims:   2,
		bDims:   1,
		retDims: 1,
	}

	if c, err = ApplyOp(tensordot, a, b); err != nil {
		log.Fatal("matrix vector: Cannot ApplyOp:", err)
		return
	}

	aT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4}))
	bT = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{1, 2}))
	cT = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{1, 1})) // Backing doesn't matter as long as it is set

	aVal, _, _, _ = anyToValue(aT)
	bVal, _, _, _ = anyToValue(bT)
	cVal, _, _, _ = anyToValue(cT)

	a.bind(dvUnit(aVal))
	b.bind(dvUnit(bVal))
	c.bind(dvUnitVar(cVal)) // Will set Output derivative to all ones

	if err := tensordot.DoDiff(ExecutionContext{}, Nodes{a, b}, c); err != nil {
		log.Fatal("matrix vector: Cannot DoDiff:", err)
		return
	}

	aG, _ = a.Grad()
	bG, _ = b.Grad()

	aGfloats = extractF64s(aG)
	bGfloats = extractF64s(bG)

	aGcorrectFloats = []float64{1, 2, 1, 2}
	bGcorrectFloats = []float64{4, 6}

	assert.Equal(aGcorrectFloats, aGfloats)
	assert.Equal(bGcorrectFloats, bGfloats)

	// Matrix multiplication (2x2 · 2x2)

	g = NewGraph()

	a = NewTensor(g, Float64, 2, WithName("a"), WithShape(2, 2))
	b = NewTensor(g, Float64, 2, WithName("b"), WithShape(2, 2))

	tensordot = tensordotOp{
		aAxes:   []int{1},
		bAxes:   []int{0},
		aDims:   2,
		bDims:   2,
		retDims: 2,
	}

	if c, err = ApplyOp(tensordot, a, b); err != nil {
		log.Fatal("matrices: Cannot ApplyOp:", err)
		return
	}

	aT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4}))
	bT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4}))
	cT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 1, 1, 1})) // Backing doesn't matter as long as it is set

	aVal, _, _, _ = anyToValue(aT)
	bVal, _, _, _ = anyToValue(bT)
	cVal, _, _, _ = anyToValue(cT)

	a.bind(dvUnit(aVal))
	b.bind(dvUnit(bVal))
	c.bind(dvUnitVar(cVal)) // Will set Output derivative to all ones

	if err := tensordot.DoDiff(ExecutionContext{}, Nodes{a, b}, c); err != nil {
		log.Fatal("matrices: Cannot DoDiff:", err)
		return
	}

	aG, _ = a.Grad()
	bG, _ = b.Grad()

	aGfloats = extractF64s(aG)
	bGfloats = extractF64s(bG)

	aGcorrectFloats = []float64{3, 7, 3, 7}
	bGcorrectFloats = []float64{4, 4, 6, 6}

	assert.Equal(aGcorrectFloats, aGfloats)
	assert.Equal(bGcorrectFloats, bGfloats)

	// Total matrix contraction (both axes contracted; result is a scalar-like)

	g = NewGraph()

	a = NewTensor(g, Float64, 2, WithName("a"), WithShape(2, 2))
	b = NewTensor(g, Float64, 2, WithName("b"), WithShape(2, 2))

	tensordot = tensordotOp{
		aAxes:   []int{1, 0},
		bAxes:   []int{0, 1},
		aDims:   2,
		bDims:   2,
		retDims: 1,
	}

	if c, err = ApplyOp(tensordot, a, b); err != nil {
		log.Fatal("matrices total contraction: Cannot ApplyOp:", err)
		return
	}

	aT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4}))
	bT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{5, 6, 7, 8}))
	cT = tensor.New(tensor.WithShape(1), tensor.WithBacking([]float64{1})) // Backing doesn't matter as long as it is set

	aVal, _, _, _ = anyToValue(aT)
	bVal, _, _, _ = anyToValue(bT)
	cVal, _, _, _ = anyToValue(cT)

	a.bind(dvUnit(aVal))
	b.bind(dvUnit(bVal))
	c.bind(dvUnitVar(cVal)) // Will set Output derivative to all ones

	if err := tensordot.DoDiff(ExecutionContext{}, Nodes{a, b}, c); err != nil {
		log.Fatal("matrices total contraction: Cannot DoDiff:", err)
		return
	}

	aG, _ = a.Grad()
	bG, _ = b.Grad()

	aGfloats = extractF64s(aG)
	bGfloats = extractF64s(bG)

	aGcorrectFloats = []float64{5, 7, 6, 8}
	bGcorrectFloats = []float64{1, 3, 2, 4}

	assert.Equal(aGcorrectFloats, aGfloats)
	assert.Equal(bGcorrectFloats, bGfloats)

}
   720  
   721  func TestLinearAlgebraOps(t *testing.T) {
   722  	g := NewGraph()
   723  	x := NewMatrix(g, Float64, WithShape(2, 3), WithName("x"))
   724  	y := NewMatrix(g, Float64, WithShape(3, 5), WithName("y"))
   725  	if _, err := Mul(x, y); err != nil {
   726  		t.Fatal(err)
   727  	}
   728  
   729  	if _, err := Mul(y, x); err == nil {
   730  		t.Error("Expect an error")
   731  	}
   732  }