gorgonia.org/gorgonia@v0.9.17/operations_test.go

package gorgonia

import (
	"io/ioutil"
	"log"
	"runtime"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gorgonia.org/tensor"
)

func TestApplyOp(t *testing.T) {
	assert := assert.New(t)
	g := NewGraph()

	var cpi *Node
	var ct *Node
	var op Op

	t.Log("Simple Constant Scalar test")
	cpi = NewConstant(3.1415, WithName("constantPi"))
	cpi = g.AddNode(cpi)

	t.Logf("g: %v", cpi.g)

	op = newElemBinOp(addOpType, cpi, cpi)
	added, err := ApplyOpWithName(op, "+ pi pi", cpi, cpi)
	if err != nil {
		t.Fatal(err)
	}
	assert.Equal(g, added.g)
	assert.Equal(Float64, added.t)

	ct = NewConstant(tensor.Ones(tensor.Float64, 3, 3)) // no graph set for ct
	op = newElemBinOp(addOpType, cpi, ct)
	if added, err = ApplyOpWithName(op, "+ pi constTensor(3,3)_ones", cpi, ct); err != nil {
		t.Error(err)
	}
}

var mulTests = []struct {
	name   string
	xshape tensor.Shape
	wshape tensor.Shape

	gradX []float64
	gradW []float64
}{
	{"x vector", tensor.Shape{2}, tensor.Shape{2, 3}, []float64{3, 12}, []float64{0, 0, 0, 1, 1, 1}},
	{"x mat", tensor.Shape{3, 2}, tensor.Shape{2, 3}, []float64{3, 12, 3, 12, 3, 12}, []float64{6, 6, 6, 9, 9, 9}},
	{"x_vec_w_vec", tensor.Shape{6}, tensor.Shape{6}, []float64{0, 1, 2, 3, 4, 5}, []float64{0, 1, 2, 3, 4, 5}},
}

func TestMul(t *testing.T) {
	defer runtime.GC()
	assert := assert.New(t)
	for _, mts := range mulTests {
		g := NewGraph()
		x := NewTensor(g, Float64, mts.xshape.Dims(), WithName(mts.name), WithShape(mts.xshape...), WithInit(RangedFrom(0)))
		w := NewTensor(g, Float64, mts.wshape.Dims(), WithName("w"), WithShape(mts.wshape...), WithInit(RangedFrom(0)))

		xw, err := Mul(x, w)
		if err != nil {
			t.Errorf("Error when testing %q. Err: %v", mts.name, err)
			continue
		}

		if mts.xshape.IsVector() && mts.wshape.IsVector() {
			if _, err = Grad(xw, x, w); err != nil {
				t.Errorf("Error while differentiating %q, Err: %v", mts.name, err)
				continue
			}
		} else {
			cost, err := Sum(xw)
			if err != nil {
				t.Errorf("Error when summing %q. Err: %v", mts.name, err)
				continue
			}

			if _, err = Grad(cost, x, w); err != nil {
				t.Errorf("Error while differentiating %q, Err: %v", mts.name, err)
				continue
			}
		}

		m := NewTapeMachine(g)
		if err = m.RunAll(); err != nil {
			t.Errorf("Error while executing %q. Err: %v", mts.name, err)
			continue
		}

		gradX, err := x.Grad()
		if err != nil {
			t.Errorf("Error while getting gradient of x %q. Err: %v", mts.name, err)
		}

		gradW, err := w.Grad()
		if err != nil {
			t.Errorf("Error while getting gradient of w %q. Err: %v", mts.name, err)
		}

		assert.Equal(mts.gradX, gradX.Data().([]float64))
		assert.Equal(mts.gradW, gradW.Data().([]float64))
		assert.True(mts.xshape.Eq(gradX.Shape()))
		assert.True(mts.wshape.Eq(gradW.Shape()))
		m.Close()
	}

	t.Logf("Testing Mul with LispMachine")
	for _, mts := range mulTests {
		g := NewGraph()
		x := NewTensor(g, Float64, mts.xshape.Dims(), WithName(mts.name), WithShape(mts.xshape...), WithInit(RangedFrom(0)))
		w := NewTensor(g, Float64, mts.wshape.Dims(), WithName("w"), WithShape(mts.wshape...), WithInit(RangedFrom(0)))

		xw, err := Mul(x, w)
		if err != nil {
			t.Errorf("Error when testing %q. Err: %v", mts.name, err)
			continue
		}

		if mts.xshape.IsVector() && mts.wshape.IsVector() {
			// nothing to do: a vector-vector Mul is an inner product, so xw is already a scalar cost
		} else {
			if _, err = Sum(xw); err != nil {
				t.Errorf("Error when summing %q. Err: %v", mts.name, err)
				continue
			}
		}

		m := NewLispMachine(g)

		if err = m.RunAll(); err != nil {
			// ioutil.WriteFile(fmt.Sprintf("fullGraph_%v.dot", mts.name), []byte(g.ToDot()), 0644)
			t.Errorf("Error while executing %q. Err: %v", mts.name, err)
			continue
		}

		gradX, err := x.Grad()
		if err != nil {
			t.Errorf("Error while getting gradient of x %q. Err: %v", mts.name, err)
		}

		gradW, err := w.Grad()
		if err != nil {
			t.Errorf("Error while getting gradient of w %q. Err: %v", mts.name, err)
		}

		assert.Equal(mts.gradX, gradX.Data().([]float64))
		assert.Equal(mts.gradW, gradW.Data().([]float64))
		assert.True(mts.xshape.Eq(gradX.Shape()))
		assert.True(mts.wshape.Eq(gradW.Shape()))
		m.Close()
	}
}

