gorgonia.org/gorgonia@v0.9.17/op_math.go

package gorgonia

/*
This file holds all the Ops that are related to doing math-related work. Due to the sheer number of
mathematical operations, they're classified into 3 main types:
elemBinOp - a representation of a binary mathematical operation that is performed elementwise (examples: +, *, -, >, <)
elemUnaryOp - a representation of a unary mathematical operation that is performed elementwise
linAlgBinOp - a representation of a binary mathematical operation that is performed on matrices

The individual operators are further expanded on in the operator*.go files. Their datatypes are often embedded in the datatypes here.

For all data types, the methods are standardized and arranged in the order the Op interface is defined.
Any additional interfaces that the data type fulfils will be declared AFTER the Op interface methods.
*/

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"hash"

	"github.com/chewxy/hm"
	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

/* ELEMENTWISE BINARY OPERATION */

// elemBinOp is the representation of an operation that is to be performed elementwise
type elemBinOp struct {
	ʘBinaryOperator
	arg0, arg1 hm.Type // pruned types only plz
	retSame    bool    // for comparison ops, return same type?
}

func newEBOByType(ot ʘBinaryOperatorType, at, bt hm.Type) elemBinOp {
	var binOp ʘBinaryOperator
	switch att := at.(type) {
	case tensor.Dtype:
		switch bt.(type) {
		case tensor.Dtype:
			binOp = scalarBinOp{
				ʘBinaryOperatorType: ot,
				t:                   att,
			}
		case TensorType:
			binOp = tBinOp{
				ʘBinaryOperatorType: ot,
				tensorLeft:          false,
			}
		default:
			panic(fmt.Sprintf("Unsupported type of b %v!", bt))
		}
	case TensorType:
		binOp = tBinOp{
			ʘBinaryOperatorType: ot,
			tensorLeft:          true,
		}
	default:
		panic(fmt.Sprintf("Unsupported type of a %v!", at))
	}
	return elemBinOp{
		ʘBinaryOperator: binOp,
		arg0:            at,
		arg1:            bt,
	}
}

func newElemBinOp(ot ʘBinaryOperatorType, a, b *Node) elemBinOp {
	// at := hm.Prune(a.t)
	// bt := hm.Prune(b.t)

	return newEBOByType(ot, a.t, b.t)
}

func (op elemBinOp) Arity() int { return 2 }

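// Illustrative sketch (not part of the original source): how an elemBinOp is reached from
// the public graph-building API. Add on two matrix nodes builds an elemBinOp with addOpType;
// newEBOByType above then picks a tBinOp because both arguments are tensors. The function
// name and the chosen shapes/initialisers are assumptions for illustration only.
func exampleElemBinOpAdd() (Value, error) {
	g := NewGraph()
	a := NewMatrix(g, Float64, WithShape(2, 2), WithName("a"), WithInit(RangedFrom(0)))
	b := NewMatrix(g, Float64, WithShape(2, 2), WithName("b"), WithInit(Ones()))
	c := Must(Add(a, b)) // elementwise a + b

	m := NewTapeMachine(g)
	defer m.Close()
	if err := m.RunAll(); err != nil {
		return nil, err
	}
	return c.Value(), nil // expected: [1, 2, 3, 4] laid out as a (2, 2) matrix
}
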
// elemBinOp has either of these types:
//		elemBinOp :: (Floats a) ⇒ Tensor a → Tensor a → Tensor a
//		elemBinOp :: (Floats a) ⇒ Tensor a → a → Tensor a
//		elemBinOp :: (Floats a) ⇒ a → Tensor a → Tensor a
//		elemBinOp :: (Floats a) ⇒ a → a → a
//		elemBinOp :: (Floats a) ⇒ a → a → Bool
//		elemBinOp :: (Floats a) ⇒ Tensor a → Tensor a → Tensor Bool
//		elemBinOp :: (Floats a) ⇒ Tensor a → a → Tensor Bool
//		elemBinOp :: (Floats a) ⇒ a → Tensor a → Tensor Bool
//
// To make things clearer, it helps to consider elemBinOp to be the representation of
// a dispatch table for different functions. In a sense it's "overloading" functions.
//
// At the moment, due to my refusal to create a sum type (which requires more fiddling with data constructors),
// Type() is determined pretty much at run time.
func (op elemBinOp) Type() hm.Type {
	a := hm.TypeVariable('a')

	var a0, a1, retType hm.Type
	var arg0Dims int
	switch arg0 := op.arg0.(type) {
	case TensorType:
		arg0Dims = arg0.Dims
		a0 = makeFromTensorType(arg0, a)
		retType = makeFromTensorType(arg0, a)
	default:
		a0 = a
		retType = a
	}

	switch arg1 := op.arg1.(type) {
	case TensorType:
		if arg1.Dims >= arg0Dims {
			retType = makeFromTensorType(arg1, a)
		}
		a1 = makeFromTensorType(arg1, a)
	default:
		a1 = a
	}

	if op.isArith() || (!op.isArith() && op.retSame) {
		return hm.NewFnType(a0, a1, retType)
	}

	switch rt := retType.(type) {
	case TensorType:
		rt.Of = Bool
		retType = rt
	default:
		retType = Bool
	}

	return hm.NewFnType(a0, a1, retType)
}

// elemBinOp has these allowed shapes:
//		op :: () → () → ()
//		op :: () → (...) → (...)
//		op :: (...) → () → (...)
func (op elemBinOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err error) {
	shapeLogf("Inferring shape of %v", op)
	enterLogScope()
	defer leaveLogScope()

	if inputs[0] == nil || inputs[1] == nil {
		return nil, errors.Errorf(nyiFail, "elemBinOp.inferShape", "runtime impl")
	}

	switch x := inputs[0].(type) {
	case tensor.Shape:
		switch y := inputs[1].(type) {
		case tensor.Shape:
			switch {
			case x.IsScalarEquiv() && y.IsScalarEquiv():
				// preserve ambiguous scalar shape
				switch {
				case len(x) > 0 && x[0] == 1:
					retVal = x
				case len(y) > 0 && y[0] == 1:
					retVal = y
				case x.IsScalar() && y.IsScalar():
					retVal = scalarShape
				default:
					retVal = scalarShape
				}
			case x.IsScalar() && !y.IsScalar():
				retVal = y
			case !x.IsScalar() && y.IsScalar():
				retVal = x
			case !x.IsScalar() && !y.IsScalar():
				if !x.Eq(y) {
					return nil, errors.Errorf("Shape mismatch: %v and %v", x, y)
				}
				if x.Dims() > y.Dims() {
					retVal = x
				} else {
					retVal = y
				}
			}
		default:
			retVal = x
		}
	default:
		switch y := inputs[1].(type) {
		case tensor.Shape:
			retVal = y
		default:
			retVal = scalarShape
		}
	}
	return
}

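// Illustrative sketch (not part of the original source): the shape rules above in action.
// A scalar shape combined with a tensor shape yields the tensor shape. The direct use of the
// unexported constructors here is an assumption for illustration only.
func exampleElemBinOpInferShape() {
	op := newEBOByType(addOpType, Float64, TensorType{Dims: 2, Of: Float64})
	s, err := op.InferShape(scalarShape, tensor.Shape{3, 4})
	fmt.Println(s, err) // expected: (3, 4) <nil>
}
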
// DiffWRT gives info on whether or not the operation is actually differentiable.
// For example, this is differentiable:
//		c = a ** b
// The result of the differentiation wrt a and b would be:
//		dc/da = b * a ** (b-1)
//		dc/db = a ** b * ln(a)
//
// However, operators like < and > are NOT differentiable.
//
// This method returns a slice of bools, indicating whether differentiation with regard to its operands
// can be done. Since binOp has 2 operands, we'll return a slice of 2 bools.
func (op elemBinOp) DiffWRT(inputs int) []bool {
	if inputs != 2 {
		panic(fmt.Sprintf(binOpFail, inputs))
	}

	b := op.ʘBinaryOperator.binOpType()

	if b >= maxʘBinaryOpType {
		panic("Unsupported binary operator is not differentiable")
	}

	if b.isArith() {
		return []bool{true, true}
	}
	return []bool{false, false}
}

func (op elemBinOp) SymDiff(inputs Nodes, output, gradNode *Node) (retVal Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	b := op.ʘBinaryOperator.binOpType()

	if retVal, err = ʘBinOpDiffExprs[b](inputs[0], inputs[1], output, gradNode); err == nil {
		for _, n := range retVal {
			n.setGroup(gradClust)
		}
	}

	// needed to handle scalar gradients such as b in the logit regression example
	for i, grad := range retVal {
		if inputs[i].IsScalar() && !grad.IsScalar() {
			if retVal[i], err = Sum(grad); err != nil {
				err = errors.Wrap(err, operationError)
				return
			}
		}
	}

	return
}

func (op elemBinOp) Do(values ...Value) (Value, error) {
	return op.ʘBinaryOperator.Do(op.retSame, values...)
}

func (op elemBinOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	b := op.ʘBinaryOperator.binOpType()
	if err = ʘBinOpDiffFns[b](ctx, inputs[0], inputs[1], output); err != nil {
		if _, ok := err.(AutoDiffError); !ok {
			return errors.Wrapf(err, autodiffFail, b)
		}
		err = nil
	}

	// handle scalar gradients
	for _, in := range inputs {
		indv := in.boundTo.(*dualValue)
		if _, ok := indv.d.(Scalar); in.IsScalar() && !ok {
			indvdT := indv.d.(tensor.Tensor)
			defer returnTensor(indvdT)

			var d Value
			var t tensor.Tensor
			if t, err = tensor.Sum(indvdT); err != nil {
				return errors.Wrap(err, operationError)
			}
			defer returnTensor(t)

			d, _ = anyToScalar(t.ScalarValue())
			indv.SetDeriv(d)
		}
	}
	return
}

func (op elemBinOp) ReturnsPtr() bool { return true }

func (op elemBinOp) OverwritesInput() int {
	if _, ok := op.arg0.(TensorType); ok {
		return 0
	}

	if _, ok := op.arg1.(TensorType); ok {
		return 1
	}
	return -1
}

func (op elemBinOp) WriteHash(h hash.Hash) {
	if err := binary.Write(h, binary.LittleEndian, op.binOpType()); err != nil {
		panic(err)
	}

	fmt.Fprintf(h, "%v,%v", op.arg0, op.arg1)
}

func (op elemBinOp) Hashcode() uint32 { return simpleHash(op) }

// Fulfils UsePreallocDoer interface
func (op elemBinOp) UsePreallocDo(prealloc Value, inputs ...Value) (retVal Value, err error) {
	if !op.ReturnsPtr() {
		return op.Do(inputs...)
	}

	if pd, ok := op.ʘBinaryOperator.(usePreallocDoerBinOp); ok {
		return pd.UsePreallocDo(prealloc, op.retSame, inputs...)
	}

	if retVal, err = op.Do(inputs...); err != nil {
		return
	}
	return Copy(prealloc, retVal)
}

// Fulfils UnsafeDoer interface
func (op elemBinOp) UnsafeDo(inputs ...Value) (retVal Value, err error) {
	if !op.ReturnsPtr() {
		return op.Do(inputs...)
	}

	if ud, ok := op.ʘBinaryOperator.(unsafeDoerBinOp); ok {
		return ud.UnsafeDo(op.retSame, inputs...)
	}
	return op.Do(inputs...)
}

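// Illustrative sketch (not part of the original source): elemBinOp.SymDiff driven from the
// public API. For c = a * b, the symbolic gradients are dc/da = b and dc/db = a. The function
// name and concrete values are assumptions for illustration only.
func exampleElemBinOpSymDiff() error {
	g := NewGraph()
	a := NewScalar(g, Float64, WithName("a"), WithValue(2.0))
	b := NewScalar(g, Float64, WithName("b"), WithValue(3.0))
	c := Must(Mul(a, b))

	grads, err := Grad(c, a, b) // builds the gradient subgraph via SymDiff
	if err != nil {
		return err
	}

	m := NewTapeMachine(g)
	defer m.Close()
	if err := m.RunAll(); err != nil {
		return err
	}
	fmt.Println(grads[0].Value(), grads[1].Value()) // expected: 3 and 2
	return nil
}
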
// Fulfils the IncrDoer interface
func (op elemBinOp) IncrDo(incr Value, inputs ...Value) (err error) {
	if id, ok := op.ʘBinaryOperator.(incrDoerBinOp); ok {
		return id.IncrDo(incr, op.retSame, inputs...)
	}

	// if !op.ReturnsPtr() {
	var retVal Value
	if retVal, err = op.Do(inputs...); err != nil {
		return errors.Wrapf(err, doFail, op)
	}

	add := newEBOByType(addOpType, TypeOf(incr), TypeOf(retVal))
	if retVal, err = add.UnsafeDo(incr, retVal); err != nil {
		return errors.Wrapf(err, unsafeDoFail, add)
	}
	err = noIncrErr{retVal}
	return
	// }
}

func (op elemBinOp) String() string { return fmt.Sprintf("%v %t", op.ʘBinaryOperator, op.retSame) }

// Fulfils the BinaryOp interface
func (op elemBinOp) IsBinary() bool { return true }

/* ELEMENTWISE UNARY OP */

type elemUnaryOp struct {
	ʘUnaryOperator

	argTensor     bool
	numericResult bool // indicate if boolean results should be converted to 1 and 0 in the respective Dtype
}

func newElemUnaryOp(op ʘUnaryOperatorType, a *Node) elemUnaryOp {
	dt, err := dtypeOf(a.t)
	if err != nil {
		panic(err)
	}

	_, isTensor := a.t.(TensorType)

	var operator ʘUnaryOperator
	switch dt {
	case Float32:
		operator = sf32UnaryOperators[op]
	case Float64:
		operator = sf64UnaryOperators[op]
	}

	return elemUnaryOp{
		ʘUnaryOperator: operator,
		argTensor:      isTensor,
	}
}

func (op elemUnaryOp) Arity() int { return 1 }

// all pointwise unary operations have this type:
//		op :: (Arithable a) ⇒ a → a
func (op elemUnaryOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	return hm.NewFnType(a, a)
}

func (op elemUnaryOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err error) {
	if inputs[0] == nil {
		return nil, errors.Errorf(nyiFail, "inferShape", "nil shape")
	}

	return inputs[0].(tensor.Shape), nil
}

// DiffWRT gives info on whether or not the operation is actually differentiable with respect to its inputs.
//
// Some operations, such as ceil(), sign() and floor(), cannot be differentiated with respect to their inputs (or I don't actually know how to do them)
func (op elemUnaryOp) DiffWRT(inputs int) []bool {
	if inputs != 1 {
		panic(fmt.Sprintf("unary operator only supports one input, got %d instead", inputs))
	}

	u := op.ʘUnaryOperator.unaryOpType()

	if u >= maxʘUnaryOperator {
		panic("Unsupported unary operator is not differentiable")
	}
	return []bool{ʘUnaryOpDifferentiable[u]}
}

func (op elemUnaryOp) SymDiff(inputs Nodes, output, gradNode *Node) (retVal Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	u := op.ʘUnaryOperator.unaryOpType()

	var n *Node
	if n, err = ʘUnaryOpDiffExprs[u](inputs[0], output, gradNode); err == nil {
		n.setGroup(gradClust)
		retVal = Nodes{n}
	}
	return
}

func (op elemUnaryOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	u := op.ʘUnaryOperator.unaryOpType()
	return ʘUnaryOpDiffFns[u](inputs[0], output)
}

func (op elemUnaryOp) Do(inputs ...Value) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	return op.do(inputs[0])
}

func (op elemUnaryOp) ReturnsPtr() bool { return true }

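// Illustrative sketch (not part of the original source): elemUnaryOp at work through the
// public API. Exp applies the exponential elementwise; since the op's type is a → a, the
// output shape matches the input. Names and shapes are assumptions for illustration only.
func exampleElemUnaryOpExp() error {
	g := NewGraph()
	x := NewVector(g, Float64, WithShape(3), WithName("x"), WithInit(RangedFrom(0)))
	y := Must(Exp(x))

	m := NewTapeMachine(g)
	defer m.Close()
	if err := m.RunAll(); err != nil {
		return err
	}
	fmt.Println(y.Value()) // expected: approximately [1, 2.718, 7.389]
	return nil
}
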
func (op elemUnaryOp) OverwritesInput() int {
	if op.argTensor {
		return 0
	}
	return -1
}

func (op elemUnaryOp) WriteHash(h hash.Hash) {
	if err := binary.Write(h, binary.LittleEndian, op.unaryOpType()); err != nil {
		panic(err)
	}

	if op.argTensor {
		h.Write([]byte{1})
	} else {
		h.Write([]byte{0})
	}
}

func (op elemUnaryOp) Hashcode() uint32 { return simpleHash(op) }

// fulfils UnsafeDoer interface
func (op elemUnaryOp) UnsafeDo(inputs ...Value) (Value, error) {
	if err := checkArity(op, len(inputs)); err != nil {
		return nil, err
	}
	return op.do(inputs[0], tensor.UseUnsafe())
}

// fulfils UnaryOp interface

func (op elemUnaryOp) isUnary() bool { return true }

// misc private methods

func (op elemUnaryOp) do(a Value, opts ...tensor.FuncOpt) (retVal Value, err error) {
	switch v := a.(type) {
	case tensor.Tensor:
		return unaryCheckApply(op.ʘUnaryOperator, v, opts...)
	case Scalar:
		vt := v.Dtype()
		switch vt {
		case tensor.Float32:
			vs := v.(*F32)
			f := float32(*vs)
			opFn := op.ʘUnaryOperator.(*sf32UnaryOperator)
			retVal, _ = anyToScalar((*opFn)(f))
		case tensor.Float64:
			vs := v.(*F64)
			f := float64(*vs)
			opFn := op.ʘUnaryOperator.(*sf64UnaryOperator)
			retVal, _ = anyToScalar((*opFn)(f))
		default:
			return nil, errors.Errorf(nyiFail, "elemUnaryOp.do", vt)
		}
	}
	return
}

/* LINEAR ALGEBRA RELATED OPERATIONS */

type linAlgBinOp struct {
	āBinaryOperator
	transA, transB bool
}

func (op linAlgBinOp) Arity() int { return 2 }

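// Illustrative sketch (not part of the original source): the matMulOperator shape rule below,
// (m, k) × (k, n) → (m, n), reached via the public API (Mul dispatches matrix × matrix
// arguments to a linAlgBinOp). The function name and shapes are assumptions for illustration only.
func exampleMatMulShape() {
	g := NewGraph()
	a := NewMatrix(g, Float64, WithShape(2, 3), WithName("a"), WithInit(RangedFrom(0)))
	b := NewMatrix(g, Float64, WithShape(3, 4), WithName("b"), WithInit(RangedFrom(0)))
	c := Must(Mul(a, b)) // matrix multiplication
	fmt.Println(c.Shape()) // expected: (2, 4)
}
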
func (op linAlgBinOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err error) {
	shapeLogf("Inferring shape of %v", op)
	enterLogScope()
	defer leaveLogScope()

	if inputs[0] == nil || inputs[1] == nil {
		return nil, nyi("InferShape for linalgBinOp", "runtime impl")
	}

	x, y := inputs[0].(tensor.Shape), inputs[1].(tensor.Shape)
	if x == nil || y == nil {
		return nil, errors.Errorf("Cannot infer shape from %v %v", x, y)
	}

	shapeLogf("x.shape: %v; y.shape: %v", x, y)
	// TODO: add checks for tensors greater than 2D

	switch op.āBinaryOperator {
	case matMulOperator:
		if op.transA {
			x = transpose2D(x)
			defer tensor.ReturnInts(x)
		}
		if op.transB {
			y = transpose2D(y)
			defer tensor.ReturnInts(y)
		}

		if x[1] != y[0] {
			return nil, errors.Errorf("Inner dimensions do not match up")
		}

		retVal = tensor.Shape{x[0], y[1]}
	case matVecMulOperator:
		if op.transA {
			x = transpose2D(x)
			defer tensor.ReturnInts(x)
		}
		if x[0] != y[0] && x[1] != y[0] {
			return nil, errors.Errorf("Incompatible shapes: %v and %v", x, y)
		}

		switch {
		case x[0] == y[0]:
			retVal = tensor.Shape{x[1]}
		case x[1] == y[0]:
			retVal = tensor.Shape{x[0]}
		}

	case vecDotOperator:
		retVal = scalarShape
	case outerProdOperator:
		// outer products only handle vec × vec for now
		retVal = tensor.Shape{x.TotalSize(), y.TotalSize()}
	case batchedMatMulOperator:
		x = x.Clone()
		y = y.Clone()
		innerX := x[len(x)-2:]
		outerX := x[:len(x)-2]
		innerY := y[len(y)-2:]
		outerY := y[:len(y)-2]
		if !outerX.Eq(outerY) {
			return nil, errors.Errorf("Expected outer dimensions of %v and %v to match. Got %v and %v", x, y, outerX, outerY)
		}

		// batchSize := outerX.TotalSize()
		if op.transA {
			innerX = transpose2D(innerX)
			defer tensor.ReturnInts(innerX)
		}
		if op.transB {
			innerY = transpose2D(innerY)
			defer tensor.ReturnInts(innerY)
		}
		retVal = append(outerX, innerX[0], innerY[1])
	}
	return
}

func (op linAlgBinOp) SymDiff(inputs Nodes, output, gradNode *Node) (retVal Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	o := op.āBinaryOperator

	if retVal, err = āBinOpDiffExprs[o](op.transA, op.transB, inputs[0], inputs[1], output, gradNode); err != nil {
		return nil, errors.Wrap(err, "Failed to differentiate expressions")
	}

	for _, n := range retVal {
		n.setGroup(gradClust)
	}
	return
}

func (op linAlgBinOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	o := op.āBinaryOperator
	return āBinOpDiffs[o](ctx, op.transA, op.transB, inputs[0], inputs[1], output)
}

func (op linAlgBinOp) Do(inputs ...Value) (retVal Value, err error) { return op.do(inputs) }
func (op linAlgBinOp) ReturnsPtr() bool                             { return true }
func (op linAlgBinOp) OverwritesInput() int                         { return -1 }

func (op linAlgBinOp) WriteHash(h hash.Hash) {
	if err := binary.Write(h, binary.LittleEndian, op.āBinaryOperator); err != nil {
		panic(err)
	}

	if op.transA {
		h.Write([]byte{1})
	} else {
		h.Write([]byte{0})
	}

	if op.transB {
		h.Write([]byte{1})
	} else {
		h.Write([]byte{0})
	}
}

func (op linAlgBinOp) Hashcode() uint32 { return simpleHash(op) }

func (op linAlgBinOp) String() string {
	var buf bytes.Buffer

	switch op.āBinaryOperator {
	case matMulOperator, matVecMulOperator, batchedMatMulOperator:
		buf.WriteString("A")
	case vecDotOperator, outerProdOperator:
		buf.WriteString("a")
	}

	if op.transA {
		buf.WriteString("ᵀ")
	}

	switch op.āBinaryOperator {
	case matMulOperator, batchedMatMulOperator:
		fmt.Fprintf(&buf, " %v B", op.āBinaryOperator)
	case matVecMulOperator, vecDotOperator, outerProdOperator:
		fmt.Fprintf(&buf, " %v b", op.āBinaryOperator)
	}

	if op.transB {
		buf.WriteString("ᵀ")
	}

	return buf.String()
}

// fulfils IncrDoer
func (op linAlgBinOp) IncrDo(incr Value, inputs ...Value) (err error) {
	t, ok := incr.(tensor.Tensor)

	switch {
	case ok && op.āBinaryOperator != batchedMatMulOperator:
		_, err = op.do(inputs, tensor.WithIncr(t))
		return
	case ok && op.āBinaryOperator == batchedMatMulOperator:
		_, err = op.preallocBatchMatMul(true, incr, inputs...)
		return
	}

	var retVal Value
	if retVal, err = op.do(inputs); err != nil {
		return errors.Wrapf(err, doFail, op)
	}

	add := newEBOByType(addOpType, TypeOf(incr), TypeOf(retVal))
	if retVal, err = add.UnsafeDo(incr, retVal); err != nil {
		return errors.Wrapf(err, unsafeDoFail, add)
	}

	err = noIncrErr{retVal}
	return
}

// fulfils UsePreallocDoer
func (op linAlgBinOp) UsePreallocDo(prealloc Value, inputs ...Value) (retVal Value, err error) {
	t, ok := prealloc.(tensor.Tensor)
	if !ok {
		return nil, errors.Errorf("Expected Tensor as preallocated value. Got %v of %T instead", prealloc, prealloc)
	}
	if op.āBinaryOperator == batchedMatMulOperator {
		return op.preallocBatchMatMul(false, prealloc, inputs...)
	}
	return op.do(inputs, tensor.WithReuse(t))
}

// fulfils BinaryOp
func (op linAlgBinOp) IsBinary() bool { return true }

/* PRIVATE METHODS */

func (op linAlgBinOp) do(inputs []Value, opts ...tensor.FuncOpt) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	a, b := inputs[0].(tensor.Tensor), inputs[1].(tensor.Tensor)

	if op.transA && op.āBinaryOperator != batchedMatMulOperator {
		if err = a.T(); err != nil {
			return nil, errors.Wrap(err, tFail)
		}

		// untranspose
		defer a.T()
	}

	if op.transB && op.āBinaryOperator != batchedMatMulOperator {
		if err = b.T(); err != nil {
			return nil, errors.Wrap(err, tFail)
		}

		// untranspose
		defer b.T()
	}

	switch op.āBinaryOperator {
	case matMulOperator:
		retVal, err = tensor.MatMul(a, b, opts...)
	case matVecMulOperator:
		retVal, err = tensor.MatVecMul(a, b, opts...)
	case vecDotOperator:
		var ret interface{}

		if ret, err = tensor.Inner(a, b); err != nil {
			return nil, errors.Wrapf(err, "Failed to carry out linalgBinOp operation %v", op)
		}

		retVal, _ = anyToScalar(ret)
	case outerProdOperator:
		retVal, err = tensor.Outer(a, b, opts...)
	case batchedMatMulOperator:
		// checks were done when the op was created
		retVal, err = batchedMatMul(a, b, nil, op.transA, op.transB, false)
	}

	if err != nil {
		return nil, fmt.Errorf("linAlgBinOp %v %s %v error: %w", a.Shape(), op.āBinaryOperator, b.Shape(), err)
	}

	return retVal, nil
}

func (op linAlgBinOp) preallocBatchMatMul(incr bool, prealloc Value, inputs ...Value) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	a, b := inputs[0].(tensor.Tensor), inputs[1].(tensor.Tensor)
	c := prealloc.(tensor.Tensor)
	return batchedMatMul(a, b, c, op.transA, op.transB, incr)
}

type tensordotOp struct {
	aAxes   []int
	bAxes   []int
	aDims   int
	bDims   int
	retDims int // Dimension of the tensor resulting from operation
}

func makeTensordotOp(a, b *Node, aAxes, bAxes []int) tensordotOp {
	aDims := a.Shape().Dims()
	bDims := b.Shape().Dims()
	retDims := a.Shape().Dims() + b.Shape().Dims() - 2*len(aAxes)
	if retDims < 0 {
		retDims = 0
	}
	return tensordotOp{
		aAxes:   aAxes,
		bAxes:   bAxes,
		aDims:   aDims,
		bDims:   bDims,
		retDims: retDims,
	}
}

func (op tensordotOp) Arity() int { return 2 }

func (op tensordotOp) Type() hm.Type {
	var tRet hm.Type
	if op.retDims == 0 {
		tRet = hm.TypeVariable('a')
	} else {
		tRet = newTensorType(op.retDims, hm.TypeVariable('a'))
	}
	ta := newTensorType(op.aDims, hm.TypeVariable('a'))
	tb := newTensorType(op.bDims, hm.TypeVariable('a'))

	return hm.NewFnType(ta, tb, tRet)
}

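// Illustrative sketch (not part of the original source): the retDims arithmetic above.
// Contracting two axes of one 3-tensor against two axes of another leaves 3 + 3 - 2*2 = 2
// dimensions. Names and shapes are assumptions for illustration only.
func exampleTensordotShape() error {
	g := NewGraph()
	a := NewTensor(g, Float64, 3, WithShape(3, 4, 5), WithName("a"), WithInit(RangedFrom(0)))
	b := NewTensor(g, Float64, 3, WithShape(4, 3, 6), WithName("b"), WithInit(RangedFrom(0)))
	c, err := Tensordot([]int{0, 1}, []int{1, 0}, a, b)
	if err != nil {
		return err
	}
	fmt.Println(c.Shape()) // expected: (5, 6)
	return nil
}
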
func (op tensordotOp) InferShape(ds ...DimSizer) (tensor.Shape, error) {
	if err := checkArity(op, len(ds)); err != nil {
		return nil, errors.Wrap(err, "tensordot")
	}

	shapes, err := DimSizersToShapes(ds)
	if err != nil {
		return nil, err
	}

	aShape := shapes[0]
	bShape := shapes[1]

	aAxes := op.aAxes
	bAxes := op.bAxes

	shapeBackingLen := op.retDims

	shapeBacking := make([]int, shapeBackingLen)

	shapeBackingPos := 0

	for aShapeIndex, aShapeValue := range aShape {
		if 0 > contains(aAxes, aShapeIndex) {
			shapeBacking[shapeBackingPos] = aShapeValue
			shapeBackingPos++
		}
	}

	for bShapeIndex, bShapeValue := range bShape {
		if 0 > contains(bAxes, bShapeIndex) {
			shapeBacking[shapeBackingPos] = bShapeValue
			shapeBackingPos++
		}
	}

	return tensor.Shape(shapeBacking), nil
}

func (op tensordotOp) Do(vals ...Value) (Value, error) {
	if err := checkArity(op, len(vals)); err != nil {
		return nil, errors.Wrap(err, "tensordot")
	}

	ts, err := valuesToTensors(vals)
	if err != nil {
		return nil, errors.Wrap(err, "tensordot - valuesToTensors failed")
	}

	return tensor.Contract(ts[0], ts[1], op.aAxes, op.bAxes)
}

func (op tensordotOp) ReturnsPtr() bool { return true }

func (op tensordotOp) CallsExtern() bool { return false }

func (op tensordotOp) OverwritesInput() int { return -1 }

func (op tensordotOp) WriteHash(h hash.Hash) {
	h.Write([]byte("tensordotOp"))
	fmt.Fprintf(h, "aAxes: %d, bAxes: %d, dims: %d", op.aAxes, op.bAxes, op.retDims)
}

func (op tensordotOp) Hashcode() uint32 { return simpleHash(op) }

func (op tensordotOp) String() string {
	return fmt.Sprintf("Tensordot(aAxes=%d, bAxes=%d)", op.aAxes, op.bAxes)
}

func (op tensordotOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error {
	odv := output.boundTo.(*dualValue)
	odvd := odv.d.(tensor.Tensor)

	for inNr, in := range inputs {
		// abuse of language below: "i" up front will refer to the current "in",
		// "other" to the other input (there are only two)

		// Whose derivative are we calculating?
		var iAxes []int
		var otherAxes []int
		var otherdv *dualValue
		var iWasFirstArgument bool

		if 0 == inNr {
			iAxes = op.aAxes
			otherAxes = op.bAxes
			otherdv = inputs[1].boundTo.(*dualValue)
			iWasFirstArgument = true
		} else {
			iAxes = op.bAxes
			otherAxes = op.aAxes
			otherdv = inputs[0].boundTo.(*dualValue)
			iWasFirstArgument = false
		}

		idv := in.boundTo.(*dualValue)
		idvd := idv.d.(tensor.Tensor)

		otherdvv := otherdv.Value.(tensor.Tensor)

		// Below, a tensordot will be performed: its output axes will be in the wrong order w.r.t. the input.
		// What is the correct permutation/pattern?
		iAxesCoSorted := make([]int, len(iAxes))
		for index, value := range iAxes {
			iAxesCoSorted[index] = value
		}

		otherAxesSorted := make([]int, len(otherAxes))
		for index, value := range otherAxes {
			otherAxesSorted[index] = value
		}

		sortUniqueIntWithImitator(otherAxesSorted, iAxesCoSorted)
		pattern := make([]int, len(in.Shape()))
		counter := len(iAxes)

		for patternIndex := 0; patternIndex < len(pattern); patternIndex++ {
			iAxesCoSortedIndex := contains(iAxesCoSorted, patternIndex)
			if 0 <= iAxesCoSortedIndex {
				pattern[patternIndex] = iAxesCoSortedIndex
			} else {
				pattern[patternIndex] = counter
				counter++
			}
		}
		// if the shape is scalar-equivalent, then we'll not have any transforms
		if in.Shape().IsScalarEquiv() {
			pattern = pattern[:0]
		}

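		// Worked example (illustrative, not in the original source): for a plain matrix
		// product expressed as a tensordot (a: (2, 3), b: (3, 4), aAxes = [1], bAxes = [0])
		// and in == a, iAxesCoSorted is [1], so pattern works out to [1, 0]. The tensordot
		// computed below then has shape (3, 2), and transposing by pattern restores a's
		// shape (2, 3); in matrix notation, dA = grad·Bᵀ.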
		// Which axes of the other tensor and the output should be contracted?
		// Other tensor: All axes that weren't contracted (with i ;-) ) in the original tensordot
		// With the exception of scalars
		dOtherAxes := make([]int, otherdvv.Dims())

		if !otherdvv.Shape().IsScalarEquiv() {
			var dOtherAxesIndex int

			for axis := 0; axis < otherdvv.Dims(); axis++ {
				if 0 > contains(otherAxes, axis) {
					dOtherAxes[dOtherAxesIndex] = axis
					dOtherAxesIndex++
				}
			}

			dOtherAxes = dOtherAxes[0:dOtherAxesIndex]
		}

		// Output: All axes which belong to other in the output of the original tensordot, so this depends on input ordering
		dOutputAxes := make([]int, len(dOtherAxes))
		if iWasFirstArgument {
			outputOtherAxesStart := odvd.Dims() - len(dOtherAxes)

			for axis := 0; axis < len(dOtherAxes); axis++ {
				dOutputAxes[axis] = outputOtherAxesStart + axis
			}
		} else {
			for axis := 0; axis < len(dOtherAxes); axis++ {
				dOutputAxes[axis] = axis
			}
		}

		// perform tensordot
		switch st := odvd.(type) {
		case *tensor.Dense:

			otherdvvDense := otherdvv.(*tensor.Dense)
			odvdDense := odvd.(*tensor.Dense)
			var tensordot *tensor.Dense
			var err error

			switch {
			case odvdDense.Shape().IsScalarEquiv():
				tensordot, err = otherdvvDense.MulScalar(odvdDense, true)
			case otherdvvDense.IsVector() && odvdDense.IsVector() && 0 == len(dOtherAxes): // TensorMul does not support creating a matrix from two vectors
				// Reformat the vectors, so that MatMul will create a matrix from them
				var otherdvvDenseShapeOld tensor.Shape
				var odvdDenseShapeOld tensor.Shape

				otherdvvDenseReshaped := false
				if !otherdvvDense.IsColVec() {
					otherdvvDenseShapeOld = otherdvvDense.Shape().Clone()

					otherdvvVecDims, err := (otherdvvDense.AP.Shape()).DimSize(0)
					if err != nil {
						return err
					}

					otherdvvDenseReshaped = true
					otherdvvDense.Reshape(otherdvvVecDims, 1)
				}

				odvdDenseReshaped := false
				if !odvdDense.IsRowVec() {
					odvdDenseShapeOld = odvdDense.Shape().Clone()
					odvdDenseVecDims, err := (odvdDense.AP.Shape()).DimSize(0)

					if err != nil {
						return err
					}

					odvdDenseReshaped = true
					odvdDense.Reshape(1, odvdDenseVecDims)
				}

				tensordot, err = otherdvvDense.MatMul(odvdDense)

				// Undo the reshapes
				if otherdvvDenseReshaped {
					otherdvvDense.Reshape(otherdvvDenseShapeOld...)
				}

				if odvdDenseReshaped {
					odvdDense.Reshape(odvdDenseShapeOld...)
				}

			default:
				tensordot, err = otherdvvDense.TensorMul(odvdDense, dOtherAxes, dOutputAxes)

			}

			if err != nil {
				return err
			}
			tensordotPerm, err := tensor.T(tensordot, pattern...)
			if err != nil {
				return err
			}

			tensordotPermDense := tensordotPerm.(*tensor.Dense)

			d := idvd.(*tensor.Dense)
			d.Add(tensordotPermDense, tensor.UseUnsafe()) // TODO: should output directly into d and save the add

		default:
			return errors.Errorf(nyiTypeFail, "Do Diff (hack)", st)
		}
	}

	return nil
}

func (op tensordotOp) DiffWRT(inputs int) []bool {
	retVal := make([]bool, inputs)
	for i := range retVal {
		retVal[i] = true
	}
	return retVal
}

func (op tensordotOp) SymDiff(inputs Nodes, output *Node, grad *Node) (retVal Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	retVal = make(Nodes, len(inputs))

	for inNr, in := range inputs {
		// abuse of language below: "i" up front will refer to the current "in",
		// "other" to the other input (there are only two)

		// Whose derivative are we calculating?
		var iAxes []int
		var otherAxes []int
		var iWasFirstArgument bool
		var other *Node

		if 0 == inNr {
			iAxes = op.aAxes
			otherAxes = op.bAxes
			other = inputs[1]
			iWasFirstArgument = true
		} else {
			iAxes = op.bAxes
			otherAxes = op.aAxes
			other = inputs[0]
			iWasFirstArgument = false
		}

		// Below, a tensordot will be performed: its output axes will be in the wrong order w.r.t. the input.
		// What is the correct permutation/pattern?
		iAxesCoSorted := make([]int, len(iAxes))
		for index, value := range iAxes {
			iAxesCoSorted[index] = value
		}

		otherAxesSorted := make([]int, len(otherAxes))
		for index, value := range otherAxes {
			otherAxesSorted[index] = value
		}

		sortUniqueIntWithImitator(otherAxesSorted, iAxesCoSorted)

		pattern := make([]int, len(in.shape))
		counter := len(iAxes)

		for patternIndex := 0; patternIndex < len(pattern); patternIndex++ {
			iAxesCoSortedIndex := contains(iAxesCoSorted, patternIndex)
			if 0 <= iAxesCoSortedIndex {
				pattern[patternIndex] = iAxesCoSortedIndex
			} else {
				pattern[patternIndex] = counter
				counter++
			}
		}

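		// Worked example (illustrative, not in the original source): continuing the matrix
		// product case from DoDiff above (a: (2, 3), b: (3, 4), aAxes = [1], bAxes = [0]),
		// for the second input b the pattern is the identity [0, 1]; dOtherAxes and dGradAxes
		// below both come out as [0], and the tensordot yields Aᵀ·grad with b's shape (3, 4).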
		// Which axes of the other tensor and the output should be contracted?
		// Other tensor: All axes that weren't contracted (with i ;-) ) in the original tensordot
		// With the exception of scalars
		dOtherAxes := make([]int, other.Dims())
		if !other.Shape().IsScalarEquiv() {
			var dOtherAxesIndex int

			for axis := 0; axis < other.Dims(); axis++ {
				if 0 > contains(otherAxes, axis) {
					dOtherAxes[dOtherAxesIndex] = axis
					dOtherAxesIndex++
				}
			}
			dOtherAxes = dOtherAxes[0:dOtherAxesIndex]
		}

		// Grad: All axes which belong to other in the output of the original tensordot, so this depends on input ordering
		dGradAxes := make([]int, len(dOtherAxes))
		if iWasFirstArgument {
			gradAxesStart := grad.Dims() - len(dOtherAxes)

			for axis := 0; axis < len(dOtherAxes); axis++ {
				dGradAxes[axis] = gradAxesStart + axis
			}
		} else {
			for axis := 0; axis < len(dOtherAxes); axis++ {
				dGradAxes[axis] = axis
			}
		}

		// perform tensordot
		var tensordot *Node
		switch {
		case grad.Shape().IsScalarEquiv():
			if tensordot, err = HadamardProd(other, grad); err != nil {
				err = SymDiffError{
					nodes:  inputs,
					single: other,
					grad:   grad,
					err:    errors.Wrap(err, "While performing tensordot of (other × grad) in SymDiff of `tensordotOp`. Nodes() returns the inputs. Node() returns `other`. Grad() returns `grad`"),
				}
				return nil, err
			}

		case other.Shape().IsVector() && grad.Shape().IsVector() && 0 == len(dOtherAxes): // TensorMul does not support creating a matrix from two vectors
			// Reformat the vectors, so that MatMul will create a matrix from them
			otherCorrectShape := other
			if !other.IsColVec() {
				otherVecDims, err := (other.Shape()).DimSize(0)
				if err != nil {
					err = SymDiffError{
						nodes:  inputs,
						single: other,
						err:    errors.Wrap(err, "While getting .DimSize(0) of other, while SymDiff-ing. Nodes() returns the inputs, Node() returns `other`. There is no Grad or Grad map."),
					}
					return nil, err
				}

				if otherCorrectShape, err = Reshape(other, tensor.Shape{otherVecDims, 1}); err != nil {
					return nil, err
				}
			}

			gradCorrectShape := grad
			if !grad.IsRowVec() {
				gradVecDims, err := (grad.Shape()).DimSize(0)

				if err != nil {
					return nil, err
				}

				if gradCorrectShape, err = Reshape(grad, tensor.Shape{1, gradVecDims}); err != nil {
					return nil, err
				}
			}

			op := linAlgBinOp{āBinaryOperator: matMulOperator}
			if tensordot, err = binOpNode(op, otherCorrectShape, gradCorrectShape); err != nil {
				return nil, err
			}

		default:
			tensordot, err = Tensordot(dOtherAxes, dGradAxes, other, grad)
		}

		if err != nil {
			return nil, err
		}

		ret, err := Transpose(tensordot, pattern...)

		if err != nil {
			return nil, err
		}

		retVal[inNr] = ret
	}

	return retVal, nil
}

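// Illustrative sketch (not part of the original source): tensordotOp.SymDiff via the public
// API. For c = Tensordot([1], [0], a, b) with a: (2, 3) and b: (3, 4), the symbolic gradients
// come back with the shapes of a and b. Names and shapes are assumptions for illustration only.
func exampleTensordotGradShapes() error {
	g := NewGraph()
	a := NewMatrix(g, Float64, WithShape(2, 3), WithName("a"), WithInit(RangedFrom(0)))
	b := NewMatrix(g, Float64, WithShape(3, 4), WithName("b"), WithInit(RangedFrom(0)))

	c, err := Tensordot([]int{1}, []int{0}, a, b) // same contraction as a matrix product
	if err != nil {
		return err
	}
	cost, err := Sum(c) // Grad needs a scalar cost
	if err != nil {
		return err
	}

	grads, err := Grad(cost, a, b)
	if err != nil {
		return err
	}
	fmt.Println(grads[0].Shape(), grads[1].Shape()) // expected: (2, 3) and (3, 4)
	return nil
}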