gorgonia.org/gorgonia@v0.9.17/op_math_test.go (about) 1 package gorgonia 2 3 import ( 4 "log" 5 "runtime" 6 "testing" 7 8 "github.com/pkg/errors" 9 "github.com/stretchr/testify/assert" 10 "gorgonia.org/tensor" 11 ) 12 13 type binOpTest struct { 14 binOp func(*Node, *Node) (*Node, error) 15 a, b Value 16 17 correct Value 18 correctDerivA Value 19 correctDerivB Value 20 correctShape tensor.Shape 21 } 22 23 var binOpTests = []binOpTest{ 24 25 {Add, 26 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 27 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 28 29 tensor.New(tensor.WithBacking([]float64{2, 4, 6, 8})), 30 tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})), 31 tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})), 32 tensor.Shape{4}, 33 }, 34 35 {Add, 36 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 37 NewF64(1.0), 38 39 tensor.New(tensor.WithBacking([]float64{2, 3, 4, 5})), 40 tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})), 41 NewF64(4.0), 42 tensor.Shape{4}, 43 }, 44 45 {Add, 46 NewF64(1.0), 47 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 48 49 tensor.New(tensor.WithBacking([]float64{2, 3, 4, 5})), 50 NewF64(4.0), 51 tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})), 52 tensor.Shape{4}, 53 }, 54 55 {Add, 56 NewF64(1.0), 57 NewF64(1.0), 58 59 NewF64(2.0), 60 NewF64(1.0), 61 NewF64(1.0), 62 scalarShape, 63 }, 64 65 {Sub, 66 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 67 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 68 69 tensor.New(tensor.WithBacking([]float64{0, 0, 0, 0})), 70 tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})), 71 tensor.New(tensor.WithBacking([]float64{-1, -1, -1, -1})), 72 tensor.Shape{4}, 73 }, 74 75 {Sub, 76 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 77 NewF64(1.0), 78 79 tensor.New(tensor.WithBacking([]float64{0, 1, 2, 3})), 80 tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})), 81 NewF64(-4.0), 82 tensor.Shape{4}, 83 }, 84 85 {Sub, 86 NewF64(1.0), 87 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 88 89 tensor.New(tensor.WithBacking([]float64{0, -1, -2, -3})), 90 NewF64(4.0), 91 tensor.New(tensor.WithBacking([]float64{-1, -1, -1, -1})), 92 tensor.Shape{4}, 93 }, 94 95 {Sub, 96 NewF64(1.0), 97 NewF64(1.0), 98 99 NewF64(0.0), 100 NewF64(1.0), 101 NewF64(-1.0), 102 scalarShape, 103 }, 104 105 {HadamardProd, 106 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 107 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 108 109 tensor.New(tensor.WithBacking([]float64{1, 4, 9, 16})), 110 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 111 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 112 tensor.Shape{4}, 113 }, 114 115 {Mul, 116 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 117 NewF64(1.0), 118 119 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 120 tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})), 121 NewF64(10), 122 tensor.Shape{4}, 123 }, 124 125 {Mul, 126 NewF64(1.0), 127 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 128 129 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 130 NewF64(10), 131 tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})), 132 tensor.Shape{4}, 133 }, 134 135 {Mul, 136 NewF64(1.0), 137 NewF64(1.0), 138 139 NewF64(1.0), 140 NewF64(1.0), 141 NewF64(1.0), 142 scalarShape, 143 }, 144 145 {HadamardDiv, 146 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 147 tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})), 148 149 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 150 tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})), 151 tensor.New(tensor.WithBacking([]float64{-1, -2, -3, -4})), 152 tensor.Shape{4}, 153 }, 154 155 {Div, 156 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 157 NewF64(1.0), 158 159 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4})), 160 tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})), 161 NewF64(-10), 162 tensor.Shape{4}, 163 }, 164 165 {Div, 166 NewF64(1), 167 tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})), 168 169 tensor.New(tensor.WithBacking([]float64{1, 1, 1, 1})), 170 NewF64(4), 171 tensor.New(tensor.WithBacking([]float64{-1, -1, -1, -1})), 172 tensor.Shape{4}, 173 }, 174 175 {Div, 176 NewF64(1.0), 177 NewF64(1.0), 178 179 NewF64(1.0), 180 NewF64(1.0), 181 NewF64(-1.0), 182 scalarShape, 183 }, 184 185 // Float32 186 187 {Add, 188 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 189 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 190 191 tensor.New(tensor.WithBacking([]float32{2, 4, 6, 8})), 192 tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})), 193 tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})), 194 tensor.Shape{4}, 195 }, 196 197 {Add, 198 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 199 NewF32(1.0), 200 201 tensor.New(tensor.WithBacking([]float32{2, 3, 4, 5})), 202 tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})), 203 NewF32(4.0), 204 tensor.Shape{4}, 205 }, 206 207 {Add, 208 NewF32(1.0), 209 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 210 211 tensor.New(tensor.WithBacking([]float32{2, 3, 4, 5})), 212 NewF32(4.0), 213 tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})), 214 tensor.Shape{4}, 215 }, 216 217 {Add, 218 NewF32(1.0), 219 NewF32(1.0), 220 221 NewF32(2.0), 222 NewF32(1.0), 223 NewF32(1.0), 224 scalarShape, 225 }, 226 227 {Sub, 228 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 229 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 230 231 tensor.New(tensor.WithBacking([]float32{0, 0, 0, 0})), 232 tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})), 233 tensor.New(tensor.WithBacking([]float32{-1, -1, -1, -1})), 234 tensor.Shape{4}, 235 }, 236 237 {Sub, 238 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 239 NewF32(1.0), 240 241 tensor.New(tensor.WithBacking([]float32{0, 1, 2, 3})), 242 tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})), 243 NewF32(-4.0), 244 tensor.Shape{4}, 245 }, 246 247 {Sub, 248 NewF32(1.0), 249 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 250 251 tensor.New(tensor.WithBacking([]float32{0, -1, -2, -3})), 252 NewF32(4.0), 253 tensor.New(tensor.WithBacking([]float32{-1, -1, -1, -1})), 254 tensor.Shape{4}, 255 }, 256 257 {Sub, 258 NewF32(1.0), 259 NewF32(1.0), 260 261 NewF32(0.0), 262 NewF32(1.0), 263 NewF32(-1.0), 264 scalarShape, 265 }, 266 267 {HadamardProd, 268 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 269 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 270 271 tensor.New(tensor.WithBacking([]float32{1, 4, 9, 16})), 272 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 273 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 274 tensor.Shape{4}, 275 }, 276 277 {Mul, 278 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 279 NewF32(1.0), 280 281 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 282 tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})), 283 NewF32(10), 284 tensor.Shape{4}, 285 }, 286 287 {Mul, 288 NewF32(1.0), 289 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 290 291 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 292 NewF32(10), 293 tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})), 294 tensor.Shape{4}, 295 }, 296 297 {Mul, 298 NewF32(1.0), 299 NewF32(1.0), 300 301 NewF32(1.0), 302 NewF32(1.0), 303 NewF32(1.0), 304 scalarShape, 305 }, 306 307 {HadamardDiv, 308 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 309 tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})), 310 311 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 312 tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})), 313 tensor.New(tensor.WithBacking([]float32{-1, -2, -3, -4})), 314 tensor.Shape{4}, 315 }, 316 317 {Div, 318 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 319 NewF32(1.0), 320 321 tensor.New(tensor.WithBacking([]float32{1, 2, 3, 4})), 322 tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})), 323 NewF32(-10), 324 tensor.Shape{4}, 325 }, 326 327 {Div, 328 NewF32(1), 329 tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})), 330 331 tensor.New(tensor.WithBacking([]float32{1, 1, 1, 1})), 332 NewF32(4), 333 tensor.New(tensor.WithBacking([]float32{-1, -1, -1, -1})), 334 tensor.Shape{4}, 335 }, 336 337 {Div, 338 NewF32(1.0), 339 NewF32(1.0), 340 341 NewF32(1.0), 342 NewF32(1.0), 343 NewF32(-1.0), 344 scalarShape, 345 }, 346 347 { 348 func(a *Node, b *Node) (*Node, error) { 349 return BatchedMatMul(a, b, false, false) 350 }, 351 tensor.New(tensor.WithShape(2, 3, 4), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})), 352 tensor.New(tensor.WithShape(2, 4, 1), tensor.WithBacking([]float64{1, 2, 3, 4, 1, 2, 3, 4})), 353 354 tensor.New(tensor.WithBacking([]float64{30, 70, 110, 30, 70, 110})), 355 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4})), 356 tensor.New(tensor.WithBacking([]float64{15, 18, 21, 24, 15, 18, 21, 24})), 357 tensor.Shape{2, 3, 1}, 358 }, 359 } 360 361 func TestBasicArithmetic(t *testing.T) { 362 for i, bot := range binOpTests { 363 if err := testOneArithTape(t, bot, i); err != nil { 364 t.Fatalf("Test %d, Err: %+v", i, err) 365 } 366 runtime.GC() 367 } 368 369 for i, bot := range binOpTests { 370 // log.Printf("Test %d", i) 371 if err := testOneArithLisp(t, bot, i); err != nil { 372 t.Fatalf("Test %d, Err: %+v", i, err) 373 } 374 runtime.GC() 375 } 376 } 377 378 func testOneArithLisp(t *testing.T, bot binOpTest, i int) error { 379 g := NewGraph() 380 xV, _ := CloneValue(bot.a) 381 yV, _ := CloneValue(bot.b) 382 x := NodeFromAny(g, xV, WithName("x")) 383 y := NodeFromAny(g, yV, WithName("y")) 384 385 var ret *Node 386 var retVal Value 387 var err error 388 if ret, err = bot.binOp(x, y); err != nil { 389 return errors.Wrapf(err, "do binop failure") 390 } 391 Read(ret, &retVal) 392 393 if !(xV.Shape().IsScalar() && yV.Shape().IsScalar()) { 394 Must(Sum(ret)) 395 } 396 m1 := NewLispMachine(g) 397 defer m1.Close() 398 if err = m1.RunAll(); err != nil { 399 return errors.Wrapf(err, "Error while running") 400 } 401 402 as := newAssertState(assert.New(t)) 403 as.Equal(bot.correct.Data(), retVal.Data(), "Test %d result", i) 404 as.True(bot.correctShape.Eq(ret.Shape())) 405 406 var xG, yG Value 407 if xG, err = x.Grad(); err != nil { 408 return errors.Wrapf(err, "Failed to get the grad of x") 409 } 410 411 if yG, err = y.Grad(); err != nil { 412 return errors.Wrapf(err, "Failed to get the grad of y") 413 } 414 415 as.Equal(bot.correctDerivA.Data(), xG.Data(), "Test %v xgrad", i) 416 as.Equal(bot.correctDerivB.Data(), yG.Data(), "Test %v ygrad. Expected %v. Got %v", i, bot.correctDerivB, yG) 417 if !as.cont { 418 t.Errorf("an error occurred") 419 } 420 421 if assertGraphEngine(t, g, stdengType); t.Failed() { 422 return errors.New("Lisp Machine Graph Engine expected") 423 } 424 return nil 425 } 426 427 func testOneArithTape(t *testing.T, bot binOpTest, i int) error { 428 g := NewGraph() 429 xV, _ := CloneValue(bot.a) 430 yV, _ := CloneValue(bot.b) 431 x := NodeFromAny(g, xV, WithName("x")) 432 y := NodeFromAny(g, yV, WithName("y")) 433 434 var ret *Node 435 var retVal Value 436 var err error 437 if ret, err = bot.binOp(x, y); err != nil { 438 return errors.Wrapf(err, "binOp() failed") 439 } 440 Read(ret, &retVal) 441 442 cost := Must(Sum(ret)) 443 var grads Nodes 444 if grads, err = Grad(cost, x, y); err != nil { 445 return errors.Wrapf(err, "Grad failed") 446 } 447 448 m1 := NewTapeMachine(g) 449 defer m1.Close() 450 if err = m1.RunAll(); err != nil { 451 t.Logf("%v", m1.Prog()) 452 return errors.Wrapf(err, "Error while running") 453 } 454 455 as := newAssertState(assert.New(t)) 456 as.True(bot.a.Shape().Eq(x.Shape()), "Test op doesn't change shape of input node") 457 as.True(bot.b.Shape().Eq(y.Shape()), "Test op doesn't change shape of input node") 458 as.Equal(bot.correct.Data(), retVal.Data(), "Test %d result", i) 459 as.True(bot.correctShape.Eq(ret.Shape())) 460 as.Equal(2, len(grads)) 461 as.Equal(bot.correctDerivA.Data(), grads[0].Value().Data(), "Test %v xgrad", i) 462 as.Equal(bot.correctDerivB.Data(), grads[1].Value().Data(), "Test %v ygrad. Expected %v. Got %v", i, bot.correctDerivB, grads[1].Value()) 463 if !as.cont { 464 prog := m1.Prog() 465 return errors.Errorf("Failed. Prog %v", prog) 466 } 467 468 if assertGraphEngine(t, g, stdengType); t.Failed() { 469 return errors.Errorf("BasicArithmetic. Engine of Graph is not stdengType.") 470 } 471 return nil 472 } 473 474 func TestTensordotOpDoDiff(t *testing.T) { 475 assert := assert.New(t) 476 477 // Vectors 478 g := NewGraph() 479 a := NewTensor(g, Float64, 1, WithName("a"), WithShape(1)) 480 b := NewTensor(g, Float64, 1, WithName("b"), WithShape(1)) 481 482 tensordot := tensordotOp{ 483 aAxes: []int{0}, 484 bAxes: []int{0}, 485 aDims: 0, 486 bDims: 0, 487 retDims: 0, 488 } 489 490 c, err := ApplyOp(tensordot, a, b) 491 492 if err != nil { 493 log.Fatalf("scalars: Cannot ApplyOp: %+v", err) 494 return 495 } 496 497 aT := tensor.New(tensor.WithShape(), tensor.WithBacking([]float64{2})) 498 bT := tensor.New(tensor.WithShape(), tensor.WithBacking([]float64{21})) 499 cT := tensor.New(tensor.WithShape(), tensor.WithBacking([]float64{1})) // Backing doesn't matter as long as it is set 500 501 aVal, _, _, _ := anyToValue(aT) 502 bVal, _, _, _ := anyToValue(bT) 503 cVal, _, _, _ := anyToValue(cT) 504 505 a.bind(dvUnit(aVal)) 506 b.bind(dvUnit(bVal)) 507 c.bind(dvUnitVar(cVal)) // Will set Output derivative to all ones 508 509 if err := tensordot.DoDiff(ExecutionContext{}, Nodes{a, b}, c); err != nil { 510 t.Fatalf("scalars: Cannot DoDiff: %+v", err) 511 } 512 513 aG, _ := a.Grad() 514 aGfloat := aG.Data() 515 516 bG, _ := b.Grad() 517 bGfloat := bG.Data() 518 519 aGcorrect := 21.0 520 bGcorrect := 2.0 521 522 assert.Equal(aGcorrect, aGfloat) 523 assert.Equal(bGcorrect, bGfloat) 524 525 // Vectors 526 527 g = NewGraph() 528 a = NewTensor(g, Float64, 1, WithName("a"), WithShape(2)) 529 b = NewTensor(g, Float64, 1, WithName("b"), WithShape(2)) 530 531 tensordot = tensordotOp{ 532 aAxes: []int{0}, 533 bAxes: []int{0}, 534 aDims: 1, 535 bDims: 1, 536 retDims: 1, 537 } 538 539 if c, err = ApplyOp(tensordot, a, b); err != nil { 540 log.Fatal("vectors: Cannot ApplyOp:", err) 541 return 542 } 543 544 aT = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{1, 2})) 545 bT = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{3, 4})) 546 cT = tensor.New(tensor.WithShape(1), tensor.WithBacking([]float64{1})) // Backing doesn't matter as long as it is set 547 548 aVal, _, _, _ = anyToValue(aT) 549 bVal, _, _, _ = anyToValue(bT) 550 cVal, _, _, _ = anyToValue(cT) 551 552 a.bind(dvUnit(aVal)) 553 b.bind(dvUnit(bVal)) 554 c.bind(dvUnitVar(cVal)) // Will set Output derivative to all ones 555 556 if err := tensordot.DoDiff(ExecutionContext{}, Nodes{a, b}, c); err != nil { 557 log.Fatal("vectors: Cannot DoDiff:", err) 558 return 559 } 560 561 aG, _ = a.Grad() 562 bG, _ = b.Grad() 563 564 aGfloats := extractF64s(aG) 565 bGfloats := extractF64s(bG) 566 567 aGcorrectFloats := []float64{3, 4} 568 bGcorrectFloats := []float64{1, 2} 569 570 assert.Equal(aGcorrectFloats, aGfloats) 571 assert.Equal(bGcorrectFloats, bGfloats) 572 573 // Matrix and Vector 574 575 g = NewGraph() 576 a = NewTensor(g, Float64, 2, WithName("a"), WithShape(2, 2)) 577 b = NewTensor(g, Float64, 1, WithName("b"), WithShape(2)) 578 579 tensordot = tensordotOp{ 580 aAxes: []int{1}, 581 bAxes: []int{0}, 582 aDims: 2, 583 bDims: 1, 584 retDims: 1, 585 } 586 587 if c, err = ApplyOp(tensordot, a, b); err != nil { 588 log.Fatal("matrix vector: Cannot ApplyOp:", err) 589 return 590 } 591 592 aT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})) 593 bT = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{1, 2})) 594 cT = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{1, 1})) // Backing doesn't matter as long as it is set 595 596 aVal, _, _, _ = anyToValue(aT) 597 bVal, _, _, _ = anyToValue(bT) 598 cVal, _, _, _ = anyToValue(cT) 599 600 a.bind(dvUnit(aVal)) 601 b.bind(dvUnit(bVal)) 602 c.bind(dvUnitVar(cVal)) // Will set Output derivative to all ones 603 604 if err := tensordot.DoDiff(ExecutionContext{}, Nodes{a, b}, c); err != nil { 605 log.Fatal("matrix vector: Cannot DoDiff:", err) 606 return 607 } 608 609 aG, _ = a.Grad() 610 bG, _ = b.Grad() 611 612 aGfloats = extractF64s(aG) 613 bGfloats = extractF64s(bG) 614 615 aGcorrectFloats = []float64{1, 2, 1, 2} 616 bGcorrectFloats = []float64{4, 6} 617 618 assert.Equal(aGcorrectFloats, aGfloats) 619 assert.Equal(bGcorrectFloats, bGfloats) 620 621 // Matrix multiplication 622 623 g = NewGraph() 624 625 a = NewTensor(g, Float64, 2, WithName("a"), WithShape(2, 2)) 626 b = NewTensor(g, Float64, 2, WithName("b"), WithShape(2, 2)) 627 628 tensordot = tensordotOp{ 629 aAxes: []int{1}, 630 bAxes: []int{0}, 631 aDims: 2, 632 bDims: 2, 633 retDims: 2, 634 } 635 636 if c, err = ApplyOp(tensordot, a, b); err != nil { 637 log.Fatal("matrices: Cannot ApplyOp:", err) 638 return 639 } 640 641 aT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})) 642 bT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})) 643 cT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 1, 1, 1})) // Backing doesn't matter as long as it is set 644 645 aVal, _, _, _ = anyToValue(aT) 646 bVal, _, _, _ = anyToValue(bT) 647 cVal, _, _, _ = anyToValue(cT) 648 649 a.bind(dvUnit(aVal)) 650 b.bind(dvUnit(bVal)) 651 c.bind(dvUnitVar(cVal)) // Will set Output derivative to all ones 652 653 if err := tensordot.DoDiff(ExecutionContext{}, Nodes{a, b}, c); err != nil { 654 log.Fatal("matrices: Cannot DoDiff:", err) 655 return 656 } 657 658 aG, _ = a.Grad() 659 bG, _ = b.Grad() 660 661 aGfloats = extractF64s(aG) 662 bGfloats = extractF64s(bG) 663 664 aGcorrectFloats = []float64{3, 7, 3, 7} 665 bGcorrectFloats = []float64{4, 4, 6, 6} 666 667 assert.Equal(aGcorrectFloats, aGfloats) 668 assert.Equal(bGcorrectFloats, bGfloats) 669 670 // Total matrix contraction 671 672 g = NewGraph() 673 674 a = NewTensor(g, Float64, 2, WithName("a"), WithShape(2, 2)) 675 b = NewTensor(g, Float64, 2, WithName("b"), WithShape(2, 2)) 676 677 tensordot = tensordotOp{ 678 aAxes: []int{1, 0}, 679 bAxes: []int{0, 1}, 680 aDims: 2, 681 bDims: 2, 682 retDims: 1, 683 } 684 685 if c, err = ApplyOp(tensordot, a, b); err != nil { 686 log.Fatal("matrices total contraction: Cannot ApplyOp:", err) 687 return 688 } 689 690 aT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})) 691 bT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{5, 6, 7, 8})) 692 cT = tensor.New(tensor.WithShape(1), tensor.WithBacking([]float64{1})) // Backing doesn't matter as long as it is set 693 694 aVal, _, _, _ = anyToValue(aT) 695 bVal, _, _, _ = anyToValue(bT) 696 cVal, _, _, _ = anyToValue(cT) 697 698 a.bind(dvUnit(aVal)) 699 b.bind(dvUnit(bVal)) 700 c.bind(dvUnitVar(cVal)) // Will set Output derivative to all ones 701 702 if err := tensordot.DoDiff(ExecutionContext{}, Nodes{a, b}, c); err != nil { 703 log.Fatal("matrices total contraction: Cannot DoDiff:", err) 704 return 705 } 706 707 aG, _ = a.Grad() 708 bG, _ = b.Grad() 709 710 aGfloats = extractF64s(aG) 711 bGfloats = extractF64s(bG) 712 713 aGcorrectFloats = []float64{5, 7, 6, 8} 714 bGcorrectFloats = []float64{1, 3, 2, 4} 715 716 assert.Equal(aGcorrectFloats, aGfloats) 717 assert.Equal(bGcorrectFloats, bGfloats) 718 719 } 720 721 func TestLinearAlgebraOps(t *testing.T) { 722 g := NewGraph() 723 x := NewMatrix(g, Float64, WithShape(2, 3), WithName("x")) 724 y := NewMatrix(g, Float64, WithShape(3, 5), WithName("y")) 725 if _, err := Mul(x, y); err != nil { 726 t.Fatal(err) 727 } 728 729 if _, err := Mul(y, x); err == nil { 730 t.Error("Expect an error") 731 } 732 }