var gtTests = []struct {
	a, b    Value
	retSame bool

	expected Value
	err      bool
}{
	// s-s
	{NewF64(float64(1)), NewF64(float64(0)), true, NewF64(1.0), false},
	{NewF64(float64(0)), NewF64(float64(1)), true, NewF64(0.0), false},
	{NewF64(float64(1)), NewF64(float64(0)), false, NewB(true), false},
	{NewF32(float32(0)), NewF32(float32(1)), false, NewB(false), false},

	// s-t
	{
		NewF64(float64(1)), tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{0, 2})),
		true,
		tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{1, 0})),
		false,
	},

	{
		NewF32(float32(1)), tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{0, 2})),
		false,
		tensor.New(tensor.WithShape(2), tensor.WithBacking([]bool{true, false})),
		false,
	},

	// t-s
	{
		tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{0, 2})), NewF64(float64(1)),
		true,
		tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{0, 1})),
		false,
	},

	{
		tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{0, 2})), NewF32(float32(1)),
		false,
		tensor.New(tensor.WithShape(2), tensor.WithBacking([]bool{false, true})),
		false,
	},

	// t-t
	{
		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{0, 1, 2, 3, 4, 5})),
		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{5, 4, 3, 2, 1, 0})),
		true,

		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{0, 0, 0, 1, 1, 1})),
		false,
	},

	{
		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{0, 1, 2, 3, 4, 5})),
		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{5, 4, 3, 2, 1, 0})),
		false,

		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]bool{false, false, false, true, true, true})),
		false,
	},

	// stupids

	// different shapes
	{
		tensor.New(tensor.Of(tensor.Float32), tensor.WithShape(2)), tensor.New(tensor.Of(tensor.Float32), tensor.WithShape(4)),
		true, nil, true,
	},

	// different dtypes
	{
		tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(2)), tensor.New(tensor.Of(tensor.Float32), tensor.WithShape(2)),
		true, nil, true,
	},
}

func TestGt(t *testing.T) {
	defer runtime.GC()
	for i, gtts := range gtTests {
		// if i != 11 {
		// 	continue
		// }
		g := NewGraph()
		a := NodeFromAny(g, gtts.a, WithName("a"))
		b := NodeFromAny(g, gtts.b, WithName("b"))

		var ret *Node
		var err error
		ret, err = Gt(a, b, gtts.retSame)

		switch {
		case gtts.err:
			if err == nil {
				t.Errorf("Expected an error in Test %d", i)
			}
			continue
		case !gtts.err && err != nil:
			t.Errorf("Test %d: %+v", i, err)
			continue
		}

		if gtts.retSame {
			cost := Must(Sum(ret))
			Grad(cost, a, b)
		}

		m1 := NewTapeMachine(g)
		if err = m1.RunAll(); err != nil {
			ioutil.WriteFile("fail.dot", []byte(g.ToDot()), 0644)
			t.Errorf("%v", m1.Prog())
			t.Errorf("Test %d: %+v", i, err)
			continue
		}

		if !ValueEq(gtts.expected, ret.Value()) {
			t.Errorf("Test %d Expected %v. Got %v", i, gtts.expected, ret.Value())
		}

		// Test LispMachine implementation
		h := NewGraph()
		x := NodeFromAny(h, gtts.a, WithName("x"))
		y := NodeFromAny(h, gtts.b, WithName("y"))
		ret2, _ := Gt(x, y, gtts.retSame)

		var m2 VM
		if gtts.retSame {
			Must(Sum(ret2))
			m2 = NewLispMachine(h)
		} else {
			m2 = NewLispMachine(h, ExecuteFwdOnly())
		}
		if err = m2.RunAll(); err != nil {
			t.Errorf("Test %d LispMachine: %+v", i, err)
			continue
		}

		if !ValueEq(ret.Value(), ret2.Value()) {
			t.Errorf("Test %d. Expected %v. Got %v", i, ret.Value(), ret2.Value())
		}
		m1.Close()
		m2.Close()
		runtime.GC()
	}

	// other special cases
	g := NewGraph()
	c := NewConstant(F64(1))
	// T := NewTensor(g, Float64, 1, WithShape(2), WithInit(RangedFrom(0)))
	T := UniformRandomNode(g, Float64, 0, 1, 2)

	var gt *Node
	var err error
	if gt, err = Gt(c, T, true); err != nil {
		t.Error(err)
	}
	cost := Must(Sum(gt))
	Grad(cost, T)

	m1 := NewTapeMachine(g)
	defer m1.Close()
	if err = m1.RunAll(); err != nil {
		t.Error(err)
	}

	if (TensorType{Dims: 1, Of: Float64}) != TypeOf(gt.Value()) {
		t.Error("Expected a tensor type of float64")
	}

	// Same test as above, but using *lispMachine

	h := NewGraph()
	d := NewConstant(F64(1))
	U := UniformRandomNode(h, Float64, 0, 1, 2)
	var gt2 *Node
	if gt2, err = Gt(d, U, true); err != nil {
		t.Error(err)
	}
	Must(Sum(gt2))

	m2 := NewLispMachine(h)
	defer m2.Close()
	if err = m2.RunAll(); err != nil {
		t.Error(err)
	}

	if (TensorType{Dims: 1, Of: Float64}) != TypeOf(gt2.Value()) {
		t.Error("Expected a tensor type of float64")
	}

	t.Logf("%v", gt2.Value())
	runtime.GC()
}

