// gorgonia.org/gorgonia@v0.9.17/op_reduction_test.go

package gorgonia

import (
	"fmt"
	"runtime"
	"testing"

	"github.com/stretchr/testify/assert"
	"gorgonia.org/tensor"
)

func TestSumOpGrad(t *testing.T) {
	t.SkipNow()
	assert := assert.New(t)
	// var g *ExprGraph
	var z, sz *Node
	var grads Nodes
	var err error
	var op sumOp

	_, _, _, z = simpleVecEqn()
	sz = Must(Sum(z))
	// t.Logf(" %v %v %v %v", g, x, y, z)

	diffWRT := sz.diffWRT()
	assert.Equal([]bool{true}, diffWRT)

	op = sz.op.(sumOp)
	grads, err = op.SymDiff(Nodes{z}, sz, onef64)
	assert.Nilf(err, "Got %+v", err)
	assert.Equal(1, len(grads))
	t.Logf("%v", grads[0])
}

func TestSumOpFakeVec(t *testing.T) {
	g := NewGraph()

	xv := tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(2, 1))
	yv := tensor.New(tensor.WithBacking([]float64{10, 20}), tensor.WithShape(1, 2))
	x := NewMatrix(g, Float64, WithName("x"), WithShape(2, 1), WithValue(xv))
	y := NewMatrix(g, Float64, WithName("y"), WithShape(1, 2), WithValue(yv))
	sx, _ := Sum(x)
	sy, _ := Sum(y)

	assert.True(t, sx.Shape().Eq(tensor.ScalarShape()))
	assert.True(t, sy.Shape().Eq(tensor.ScalarShape()))

	sx2, _ := Sum(x, 1)
	assert.True(t, sx2.Shape().Eq(tensor.Shape{2}))

	vm := NewTapeMachine(g)
	vm.RunAll()

	assert.Equal(t, 3.0, sx.Value().Data(), "Expected sx to be 3.0")
	assert.Equal(t, 30.0, sy.Value().Data(), "Expected sy to be 30.0")
	assert.Equal(t, []float64{1, 2}, sx2.Value().Data(), "sx2 should be a flat array")
}

func TestSumOpDiff(t *testing.T) {
	defer runtime.GC()
	assert := assert.New(t)
	var g, g2 *ExprGraph
	var x, y, z, a, b, c *Node
	// var x, y, a, b *Node
	var xG, yG, aG, bG Value
	// var xG, aG Value
	// var prog *program
	// var locMap map[*Node]register
	var m *tapeMachine
	var m2 *lispMachine
	var err error

	// Basic Test case: a vector is summed

	g = NewGraph()
	x = NewVector(g, Float64, WithName("x"), WithShape(5), WithInit(RangedFrom(0)))
	y = Must(Sum(x))
	WithName("y")(y)

	Grad(y, x)

	// ioutil.WriteFile("SumOp.dot", []byte(g.ToDot()), 0644)

	m = NewTapeMachine(g)
	defer m.Close()
	if err = m.RunAll(); err != nil {
		t.Error(err)
	}

	g2 = NewGraph()
	a = NewVector(g2, Float64, WithShape(5), WithInit(RangedFrom(0)))
	b = Must(Sum(a))

	m2 = NewLispMachine(g2, WithWatchlist())
	defer m2.Close()
	if err = m2.RunAll(); err != nil {
		t.Error(err)
	}

	if aG, err = a.Grad(); err != nil {
		t.Error(err)
	}

	if xG, err = x.Grad(); err != nil {
		t.Error(err)
	}

	if bG, err = b.Grad(); err != nil {
		t.Error(err)
	}

	if yG, err = y.Grad(); err != nil {
		t.Error(err)
	}

	assert.True(ValueEq(x.Value(), a.Value()))
	assert.True(ValueEq(xG, aG))
	assert.True(ValueEq(y.Value(), b.Value()))
	assert.True(ValueEq(yG, bG))

	// long standing bug: sometimes the derivation will get executed in the machine first
	// for example, the deriv of y is 1, and occasionally, the machine will choose to
	// execute const 1 into register 0
	// It would then fail to bind to y's boundTo, because at that point in time, y is still unknown.

	// assert.Equal(y.Grad(), b.Grad())

	// Slightly more advanced test case: a matrix is summed
	g = NewGraph()
	x = NewMatrix(g, Float64, WithName("x"), WithShape(11, 7), WithInit(RangedFrom(0)))
	y = Must(Sum(x))
	WithName("y")(y)

	Grad(y, x)

	m = NewTapeMachine(g)
	defer m.Close()
	if err = m.RunAll(); err != nil {
		t.Error(err)
	}

	g2 = NewGraph()
	a = NewMatrix(g2, Float64, WithName("x"), WithShape(11, 7), WithInit(RangedFrom(0)))
	b = Must(Sum(a))

	m2 = NewLispMachine(g2)
	defer m2.Close()
	if err = m2.RunAll(); err != nil {
		t.Error(err)
	}

	if aG, err = a.Grad(); err != nil {
		t.Error(err)
	}

	if xG, err = x.Grad(); err != nil {
		t.Error(err)
	}
	if bG, err = b.Grad(); err != nil {
		t.Error(err)
	}

	if yG, err = y.Grad(); err != nil {
		t.Error(err)
	}
	assert.True(ValueEq(x.Value(), a.Value()))
	assert.True(ValueEq(xG, aG))
	assert.True(ValueEq(y.Value(), b.Value()))
	assert.True(ValueEq(yG, bG))

	/* Sum is not the root node */

	g = NewGraph()
	x = NewMatrix(g, Float64, WithName("x"), WithShape(11, 7), WithInit(RangedFrom(0)))
	y = Must(Sum(x))
	z = Must(Add(y, twof64))

	if _, err = Grad(z, x); err != nil {
		t.Fatal(err)
	}

	m = NewTapeMachine(g)
	defer m.Close()
	if err = m.RunAll(); err != nil {
		t.Errorf("%v", m.Prog())
		t.Error(err)
	}

	g2 = NewGraph()
	a = NewMatrix(g2, Float64, WithName("x"), WithShape(11, 7), WithInit(RangedFrom(0)))
	b = Must(Sum(a))
	c = Must(Add(b, twof64))

	m2 = NewLispMachine(g2)
	defer m2.Close()
	if err = m2.RunAll(); err != nil {
		t.Fatalf("%+v", err)
	}

	if aG, err = a.Grad(); err != nil {
		t.Error(err)
	}

	if xG, err = x.Grad(); err != nil {
		t.Error(err)
	}

	if bG, err = b.Grad(); err != nil {
		t.Error(err)
	}

	if yG, err = y.Grad(); err != nil {
		t.Error(err)
	}

	assert.True(ValueEq(x.Value(), a.Value()))
	assert.True(ValueEq(xG, aG))
	assert.True(ValueEq(y.Value(), b.Value()))
	assert.True(ValueEq(yG, bG))
	assert.True(ValueEq(z.Value(), c.Value()))

	runtime.GC()
}

