gorgonia.org/gorgonia@v0.9.17/operatorPointwise_binary.go

package gorgonia

import (
	"math"

	"github.com/chewxy/math32"
	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

// incrDoerBinOp is implemented by binary operators that can accumulate their result into an existing Value.
type incrDoerBinOp interface {
	IncrDo(v Value, retSame bool, inputs ...Value) error
}

// usePreallocDoerBinOp is implemented by binary operators that can write their result into a preallocated Value.
type usePreallocDoerBinOp interface {
	UsePreallocDo(v Value, retSame bool, inputs ...Value) (retVal Value, err error)
}

// unsafeDoerBinOp is implemented by binary operators that can perform the operation in place, reusing an input.
type unsafeDoerBinOp interface {
	UnsafeDo(retSame bool, inputs ...Value) (Value, error)
}

/* BINARY OPERATOR */

type ʘBinaryOperator interface {
	isArith() bool
	binOpType() ʘBinaryOperatorType
	Do(bool, ...Value) (Value, error)
	String() string
}

// scalarBinOp is a ʘBinaryOperator that operates on two scalar Values of the same Dtype.
type scalarBinOp struct {
	ʘBinaryOperatorType
	t tensor.Dtype
}

func (o scalarBinOp) Arity() int                     { return 2 }
func (o scalarBinOp) binOpType() ʘBinaryOperatorType { return o.ʘBinaryOperatorType }
func (o scalarBinOp) isArith() bool                  { return o.ʘBinaryOperatorType.isArith() }
func (o scalarBinOp) String() string                 { return o.ʘBinaryOperatorType.String() }

func (o scalarBinOp) Do(same bool, vals ...Value) (retVal Value, err error) {
	if err = checkArity(o, len(vals)); err != nil {
		return
	}

	at := TypeOf(vals[0])
	bt := TypeOf(vals[1])
	if !at.Eq(bt) {
		err = errors.Errorf("Type Mismatch: %v != %v", at, bt)
		return
	}

	var r interface{} // float or bool only plz
	switch a := vals[0].(type) {
	case *F64:
		b := vals[1].(*F64)
		switch o.ʘBinaryOperatorType {
		case addOpType:
			r = NewF64(a.any() + b.any())
		case subOpType:
			r = NewF64(a.any() - b.any())
		case mulOpType:
			r = NewF64(a.any() * b.any())
		case divOpType:
			r = NewF64(a.any() / b.any())
		case powOpType:
			r = NewF64(math.Pow(a.any(), b.any()))
		case ltOpType:
			r = NewB(a.any() < b.any())
		case gtOpType:
			r = NewB(a.any() > b.any())
		case lteOpType:
			r = NewB(a.any() <= b.any())
		case gteOpType:
			r = NewB(a.any() >= b.any())
		case eqOpType:
			r = NewB(a.any() == b.any())
		case neOpType:
			r = NewB(a.any() != b.any())
		default:
			err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Float64", o.ʘBinaryOperatorType)
		}

		if same && !o.isArith() {
			if *(r.(*B)) {
				r = NewF64(1.0)
			} else {
				r = NewF64(0.0)
			}
		}

	case *F32:
		b := vals[1].(*F32)
		switch o.ʘBinaryOperatorType {
		case addOpType:
			r = NewF32(a.any() + b.any())
		case subOpType:
			r = NewF32(a.any() - b.any())
		case mulOpType:
			r = NewF32(a.any() * b.any())
		case divOpType:
			r = NewF32(a.any() / b.any())
		case powOpType:
			r = NewF32(math32.Pow(float32(a.any()), float32(b.any())))
		case ltOpType:
			r = NewB(a.any() < b.any())
		case gtOpType:
			r = NewB(a.any() > b.any())
		case lteOpType:
			r = NewB(a.any() <= b.any())
		case gteOpType:
			r = NewB(a.any() >= b.any())
		case eqOpType:
			r = NewB(a.any() == b.any())
		case neOpType:
			r = NewB(a.any() != b.any())
		default:
			err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Float32", o.ʘBinaryOperatorType)
		}

		if same && !o.isArith() {
			if *(r.(*B)) {
				r = NewF32(1)
			} else {
				r = NewF32(0)
			}
		}

	case *I:
		b := vals[1].(*I)
		switch o.ʘBinaryOperatorType {
		case addOpType:
			r = NewI(a.any() + b.any())
		case subOpType:
			r = NewI(a.any() - b.any())
		case mulOpType:
			r = NewI(a.any() * b.any())
		case divOpType:
			r = NewI(a.any() / b.any())
		// case powOpType:
		// 	r = math.Pow(a, b)
		case ltOpType:
			r = NewB(a.any() < b.any())
		case gtOpType:
			r = NewB(a.any() > b.any())
		case lteOpType:
			r = NewB(a.any() <= b.any())
		case gteOpType:
			r = NewB(a.any() >= b.any())
		case eqOpType:
			r = NewB(a.any() == b.any())
		case neOpType:
			r = NewB(a.any() != b.any())
		default:
			err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Int", o.ʘBinaryOperatorType)
		}

		if same && !o.isArith() {
			if *(r.(*B)) {
				r = NewI(1)
			} else {
				r = NewI(0)
			}
		}

	case *I32:
		b := vals[1].(*I32)
		switch o.ʘBinaryOperatorType {
		case addOpType:
			r = NewI32(a.any() + b.any())
		case subOpType:
			r = NewI32(a.any() - b.any())
		case mulOpType:
			r = NewI32(a.any() * b.any())
		case divOpType:
			r = NewI32(a.any() / b.any())
		// case powOpType:
		// 	r = math.Pow(a, b)
		case ltOpType:
			r = NewB(a.any() < b.any())
		case gtOpType:
			r = NewB(a.any() > b.any())
		case lteOpType:
			r = NewB(a.any() <= b.any())
		case gteOpType:
			r = NewB(a.any() >= b.any())
		case eqOpType:
			r = NewB(a.any() == b.any())
		case neOpType:
			r = NewB(a.any() != b.any())
		default:
			err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Int32", o.ʘBinaryOperatorType)
		}

		if same && !o.isArith() {
			if *(r.(*B)) {
				r = NewI32(1)
			} else {
				r = NewI32(0)
			}
		}

	case *I64:
		b := vals[1].(*I64)
		switch o.ʘBinaryOperatorType {
		case addOpType:
			r = NewI64(a.any() + b.any())
		case subOpType:
			r = NewI64(a.any() - b.any())
		case mulOpType:
			r = NewI64(a.any() * b.any())
		case divOpType:
			r = NewI64(a.any() / b.any())
		// case powOpType:
		// 	r = math.Pow(a, b)
		case ltOpType:
			r = NewB(a.any() < b.any())
		case gtOpType:
			r = NewB(a.any() > b.any())
		case lteOpType:
			r = NewB(a.any() <= b.any())
		case gteOpType:
			r = NewB(a.any() >= b.any())
		case eqOpType:
			r = NewB(a.any() == b.any())
		case neOpType:
			r = NewB(a.any() != b.any())
		default:
			err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Int64", o.ʘBinaryOperatorType)
		}

		if same && !o.isArith() {
			if *(r.(*B)) {
				r = NewI64(1)
			} else {
				r = NewI64(0)
			}
		}

	case *U8:
		b := vals[1].(*U8)
		switch o.ʘBinaryOperatorType {
		case addOpType:
			r = NewU8(a.any() + b.any())
		case subOpType:
			r = NewU8(a.any() - b.any())
		case mulOpType:
			r = NewU8(a.any() * b.any())
		case divOpType:
			r = NewU8(a.any() / b.any())
		// case powOpType:
		// 	r = math.Pow(a, b)
		case ltOpType:
			r = NewB(a.any() < b.any())
		case gtOpType:
			r = NewB(a.any() > b.any())
		case lteOpType:
			r = NewB(a.any() <= b.any())
		case gteOpType:
			r = NewB(a.any() >= b.any())
		case eqOpType:
			r = NewB(a.any() == b.any())
		case neOpType:
			r = NewB(a.any() != b.any())
		default:
			err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Byte", o.ʘBinaryOperatorType)
		}

		if same && !o.isArith() {
			if *(r.(*B)) {
				r = NewU8(1)
			} else {
				r = NewU8(0)
			}
		}

	case *B:
		b := vals[1].(*B)
		switch o.ʘBinaryOperatorType {
		case eqOpType:
			r = NewB(a.any() == b.any())
		case neOpType:
			r = NewB(a.any() != b.any())
		default:
			err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Bool", o.ʘBinaryOperatorType)
		}

	default:
		err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Unhandled Scalar Type", o.t)
	}

	if err != nil {
		return
	}

	retVal, _ = anyToScalar(r)
	return
}

// tBinOp is a ʘBinaryOperator where at least one operand is a tensor.
// tensorLeft records whether the tensor is the left operand.
type tBinOp struct {
	ʘBinaryOperatorType
	tensorLeft bool
}

func (o tBinOp) Arity() int                     { return 2 }
func (o tBinOp) binOpType() ʘBinaryOperatorType { return o.ʘBinaryOperatorType }
func (o tBinOp) String() string                 { return o.ʘBinaryOperatorType.String() }
func (o tBinOp) isArith() bool                  { return o.ʘBinaryOperatorType.isArith() }

func (o tBinOp) Do(same bool, inputs ...Value) (Value, error) {
	if same {
		return o.do(inputs, tensor.AsSameType())
	}
	return o.do(inputs)
}

func (o tBinOp) UnsafeDo(retSame bool, inputs ...Value) (Value, error) {
	if retSame {
		return o.do(inputs, tensor.AsSameType(), tensor.UseUnsafe())
	}
	return o.do(inputs, tensor.UseUnsafe())
}

func (o tBinOp) UsePreallocDo(v Value, retSame bool, inputs ...Value) (retVal Value, err error) {
	t, ok := v.(tensor.Tensor)
	if !ok {
		return nil, errors.Errorf("Expected Tensor as preallocated value. Got %v of %T instead", v, v)
	}

	reuse := t
	if retSame {
		return o.do(inputs, tensor.WithReuse(reuse), tensor.AsSameType())
	}
	return o.do(inputs, tensor.WithReuse(reuse))
}

func (o tBinOp) IncrDo(incr Value, retSame bool, inputs ...Value) (err error) {
	reuse, ok := incr.(tensor.Tensor)
	if ok {
		_, err = o.do(inputs, tensor.WithIncr(reuse))
		return
	}

	var retVal Value
	if retSame {
		if retVal, err = o.do(inputs, tensor.AsSameType()); err != nil {
			return errors.Wrapf(err, doFail, o)
		}
	} else {
		if retVal, err = o.do(inputs); err != nil {
			return errors.Wrapf(err, doFail, o)
		}
	}

	add := newEBOByType(addOpType, TypeOf(incr), TypeOf(retVal))
	if retVal, err = add.UnsafeDo(incr, retVal); err != nil {
		return errors.Wrapf(err, unsafeDoFail, add)
	}

	err = noIncrErr{retVal}
	return
}

func (o tBinOp) do(vals []Value, opts ...tensor.FuncOpt) (retVal Value, err error) {
	if err = checkArity(o, len(vals)); err != nil {
		return
	}

	// typecheck the operands
	d0 := vals[0].Dtype()
	d1 := vals[1].Dtype()

	if d0 != d1 {
		return nil, errors.Errorf("Dtype mismatch for bin op: %v and %v", d0, d1)
	}

	// extract the goddamn values
	var a, b interface{}
	if o.tensorLeft {
		t, ok := vals[0].(tensor.Tensor)
		if !ok {
			return nil, errors.Errorf("Expected left value to be Tensor. Got %v of %T instead", vals[0], vals[0])
		}
		a = tensor.Materialize(t)
		// a = t

		switch other := vals[1].(type) {
		case *F64:
			b = other.any()
		case *F32:
			b = other.any()
		case tensor.Tensor:
			b = tensor.Materialize(other)
		default:
			return nil, errors.Errorf(nyiFail, "tBinOp.do()", vals[1])
		}
	} else {
		t, ok := vals[1].(tensor.Tensor)
		if !ok {
			return nil, errors.Errorf("Expected right value to be Tensor. Got %v of %T instead", vals[1], vals[1])
		}
		b = tensor.Materialize(t)

		switch other := vals[0].(type) {
		case *F64:
			a = other.any()
		case *F32:
			a = other.any()
		case tensor.Tensor:
			a = tensor.Materialize(other)
		default:
			return nil, errors.Errorf(nyiFail, "tBinOp.do()", vals[0])
		}
	}

	if o.isArith() {
		fn := binOps[o.ʘBinaryOperatorType]
		if fn == nil {
			return nil, errors.Errorf("nil function returned for %v", o.ʘBinaryOperatorType)
		}
		retVal, err = (*fn)(a, b, opts...)
	} else {
		fn := cmpOps[o.ʘBinaryOperatorType]
		if fn == nil {
			return nil, errors.Errorf("nil function returned for %v", o.ʘBinaryOperatorType)
		}
		retVal, err = (*fn)(a, b, opts...)
	}
	return
}

// type binDiffFn func(x, y, z, gradZ *Node) (Nodes, err error)

func addDiffExpr(x, y, z, gradZ *Node) (retVal Nodes, err error) {
	return Nodes{gradZ, gradZ}, nil
}

func addDiff(ctx ExecutionContext, x, y, z *Node) (err error) {
	xdv, ydv := getDV(x, y)

	// set up the op to be executed
	op := NewAddOp(x, z, ctx)
	op.Device = x.Device()
	op.UseUnsafe = true

	// we'll use the same device as the one the node's data resides on
	dev := op.Device

	var d, xd, yd, zd Value
	var extra bool

	// allocate if necessary
	if xd, extra, err = x.GradOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, x, dev)
	}
	if extra {
		defer ctx.PutValue(dev, xd)
	}

	if zd, extra, err = z.GradOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, z, dev)
	}
	if extra {
		defer ctx.PutValue(dev, zd)
	}

	// if x is scalar, an additional vector needs to be acquired
	if x.IsScalar() && dev != CPU {
		var mem tensor.Memory
		var xd2 Value
		memsize := calcMemSize(zd.Dtype(), zd.Shape())
		if mem, err = ctx.Get(dev, memsize); err != nil {
			return
		}

		if xd2, err = makeValueFromMem(z.t, zd.Shape(), mem); err != nil {
			return
		}

		op.Prealloc = xd2
		defer ctx.Signal()
	}

	// xd += zd
	if d, err = op.Do(xd, zd); err != nil {
		return errors.Wrapf(err, doFail, op)
	}
	xdv.SetDeriv(d)

	// set up the op to be executed for y
	op = NewAddOp(y, z, ctx)
	op.Device = y.Device()
	op.UseUnsafe = true

	dev = op.Device

	if yd, extra, err = y.GradOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, y, dev)
	}
	if extra {
		defer ctx.PutValue(dev, yd)
	}

	if zd, extra, err = z.GradOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, z, dev)
	}
	if extra {
		defer ctx.PutValue(dev, zd)
	}

	// if y is scalar, an additional vector needs to be acquired
	if y.IsScalar() && dev != CPU {
		var mem tensor.Memory
		var yd2 Value
		memsize := calcMemSize(zd.Dtype(), zd.Shape())
		if mem, err = ctx.Get(dev, memsize); err != nil {
			return
		}
		if yd2, err = makeValueFromMem(z.t, zd.Shape(), mem); err != nil {
			return
		}

		op.Prealloc = yd2
		defer ctx.Signal()
	}

	// yd += zd
	if d, err = op.Do(yd, zd); err != nil {
		return errors.Wrapf(err, doFail, op)
	}
	ydv.SetDeriv(d) // ignore errors on purpose

	return nil
}

func subDiffExpr(x, y, z, gradZ *Node) (retVal Nodes, err error) {
	var dzdy *Node
	if dzdy, err = Neg(gradZ); err == nil {
		WithGroupName(gradClust)(dzdy)
		WithGroupName(gradClust)(gradZ)
		retVal = Nodes{gradZ, dzdy}
	} else {
		return nil, errors.Wrap(err, "Failed to carry Neg()")
	}
	return
}

func subDiff(ctx ExecutionContext, x, y, z *Node) (err error) {
	xdv, ydv := getDV(x, y)

	add := NewAddOp(x, z, ctx)
	sub := NewSubOp(y, z, ctx)
	add.Device = x.Device()
	sub.Device = y.Device()
	sub.UseUnsafe = true
	add.UseUnsafe = true
	// sub := newEBOByType(subOpType, y.t, z.t)
	// add := newEBOByType(addOpType, x.t, z.t)
	dev := sub.Device
	var xd, yd, zd, d Value
	var extra bool

	if zd, extra, err = z.GradOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, z, dev)
	}
	if extra {
		defer ctx.PutValue(dev, zd)
	}

	if yd, extra, err = y.GradOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, y, dev)
	}
	if extra {
		defer ctx.PutValue(dev, yd)
	}

	// if y is scalar, an additional vector needs to be allocated for the prealloc
	switch {
	case y.IsScalar() && dev != CPU:
		var mem tensor.Memory
		var yd2 Value
		memsize := calcMemSize(zd.Dtype(), zd.Shape())
		if mem, err = ctx.Get(dev, memsize); err != nil {
			return errors.Wrapf(err, allocFail, memsize, dev)
		}
		if yd2, err = makeValueFromMem(z.t, zd.Shape(), mem); err != nil {
			return errors.Wrapf(err, makeValueFail, z.t, zd.Shape())
		}

		sub.Prealloc = yd2
		defer ctx.Signal()
	case y.IsScalar() && dev == CPU:
		if sub.Prealloc, err = makeValue(z.t, zd.Shape()); err != nil {
			return
		}
	}

	// dz/dy
	if d, err = sub.Do(yd, zd); err != nil {
		return errors.Wrapf(err, doFail, sub)
	}
	ydv.SetDeriv(d) // errors are ignored on purpose

	// handle x

	dev = add.Device
	if zd, extra, err = z.GradOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, z, dev)
	}
	if extra {
		defer ctx.PutValue(dev, zd)
	}

	if xd, extra, err = x.GradOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, x, dev)
	}
	if extra {
		defer ctx.PutValue(dev, xd)
	}

	// if x is scalar, an additional vector needs to be allocated for the prealloc
	switch {
	case x.IsScalar() && dev != CPU:
		var mem tensor.Memory
		var xd2 Value
		memsize := calcMemSize(zd.Dtype(), zd.Shape())
		if mem, err = ctx.Get(dev, memsize); err != nil {
			return
		}

		if xd2, err = makeValueFromMem(z.t, zd.Shape(), mem); err != nil {
			return
		}
		add.Prealloc = xd2
		defer ctx.Signal()
	case x.IsScalar() && dev == CPU:
		if add.Prealloc, err = makeValue(z.t, zd.Shape()); err != nil {
			return
		}
	}

	// dz/dx
	if d, err = add.Do(xd, zd); err != nil {
		return errors.Wrapf(err, doFail, add)
	}
	xdv.SetDeriv(d) // ignore errors on purpose

	return nil
}

func hadamardProdDiffExpr(x, y, z, gradZ *Node) (retVal Nodes, err error) {
	var dzdx, dzdy *Node
	if dzdx, err = HadamardProd(y, gradZ); err == nil {
		dzdy, err = HadamardProd(x, gradZ)
		if err != nil {
			return nil, errors.Wrap(err, "Failed to carry HadamardProd()")
		}
		WithGroupName(gradClust)(dzdx)
		WithGroupName(gradClust)(dzdy)
		retVal = Nodes{dzdx, dzdy}
		return
	}
	return nil, errors.Wrap(err, "Failed to carry HadamardProd()")
}

func hadamardProdDiff(ctx ExecutionContext, x, y, z *Node) (err error) {
	xdv, ydv := getDV(x, y)

	var mul *ExternalOp
	var dev Device
	var xd, yd, zd, d Value
	var extra bool

	if x.isConstant() {
		goto dzdy
	}

	// dzdx
	mul = NewHadamardProdOp(y, z, ctx)
	mul.Device = x.Device()
	dev = mul.Device

	if xd, extra, err = x.GradOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, x, dev)
	}
	if extra {
		defer ctx.PutValue(dev, xd)
	}

	if yd, extra, err = y.ValueOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, y, dev)
	}
	if extra {
		defer ctx.PutValue(dev, yd)
	}

	if zd, extra, err = z.GradOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, z, dev)
	}
	if extra {
		defer ctx.PutValue(dev, zd)
	}

	mul.Incr = xd

	// if x is scalar, it needs to be broadcast across to the shape of z's gradient
	if x.IsScalar() && dev != CPU && !zd.Shape().IsScalar() {
		var memIncr, mem2 tensor.Memory
		var xdIncr, xd2 Value
		memsize := calcMemSize(zd.Dtype(), zd.Shape())
		if mem2, err = ctx.Get(dev, memsize); err != nil {
			return errors.Wrapf(err, allocFail, memsize, dev)
		}

		if xd2, err = makeValueFromMem(z.t, zd.Shape(), mem2); err != nil {
			return errors.Wrapf(err, makeValueFail, z.t, zd.Shape())
		}

		// "broadcast" x (in a very sloppy way)
		if memIncr, err = ctx.Get(dev, memsize); err != nil {
			return errors.Wrapf(err, allocFail, memsize, dev)
		}

		if xdIncr, err = makeValueFromMem(z.t, zd.Shape(), memIncr); err != nil {
			return errors.Wrapf(err, makeValueFail, z.t, zd.Shape())
		}
		xdIncr.(tensor.Tensor).Memset(xdv.d.Data())

		mul.Prealloc = xd2
		mul.Incr = xdIncr

		defer ctx.PutValue(dev, xd2) // xd2 is temporary, we need to dealloc it
		defer ctx.Signal()           // work needs to be done
	}

	if d, err = mul.Do(yd, zd); err != nil {
		return errors.Wrapf(err, "IncrDo xd failed")
	}

	xdv.SetDeriv(d)

dzdy:
	if y.isConstant() {
		goto end
	}

	mul = NewHadamardProdOp(x, z, ctx)
	mul.Device = y.Device()
	dev = mul.Device

	if xd, extra, err = x.ValueOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, x, dev)
	}
	if extra {
		defer ctx.PutValue(dev, xd)
	}

	if yd, extra, err = y.GradOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, y, dev)
	}
	if extra {
		defer ctx.PutValue(dev, yd)
	}

	if zd, extra, err = z.GradOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, z, dev)
	}
	if extra {
		defer ctx.PutValue(dev, zd)
	}

	mul.Incr = yd

	// if y is scalar, it needs to be broadcast across to the shape of z's gradient
	if y.IsScalar() && dev != CPU && !zd.Shape().IsScalar() {
		var memIncr, mem2 tensor.Memory
		var ydIncr, yd2 Value
		memsize := calcMemSize(zd.Dtype(), zd.Shape())
		if mem2, err = ctx.Get(dev, memsize); err != nil {
			return errors.Wrapf(err, allocFail, memsize, dev)
		}

		if yd2, err = makeValueFromMem(z.t, zd.Shape(), mem2); err != nil {
			return errors.Wrapf(err, makeValueFail, z.t, zd.Shape())
		}

		// "broadcast" y (in a very sloppy way)
		if memIncr, err = ctx.Get(dev, memsize); err != nil {
			return errors.Wrapf(err, allocFail, memsize, dev)
		}

		if ydIncr, err = makeValueFromMem(z.t, zd.Shape(), memIncr); err != nil {
			return errors.Wrapf(err, makeValueFail, z.t, zd.Shape())
		}
		ydIncr.(tensor.Tensor).Memset(ydv.d.Data())

		mul.Prealloc = yd2
		mul.Incr = ydIncr

		defer ctx.PutValue(dev, yd2) // yd2 is temporary, we need to dealloc it
		defer ctx.Signal()           // work needs to be done
	}

	if d, err = mul.Do(xd, zd); err != nil {
		return errors.Wrapf(err, "IncrDo yd failed")
	}
	ydv.SetDeriv(d)

end:
	return nil
}

func hadamardDivDiffExpr(x, y, z, gradZ *Node) (retVal Nodes, err error) {
	var dzdx, dzdy *Node
	if dzdx, err = HadamardDiv(gradZ, y); err == nil {
		WithGroupName(gradClust)(dzdx)
		if dzdy, err = HadamardDiv(z, y); err == nil {
			WithGroupName(gradClust)(dzdy)
			if dzdy, err = Neg(dzdy); err == nil {
				WithGroupName(gradClust)(dzdy)
				if dzdy, err = HadamardProd(dzdy, gradZ); err == nil {
					WithGroupName(gradClust)(dzdy)
					retVal = Nodes{dzdx, dzdy}
					return
				}
				return nil, errors.Wrap(err, "Failed to carry HadamardProd()")
			}
			return nil, errors.Wrap(err, "Failed to carry Neg()")
		}
		return nil, errors.Wrap(err, "Failed to carry HadamardDiv()")
	}
	return nil, errors.Wrap(err, "Failed to carry HadamardDiv()")
}

func hadamardDivDiff(ctx ExecutionContext, x, y, z *Node) (err error) {
	xdv, ydv, zdv := getDV3(x, y, z)

	// dzdx = 1/y * dz
	div := newEBOByType(divOpType, TypeOf(zdv.d), TypeOf(ydv.Value))
	err = div.IncrDo(xdv.d, zdv.d, ydv.Value)
	if err != nil {
		var ver Valuer
		var ok bool
		if ver, ok = err.(Valuer); !ok {
			return
		}

		xdv.SetDeriv(ver.Value()) // ignore errors on purpose
	}

	// dzdy = -x/y²
	// TODO: investigate if this can be done (if no other node uses z):
	//	unsafe do : neg zdv.d
	//	unsafe do : mul zdv.d, zdv.Value
	//	incr do   : <incr: ydv.d> div zdv.d, ydv.Value
	var d Value
	if d, err = div.Do(zdv.Value, ydv.Value); err != nil {
		return errors.Wrapf(err, doFail, div)
	}

	neg := newElemUnaryOp(negOpType, y)
	if d, err = neg.Do(d); err != nil {
		return errors.Wrapf(err, doFail, neg)
	}

	mul := newElemBinOp(mulOpType, z, y)
	err = mul.IncrDo(ydv.d, zdv.d, d)
	if err != nil {
		var ver Valuer
		var ok bool
		if ver, ok = err.(Valuer); !ok {
			return
		}

		ydv.SetDeriv(ver.Value()) // ignore errors on purpose
	}

	return nil
}

// TODO: go back in time, pay more attention to calculus class in high school and learn how to differentiate x^y
func hadamardPowDiffExpr(x, y, z, grad *Node) (retVal Nodes, err error) {
	var one *Node
	var dt tensor.Dtype

	if dt, err = dtypeOf(y.t); err != nil {
		return nil, errors.Wrapf(err, dtypeExtractionFail, y.t)
	}

	switch dt {
	case Float32:
		one = onef32
	case Float64:
		one = onef64
	default:
		err = errors.Errorf(nyiTypeFail, "Hadamard Power Diff", y.t)
		return
	}

	var ym1, pow *Node
	if ym1, err = Sub(y, one); err != nil {
		return
	}

	if pow, err = Pow(x, ym1); err != nil {
		return
	}

	var dzdx *Node
	if dzdx, err = HadamardProd(grad, y); err != nil {
		return
	}
	if dzdx, err = HadamardProd(dzdx, pow); err != nil {
		return
	}

	var logx *Node
	if logx, err = Log(x); err != nil {
		return
	}

	var dzdy *Node
	if dzdy, err = HadamardProd(grad, z); err != nil {
		return
	}
	if dzdy, err = HadamardProd(dzdy, logx); err != nil {
		return
	}

	retVal = Nodes{dzdx, dzdy}
	return
	// return nil, errors.New("hadamardPowDiffExpr not yet implemented")
}

func hadamardPowDiff(ctx ExecutionContext, x, y, z *Node) (err error) {
	xdv, ydv, zdv := getDV3(x, y, z)

	var ym1 Value
	switch ydvt := ydv.Value.(type) {
	case *F64:
		ym1 = NewF64(ydvt.any() - float64(1))
	case *F32:
		ym1 = NewF32(ydvt.any() - float32(1))
	case *tensor.Dense:
		var one interface{}
		switch ydvt.Dtype() {
		case tensor.Float64:
			one = float64(1)
		case tensor.Float32:
			one = float32(1)
		}
		if ym1, err = tensor.Sub(ydvt, one); err != nil {
			return
		}
	default:
		err = errors.Errorf(nyiTypeFail, "hadamardPowDiff", ydv.Value)
		return
	}

	// dzdx
	var pow Value
	powOp := newEBOByType(powOpType, TypeOf(xdv.Value), TypeOf(ym1))
	if pow, err = powOp.Do(xdv.Value, ym1); err != nil {
		return
	}

	mul := newEBOByType(mulOpType, TypeOf(ydv.Value), TypeOf(xdv.Value))
	if pow, err = mul.UnsafeDo(pow, ydv.Value); err != nil {
		return
	}

	if err = mul.IncrDo(xdv.d, pow, zdv.d); err != nil {
		var ver Valuer
		var ok bool
		if ver, ok = err.(Valuer); !ok {
			return
		}

		xdv.SetDeriv(ver.Value())
	}

	// dzdy
	var logx Value
	logOp := newElemUnaryOp(lnOpType, x)
	if logx, err = logOp.Do(xdv.Value); err != nil {
		return
	}
	if logx, err = mul.Do(zdv.Value, logx); err != nil {
		return
	}
	if err = mul.IncrDo(ydv.d, logx, zdv.d); err != nil {
		var ver Valuer
		var ok bool
		if ver, ok = err.(Valuer); !ok {
			return
		}

		ydv.SetDeriv(ver.Value())
	}
	return nil
}

func nondiffBinOpExpr(x, y, z, grad *Node) (retVal Nodes, err error) {
	return nil, errors.New("Nondifferentiable")
}

func nondiffBinOp(ctx ExecutionContext, x, y, z *Node) (err error) {
	return AutoDiffError{}
}
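
// exampleScalarBinOpAdd is an illustrative usage sketch and not part of the
// upstream file: it shows how scalarBinOp.Do is driven with two scalar
// Values, assuming the NewF64 constructor, the addOpType constant and the
// package-level Float64 dtype behave exactly as they are used elsewhere in
// this file. Treat it as a hedged example rather than library API.
func exampleScalarBinOpAdd() (Value, error) {
	op := scalarBinOp{ʘBinaryOperatorType: addOpType, t: Float64}
	// For arithmetic op types the `same` flag is ignored; for comparison op
	// types (ltOpType, eqOpType, ...) same=true coerces the boolean result
	// back into the operand dtype (1.0 / 0.0).
	return op.Do(false, NewF64(2), NewF64(3)) // expected: *F64 holding 5
}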