func TestMisha(t *testing.T) {
	defer runtime.GC()
	assert := assert.New(t)
	g := NewGraph()
	var err error
	var x0, x1, x2, f0, f1, f2 *Node
	var grad0, grad1, grad2 Nodes

	x0 = NewScalar(g, Float64, WithName("x0"))
	x1 = NewScalar(g, Float64, WithName("x1"))
	x2 = NewScalar(g, Float64, WithName("x2"))

	Let(x0, -2.5)
	Let(x1, -2.2)
	Let(x2, 1.0)

	f0 = Must(Mish(x0))
	f1 = Must(Mish(x1))
	f2 = Must(Mish(x2))

	if grad0, err = Grad(f0, x0); err != nil {
		t.Error(err)
	}
	if grad1, err = Grad(f1, x1); err != nil {
		t.Error(err)
	}
	if grad2, err = Grad(f2, x2); err != nil {
		t.Error(err)
	}

	machine := NewTapeMachine(g)
	defer machine.Close()
	if err = machine.RunAll(); err != nil {
		t.Error(err)
	}

	// assert non-monotonicity of Mish
	// x0 < x1 < x2 && f0 > f1 < f2
	assert.Less(extractF64(x0.Value()), extractF64(x1.Value()))
	assert.Less(extractF64(x1.Value()), extractF64(x2.Value()))
	assert.Greater(extractF64(f0.Value()), extractF64(f1.Value()))
	assert.Less(extractF64(f1.Value()), extractF64(f2.Value()))

	// assert non-monotonicity of Mish'
	assert.Greater(extractF64(grad0[0].Value()), extractF64(grad1[0].Value()))
	assert.Less(extractF64(grad1[0].Value()), extractF64(grad2[0].Value()))
}

func TestSoftMax(t *testing.T) {
	defer runtime.GC()
	g := NewGraph()
	xT := tensor.New(tensor.WithBacking([]float64{0.1, 0.2, -0.3, 0.4, 0.5}))
	x := NewVector(g, Float64, WithShape(5), WithValue(xT))
	sm := Must(SoftMax(x))
	logsm := Must(Neg(Must(Log(sm))))
	cost := Must(Slice(logsm, S(2)))

	if _, err := Grad(cost, x); err != nil {
		t.Error(err)
	}

	m := NewTapeMachine(g, TraceExec())
	defer m.Close()
	if err := m.RunAll(); err != nil {
		t.Error(err)
	}
	ioutil.WriteFile("fullGraph.dot", []byte(g.ToDot()), 0644)
	var xG Value
	var err error
	if xG, err = x.Grad(); err != nil {
		t.Error(err)
	}

	// machine 2, graph 2
	h := NewGraph()
	xT2 := tensor.New(tensor.WithBacking([]float64{0.1, 0.2, -0.3, 0.4, 0.5}))
	x2 := NewVector(h, Float64, WithShape(5), WithValue(xT2))
	sm2 := Must(SoftMax(x2))
	logsm2 := Must(Neg(Must(Log(sm2))))
	Must(Slice(logsm2, S(2)))

	m2 := NewLispMachine(h)
	defer m2.Close()
	if err = m2.RunAll(); err != nil {
		log.Printf("ERR %v", err)
		t.Error(err)
	}

	var x2G Value
	if x2G, err = x2.Grad(); err != nil {
		t.Error(err)
	}

	if !floatsEqual64(xG.Data().([]float64), x2G.Data().([]float64)) {
		t.Errorf("Expected both gradients of X to be the same.")
	}
	t.Logf("\n%v\n%v\n%v", sm.Value(), logsm.Value(), cost.Value())
	correctXGrad := []float64{
		0.178025447751409, 0.1967485475322529, -0.8806659736677602, 0.24030921861990098, 0.2655827597641975,
	}

	if !floatsEqual64(correctXGrad, x2G.Data().([]float64)) {
		t.Errorf("Expected results to be %v. Got %v.", correctXGrad, x2G.Data())
	}
	if !floatsEqual64(correctXGrad, xG.Data().([]float64)) {
		t.Errorf("Expected results to be %v. Got %v.", correctXGrad, xG.Data())
	}
}

