gorgonia.org/gorgonia@v0.9.17/operatorPointwise_binary.go

     1  package gorgonia
     2  
     3  import (
     4  	"math"
     5  
     6  	"github.com/chewxy/math32"
     7  	"github.com/pkg/errors"
     8  	"gorgonia.org/tensor"
     9  )
    10  
    11  type incrDoerBinOp interface {
    12  	IncrDo(v Value, retSame bool, inputs ...Value) error
    13  }
    14  type usePreallocDoerBinOp interface {
    15  	UsePreallocDo(v Value, retSame bool, inputs ...Value) (retVal Value, err error)
    16  }
    17  type unsafeDoerBinOp interface {
    18  	UnsafeDo(retSame bool, inputs ...Value) (Value, error)
    19  }
    20  
    21  /* BINARY OPERATOR */
    22  
    23  type ʘBinaryOperator interface {
    24  	isArith() bool
    25  	binOpType() ʘBinaryOperatorType
    26  	Do(bool, ...Value) (Value, error)
    27  	String() string
    28  }
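
         // A sketch (hypothetical helper, not part of the package API) of how an
         // executor might dispatch on the capability interfaces above, preferring
         // the in-place path when an operator provides one.
         func tryUnsafeDo(op ʘBinaryOperator, retSame bool, inputs ...Value) (Value, error) {
         	if ud, ok := op.(unsafeDoerBinOp); ok {
         		return ud.UnsafeDo(retSame, inputs...)
         	}
         	return op.Do(retSame, inputs...)
         }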
    29  
    30  type scalarBinOp struct {
    31  	ʘBinaryOperatorType
    32  	t tensor.Dtype
    33  }
    34  
    35  func (o scalarBinOp) Arity() int                     { return 2 }
    36  func (o scalarBinOp) binOpType() ʘBinaryOperatorType { return o.ʘBinaryOperatorType }
    37  func (o scalarBinOp) isArith() bool                  { return o.ʘBinaryOperatorType.isArith() }
    38  func (o scalarBinOp) String() string                 { return o.ʘBinaryOperatorType.String() }
    39  
    40  func (o scalarBinOp) Do(same bool, vals ...Value) (retVal Value, err error) {
    41  	if err = checkArity(o, len(vals)); err != nil {
    42  		return
    43  	}
    44  
    45  	at := TypeOf(vals[0])
    46  	bt := TypeOf(vals[1])
    47  	if !at.Eq(bt) {
    48  		err = errors.Errorf("Type Mismatch: %v != %v", at, bt)
    49  		return
    50  	}
    51  
    52  	var r interface{} // float or bool only plz
    53  	switch a := vals[0].(type) {
    54  	case *F64:
    55  		b := vals[1].(*F64)
    56  		switch o.ʘBinaryOperatorType {
    57  		case addOpType:
    58  			r = NewF64(a.any() + b.any())
    59  		case subOpType:
    60  			r = NewF64(a.any() - b.any())
    61  		case mulOpType:
    62  			r = NewF64(a.any() * b.any())
    63  		case divOpType:
    64  			r = NewF64(a.any() / b.any())
    65  		case powOpType:
    66  			r = NewF64(math.Pow(a.any(), b.any()))
    67  		case ltOpType:
    68  			r = NewB(a.any() < b.any())
    69  		case gtOpType:
    70  			r = NewB(a.any() > b.any())
    71  		case lteOpType:
    72  			r = NewB(a.any() <= b.any())
    73  		case gteOpType:
    74  			r = NewB(a.any() >= b.any())
    75  		case eqOpType:
    76  			r = NewB(a.any() == b.any())
    77  		case neOpType:
    78  			r = NewB(a.any() != b.any())
    79  		default:
    80  			err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Float64", o.ʘBinaryOperatorType)
    81  		}
    82  
     83  		if err == nil && same && !o.isArith() {
    84  			if *(r.(*B)) {
    85  				r = NewF64(1.0)
    86  			} else {
    87  				r = NewF64(0.0)
    88  			}
    89  		}
    90  
    91  	case *F32:
    92  		b := vals[1].(*F32)
    93  		switch o.ʘBinaryOperatorType {
    94  		case addOpType:
    95  			r = NewF32(a.any() + b.any())
    96  		case subOpType:
    97  			r = NewF32(a.any() - b.any())
    98  		case mulOpType:
    99  			r = NewF32(a.any() * b.any())
   100  		case divOpType:
   101  			r = NewF32(a.any() / b.any())
   102  		case powOpType:
   103  			r = NewF32(math32.Pow(float32(a.any()), float32(b.any())))
   104  		case ltOpType:
   105  			r = NewB(a.any() < b.any())
   106  		case gtOpType:
   107  			r = NewB(a.any() > b.any())
   108  		case lteOpType:
   109  			r = NewB(a.any() <= b.any())
   110  		case gteOpType:
   111  			r = NewB(a.any() >= b.any())
   112  		case eqOpType:
   113  			r = NewB(a.any() == b.any())
   114  		case neOpType:
   115  			r = NewB(a.any() != b.any())
   116  		default:
   117  			err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Float32", o.ʘBinaryOperatorType)
   118  		}
   119  
    120  		if err == nil && same && !o.isArith() {
   121  			if *(r.(*B)) {
   122  				r = NewF32(1)
   123  			} else {
   124  				r = NewF32(0)
   125  			}
   126  		}
   127  
   128  	case *I:
   129  		b := vals[1].(*I)
   130  		switch o.ʘBinaryOperatorType {
   131  		case addOpType:
   132  			r = NewI(a.any() + b.any())
   133  		case subOpType:
   134  			r = NewI(a.any() - b.any())
   135  		case mulOpType:
   136  			r = NewI(a.any() * b.any())
   137  		case divOpType:
   138  			r = NewI(a.any() / b.any())
   139  		// case powOpType:
   140  		// 	r = math.Pow(a, b)
   141  		case ltOpType:
   142  			r = NewB(a.any() < b.any())
   143  		case gtOpType:
   144  			r = NewB(a.any() > b.any())
   145  		case lteOpType:
   146  			r = NewB(a.any() <= b.any())
   147  		case gteOpType:
   148  			r = NewB(a.any() >= b.any())
   149  		case eqOpType:
   150  			r = NewB(a.any() == b.any())
   151  		case neOpType:
   152  			r = NewB(a.any() != b.any())
   153  		default:
   154  			err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Int", o.ʘBinaryOperatorType)
   155  		}
   156  
    157  		if err == nil && same && !o.isArith() {
   158  			if *(r.(*B)) {
   159  				r = NewI(1)
   160  			} else {
   161  				r = NewI(0)
   162  			}
   163  		}
   164  	case *I32:
   165  		b := vals[1].(*I32)
   166  		switch o.ʘBinaryOperatorType {
   167  		case addOpType:
   168  			r = NewI32(a.any() + b.any())
   169  		case subOpType:
   170  			r = NewI32(a.any() - b.any())
   171  		case mulOpType:
   172  			r = NewI32(a.any() * b.any())
   173  		case divOpType:
   174  			r = NewI32(a.any() / b.any())
   175  		// case powOpType:
   176  		// 	r = math.Pow(a, b)
   177  		case ltOpType:
   178  			r = NewB(a.any() < b.any())
   179  		case gtOpType:
   180  			r = NewB(a.any() > b.any())
   181  		case lteOpType:
   182  			r = NewB(a.any() <= b.any())
   183  		case gteOpType:
   184  			r = NewB(a.any() >= b.any())
   185  		case eqOpType:
   186  			r = NewB(a.any() == b.any())
   187  		case neOpType:
   188  			r = NewB(a.any() != b.any())
   189  		default:
   190  			err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Int32", o.ʘBinaryOperatorType)
   191  		}
   192  
    193  		if err == nil && same && !o.isArith() {
   194  			if *(r.(*B)) {
   195  				r = NewI32(1)
   196  			} else {
   197  				r = NewI32(0)
   198  			}
   199  		}
   200  	case *I64:
   201  		b := vals[1].(*I64)
   202  		switch o.ʘBinaryOperatorType {
   203  		case addOpType:
   204  			r = NewI64(a.any() + b.any())
   205  		case subOpType:
   206  			r = NewI64(a.any() - b.any())
   207  		case mulOpType:
   208  			r = NewI64(a.any() * b.any())
   209  		case divOpType:
   210  			r = NewI64(a.any() / b.any())
   211  		// case powOpType:
   212  		// 	r = math.Pow(a, b)
   213  		case ltOpType:
   214  			r = NewB(a.any() < b.any())
   215  		case gtOpType:
   216  			r = NewB(a.any() > b.any())
   217  		case lteOpType:
   218  			r = NewB(a.any() <= b.any())
   219  		case gteOpType:
   220  			r = NewB(a.any() >= b.any())
   221  		case eqOpType:
   222  			r = NewB(a.any() == b.any())
   223  		case neOpType:
   224  			r = NewB(a.any() != b.any())
   225  		default:
   226  			err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Int64", o.ʘBinaryOperatorType)
   227  		}
   228  
    229  		if err == nil && same && !o.isArith() {
   230  			if *(r.(*B)) {
   231  				r = NewI64(1)
   232  			} else {
   233  				r = NewI64(0)
   234  			}
   235  		}
   236  	case *U8:
   237  		b := vals[1].(*U8)
   238  		switch o.ʘBinaryOperatorType {
   239  		case addOpType:
   240  			r = NewU8(a.any() + b.any())
   241  		case subOpType:
   242  			r = NewU8(a.any() - b.any())
   243  		case mulOpType:
   244  			r = NewU8(a.any() * b.any())
   245  		case divOpType:
   246  			r = NewU8(a.any() / b.any())
   247  		// case powOpType:
   248  		// 	r = math.Pow(a, b)
   249  		case ltOpType:
   250  			r = NewB(a.any() < b.any())
   251  		case gtOpType:
   252  			r = NewB(a.any() > b.any())
   253  		case lteOpType:
   254  			r = NewB(a.any() <= b.any())
   255  		case gteOpType:
   256  			r = NewB(a.any() >= b.any())
   257  		case eqOpType:
   258  			r = NewB(a.any() == b.any())
   259  		case neOpType:
   260  			r = NewB(a.any() != b.any())
   261  		default:
   262  			err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Byte", o.ʘBinaryOperatorType)
   263  		}
   264  
    265  		if err == nil && same && !o.isArith() {
   266  			if *(r.(*B)) {
   267  				r = NewU8(1)
   268  			} else {
   269  				r = NewU8(0)
   270  			}
   271  		}
   272  	case *B:
   273  		b := vals[1].(*B)
   274  		switch o.ʘBinaryOperatorType {
   275  		case eqOpType:
   276  			r = NewB(a.any() == b.any())
   277  		case neOpType:
   278  			r = NewB(a.any() != b.any())
   279  		default:
   280  			err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Bool", o.ʘBinaryOperatorType)
   281  		}
   282  
   283  	default:
   284  		err = errors.Errorf(nyiFail, "scalarBinOp.Do() - Unhandled Scalar Type", o.t)
   285  	}
   286  
   287  	if err != nil {
   288  		return
   289  	}
   290  
   291  	retVal, _ = anyToScalar(r)
   292  	return
   293  }
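
         // exampleScalarAdd is an illustrative sketch of calling scalarBinOp.Do
         // directly; addOpType and Float64 are the same identifiers used elsewhere
         // in this package, while the operand values are made up.
         func exampleScalarAdd() (Value, error) {
         	op := scalarBinOp{ʘBinaryOperatorType: addOpType, t: Float64}
         	return op.Do(false, NewF64(2), NewF64(3)) // returns a *F64 holding 5
         }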
   294  
   295  type tBinOp struct {
   296  	ʘBinaryOperatorType
   297  	tensorLeft bool
   298  }
   299  
   300  func (o tBinOp) Arity() int                     { return 2 }
   301  func (o tBinOp) binOpType() ʘBinaryOperatorType { return o.ʘBinaryOperatorType }
   302  func (o tBinOp) String() string                 { return o.ʘBinaryOperatorType.String() }
   303  func (o tBinOp) isArith() bool                  { return o.ʘBinaryOperatorType.isArith() }
   304  
   305  func (o tBinOp) Do(same bool, inputs ...Value) (Value, error) {
   306  	if same {
   307  		return o.do(inputs, tensor.AsSameType())
   308  	}
   309  	return o.do(inputs)
   310  }
   311  
   312  func (o tBinOp) UnsafeDo(retSame bool, inputs ...Value) (Value, error) {
   313  	if retSame {
   314  		return o.do(inputs, tensor.AsSameType(), tensor.UseUnsafe())
   315  	}
   316  	return o.do(inputs, tensor.UseUnsafe())
   317  }
   318  func (o tBinOp) UsePreallocDo(v Value, retSame bool, inputs ...Value) (retVal Value, err error) {
   319  	t, ok := v.(tensor.Tensor)
   320  	if !ok {
   321  		return nil, errors.Errorf("Expected Tensor as preallocated value. Got %v of %T instead", v, v)
   322  	}
   323  
   324  	reuse := t
   325  	if retSame {
   326  		return o.do(inputs, tensor.WithReuse(reuse), tensor.AsSameType())
   327  	}
   328  	return o.do(inputs, tensor.WithReuse(reuse))
   329  }
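
         // examplePrealloc is a sketch of supplying a preallocated output, assuming
         // both operands are values of the same shape and dtype (hypothetical helper).
         func examplePrealloc(op tBinOp, a, b Value) (Value, error) {
         	out := tensor.New(tensor.Of(a.Dtype()), tensor.WithShape(a.Shape()...))
         	return op.UsePreallocDo(out, false, a, b)
         }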
   330  
   331  func (o tBinOp) IncrDo(incr Value, retSame bool, inputs ...Value) (err error) {
   332  	reuse, ok := incr.(tensor.Tensor)
   333  	if ok {
   334  		_, err = o.do(inputs, tensor.WithIncr(reuse))
   335  		return
   336  	}
   337  
   338  	var retVal Value
   339  	if retSame {
   340  		if retVal, err = o.do(inputs, tensor.AsSameType()); err != nil {
   341  			return errors.Wrapf(err, doFail, o)
   342  		}
   343  	} else {
   344  		if retVal, err = o.do(inputs); err != nil {
   345  			return errors.Wrapf(err, doFail, o)
   346  		}
   347  
   348  	}
   349  
   350  	add := newEBOByType(addOpType, TypeOf(incr), TypeOf(retVal))
   351  	if retVal, err = add.UnsafeDo(incr, retVal); err != nil {
   352  		return errors.Wrapf(err, unsafeDoFail, add)
   353  	}
   354  
   355  	err = noIncrErr{retVal}
   356  	return
   357  }
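
         // exampleIncrDo sketches the caller-side contract of IncrDo: when incr is
         // not a tensor, the computed result is delivered through the noIncrErr
         // sentinel and recovered via the Valuer interface, the same pattern the
         // diff functions below use (hypothetical helper).
         func exampleIncrDo(op tBinOp, incr Value, inputs ...Value) (Value, error) {
         	if err := op.IncrDo(incr, false, inputs...); err != nil {
         		if ver, ok := err.(Valuer); ok {
         			return ver.Value(), nil // not a failure: the result rode in on the sentinel
         		}
         		return nil, err
         	}
         	return incr, nil // incr was a tensor and was accumulated into in place
         }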
   358  
   359  func (o tBinOp) do(vals []Value, opts ...tensor.FuncOpt) (retVal Value, err error) {
   360  	if err = checkArity(o, len(vals)); err != nil {
   361  		return
   362  	}
   363  
   364  	// typecheck the operands
   365  	d0 := vals[0].Dtype()
   366  	d1 := vals[1].Dtype()
   367  
   368  	if d0 != d1 {
   369  		return nil, errors.Errorf("Dtype mismatch for bin op: %v and %v", d0, d1)
   370  	}
   371  
   372  	// extract the goddamn values
   373  	var a, b interface{}
   374  	if o.tensorLeft {
   375  		t, ok := vals[0].(tensor.Tensor)
   376  		if !ok {
   377  			return nil, errors.Errorf("Expected left value to be Tensor. Got %v of %T instead", vals[0], vals[0])
   378  		}
   379  		a = tensor.Materialize(t)
   380  		// a = t
   381  
   382  		switch other := vals[1].(type) {
   383  		case *F64:
   384  			b = other.any()
   385  		case *F32:
   386  			b = other.any()
   387  		case tensor.Tensor:
   388  			b = tensor.Materialize(other)
   389  		default:
   390  			return nil, errors.Errorf(nyiFail, "tBinOp.do()", vals[1])
   391  		}
   392  	} else {
   393  		t, ok := vals[1].(tensor.Tensor)
   394  		if !ok {
   395  			return nil, errors.Errorf("Expected right value to be Tensor. Got %v of %T instead", vals[1], vals[1])
   396  		}
   397  		b = tensor.Materialize(t)
   398  
   399  		switch other := vals[0].(type) {
   400  		case *F64:
   401  			a = other.any()
   402  		case *F32:
   403  			a = other.any()
   404  		case tensor.Tensor:
   405  			a = tensor.Materialize(other)
   406  		default:
    407  			return nil, errors.Errorf(nyiFail, "tBinOp.do()", vals[0])
   408  		}
   409  	}
   410  
   411  	if o.isArith() {
   412  		fn := binOps[o.ʘBinaryOperatorType]
   413  		if fn == nil {
   414  			return nil, errors.Errorf("nil function returned for %v", o.ʘBinaryOperatorType)
   415  		}
   416  		retVal, err = (*fn)(a, b, opts...)
   417  	} else {
   418  		fn := cmpOps[o.ʘBinaryOperatorType]
   419  		if fn == nil {
   420  			return nil, errors.Errorf("nil function returned for %v", o.ʘBinaryOperatorType)
   421  		}
   422  		retVal, err = (*fn)(a, b, opts...)
   423  
   424  	}
   425  	return
   426  }
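
         // exampleTensorLt is a sketch of an element-wise comparison going through
         // tBinOp.do; same=true keeps the operands' dtype instead of returning
         // bools. The backing values are made up.
         func exampleTensorLt() (Value, error) {
         	op := tBinOp{ʘBinaryOperatorType: ltOpType, tensorLeft: true}
         	a := tensor.New(tensor.WithBacking([]float64{1, 2, 3}))
         	b := tensor.New(tensor.WithBacking([]float64{3, 2, 1}))
         	return op.Do(true, a, b) // [1, 0, 0] as float64
         }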
   427  
   428  // type binDiffFn func(x, y, z, gradZ *Node) (Nodes, err error)
   429  
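         // For z = x + y: ∂z/∂x = ∂z/∂y = 1, so both input gradients are simply gradZ.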
   430  func addDiffExpr(x, y, z, gradZ *Node) (retVal Nodes, err error) {
   431  	return Nodes{gradZ, gradZ}, nil
   432  }
   433  
   434  func addDiff(ctx ExecutionContext, x, y, z *Node) (err error) {
   435  	xdv, ydv := getDV(x, y)
   436  
   437  	// set up the op to be executed
   438  	op := NewAddOp(x, z, ctx)
   439  	op.Device = x.Device()
   440  	op.UseUnsafe = true
   441  
   442  	// we'll use the same device as the device the data from the node resides in
   443  	dev := op.Device
   444  
   445  	var d, xd, yd, zd Value
   446  	var extra bool
   447  
   448  	// allocate if necessary
   449  	if xd, extra, err = x.GradOnDevice(dev, ctx.External); err != nil {
   450  		return errors.Wrapf(err, gradOnDeviceFail, x, dev)
   451  	}
   452  	if extra {
   453  		defer ctx.PutValue(dev, xd)
   454  	}
   455  
   456  	if zd, extra, err = z.GradOnDevice(dev, ctx.External); err != nil {
   457  		return errors.Wrapf(err, gradOnDeviceFail, z, dev)
   458  	}
   459  	if extra {
    460  		defer ctx.PutValue(dev, zd)
   461  	}
   462  
   463  	// if x is scalar, an additional vector needs to be acquired
   464  	if x.IsScalar() && dev != CPU {
   465  		var mem tensor.Memory
   466  		var xd2 Value
   467  		memsize := calcMemSize(zd.Dtype(), zd.Shape())
   468  		if mem, err = ctx.Get(dev, memsize); err != nil {
   469  			return
   470  		}
   471  
   472  		if xd2, err = makeValueFromMem(z.t, zd.Shape(), mem); err != nil {
   473  			return
   474  		}
   475  
   476  		op.Prealloc = xd2
   477  		defer ctx.Signal()
   478  	}
   479  
   480  	// xd += zd
   481  	if d, err = op.Do(xd, zd); err != nil {
   482  		return errors.Wrapf(err, doFail, op)
   483  	}
   484  	xdv.SetDeriv(d)
   485  
   486  	// set up the op to be executed for y
   487  	op = NewAddOp(y, z, ctx)
   488  	op.Device = y.Device()
   489  	op.UseUnsafe = true
   490  
   491  	dev = op.Device
   492  
   493  	if yd, extra, err = y.GradOnDevice(dev, ctx.External); err != nil {
   494  		return errors.Wrapf(err, gradOnDeviceFail, y, dev)
   495  	}
   496  	if extra {
   497  		defer ctx.PutValue(dev, yd)
   498  	}
   499  
   500  	if zd, extra, err = z.GradOnDevice(dev, ctx.External); err != nil {
   501  		return errors.Wrapf(err, gradOnDeviceFail, z, dev)
   502  	}
   503  	if extra {
   504  		defer ctx.PutValue(dev, zd)
   505  	}
   506  
   507  	// if y is scalar, an additional vector needs to be acquired
   508  	if y.IsScalar() && dev != CPU {
   509  		var mem tensor.Memory
   510  		var yd2 Value
   511  		memsize := calcMemSize(zd.Dtype(), zd.Shape())
   512  		if mem, err = ctx.Get(dev, memsize); err != nil {
   513  			return
   514  		}
   515  		if yd2, err = makeValueFromMem(z.t, zd.Shape(), mem); err != nil {
   516  			return
   517  		}
   518  
   519  		op.Prealloc = yd2
   520  		defer ctx.Signal()
   521  	}
   522  
   523  	// yd += zd
   524  	if d, err = op.Do(yd, zd); err != nil {
   525  		return errors.Wrapf(err, doFail, op)
   526  	}
   527  	ydv.SetDeriv(d) // ignore errors on purpose
   528  
   529  	return nil
   530  }
   531  
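         // For z = x - y: ∂z/∂x = 1 and ∂z/∂y = -1, hence the Neg on gradZ for y.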
   532  func subDiffExpr(x, y, z, gradZ *Node) (retVal Nodes, err error) {
   533  	var dzdy *Node
   534  	if dzdy, err = Neg(gradZ); err == nil {
   535  		WithGroupName(gradClust)(dzdy)
   536  		WithGroupName(gradClust)(gradZ)
   537  		retVal = Nodes{gradZ, dzdy}
   538  	} else {
   539  		return nil, errors.Wrap(err, "Failed to carry Neg()")
   540  	}
   541  	return
   542  }
   543  
   544  func subDiff(ctx ExecutionContext, x, y, z *Node) (err error) {
   545  	xdv, ydv := getDV(x, y)
   546  
   547  	add := NewAddOp(x, z, ctx)
   548  	sub := NewSubOp(y, z, ctx)
   549  	add.Device = x.Device()
   550  	sub.Device = y.Device()
   551  	sub.UseUnsafe = true
   552  	add.UseUnsafe = true
   553  	// sub := newEBOByType(subOpType, y.t, z.t)
   554  	// add := newEBOByType(addOpType, x.t, z.t)
   555  
   556  	dev := sub.Device
   557  	var xd, yd, zd, d Value
   558  	var extra bool
   559  
   560  	if zd, extra, err = z.GradOnDevice(dev, ctx.External); err != nil {
   561  		return errors.Wrapf(err, gradOnDeviceFail, z, dev)
   562  	}
   563  	if extra {
   564  		defer ctx.PutValue(dev, zd)
   565  	}
   566  
   567  	if yd, extra, err = y.GradOnDevice(dev, ctx.External); err != nil {
   568  		return errors.Wrapf(err, gradOnDeviceFail, y, dev)
   569  	}
   570  	if extra {
   571  		defer ctx.PutValue(dev, yd)
   572  	}
   573  
    574  	// if y is scalar, an additional vector needs to be allocated for the prealloc
   575  	switch {
   576  	case y.IsScalar() && dev != CPU:
   577  		var mem tensor.Memory
   578  		var yd2 Value
   579  		memsize := calcMemSize(zd.Dtype(), zd.Shape())
   580  		if mem, err = ctx.Get(dev, memsize); err != nil {
   581  			return errors.Wrapf(err, allocFail, memsize, dev)
   582  		}
   583  		if yd2, err = makeValueFromMem(z.t, zd.Shape(), mem); err != nil {
   584  			return errors.Wrapf(err, makeValueFail, z.t, zd.Shape())
   585  		}
   586  
   587  		sub.Prealloc = yd2
   588  		defer ctx.Signal()
   589  	case y.IsScalar() && dev == CPU:
   590  		if sub.Prealloc, err = makeValue(z.t, zd.Shape()); err != nil {
   591  			return
   592  		}
   593  	}
   594  
   595  	// dz/dy
   596  	if d, err = sub.Do(yd, zd); err != nil {
   597  		return errors.Wrapf(err, doFail, sub)
   598  	}
   599  	ydv.SetDeriv(d) // errors are ignored on purpose
   600  
    601  	// handle x
   602  
   603  	dev = add.Device
   604  	if zd, extra, err = z.GradOnDevice(dev, ctx.External); err != nil {
   605  		return errors.Wrapf(err, gradOnDeviceFail, z, dev)
   606  	}
   607  	if extra {
   608  		defer ctx.PutValue(dev, zd)
   609  	}
   610  
   611  	if xd, extra, err = x.GradOnDevice(dev, ctx.External); err != nil {
   612  		return errors.Wrapf(err, gradOnDeviceFail, x, dev)
   613  	}
   614  	if extra {
   615  		defer ctx.PutValue(dev, xd)
   616  	}
   617  
   618  	switch {
   619  	case x.IsScalar() && dev != CPU:
   620  		var mem tensor.Memory
   621  		var xd2 Value
   622  		memsize := calcMemSize(zd.Dtype(), zd.Shape())
   623  		if mem, err = ctx.Get(dev, memsize); err != nil {
   624  			return
   625  		}
   626  
   627  		if xd2, err = makeValueFromMem(z.t, zd.Shape(), mem); err != nil {
   628  			return
   629  		}
   630  		add.Prealloc = xd2
   631  		defer ctx.Signal()
   632  	case x.IsScalar() && dev == CPU:
    633  		if add.Prealloc, err = makeValue(z.t, zd.Shape()); err != nil {
   634  			return
   635  		}
   636  	}
   637  
   638  	// dz/dx
   639  	if d, err = add.Do(xd, zd); err != nil {
   640  		return errors.Wrapf(err, doFail, add)
   641  	}
   642  	xdv.SetDeriv(d) // ignore errors on purpose
   643  
   644  	return nil
   645  }
   646  
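         // For z = x ⊙ y (element-wise product): ∂z/∂x = y and ∂z/∂y = x, so the
         // gradients are y ⊙ gradZ and x ⊙ gradZ.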
   647  func hadamardProdDiffExpr(x, y, z, gradZ *Node) (retVal Nodes, err error) {
   648  	var dzdx, dzdy *Node
   649  	if dzdx, err = HadamardProd(y, gradZ); err == nil {
   650  		dzdy, err = HadamardProd(x, gradZ)
   651  		if err != nil {
   652  			return nil, errors.Wrap(err, "Failed to carry HadamardProd()")
   653  		}
   654  		WithGroupName(gradClust)(dzdx)
   655  		WithGroupName(gradClust)(dzdy)
   656  		retVal = Nodes{dzdx, dzdy}
   657  		return
   658  	}
   659  	return nil, errors.Wrap(err, "Failed to carry HadamardProd()")
   660  }
   661  
   662  func hadamardProdDiff(ctx ExecutionContext, x, y, z *Node) (err error) {
   663  	xdv, ydv := getDV(x, y)
   664  
   665  	var mul *ExternalOp
   666  	var dev Device
   667  	var xd, yd, zd, d Value
   668  	var extra bool
   669  
   670  	if x.isConstant() {
   671  		goto dzdy
   672  	}
   673  
    674  	// dzdx
   675  	mul = NewHadamardProdOp(y, z, ctx)
   676  	mul.Device = x.Device()
   677  	dev = mul.Device
   678  
   679  	if xd, extra, err = x.GradOnDevice(dev, ctx.External); err != nil {
   680  		return errors.Wrapf(err, gradOnDeviceFail, x, dev)
   681  	}
   682  	if extra {
   683  		defer ctx.PutValue(dev, xd)
   684  	}
   685  
   686  	if yd, extra, err = y.ValueOnDevice(dev, ctx.External); err != nil {
   687  		return errors.Wrapf(err, gradOnDeviceFail, y, dev)
   688  	}
   689  	if extra {
   690  		defer ctx.PutValue(dev, yd)
   691  	}
   692  
   693  	if zd, extra, err = z.GradOnDevice(dev, ctx.External); err != nil {
   694  		return errors.Wrapf(err, gradOnDeviceFail, z, dev)
   695  	}
   696  	if extra {
   697  		defer ctx.PutValue(dev, zd)
   698  	}
   699  
   700  	mul.Incr = xd
   701  
    702  	// if x is a scalar, its gradient needs to be broadcast across the shape of z
   703  	if x.IsScalar() && dev != CPU && !zd.Shape().IsScalar() {
   704  		var memIncr, mem2 tensor.Memory
   705  		var xdIncr, xd2 Value
   706  		memsize := calcMemSize(zd.Dtype(), zd.Shape())
   707  		if mem2, err = ctx.Get(dev, memsize); err != nil {
   708  			return errors.Wrapf(err, allocFail, memsize, dev)
   709  		}
   710  
   711  		if xd2, err = makeValueFromMem(z.t, zd.Shape(), mem2); err != nil {
   712  			return errors.Wrapf(err, makeValueFail, z.t, zd.Shape())
   713  		}
   714  
   715  		// "broadcast" x (in a very sloppy way)
   716  		if memIncr, err = ctx.Get(dev, memsize); err != nil {
   717  			return errors.Wrapf(err, allocFail, memsize, dev)
   718  		}
   719  
   720  		if xdIncr, err = makeValueFromMem(z.t, zd.Shape(), memIncr); err != nil {
   721  			return errors.Wrapf(err, makeValueFail, z.t, zd.Shape())
   722  		}
   723  		xdIncr.(tensor.Tensor).Memset(xdv.d.Data())
   724  
   725  		mul.Prealloc = xd2
   726  		mul.Incr = xdIncr
   727  
   728  		defer ctx.PutValue(dev, xd2) // xd2 is temporary, we need to dealloc it
   729  		defer ctx.Signal()           // work needs to be done
   730  	}
   731  
   732  	if d, err = mul.Do(yd, zd); err != nil {
    733  		return errors.Wrapf(err, "IncrDo xd failed")
   734  	}
   735  
   736  	xdv.SetDeriv(d)
   737  
   738  dzdy:
   739  	if y.isConstant() {
   740  		goto end
   741  	}
   742  
   743  	mul = NewHadamardProdOp(x, z, ctx)
   744  	mul.Device = y.Device()
   745  	dev = mul.Device
   746  
   747  	if xd, extra, err = x.ValueOnDevice(dev, ctx.External); err != nil {
   748  		return errors.Wrapf(err, gradOnDeviceFail, x, dev)
   749  	}
   750  	if extra {
   751  		defer ctx.PutValue(dev, xd)
   752  	}
   753  
   754  	if yd, extra, err = y.GradOnDevice(dev, ctx.External); err != nil {
   755  		return errors.Wrapf(err, gradOnDeviceFail, y, dev)
   756  	}
   757  	if extra {
   758  		defer ctx.PutValue(dev, yd)
   759  	}
   760  
   761  	if zd, extra, err = z.GradOnDevice(dev, ctx.External); err != nil {
   762  		return errors.Wrapf(err, gradOnDeviceFail, z, dev)
   763  	}
   764  	if extra {
   765  		defer ctx.PutValue(dev, zd)
   766  	}
   767  
   768  	mul.Incr = yd
   769  
    770  	// if y is a scalar, its gradient needs to be broadcast across the shape of z
   771  	if y.IsScalar() && dev != CPU && !zd.Shape().IsScalar() {
   772  		var memIncr, mem2 tensor.Memory
   773  		var ydIncr, yd2 Value
   774  		memsize := calcMemSize(zd.Dtype(), zd.Shape())
   775  		if mem2, err = ctx.Get(dev, memsize); err != nil {
   776  			return errors.Wrapf(err, allocFail, memsize, dev)
   777  		}
   778  
   779  		if yd2, err = makeValueFromMem(z.t, zd.Shape(), mem2); err != nil {
   780  			return errors.Wrapf(err, makeValueFail, z.t, zd.Shape())
   781  		}
   782  
   783  		// "broadcast" y (in a very sloppy way)
   784  		if memIncr, err = ctx.Get(dev, memsize); err != nil {
   785  			return errors.Wrapf(err, allocFail, memsize, dev)
   786  		}
   787  
   788  		if ydIncr, err = makeValueFromMem(z.t, zd.Shape(), memIncr); err != nil {
   789  			return errors.Wrapf(err, makeValueFail, z.t, zd.Shape())
   790  		}
   791  		ydIncr.(tensor.Tensor).Memset(ydv.d.Data())
   792  
   793  		mul.Prealloc = yd2
   794  		mul.Incr = ydIncr
   795  
   796  		defer ctx.PutValue(dev, yd2) // yd2 is temporary, we need to dealloc it
   797  		defer ctx.Signal()           // work needs to be done
   798  	}
   799  
   800  	if d, err = mul.Do(xd, zd); err != nil {
   801  		return errors.Wrapf(err, "IncrDo yd failed")
   802  	}
   803  	ydv.SetDeriv(d)
   804  
   805  end:
   806  	return nil
   807  }
   808  
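         // For z = x / y (element-wise): ∂z/∂x = 1/y and ∂z/∂y = -x/y² = -z/y, so
         // dzdx = gradZ / y and dzdy = -(z/y) ⊙ gradZ, which is what gets built below.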
   809  func hadamardDivDiffExpr(x, y, z, gradZ *Node) (retVal Nodes, err error) {
   810  	var dzdx, dzdy *Node
   811  	if dzdx, err = HadamardDiv(gradZ, y); err == nil {
   812  		WithGroupName(gradClust)(dzdx)
   813  		if dzdy, err = HadamardDiv(z, y); err == nil {
   814  			WithGroupName(gradClust)(dzdy)
   815  			if dzdy, err = Neg(dzdy); err == nil {
   816  				WithGroupName(gradClust)(dzdy)
   817  				if dzdy, err = HadamardProd(dzdy, gradZ); err == nil {
   818  					WithGroupName(gradClust)(dzdy)
   819  					retVal = Nodes{dzdx, dzdy}
   820  					return
   821  				}
   822  				return nil, errors.Wrap(err, "Failed to carry HadamardProd()")
   823  			}
   824  			return nil, errors.Wrap(err, "Failed to carry Neg()")
   825  		}
    826  		return nil, errors.Wrap(err, "Failed to carry HadamardDiv()")
   827  	}
    828  	return nil, errors.Wrap(err, "Failed to carry HadamardDiv()")
   829  }
   830  
   831  func hadamardDivDiff(ctx ExecutionContext, x, y, z *Node) (err error) {
   832  	xdv, ydv, zdv := getDV3(x, y, z)
   833  
   834  	// dzdx = 1/y * dz
   835  	div := newEBOByType(divOpType, TypeOf(zdv.d), TypeOf(ydv.Value))
   836  	err = div.IncrDo(xdv.d, zdv.d, ydv.Value)
   837  	if err != nil {
   838  		var ver Valuer
   839  		var ok bool
   840  		if ver, ok = err.(Valuer); !ok {
   841  			return
   842  		}
   843  
   844  		xdv.SetDeriv(ver.Value()) // ignore errors on purpose
   845  	}
   846  
    847  	// dzdy = -x/y²
   848  	// TODO: investigate if this can be done (if no other node uses z):
   849  	//		unsafe do : neg zdv.d
   850  	// 		unsafe do : mul zdv.d, zdv.Value
   851  	//		incr do   : <incr: ydv.d> div zdv.d, ydv.Value
   852  	var d Value
   853  	if d, err = div.Do(zdv.Value, ydv.Value); err != nil {
   854  		return errors.Wrapf(err, doFail, div)
   855  	}
   856  
   857  	neg := newElemUnaryOp(negOpType, y)
   858  	if d, err = neg.Do(d); err != nil {
   859  		return errors.Wrapf(err, doFail, neg)
   860  	}
   861  
   862  	mul := newElemBinOp(mulOpType, z, y)
   863  	err = mul.IncrDo(ydv.d, zdv.d, d)
   864  	if err != nil {
   865  		var ver Valuer
   866  		var ok bool
   867  		if ver, ok = err.(Valuer); !ok {
   868  			return
   869  		}
   870  
   871  		ydv.SetDeriv(ver.Value()) // ignore errors on purpose
   872  	}
   873  
   874  	return nil
   875  }
   876  
   877  // TODO: go back in time, pay more attention to calculus class in high school and learn how to differentiate x^y
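         // For z = x^y: ∂z/∂x = y·x^(y-1) and ∂z/∂y = z·ln(x), which is exactly the
         // expression graph assembled below.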
   878  func hadamardPowDiffExpr(x, y, z, grad *Node) (retVal Nodes, err error) {
   879  	var one *Node
   880  	var dt tensor.Dtype
   881  
   882  	if dt, err = dtypeOf(y.t); err != nil {
   883  		return nil, errors.Wrapf(err, dtypeExtractionFail, y.t)
   884  	}
   885  
   886  	switch dt {
   887  	case Float32:
   888  		one = onef32
   889  	case Float64:
   890  		one = onef64
   891  	default:
   892  		err = errors.Errorf(nyiTypeFail, "Hadamard Power Diff", y.t)
   893  		return
   894  	}
   895  
   896  	var ym1, pow *Node
   897  	if ym1, err = Sub(y, one); err != nil {
   898  		return
   899  	}
   900  
   901  	if pow, err = Pow(x, ym1); err != nil {
   902  		return
   903  	}
   904  
   905  	var dzdx *Node
   906  	if dzdx, err = HadamardProd(grad, y); err != nil {
   907  		return
   908  	}
   909  	if dzdx, err = HadamardProd(dzdx, pow); err != nil {
   910  		return
   911  	}
   912  
   913  	var logx *Node
   914  	if logx, err = Log(x); err != nil {
   915  		return
   916  	}
   917  
   918  	var dzdy *Node
   919  	if dzdy, err = HadamardProd(grad, z); err != nil {
   920  		return
   921  	}
   922  	if dzdy, err = HadamardProd(dzdy, logx); err != nil {
   923  		return
   924  	}
   925  
   926  	retVal = Nodes{dzdx, dzdy}
   927  	return
   928  	// return nil, errors.New("hadamardPowDiffExpr not yet implemented")
   929  }
   930  
   931  func hadamardPowDiff(ctx ExecutionContext, x, y, z *Node) (err error) {
   932  	xdv, ydv, zdv := getDV3(x, y, z)
   933  
   934  	var ym1 Value
   935  	switch ydvt := ydv.Value.(type) {
   936  	case *F64:
   937  		ym1 = NewF64(ydvt.any() - float64(1))
   938  	case *F32:
   939  		ym1 = NewF32(ydvt.any() - float32(1))
   940  	case *tensor.Dense:
   941  		var one interface{}
   942  		switch ydvt.Dtype() {
   943  		case tensor.Float64:
   944  			one = float64(1)
   945  		case tensor.Float32:
   946  			one = float32(1)
   947  		}
   948  		if ym1, err = tensor.Sub(ydvt, one); err != nil {
   949  			return
   950  		}
   951  	default:
   952  		err = errors.Errorf(nyiTypeFail, "hadamardPowDiff", ydv.Value)
   953  		return
   954  	}
   955  
   956  	// dzdx
   957  	var pow Value
   958  	powOp := newEBOByType(powOpType, TypeOf(xdv.Value), TypeOf(ym1))
   959  	if pow, err = powOp.Do(xdv.Value, ym1); err != nil {
   960  		return
   961  	}
   962  
   963  	mul := newEBOByType(mulOpType, TypeOf(ydv.Value), TypeOf(xdv.Value))
   964  	if pow, err = mul.UnsafeDo(pow, ydv.Value); err != nil {
   965  		return
   966  	}
   967  
   968  	if err = mul.IncrDo(xdv.d, pow, zdv.d); err != nil {
   969  		var ver Valuer
   970  		var ok bool
   971  		if ver, ok = err.(Valuer); !ok {
   972  			return
   973  		}
   974  
   975  		xdv.SetDeriv(ver.Value())
   976  	}
   977  
   978  	// dzdy
   979  	var logx Value
   980  	logOp := newElemUnaryOp(lnOpType, x)
   981  	if logx, err = logOp.Do(xdv.Value); err != nil {
   982  		return
   983  	}
   984  	if logx, err = mul.Do(zdv.Value, logx); err != nil {
   985  		return
   986  	}
   987  	if err = mul.IncrDo(ydv.d, logx, zdv.d); err != nil {
   988  		var ver Valuer
   989  		var ok bool
   990  		if ver, ok = err.(Valuer); !ok {
   991  			return
   992  		}
   993  
   994  		ydv.SetDeriv(ver.Value())
   995  	}
   996  	return nil
   997  }
   998  
   999  func nondiffBinOpExpr(x, y, z, grad *Node) (retVal Nodes, err error) {
  1000  	return nil, errors.New("Nondifferentiable")
  1001  }
  1002  
  1003  func nondiffBinOp(ctx ExecutionContext, x, y, z *Node) (err error) {
  1004  	return AutoDiffError{}
  1005  }