gorgonia.org/gorgonia@v0.9.17/operatorPointwise_unary_test.go (about) 1 package gorgonia 2 3 import ( 4 "math" 5 "math/rand" 6 "testing" 7 8 "github.com/chewxy/math32" 9 "github.com/stretchr/testify/assert" 10 "gorgonia.org/dawson" 11 "gorgonia.org/tensor" 12 ) 13 14 func unaryOpTest(t *testing.T, dt tensor.Dtype, shape tensor.Shape, fn func(*Node) (*Node, error)) (x, y, a, b *Node, v Value, err error) { 15 var xV, aV Value 16 var any interface{} 17 if shape.IsScalar() { 18 if dt == tensor.Float64 { 19 any = rand.ExpFloat64() 20 } else { 21 any = float32(rand.ExpFloat64()) 22 } 23 } else { 24 any = tensor.New(tensor.WithBacking(tensor.Random(dt, shape.TotalSize()))) 25 } 26 if v, _, _, err = anyToValue(any); err != nil { 27 t.Errorf("anyToValue failed %v", err) 28 return 29 } 30 if xV, err = CloneValue(v); err != nil { 31 t.Errorf("Clone to xV failed %v", err) 32 return 33 } 34 35 g := NewGraph() 36 x = NodeFromAny(g, xV, WithName("x")) 37 y = Must(fn(x)) 38 Must(Sum(y)) 39 40 var grads Nodes 41 h := NewGraph() 42 a = NodeFromAny(h, xV, WithName("x")) 43 b = Must(fn(a)) 44 cost := Must(Sum(b)) 45 if grads, err = Grad(cost, a); err != nil { 46 t.Errorf("Unable to get gradient %v", err) 47 return 48 } 49 50 if aV, err = CloneValue(v); err != nil { 51 t.Errorf("Clone to aV failed: %v", err) 52 return 53 } 54 55 m0 := NewLispMachine(g) 56 m1 := NewTapeMachine(h) 57 defer m1.Close() 58 defer m0.Close() 59 60 Let(x, xV) 61 if err = m0.RunAll(); err != nil { 62 t.Errorf("m0 failed: %v", err) 63 return 64 } 65 66 Let(a, aV) 67 if err = m1.RunAll(); err != nil { 68 t.Errorf("m1 failed: %v", err) 69 return 70 } 71 72 var yV, xG, bV, aG Value 73 yV = y.Value() 74 if xG, err = x.Grad(); err != nil { 75 t.Errorf("x has no grad: %v", err) 76 return 77 } 78 79 bV = b.Value() 80 if aG, err = a.Grad(); err != nil { 81 t.Errorf("a has no grad: %v", err) 82 t.Logf("a.deriv %p | %p", a.deriv, grads[0]) 83 return 84 } 85 86 if !ValueClose(yV, bV) { 87 t.Errorf("Expected yV and bV to be close. yV: %v, bV: %v", yV, bV) 88 } 89 90 if !ValueClose(aG, xG) { 91 t.Errorf("Expected aG and xG to be close. aG: %v, xG %v", aG, xG) 92 } 93 94 return 95 } 96 97 func unaryOpDiffTest(op ʘUnaryOperatorType) (xRandVal float64, x, y, xT, yT *Node, err error) { 98 _, x, y = simpleUnaryEqn() 99 100 xRandVal = rand.ExpFloat64() 101 fn := *(sf64UnaryOperators[op]) 102 diff := ʘUnaryOpDiffFns[op] 103 104 // let the first stone be cast! 105 Let(x, xRandVal) 106 v, _, _, _ := anyToValue(fn(xRandVal)) // as if the graph has been executed upon 107 ydv := variableDV(v) 108 109 if err = y.bind(ydv); err != nil { 110 return 111 } 112 113 if err = x.bind(dvUnit(x.boundTo)); err != nil { 114 return 115 } 116 117 if err = diff(x, y); err != nil { 118 return 119 } 120 121 // Tensor edition 122 _, xT, yT = simpleUnaryVecEqn() 123 124 xBack := []float64{-xRandVal, xRandVal} 125 yBack := []float64{fn(-xRandVal), fn(xRandVal)} 126 Let(xT, tensor.New(tensor.WithShape(2, 1), tensor.WithBacking(xBack))) 127 vT, _, _, _ := anyToValue(tensor.New(tensor.WithShape(2, 1), tensor.WithBacking(yBack))) 128 yTdv := variableDV(vT) 129 130 if err = yT.bind(yTdv); err != nil { 131 return 132 } 133 134 if err = xT.bind(dvUnit(xT.boundTo)); err != nil { 135 return 136 } 137 138 if err = diff(xT, yT); err != nil { 139 return 140 } 141 return 142 } 143 144 func TestAbs(t *testing.T) { 145 assert := assert.New(t) 146 147 var x, y, a, b *Node 148 var v Value 149 var yV, xG, bV, aG Value 150 var err error 151 152 /* FLOAT 64 Scalar */ 153 154 x, y, a, b, v, err = unaryOpTest(t, Float64, tensor.Shape{}, Abs) 155 if err != nil { 156 t.Fatal(err) 157 } 158 159 yV = y.Value() 160 if xG, err = x.Grad(); err != nil { 161 t.Errorf("x has no grad: %v", err) 162 return 163 } 164 165 bV = b.Value() 166 if aG, err = a.Grad(); err != nil { 167 t.Errorf("a has no grad: %v", err) 168 } 169 170 correctF64 := math.Abs(v.Data().(float64)) 171 assert.True(ValueClose(NewF64(correctF64), yV)) 172 assert.True(ValueClose(NewF64(correctF64), bV)) 173 assert.True(ValueClose(NewF64(1.0), xG)) 174 assert.True(ValueClose(NewF64(1.0), aG)) 175 176 /* FLOAT 32 Scalar */ 177 178 x, y, a, b, v, err = unaryOpTest(t, Float32, tensor.Shape{}, Abs) 179 if err != nil { 180 t.Fatal(err) 181 } 182 183 yV = y.Value() 184 if xG, err = x.Grad(); err != nil { 185 t.Errorf("x has no grad: %v", err) 186 return 187 } 188 189 bV = b.Value() 190 if aG, err = a.Grad(); err != nil { 191 t.Errorf("a has no grad: %v", err) 192 } 193 194 correctF32 := math32.Abs(v.Data().(float32)) 195 assert.True(ValueClose(NewF32(correctF32), yV)) 196 assert.True(ValueClose(NewF32(correctF32), bV)) 197 assert.True(ValueClose(NewF32(1.0), xG)) 198 assert.True(ValueClose(NewF32(1.0), aG)) 199 200 /* FLOAT64 Vector */ 201 202 x, y, a, b, v, err = unaryOpTest(t, Float64, tensor.Shape{10}, Abs) 203 if err != nil { 204 t.Fatal(err) 205 } 206 207 yV = y.Value() 208 if xG, err = x.Grad(); err != nil { 209 t.Errorf("x has no grad: %v", err) 210 return 211 } 212 213 bV = b.Value() 214 if aG, err = a.Grad(); err != nil { 215 t.Errorf("a has no grad: %v", err) 216 } 217 218 absF64s := v.Data().([]float64) 219 backingGrad64 := make([]float64, len(absF64s)) 220 for i, v := range absF64s { 221 absF64s[i] = math.Abs(v) 222 if v > 0 { 223 backingGrad64[i] = 1 224 } else { 225 backingGrad64[i] = -1 226 } 227 } 228 correctVecF64 := tensor.New(tensor.WithBacking(absF64s)) 229 gradF64s := tensor.New(tensor.WithBacking(backingGrad64)) 230 231 assert.True(ValueClose(correctVecF64, yV)) 232 assert.True(ValueClose(correctVecF64, bV)) 233 assert.True(ValueClose(gradF64s, xG), "xG %v", xG) 234 assert.True(ValueClose(gradF64s, aG), "aG %v", aG) 235 236 /* FLOAT32 Vector */ 237 238 x, y, a, b, v, err = unaryOpTest(t, Float32, tensor.Shape{10}, Abs) 239 if err != nil { 240 t.Fatal(err) 241 } 242 243 yV = y.Value() 244 if xG, err = x.Grad(); err != nil { 245 t.Errorf("x has no grad: %v", err) 246 return 247 } 248 249 bV = b.Value() 250 if aG, err = a.Grad(); err != nil { 251 t.Errorf("a has no grad: %v", err) 252 } 253 254 absF32s := v.Data().([]float32) 255 backingGrad32 := make([]float32, len(absF32s)) 256 for i, v := range absF32s { 257 absF32s[i] = math32.Abs(v) 258 if v > 0 { 259 backingGrad32[i] = 1 260 } else { 261 backingGrad32[i] = -1 262 } 263 } 264 correctVecF32 := tensor.New(tensor.WithBacking(absF32s)) 265 gradF32s := tensor.New(tensor.WithBacking(backingGrad32)) 266 267 assert.True(ValueClose(correctVecF32, yV)) 268 assert.True(ValueClose(correctVecF32, bV)) 269 assert.True(ValueClose(gradF32s, xG), "xG %v", xG) 270 assert.True(ValueClose(gradF32s, aG), "aG %v", aG) 271 272 } 273 274 func TestSinDiff(t *testing.T) { 275 assert := assert.New(t) 276 v, x, _, xT, _, err := unaryOpDiffTest(sinOpType) 277 if err != nil { 278 t.Error(err) 279 } 280 281 correct := math.Cos(v) 282 assert.Equal(correct, x.boundTo.(*dualValue).d.Data()) 283 284 // Tensor edition 285 xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense) 286 correctT := []float64{math.Cos(-v), math.Cos(v)} 287 assert.Equal(correctT, xdvd.Data()) 288 } 289 290 func TestCosDiff(t *testing.T) { 291 assert := assert.New(t) 292 293 v, x, _, xT, _, err := unaryOpDiffTest(cosOpType) 294 if err != nil { 295 t.Error(err) 296 } 297 298 assert.Equal(-math.Sin(v), x.boundTo.(*dualValue).d.Data()) 299 300 // Tensor edition 301 xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense) 302 correct := []float64{-math.Sin(-v), -math.Sin(v)} 303 assert.Equal(correct, xdvd.Data()) 304 } 305 306 func TestExpDiff(t *testing.T) { 307 assert := assert.New(t) 308 _, x, y, xT, yT, err := unaryOpDiffTest(expOpType) 309 if err != nil { 310 t.Error(err) 311 } 312 313 assert.Equal(y.boundTo.(*dualValue).Value, x.boundTo.(*dualValue).d) 314 315 // Tensor edition 316 xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense) 317 ydvd := yT.boundTo.(*dualValue).Value.(*tensor.Dense) 318 assert.Equal(ydvd.Data(), xdvd.Data()) 319 } 320 321 func TestLnDiff(t *testing.T) { 322 assert := assert.New(t) 323 var err error 324 v, x, _, xT, _, err := unaryOpDiffTest(lnOpType) 325 if err != nil { 326 t.Error(err) 327 } 328 correct := 1.0 / v 329 assert.Equal(correct, x.boundTo.(*dualValue).d.Data(), "v was %v", v) 330 331 // Tensor edition 332 xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense) 333 correctT := []float64{1.0 / -v, 1.0 / v} 334 assert.Equal(correctT, xdvd.Data()) 335 } 336 337 func TestLog2Diff(t *testing.T) { 338 assert := assert.New(t) 339 v, x, _, xT, _, err := unaryOpDiffTest(log2OpType) 340 if err != nil { 341 t.Error(err) 342 } 343 correct := 1.0 / (v * math.Ln2) 344 assert.Equal(correct, x.boundTo.(*dualValue).d.Data()) 345 346 // Tensor edition 347 xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense) 348 correctT := []float64{1.0 / (-v * math.Ln2), 1.0 / (v * math.Ln2)} 349 assert.Equal(correctT, xdvd.Data()) 350 } 351 352 func TestSquareDiff(t *testing.T) { 353 assert := assert.New(t) 354 var err error 355 v, x, _, xT, _, err := unaryOpDiffTest(squareOpType) 356 if err != nil { 357 t.Error(err) 358 } 359 360 assert.Equal(2*v, x.boundTo.(*dualValue).d.Data()) 361 362 // Tensor edition 363 xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense) 364 correct := []float64{2 * -v, 2 * v} 365 assert.Equal(correct, xdvd.Data()) 366 } 367 368 func TestSqrtDiff(t *testing.T) { 369 assert := assert.New(t) 370 v, x, _, xT, _, err := unaryOpDiffTest(sqrtOpType) 371 if err != nil { 372 t.Error(err) 373 } 374 375 assert.Equal(1.0/(2*math.Sqrt(v)), x.boundTo.(*dualValue).d.Data()) 376 377 // Tensor edition 378 xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense) 379 correct := []float64{1.0 / (2 * math.Sqrt(-v)), 1.0 / (2 * math.Sqrt(v))} 380 got := xdvd.Data().([]float64) 381 if !math.IsNaN(got[0]) && math.IsNaN(correct[0]) { 382 t.Error("Expected NaN for the first value") 383 } 384 if got[1] != correct[1] { 385 t.Error("Different second values") 386 } 387 } 388 389 func TestInverseDiff(t *testing.T) { 390 assert := assert.New(t) 391 v, x, _, xT, _, err := unaryOpDiffTest(inverseOpType) 392 if err != nil { 393 t.Error(err) 394 } 395 396 correct := -((1 / v) * (1 / v)) 397 assert.Equal(correct, x.boundTo.(*dualValue).d.Data()) 398 399 // Tensor edition 400 xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense) 401 correctT := []float64{correct, correct} 402 assert.Equal(correctT, xdvd.Data()) 403 } 404 405 func TestCubeDiff(t *testing.T) { 406 assert := assert.New(t) 407 v, x, _, xT, _, err := unaryOpDiffTest(cubeOpType) 408 if err != nil { 409 t.Error(err) 410 } 411 412 correct := 3 * v * v 413 xG, err := x.Grad() 414 if err != nil { 415 t.Error(err) 416 } 417 418 assert.True(dawson.CloseF64(correct, extractF64(xG)), "%v != %v", xG, correct) 419 420 // Tensor edition 421 xdvd := xT.boundTo.(*dualValue).d 422 correctT := []float64{correct, correct} 423 assert.True(floatsEqual64(correctT, extractF64s(xdvd))) 424 } 425 426 func TestTanhDiff(t *testing.T) { 427 assert := assert.New(t) 428 v, x, _, xT, _, err := unaryOpDiffTest(tanhOpType) 429 if err != nil { 430 t.Error(err) 431 } 432 433 // NOTE: there are not guarantees of identical behaviours across architectures, 434 // in this case arm64 gives different results than amd64 for Tanh. 435 // See https://github.com/golang/go/issues/18354#issuecomment-267705645 436 correct := 1.0 - (float64(math.Tanh(v)) * float64(math.Tanh(v))) // I'm surprised Golang doesn't have a secant function! 437 assert.InDeltaf(correct, x.boundTo.(*dualValue).d.Data(), 1e-14, "") 438 439 // Tensor edition 440 xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense) 441 assert.InDeltaSlicef([]float64{correct, correct}, xdvd.Data(), 1e-14, "") 442 } 443 444 func TestSigmoidDiff(t *testing.T) { 445 assert := assert.New(t) 446 v, x, _, xT, _, err := unaryOpDiffTest(sigmoidOpType) 447 if err != nil { 448 t.Error(err) 449 } 450 451 correct := math.Exp(-v) / ((1 + math.Exp(-v)) * (1 + math.Exp(-v))) 452 xG := x.boundTo.(*dualValue).d 453 assert.True(dawson.CloseF64(correct, extractF64(xG))) 454 455 // Tensor edition 456 xdvd := xT.boundTo.(*dualValue).d 457 negCorrect := math.Exp(v) / ((1 + math.Exp(v)) * (1 + math.Exp(v))) 458 corrects := []float64{negCorrect, correct} 459 assert.True(floatsEqual64(corrects, extractF64s(xdvd))) 460 } 461 462 func TestLog1pDiff(t *testing.T) { 463 assert := assert.New(t) 464 v, x, _, xT, _, err := unaryOpDiffTest(log1pOpType) 465 if err != nil { 466 t.Error(err) 467 } 468 469 correct := 1 / (1.0 + v) 470 assert.Equal(correct, x.boundTo.(*dualValue).d.Data()) 471 472 // Tensor edition 473 xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense) 474 correct0 := 1 / (1.0 - v) 475 assert.Equal([]float64{correct0, correct}, xdvd.Data()) 476 } 477 478 func TestExpm1Diff(t *testing.T) { 479 assert := assert.New(t) 480 v, x, _, xT, _, err := unaryOpDiffTest(expm1OpType) 481 if err != nil { 482 t.Error(err) 483 } 484 485 correct := math.Exp(v) 486 assert.Equal(correct, x.boundTo.(*dualValue).d.Data()) 487 488 // Tensor edition 489 xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense) 490 correct0 := math.Exp(-v) 491 assert.Equal([]float64{correct0, correct}, xdvd.Data()) 492 } 493 494 func TestSoftplus(t *testing.T) { 495 assert := assert.New(t) 496 497 var x, y, a, b *Node 498 var v Value 499 var xV, yV, xG, bV, aG Value 500 var err error 501 502 /* FLOAT64 SCALAR */ 503 504 if x, y, a, b, v, err = unaryOpTest(t, Float64, tensor.Shape{}, Softplus); err != nil { 505 t.Fatal(err) 506 } 507 508 xV = x.Value() 509 yV = y.Value() 510 if xG, err = x.Grad(); err != nil { 511 t.Errorf("x has no grad: %v", err) 512 return 513 } 514 515 bV = b.Value() 516 if aG, err = a.Grad(); err != nil { 517 t.Errorf("a has no grad: %v", err) 518 } 519 520 correctVF64 := softplusf64(v.Data().(float64)) 521 correctDF64 := sigmoidf64(xV.Data().(float64)) 522 assert.True(ValueClose(NewF64(correctVF64), yV)) 523 assert.True(ValueClose(NewF64(correctVF64), bV)) 524 assert.True(ValueClose(NewF64(correctDF64), xG)) 525 assert.True(ValueClose(NewF64(correctDF64), aG)) 526 527 /* FLOAT32 SCALAR */ 528 529 if x, y, a, b, v, err = unaryOpTest(t, Float32, tensor.Shape{}, Softplus); err != nil { 530 t.Fatal(err) 531 } 532 533 xV = x.Value() 534 yV = y.Value() 535 if xG, err = x.Grad(); err != nil { 536 t.Errorf("x has no grad: %v", err) 537 return 538 } 539 540 bV = b.Value() 541 if aG, err = a.Grad(); err != nil { 542 t.Errorf("a has no grad: %v", err) 543 } 544 545 correctVF32 := softplusf32(v.Data().(float32)) 546 correctDF32 := sigmoidf32(xV.Data().(float32)) 547 assert.True(ValueClose(NewF32(correctVF32), yV)) 548 assert.True(ValueClose(NewF32(correctVF32), bV)) 549 assert.True(ValueClose(NewF32(correctDF32), xG)) 550 assert.True(ValueClose(NewF32(correctDF32), aG)) 551 552 /* FLOAT64 Vector */ 553 554 if x, y, a, b, v, err = unaryOpTest(t, Float64, tensor.Shape{10}, Softplus); err != nil { 555 t.Fatal(err) 556 } 557 558 xV = x.Value() 559 yV = y.Value() 560 if xG, err = x.Grad(); err != nil { 561 t.Errorf("x has no grad: %v", err) 562 return 563 } 564 565 bV = b.Value() 566 if aG, err = a.Grad(); err != nil { 567 t.Errorf("a has no grad: %v", err) 568 } 569 570 correctVF64s := v.Data().([]float64) 571 correctDF64s := xV.Data().([]float64) 572 573 for i, v := range correctVF64s { 574 correctVF64s[i] = softplusf64(v) 575 correctDF64s[i] = sigmoidf64(correctDF64s[i]) 576 } 577 assert.True(floatsEqual64(correctVF64s, yV.Data().([]float64))) 578 assert.True(floatsEqual64(correctVF64s, bV.Data().([]float64))) 579 assert.True(floatsEqual64(correctDF64s, xG.Data().([]float64))) 580 assert.True(floatsEqual64(correctDF64s, aG.Data().([]float64))) 581 582 /* FLOAT32 Vector */ 583 584 if x, y, a, b, v, err = unaryOpTest(t, Float32, tensor.Shape{10}, Softplus); err != nil { 585 t.Fatal(err) 586 } 587 588 xV = x.Value() 589 yV = y.Value() 590 if xG, err = x.Grad(); err != nil { 591 t.Errorf("x has no grad: %v", err) 592 return 593 } 594 595 bV = b.Value() 596 if aG, err = a.Grad(); err != nil { 597 t.Errorf("a has no grad: %v", err) 598 } 599 600 correctVF32s := v.Data().([]float32) 601 correctDF32s := xV.Data().([]float32) 602 603 for i, v := range correctVF32s { 604 correctVF32s[i] = softplusf32(v) 605 correctDF32s[i] = sigmoidf32(correctDF32s[i]) 606 } 607 assert.True(floatsEqual32(correctVF32s, yV.Data().([]float32))) 608 assert.True(floatsEqual32(correctVF32s, bV.Data().([]float32))) 609 assert.True(floatsEqual32(correctDF32s, xG.Data().([]float32))) 610 assert.True(floatsEqual32(correctDF32s, aG.Data().([]float32))) 611 }