var sliceTests = []struct {
	name   string
	shape  tensor.Shape
	slices []tensor.Slice

	expected tensor.Shape
	data     interface{}
	err      bool
}{
	{"vec[0]", tensor.Shape{2}, []tensor.Slice{S(0)}, scalarShape, float64(0), false},
	{"vec[0:2]", tensor.Shape{2}, []tensor.Slice{S(0, 2)}, tensor.Shape{2}, []float64{0, 1}, false},
	{"Mat[0]", tensor.Shape{2, 3}, []tensor.Slice{S(0)}, tensor.Shape{3}, []float64{0, 1, 2}, false},
	{"Mat[:, 0]", tensor.Shape{2, 3}, []tensor.Slice{nil, S(0)}, tensor.Shape{2}, []float64{0, 3}, false},
	{"3Tensor[0]", tensor.Shape{2, 3, 4}, []tensor.Slice{S(0)}, tensor.Shape{3, 4}, tensor.Range(tensor.Float64, 0, 12), false},
	{"3Tensor[0:2]", tensor.Shape{2, 3, 4}, []tensor.Slice{S(0, 2)}, tensor.Shape{2, 3, 4}, tensor.Range(tensor.Float64, 0, 24), false},
	{"3Tensor[:, 0]", tensor.Shape{2, 3, 4}, []tensor.Slice{nil, S(0)}, tensor.Shape{2, 4}, []float64{0, 1, 2, 3, 12, 13, 14, 15}, false},
	{"3Tensor[0, :, 0]", tensor.Shape{2, 3, 4}, []tensor.Slice{S(0), nil, S(0)}, tensor.Shape{3}, []float64{0, 4, 8}, false},

	{"vec[:, 0]", tensor.Shape{2}, []tensor.Slice{nil, S(0)}, nil, nil, true},
}

func TestSlice(t *testing.T) {
	defer runtime.GC()
	for _, sts := range sliceTests {
		g := NewGraph()
		x := NewTensor(g, Float64, len(sts.shape), WithShape(sts.shape...), WithInit(RangedFrom(0)))
		sliced, err := Slice(x, sts.slices...)
		switch {
		case sts.err:
			if err == nil {
				t.Errorf("Expected an error while running test %q", sts.name)
			}
			continue
		case !sts.err && err != nil:
			t.Errorf("Error in %q: %+v", sts.name, err)
			continue
		}

		// test expected shapes:
		if !sts.expected.Eq(sliced.shape) {
			t.Errorf("Test %q - Expected %v. Got %v instead", sts.name, sts.expected, sliced.shape)
			continue
		}

		// test forwards and backwards prop
		cost := Must(Sum(sliced))
		if _, err := Grad(cost, x); err != nil {
			t.Errorf("Test %q failed to backprop: %+v", sts.name, err)
			continue
		}

		m1 := NewTapeMachine(g)
		if err = m1.RunAll(); err != nil {
			t.Errorf("Test %q Runtime error %+v ", sts.name, err)
			continue
		}

		sV := sliced.Value()
		if !sts.expected.Eq(sV.Shape()) {
			t.Errorf("Test %q For TapeMachine. Expected sliced value to have the shape %v. Got %v instead", sts.name, sts.expected, sV.Shape())
		}

		assert.Equal(t, sts.data, sV.Data(), "Test %q For TapeMachine data expected %v, Got %v instead. Formatted:\n %+v", sts.name, sts.data, sV.Data(), sV)

		// Test Lisp Machine for equivalence of gradients
		h := NewGraph()
		a := NewTensor(h, Float64, len(sts.shape), WithShape(sts.shape...), WithInit(RangedFrom(0)))
		sliced2 := Must(Slice(a, sts.slices...))
		Must(Sum(sliced2))

		m2 := NewLispMachine(h)
		if err = m2.RunAll(); err != nil {
			t.Errorf("Test %q Lispmachine Runtime error: %+v", sts.name, err)
			continue
		}

		s2V := sliced2.Value()
		if !sts.expected.Eq(s2V.Shape()) {
			t.Errorf("Test %q For LispMachine. Expected sliced value to have the shape %v. Got %v instead", sts.name, sts.expected, s2V.Shape())
		}

		assert.Equal(t, sts.data, s2V.Data(), "Test %q For LispMachine data expected %v, Got %v instead. Formatted:\n %+v", sts.name, sts.data, s2V.Data(), s2V)

		sG, err := sliced.Grad()
		if err != nil {
			t.Errorf("Test %q sliced has no grad: %+v", sts.name, err)
			continue
		}

		s2G, err := sliced2.Grad()
		if err != nil {
			t.Errorf("Test %q sliced2 has no grad: %+v", sts.name, err)
			continue
		}

		if !ValueEq(sG, s2G) {
			t.Errorf("Test %q - Expected sG and s2G to have the same value", sts.name)
		}

		m1.Close()
		m2.Close()

		// For visual checks
		// xG, err := x.Grad()
		// t.Logf("Test %q x: \n%+v,\n%+v", sts.name, x.Value(), xG)
	}

	// special cases with UnsafeLet
	g := NewGraph()
	x := NewTensor(g, Float64, 2, WithShape(2, 3), WithInit(RangedFrom(0)))
	sliced, _ := Slice(x, S(0))
	cost := Must(Slice(sliced, S(0)))
	Grad(cost, x)

	m := NewTapeMachine(g)
	defer m.Close()
	// mutate the graph before running
	UnsafeLet(sliced, S(1))
	UnsafeLet(cost, S(2))
	if err := m.RunAll(); err != nil {
		t.Fatal(err)
	}

	xG, err := x.Grad()
	if err != nil {
		t.Fatal(err)
	}

	// ioutil.WriteFile("blah.dot", []byte(g.ToDot()), 0644)
	assert.Equal(t, []float64{0, 0, 0, 0, 0, 1}, xG.Data())
	// visual inspection
	// t.Logf("x: \n%+v,\n%+v", x.Value(), xG)
}