func TestMaxOp(t *testing.T) {
	subTests := []reductionTest{
		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Max, along: []int{0}, wantShape: []int{2}, wantData: []float32{5, 6}},
		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Max, along: []int{1}, wantShape: []int{3}, wantData: []float32{2, 4, 6}},
		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Max, along: []int{}, wantShape: []int{}, wantData: float32(6)},
		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Max, along: []int{0, 1}, wantShape: []int{}, wantData: float32(6)},
		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Max, along: []int{1, 0}, wantShape: []int{}, wantData: float32(6)},
		//{dt: Float32, inShape: []int{1, 6}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Max, along: []int{1}, wantShape: []int{}, wantData: float32(6)},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Max,
			along:     []int{0, 1, 2, 3},
			wantShape: []int{},
			wantData:  float32(16),
		},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Max,
			along:     []int{},
			wantShape: []int{},
			wantData:  float32(16),
		},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Max,
			along:     []int{0},
			wantShape: []int{2, 2, 2},
			wantData:  []float32{9, 10, 11, 12, 13, 14, 15, 16},
		},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Max,
			along:     []int{1},
			wantShape: []int{2, 2, 2},
			wantData:  []float32{5, 6, 7, 8, 13, 14, 15, 16},
		},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Max,
			along:     []int{2},
			wantShape: []int{2, 2, 2},
			wantData:  []float32{3, 4, 7, 8, 11, 12, 15, 16},
		},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Max,
			along:     []int{3},
			wantShape: []int{2, 2, 2},
			wantData:  []float32{2, 4, 6, 8, 10, 12, 14, 16},
		},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Max,
			along:     []int{1, 3},
			wantShape: []int{2, 2},
			wantData:  []float32{6, 8, 14, 16},
		},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Max,
			along:     []int{0, 2, 3},
			wantShape: []int{2},
			wantData:  []float32{12, 16},
		},
	}

	for _, subTest := range subTests {
		t.Run(fmt.Sprintf("along %v", subTest.along), func(t *testing.T) {
			testReductionOp(t, subTest)
		})
	}
}