var sumTests = []struct {
	name  string
	shape tensor.Shape
	along []int

	expectedShape tensor.Shape
	expectedVal   Value
	expectedGrad  Value
	err           bool
}{
	{"Sum(vec)", tensor.Shape{2}, nil, scalarShape, NewF64(1.0), NewF64(1.0), false},
	{"Sum(vec, 0)", tensor.Shape{2}, []int{0}, scalarShape, NewF64(1), NewF64(1.0), false},
	{"Sum(Mat)", tensor.Shape{2, 3}, nil, scalarShape, NewF64(15.0), tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{1, 1, 1, 1, 1, 1})), false},
	{"Sum(Mat, 0)", tensor.Shape{2, 3}, []int{0}, tensor.Shape{3},
		tensor.New(tensor.WithShape(3), tensor.WithBacking([]float64{3, 5, 7})),
		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{1, 1, 1, 1, 1, 1})), false,
	},
	{"Sum(Mat, 1)", tensor.Shape{2, 3}, []int{1}, tensor.Shape{2},
		tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{3, 12})),
		tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{1, 1, 1, 1, 1, 1})), false,
	},

	// TODO: tests for 3-Tensors
	// TODO: negative and stupids cases.
}

func TestSum(t *testing.T) {
	defer runtime.GC()
	for _, sts := range sumTests {
		g := NewGraph()
		x := NewTensor(g, Float64, len(sts.shape), WithShape(sts.shape...), WithInit(RangedFrom(0)))
		var s *Node
		var err error

		if len(sts.along) == 0 {
			s, err = Sum(x)
		} else {
			s, err = Sum(x, sts.along...)
		}

		switch {
		case sts.err:
			if err == nil {
				t.Errorf("Expected an error in %q", sts.name)
			}
			continue
		case !sts.err && err != nil:
			t.Errorf("Test %q errored while Sum() %+v", sts.name, err)
			continue
		}

		if !sts.expectedShape.Eq(s.shape) {
			t.Errorf("Test %q has wrong shape. Want %v, got %v instead", sts.name, sts.expectedShape, s.shape)
			continue
		}

		cost := s
		if len(sts.along) < len(sts.shape) && len(sts.along) > 0 {
			cost = Must(Sum(s))
		}

		if _, err = Grad(cost, x); err != nil {
			t.Errorf("Test %q - Unable to back prop. Err : %+v", sts.name, err)
			continue
		}

		m := NewTapeMachine(g)
		if err = m.RunAll(); err != nil {
			t.Errorf("Test %q - Runtime error: %v", sts.name, err)
			continue
		}

		if !ValueEq(sts.expectedVal, s.Value()) {
			t.Errorf("Test %q Expected %v. Got %v", sts.name, sts.expectedVal, s.Value())
		}

		sG, err := s.Grad()
		if err != nil {
			t.Errorf("Test %q Grad() error: %+v", sts.name, err)
			continue
		}

		// LISP MACHINE TO TEST GRAD EQUIVALENCE
		h := NewGraph()
		a := NewTensor(h, Float64, len(sts.shape), WithShape(sts.shape...), WithInit(RangedFrom(0)))
		var b *Node
		if len(sts.along) == 0 {
			b = Must(Sum(a))
		} else {
			b = Must(Sum(a, sts.along...))
		}

		if len(sts.along) < len(sts.shape) && len(sts.along) > 0 {
			Must(Sum(b))
		}

		m2 := NewLispMachine(h)
		if err = m2.RunAll(); err != nil {
			t.Errorf("Test %q Lisp machine runtime error %+v", sts.name, err)
			continue
		}

		if !ValueEq(sts.expectedVal, b.Value()) {
			t.Errorf("Test %q LispMachine Run. Expected %v. Got %v instead", sts.name, sts.expectedVal, b.Value())
		}

		bG, err := b.Grad()
		if err != nil {
			t.Errorf("Test %q Grad() err in lispmachine run %+v", sts.name, err)
			continue
		}

		if !ValueEq(sG, bG) {
			t.Errorf("Expected the values of the partial derivatives of both machines to be the same")
		}

		m.Close()
		m2.Close()
	}
}

func TestNorm(t *testing.T) {
	assert := assert.New(t)
	g := NewGraph()
	x := NewMatrix(g, Float64, WithShape(3, 3))
	norm, err := Norm(x, 0, 2)
	if err != nil {
		t.Error(err)
		return
	}
	m := NewLispMachine(g, ExecuteFwdOnly())
	defer m.Close()

	xT := tensor.New(tensor.WithShape(3, 3), tensor.WithBacking(tensor.Range(tensor.Float64, 0, 9)))
	Let(x, xT)
	m.RunAll()

	correct := []float64{6.708203932499369, 8.12403840463596, 9.643650760992955}
	assert.Equal(correct, extractF64s(norm.Value()))
}

func TestMean(t *testing.T) {
	g := NewGraph()
	x := NewMatrix(g, Float64, WithShape(3, 3))
	m, err := Mean(x)
	if err != nil {
		t.Fatal(err)
	}

	if !m.IsScalar() {
		t.Error("Expected result to be scalar")
	}
}

func TestTensordot(t *testing.T) {
	assert := assert.New(t)

	// Scalars
	g := NewGraph()

	a := NewTensor(g, Float64, 0, WithName("a"), WithShape(1), WithInit(RangedFrom(2)))
	b := NewTensor(g, Float64, 0, WithName("b"), WithShape(1), WithInit(RangedFrom(21)))
	c := NewTensor(g, Float64, 0, WithName("c"), WithShape(1), WithInit(ValuesOf(1.0)))

	tensordot, err := Tensordot([]int{0}, []int{0}, a, b)
	if err == nil {
		t.Fatal("Expected scalars to fail")
	}

	// Scalar-like
	g = NewGraph()
	a = NewTensor(g, Float64, 1, WithName("a"), WithShape(1), WithInit(RangedFrom(2)))
	b = NewTensor(g, Float64, 1, WithName("b"), WithShape(1), WithInit(RangedFrom(21)))
	c = NewTensor(g, Float64, 1, WithName("c"), WithShape(1), WithInit(ValuesOf(1.0)))

	tensordot, err = Tensordot([]int{0}, []int{0}, a, b)
	if err != nil {
		t.Fatal(err)
	}
	log.Printf("SHAPE a %v b %v c %v tensordot %v", a.Shape(), b.Shape(), c.Shape(), tensordot.Shape())

	dtensordot, err := Backpropagate(Nodes{tensordot}, Nodes{c}, Nodes{a, b})
	if err != nil {
		t.Fatalf("%+v", err)
	}

	m := NewTapeMachine(g)
	defer m.Close()
	if err = m.RunAll(); err != nil {
		t.Fatal(err)
	}

	correctScalarlike := []float64{42.0}
	value := tensordot.Value().Data()
	assert.Equal(correctScalarlike, value)

	dtensordotCorrectScalarlike0 := []float64{21}
	dtensordotCorrectScalarlike1 := []float64{2}

	assert.Equal(dtensordotCorrectScalarlike0, dtensordot[0].Value().Data())
	assert.Equal(dtensordotCorrectScalarlike1, dtensordot[1].Value().Data())

	// Vectors
	g = NewGraph()
	a = NewTensor(g, Float64, 1, WithName("a"), WithShape(2), WithInit(RangedFrom(1)))
	b = NewTensor(g, Float64, 1, WithName("b"), WithShape(2), WithInit(RangedFrom(3)))
	c = NewTensor(g, Float64, 0, WithName("c"), WithShape(), WithInit(ValuesOf(1.0)))

	if tensordot, err = Tensordot([]int{0}, []int{0}, a, b); err != nil {
		t.Fatal(err)
	}

	if dtensordot, err = Backpropagate(Nodes{tensordot}, Nodes{c}, Nodes{a, b}); err != nil {
		t.Fatalf("%+v", err)
	}

	// Need to multiply dtensordot with an identity matrix, otherwise the transpose action in symdiff is not performed
	id := NewConstant(tensor.I(Float64, 2, 2, 0))

	dtensordot0 := Must(Mul(id, dtensordot[0]))
	dtensordot1 := Must(Mul(id, dtensordot[1]))

	m = NewTapeMachine(g)
	defer m.Close()
	if err = m.RunAll(); err != nil {
		t.Fatal(err)
	}

	log.Printf("TensorDot %v | %v", tensordot.Value().Shape(), tensordot.Type())
	correctScalarlike = []float64{11}
	assert.Equal(correctScalarlike, tensordot.Value().Data())

	dcorrect0 := []float64{3, 4}
	dcorrect1 := []float64{1, 2}

	assert.Equal(dcorrect0, extractF64s(dtensordot[0].Value()))
	assert.Equal(dcorrect1, extractF64s(dtensordot[1].Value()))

	// Vector and Matrix
	g = NewGraph()
	a = NewTensor(g, Float64, 2, WithName("a"), WithShape(2, 2), WithInit(RangedFrom(0)))
	b = NewTensor(g, Float64, 1, WithName("b"), WithShape(2), WithInit(RangedFrom(0)))

	c = NewTensor(g, Float64, 1, WithName("c"), WithShape(2), WithInit(ValuesOf(1.0)))

	if tensordot, err = Tensordot([]int{1}, []int{0}, a, b); err != nil {
		t.Fatal(err)
	}

	if dtensordot, err = Backpropagate(Nodes{tensordot}, Nodes{c}, Nodes{a, b}); err != nil {
		t.Fatal(err)
	}

	// Need to multiply dtensordot with an identity matrix, otherwise the transpose action in symdiff is not performed
	id = NewConstant(tensor.I(Float64, 2, 2, 0))

	if dtensordot0, err = Mul(id, dtensordot[0]); err != nil {
		t.Fatal(err)
	}
	if dtensordot1, err = Mul(id, dtensordot[1]); err != nil {
		t.Fatal(err)
	}

	m = NewTapeMachine(g)
	defer m.Close()
	if err = m.RunAll(); err != nil {
		t.Fatal(err)
	}

	correct := []float64{1, 3}
	assert.Equal(correct, extractF64s(tensordot.Value()))

	dcorrect0 = []float64{0, 1, 0, 1}
	dcorrect1 = []float64{2, 4}

	assert.Equal(dcorrect0, extractF64s(dtensordot0.Value()))
	assert.Equal(dcorrect1, extractF64s(dtensordot1.Value()))

	// Matrices
	g = NewGraph()

	a = NewTensor(g, Float64, 2, WithName("a"), WithShape(2, 2), WithInit(RangedFrom(0)))
	b = NewTensor(g, Float64, 2, WithName("b"), WithShape(2, 2), WithInit(RangedFrom(0)))

	c = NewTensor(g, Float64, 2, WithName("c"), WithShape(2, 2), WithInit(ValuesOf(1.0)))

	if tensordot, err = Tensordot([]int{1}, []int{1}, a, b); err != nil {
		t.Fatal(err)
	}

	if dtensordot, err = Backpropagate(Nodes{tensordot}, Nodes{c}, Nodes{a, b}); err != nil {
		t.Fatal(err)
	}

	// Need to multiply dtensordot with an identity matrix, otherwise the transpose action in symdiff is not performed
	id = NewConstant(tensor.I(Float64, 2, 2, 0))

	if dtensordot0, err = Mul(id, dtensordot[0]); err != nil {
		t.Fatal(err)
	}
	if dtensordot1, err = Mul(id, dtensordot[1]); err != nil {
		t.Fatal(err)
	}

	m = NewTapeMachine(g)
	if err = m.RunAll(); err != nil {
		t.Fatal(err)
	}

	correct = []float64{1, 3, 3, 13}
	assert.Equal(correct, extractF64s(tensordot.Value()))

	dcorrect := []float64{2, 4, 2, 4}
	assert.Equal(dcorrect, extractF64s(dtensordot0.Value()))
	assert.Equal(dcorrect, extractF64s(dtensordot1.Value()))

	// Total matrix contraction
	g = NewGraph()

	a = NewTensor(g, Float64, 2, WithName("a"), WithShape(2, 2), WithInit(RangedFrom(0)))
	b = NewTensor(g, Float64, 2, WithName("b"), WithShape(2, 2), WithInit(RangedFrom(0)))

	c = NewTensor(g, Float64, 0, WithName("c"), WithShape(), WithInit(ValuesOf(1.0)))

	if tensordot, err = Tensordot([]int{0, 1}, []int{0, 1}, a, b); err != nil {
		t.Fatal(err)
	}

	if dtensordot, err = Backpropagate(Nodes{tensordot}, Nodes{c}, Nodes{a, b}); err != nil {
		t.Fatal(err)
	}

	// Need to multiply dtensordot with an identity matrix, otherwise the transpose action in symdiff is not performed
	id = NewConstant(tensor.I(Float64, 2, 2, 0))

	if dtensordot0, err = Mul(id, dtensordot[0]); err != nil {
		t.Fatal(err)
	}
	if dtensordot1, err = Mul(id, dtensordot[1]); err != nil {
		t.Fatal(err)
	}

	m = NewTapeMachine(g)
	defer m.Close()
	if err = m.RunAll(); err != nil {
		t.Fatal(err)
	}

	correctScalarlike = []float64{14}
	assert.Equal(correctScalarlike, tensordot.Value().Data())

	dcorrect = []float64{0, 1, 2, 3}
	assert.Equal(dcorrect, extractF64s(dtensordot0.Value()))
	assert.Equal(dcorrect, extractF64s(dtensordot1.Value()))
}