func TestSumOp(t *testing.T) {
	subTests := []reductionTest{
		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Sum, along: []int{0}, wantShape: []int{2}, wantData: []float32{9, 12}},
		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Sum, along: []int{1}, wantShape: []int{3}, wantData: []float32{3, 7, 11}},
		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Sum, along: []int{}, wantShape: []int{}, wantData: float32(21)},
		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Sum, along: []int{0, 1}, wantShape: []int{}, wantData: float32(21)},
		{dt: Float32, inShape: []int{3, 2}, inData: []float32{1, 2, 3, 4, 5, 6}, op: Sum, along: []int{1, 0}, wantShape: []int{}, wantData: float32(21)},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Sum,
			along:     []int{0, 1, 2, 3},
			wantShape: []int{},
			wantData:  float32(136),
		},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Sum,
			along:     []int{},
			wantShape: []int{},
			wantData:  float32(136),
		},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Sum,
			along:     []int{0},
			wantShape: []int{2, 2, 2},
			wantData:  []float32{10, 12, 14, 16, 18, 20, 22, 24},
		},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Sum,
			along:     []int{1},
			wantShape: []int{2, 2, 2},
			wantData:  []float32{6, 8, 10, 12, 22, 24, 26, 28},
		},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Sum,
			along:     []int{2},
			wantShape: []int{2, 2, 2},
			wantData:  []float32{4, 6, 12, 14, 20, 22, 28, 30},
		},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Sum,
			along:     []int{3},
			wantShape: []int{2, 2, 2},
			wantData:  []float32{3, 7, 11, 15, 19, 23, 27, 31},
		},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Sum,
			along:     []int{1, 3},
			wantShape: []int{2, 2},
			wantData:  []float32{14, 22, 46, 54},
		},
		{
			dt:        Float32,
			inShape:   []int{2, 2, 2, 2},
			inData:    []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			op:        Sum,
			along:     []int{0, 2, 3},
			wantShape: []int{2},
			wantData:  []float32{52, 84},
		},
	}

	for _, subTest := range subTests {
		t.Run(fmt.Sprintf("along %v", subTest.along), func(t *testing.T) {
			testReductionOp(t, subTest)
		})
	}
}

// reductionTest describes a single reduction test case: the input tensor,
// the reduction op and the axes to reduce along, and the expected output.
type reductionTest struct {
	dt        tensor.Dtype
	inShape   tensor.Shape
	inData    interface{}
	op        func(*Node, ...int) (*Node, error)
	along     []int
	wantShape tensor.Shape
	wantData  interface{}
}

// testReductionOp builds a graph containing a single reduction node, runs it
// on a tape machine, and checks the shape and data of the result.
func testReductionOp(t *testing.T, test reductionTest) {
	g := NewGraph()
	Xn := NewTensor(g, test.dt, len(test.inShape), WithShape(test.inShape...))
	got := Must(test.op(Xn, test.along...))

	xT := tensor.New(tensor.WithShape(test.inShape...), tensor.WithBacking(test.inData))
	vm := NewTapeMachine(g)
	defer vm.Close()
	vm.Let(Xn, xT)
	err := vm.RunAll()
	if err != nil {
		t.Fatal(err)
	}
	assert := assert.New(t)
	assert.Equal(test.wantShape, got.Value().Shape(), "shape mismatch")
	assert.Equal(test.wantData, got.Value().Data(), "data mismatch")
}

func TestMaxOpGrad(t *testing.T) {
	subTests := []reductionGradTest{
		{
			dt:           Float64,
			inShape:      tensor.Shape{6},
			inData:       []float64{1, 2, 3, 4, 5, 6},
			op:           Max,
			along:        []int{},
			outGradShape: tensor.Shape{1},
			outGrad:      []float64{1},
			wantInGrad:   []float64{0, 0, 0, 0, 0, 1},
		},
		{
			dt:           Float32,
			inShape:      tensor.Shape{6},
			inData:       []float32{1, 2, 3, 4, 5, 6},
			op:           Max,
			along:        []int{0},
			outGradShape: tensor.Shape{1},
			outGrad:      []float32{1},
			wantInGrad:   []float32{0, 0, 0, 0, 0, 1},
		},
		{
			dt:           Float32,
			inShape:      tensor.Shape{6},
			inData:       []float32{1, 2, 3, 4, 5, 6},
			op:           Max,
			along:        []int{},
			outGradShape: tensor.Shape{1},
			outGrad:      []float32{1},
			wantInGrad:   []float32{0, 0, 0, 0, 0, 1},
		},
		{
			dt:           Float32,
			inShape:      tensor.Shape{3, 2},
			inData:       []float32{1, 2, 3, 4, 5, 6},
			op:           Max,
			along:        []int{0},
			outGradShape: tensor.Shape{2},
			outGrad:      []float32{0.2, 0.8},
			wantInGrad:   []float32{0, 0, 0, 0, 0.2, 0.8},
		},
		{
			dt:           Float32,
			inShape:      tensor.Shape{3, 2},
			inData:       []float32{1, 2, 3, 4, 5, 6},
			op:           Max,
			along:        []int{1},
			outGradShape: tensor.Shape{3},
			outGrad:      []float32{0.1, 0.3, 0.6},
			wantInGrad:   []float32{0, 0.1, 0, 0.3, 0, 0.6},
		},
		{
			dt:           Float32,
			inShape:      tensor.Shape{3, 2},
			inData:       []float32{1, 2, 3, 4, 5, 6},
			op:           Max,
			along:        []int{0, 1},
			outGradShape: tensor.Shape{1},
			outGrad:      []float32{1},
			wantInGrad:   []float32{0, 0, 0, 0, 0, 1},
		},
		//{
		//	dt:           Float32,
		//	inShape:      tensor.Shape{1, 6},
		//	inData:       []float32{1, 2, 3, 4, 5, 6},
		//	op:           Max,
		//	along:        []int{1},
		//	outGradShape: tensor.Shape{6},
		//	outGrad:      []float32{1},
		//	wantInGrad:   []float32{0, 0, 0, 0, 0, 1},
		//},
	}

	for _, subTest := range subTests {
		t.Run(fmt.Sprintf("%v along %v %v", subTest.inShape, subTest.along, subTest.dt), func(t *testing.T) {
			testReductionOpGrad(t, subTest)
		})
	}
}