var reshapeTests = []struct {
	testName string
	input    tensor.Shape
	to       tensor.Shape
	output   tensor.Shape
	err      bool
}{
	{"simple", tensor.Shape{2, 2}, tensor.Shape{4}, tensor.Shape{4}, false},
	{"simple big tensor", tensor.Shape{200, 200}, tensor.Shape{200 * 200}, tensor.Shape{200 * 200}, false},
	{"negative dim1 1", tensor.Shape{3, 2}, tensor.Shape{6, -1}, tensor.Shape{6, 1}, false},
	{"negative dim1 2", tensor.Shape{3, 2}, tensor.Shape{2, -1}, tensor.Shape{2, 3}, false},
	{"negative dim0 1", tensor.Shape{3, 2}, tensor.Shape{-1, 3}, tensor.Shape{2, 3}, false},
	{"negative dims0.1 with error", tensor.Shape{3, 2}, tensor.Shape{-1, -1}, nil, true},
	{"negative dim0 with error", tensor.Shape{3, 2}, tensor.Shape{4, -1}, nil, true},
}

func TestReshape(t *testing.T) {
	for _, rst := range reshapeTests {
		g := NewGraph()
		T := NewTensor(g, Float64, len(rst.input), WithShape(rst.input.Clone()...))
		T2, err := Reshape(T, rst.to.Clone())
		t.Log(T2)
		switch {
		case rst.err && err == nil:
			t.Fatalf("Expected Error when testing %v", rst)
		case rst.err:
			continue
		case err != nil:
			t.Fatal(err)
		default:
			assert.True(t, rst.output.Eq(T2.Shape()), "expected both to be the same")
		}
	}
}

func TestReshape_Dense(t *testing.T) {
	for _, rst := range reshapeTests {
		g := NewGraph()
		tT := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(rst.input.Clone()...))
		T := NodeFromAny(g, tT)
		T2, err := Reshape(T, rst.to.Clone())
		switch {
		case rst.err && err == nil:
			t.Fatalf("Expected Error when testing %v", rst)
		case rst.err:
			continue
		case err != nil:
			t.Fatal(err)
		default:
			assert.True(t, rst.output.Eq(T2.Shape()), "expected both to be the same")
		}
		m := NewTapeMachine(g)
		if err := m.RunAll(); err != nil {
			t.Errorf("Error while executing %q. Err: %v", rst.testName, err)
			continue
		}
	}
}

func TestReshapeRuntime(t *testing.T) {
	g := NewGraph()
	x := NewMatrix(g, tensor.Float64, WithName("x"), WithShape(28, 28), WithInit(GlorotU(1)))
	w := NewMatrix(g, tensor.Float64, WithName("W"), WithShape(50, 784), WithInit(GlorotU(1)))
	x2 := Must(Reshape(x, tensor.Shape{784}))
	wx := Must(Mul(w, x2))
	wx2 := Must(Reshape(wx, tensor.Shape{5, 10}))

	cost := Must(Sum(wx2))
	if _, err := Grad(cost, w); err != nil {
		t.Fatal(err)
	}
	m := NewTapeMachine(g)
	if err := m.RunAll(); err != nil {
		t.Fatal(err)
	}

	if !x.Value().Shape().Eq(tensor.Shape{28, 28}) {
		t.Errorf("A mutation of shape has occurred")
	}
}

var ravelTests = []struct {
	input  tensor.Shape
	output tensor.Shape
}{
	{
		tensor.Shape{3, 3},
		tensor.Shape{9},
	},
	{
		tensor.Shape{2, 3},
		tensor.Shape{6},
	},
	{
		tensor.Shape{2, 1, 3},
		tensor.Shape{6},
	},
	{
		tensor.Shape{1, 1, 1},
		tensor.Shape{1},
	},
}

func TestRavel(t *testing.T) {
	c := require.New(t)

	for i, rst := range ravelTests {
		g := NewGraph()
		t := NewTensor(g, Float64, len(rst.input), WithShape(rst.input...))
		t2, err := Ravel(t)

		c.NoError(err)
		c.Equal(rst.output, t2.Shape(), "expected to be flatten in test case: %d", i)
	}
}

func TestAuto(t *testing.T) {
	testCases := []struct {
		desc          string
		shapeA        tensor.Shape
		shapeB        tensor.Shape
		expectedShape tensor.Shape
		expectedErr   string
	}{
		{
			desc:        "Example 0",
			shapeA:      tensor.Shape{12},
			shapeB:      tensor.Shape{1, 11},
			expectedErr: "shapes (12) and (1, 11) should have the same dimensions",
		},
		{
			desc:          "Example 1",
			shapeA:        tensor.Shape{12, 1},
			shapeB:        tensor.Shape{12, 11},
			expectedShape: tensor.Shape{12, 11},
			expectedErr:   "",
		},
		{
			desc:          "Example 2",
			shapeA:        tensor.Shape{1, 12},
			shapeB:        tensor.Shape{11, 12},
			expectedShape: tensor.Shape{11, 12},
			expectedErr:   "",
		},
		{
			desc:          "Example 3",
			shapeA:        tensor.Shape{2, 3, 5},
			shapeB:        tensor.Shape{2, 3, 1},
			expectedShape: tensor.Shape{2, 3, 5},
			expectedErr:   "",
		},
		{
			desc:          "Example 4",
			shapeA:        tensor.Shape{2, 1, 5},
			shapeB:        tensor.Shape{2, 3, 5},
			expectedShape: tensor.Shape{2, 3, 5},
			expectedErr:   "",
		},
		{
			desc:          "Example 5",
			shapeA:        tensor.Shape{2, 1, 1},
			shapeB:        tensor.Shape{2, 5, 3},
			expectedShape: tensor.Shape{2, 5, 3},
			expectedErr:   "",
		},
	}
	for _, tC := range testCases {
		t.Run(tC.desc, func(t *testing.T) {
			c := require.New(t)

			g := NewGraph()
			a := NewTensor(g, Float64, tC.shapeA.Dims(), WithShape(tC.shapeA...), WithInit(RangedFrom(0)))
			b := NewTensor(g, Float64, tC.shapeB.Dims(), WithShape(tC.shapeB...), WithInit(RangedFrom(0)))

			out, err := Auto(BroadcastHadamardProd, a, b)

			if tC.expectedErr != "" {
				c.Error(err)
				c.Equal(tC.expectedErr, err.Error())
				return
			} else {
				c.NoError(err)
			}

			c.Equal(tC.expectedShape, out.Shape())

			out, err = Auto(BroadcastHadamardProd, b, a)
			c.NoError(err)
			c.Equal(tC.expectedShape, out.Shape())
		})
	}
}