// reductionGradTest describes a single reduction gradient test case: the input
// tensor, the reduction op and axes, the output gradient fed in, and the input
// gradient expected after backpropagation.
type reductionGradTest struct {
	dt           tensor.Dtype
	inShape      tensor.Shape
	inData       interface{}
	op           func(*Node, ...int) (*Node, error)
	along        []int
	outGradShape tensor.Shape
	outGrad      interface{}
	wantInGrad   interface{}
}

// testReductionOpGrad builds a graph with a single reduction node, backpropagates
// the given output gradient through it, runs the graph on a tape machine, and
// checks the shape and data of the resulting input gradient.
func testReductionOpGrad(t *testing.T, test reductionGradTest) {
	assert := assert.New(t)

	var xG Value
	var err error

	// Run op
	g := NewGraph()
	xN := NewTensor(g, test.dt, len(test.inShape), WithShape(test.inShape...))
	y := Must(test.op(xN, test.along...))

	outGrad := NewTensor(g, test.dt, len(test.outGradShape), WithValue(tensor.New(tensor.WithShape(test.outGradShape...), tensor.WithBacking(test.outGrad))))
	if _, err = Backpropagate(Nodes{y}, Nodes{outGrad}, Nodes{xN}); err != nil {
		t.Fatal(err)
	}

	xT := tensor.New(tensor.WithShape(test.inShape...), tensor.WithBacking(test.inData))
	vm := NewTapeMachine(g)
	defer vm.Close()
	vm.Let(xN, xT)
	if err = vm.RunAll(); err != nil {
		t.Fatal(err)
	}

	// Test grad functions
	diffWRT := y.diffWRT()
	assert.Equal([]bool{true}, diffWRT)

	if xG, err = xN.Grad(); err != nil {
		t.Fatal(err)
	}
	assert.Equal(test.inShape, xG.Shape(), "grad shape mismatch")
	assert.Equal(test.wantInGrad, xG.Data(), "grad data mismatch")
}

// TestFollowupOp confirms that an element-wise binary op will work as expected after a sum/max.
// The underlying reduction on the tensor changes the number of dimensions, but the gorgonia node does not.
// We therefore confirm that the resulting nodes actually work.
func TestFollowupOp(t *testing.T) {
	g := NewGraph()
	Xn := NewTensor(g, tensor.Float64, 4, WithShape(2, 2, 2, 2), WithInit(RangedFrom(1)))
	mx := Must(Max(Xn, 1, 2))
	sx := Must(Sum(Xn, 1, 2))
	y := NewTensor(g, tensor.Float64, 2, WithShape(2, 2), WithInit(RangedFrom(1)))

	amx := Must(Add(mx, y))
	asx := Must(Add(sx, y))
	assert.Equal(t, amx.Shape(), tensor.Shape{2, 2})
	assert.Equal(t, asx.Shape(), tensor.Shape{2, 2})
	vm := NewTapeMachine(g)
	defer vm.Close()
	err := vm.RunAll()
	if err != nil {
		t.Error(err)
	}
	assert.Equal(t, []float64{8, 10, 18, 20}, amx.Value().Data(), "data mismatch")
	assert.Equal(t, []float64{17, 22, 51, 56}, asx.Value().Data(), "data mismatch")
}