package gorgonia

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"hash"
	"sort"

	"github.com/chewxy/hm"
	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

/* This file contains tensor related Ops */

// atOp takes a Tensor and returns the value at the coordinates.
type atOp struct {
	coordinates coordinates // the coordinates to index into
	d           int         // expected dimensionality of the input tensor
}

// Arity returns 1: atOp takes exactly one tensor input.
func (op atOp) Arity() int { return 1 }

// atOp has this type
//	op :: Tensor a → a
func (op atOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	tt := makeTensorType(op.d, a)

	return hm.NewFnType(tt, a)
}

func (op atOp) ReturnsPtr() bool     { return false }
func (op atOp) OverwritesInput() int { return -1 }
func (op atOp) CallsExtern() bool    { return false }

// InferShape: extracting a single element always yields a scalar.
func (op atOp) InferShape(...DimSizer) (retVal tensor.Shape, err error) { return scalarShape, nil }

// DiffWRT returns all-false: atOp is not differentiable w.r.t. any input.
func (op atOp) DiffWRT(i int) []bool { return make([]bool, i) }

func (op atOp) SymDiff(Nodes, *Node, *Node) (Nodes, error) { return nil, nondiffErr(op) }
func (op atOp) String() string                             { return fmt.Sprintf("At(%v)", op.coordinates) }

// Do indexes the input tensor at op.coordinates and returns the element as a Value.
// Only *tensor.Dense inputs are supported.
func (op atOp) Do(inputs ...Value) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	switch tt := inputs[0].(type) {
	case *tensor.Dense:
		var r interface{}
		if r, err = tt.At(op.coordinates...); err != nil {
			err = errors.Wrap(err, opDoFail)
			return
		}

		retVal, _, _, err = anyToValue(r)
	default:
		err = errors.Errorf(nyiTypeFail, "atOp.Do()", tt)
	}
	return
}

func (op atOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "atOp%v%v", op.d, op.coordinates)
}

func (op atOp) Hashcode() uint32 { return simpleHash(op) }

// isStmt marks atOp as a statement-like op (executed for its effect/extracted
// value rather than participating in differentiation).
func (op atOp) isStmt() bool { return true }

// sizeOp returns the size of one axis of its input as a (scalar) Value.
type sizeOp struct {
	axis, d int
	val     int // if we know ahead of time what the size is...
}

func (op sizeOp) Arity() int { return 1 }

// sizeOp is a function with this type:
//	sizeOp :: Tensor d a → a
func (op sizeOp) Type() hm.Type {
	a := hm.TypeVariable('a')

	// handle scalar cases
	if op.d == 0 {
		return hm.NewFnType(a, a)
	}

	tt := makeTensorType(op.d, a)
	return hm.NewFnType(tt, a)
}

func (op sizeOp) ReturnsPtr() bool                             { return false }
func (op sizeOp) OverwritesInput() int                         { return -1 }
func (op sizeOp) CallsExtern() bool                            { return false }
func (op sizeOp) InferShape(...DimSizer) (tensor.Shape, error) { return scalarShape, nil } // TODO: return error

// DiffWRT: a size is a constant of the graph — never differentiable.
func (op sizeOp) DiffWRT(i int) []bool { return []bool{false} }

func (op sizeOp) String() string {
	if op.val != 0 {
		return fmt.Sprintf("SizeOf=%d", op.val)
	}
	return fmt.Sprintf("SizeOf(%d)", op.axis)
}

func (op sizeOp) SymDiff(inputs Nodes, output, gradNode *Node) (Nodes, error) {
	return nil, nondiffErr(op)
}

// Do returns the size of op.axis of the input, cast to the input's dtype
// (scalars report size 1).
func (op sizeOp) Do(inputs ...Value) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	switch t := inputs[0].(type) {
	case Scalar:
		retVal = one(t.Dtype())

		// bools are special
		if _, ok := t.(*B); ok {
			retVal = NewI(1)
		}
	case tensor.Tensor:
		sh := t.Shape()
		if op.axis >= len(sh) {
			return nil, errors.Errorf("Shape is %v. Want size of %d", sh, op.axis)
		}
		size := sh[op.axis]

		// cast as ... types
		switch t.Dtype() {
		case tensor.Float64:
			retVal = NewF64(float64(size))
		case tensor.Float32:
			retVal = NewF32(float32(size))
		case tensor.Int:
			retVal = NewI(size)
		default:
			return nil, errors.Errorf(nyiFail, "sizeOf.Do()", t.Dtype())
		}
	}

	return
}

func (op sizeOp) WriteHash(h hash.Hash) {
	h.Write([]byte("sizeOf"))
	if err := binary.Write(h, binary.LittleEndian, byte(op.d)); err != nil {
		panic(err)
	}
	h.Write([]byte("on"))
	if err := binary.Write(h, binary.LittleEndian, byte(op.axis)); err != nil {
		panic(err)
	}
}

func (op sizeOp) Hashcode() uint32 { return simpleHash(op) }

// DimSize reports the (pre-known) size for the op's axis; asking for any
// other axis is an error.
func (op sizeOp) DimSize(d int) (int, error) {
	if d != op.axis {
		return -1, errors.Errorf("Dimension mismatch. Size Op is for axis %d. Want Dim Size of %d", op.axis, d)
	}
	return op.val, nil
}

// repeatOp repeats its input along one axis; the repeat count is the op's
// second input at run time.
type repeatOp struct {
	along      int          // axis to repeat along
	inputShape tensor.Shape // shape of the input node at construction time
}

func newRepeatOp(along int, a *Node) *repeatOp {
	return &repeatOp{
		along:      along,
		inputShape: a.Shape().Clone(),
	}
}

// repeatedApply chains one repeatOp per axis in `along`. children[0] is the
// value to repeat; children[1:] are the per-axis repeat counts.
func repeatedApply(along []int, children Nodes) (retVal *Node, err error) {
	if len(children) != len(along)+1 {
		return nil, errors.Errorf("Expected %v children. Got %v instead (hint: along axes and number of children must match)", len(along)+1, len(children))
	}

	retVal = children[0]
	for i := range along {
		op := newRepeatOp(along[i], retVal)
		if retVal, err = ApplyOp(op, retVal, children[i+1]); err != nil {
			return nil, err
		}
	}
	return
}

func (op repeatOp) Arity() int { return 2 }

// repeat is defined as one of the following:
//	repeat :: Tensor-n a → a → Tensor-n a
//	repeat :: a → Vector a
// The end result must have the same dimensions as the input
func (op repeatOp) Type() hm.Type {

	a := hm.TypeVariable('a')

	d := op.inputShape.Dims()

	var i0t hm.Type
	var rt hm.Type

	if d == 0 {
		i0t = a
		rt = makeTensorType(d+1, a)
	} else {
		i0t = makeTensorType(d, a)
		rt = makeTensorType(d, a)
	}

	return hm.NewFnType(i0t, a, rt)
}

func (op repeatOp) ReturnsPtr() bool     { return true }
func (op repeatOp) OverwritesInput() int { return 0 }
func (op repeatOp) CallsExtern() bool    { return true } // set to true because we want to force the VM to use PreallocDo

// InferShape scales the input shape's op.along dimension by the repeat count
// taken from inputs[1], extending vector shapes with 1s where necessary.
func (op repeatOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err error) {
	retVal = inputs[0].(tensor.Shape).Clone()
	rep, err := inputs[1].DimSize(op.along)
	if err != nil {
		return nil, err
	}

	// TODO: switch stmt
	if retVal.IsVector() && retVal.Dims() <= op.along {
		// extend
		retVal = append(retVal, make(tensor.Shape, op.along-retVal.Dims()+1)...)
		for i := range retVal {
			if retVal[i] == 0 {
				retVal[i] = 1
			}
		}
	}
	if retVal.IsScalar() {
		retVal = tensor.Shape{1}
	}
	retVal[op.along] *= rep

	return
}

// DiffWRT: differentiable only w.r.t. the repeated value, not the repeat count.
// NOTE(review): panics if called with i == 0 — presumably never happens since Arity is 2.
func (op repeatOp) DiffWRT(i int) []bool {
	symdiffLogf("DiffWRT: %d", i)
	retVal := make([]bool, i)
	retVal[0] = true
	return retVal
}

// SymDiff: the gradient of repeat is the sum over the repeated axis.
func (op repeatOp) SymDiff(inputs Nodes, output, gradNode *Node) (retVal Nodes, err error) {
	var n *Node
	if n, err = Sum(gradNode, op.along); err == nil {
		n.setGroup(gradClust)
	}
	retVal = make(Nodes, len(inputs))
	retVal[0] = n
	return
}

// DoDiff accumulates the gradient of the output into the input's derivative.
// Conceptually: reshape the output gradient so the repeated copies occupy
// their own axis, sum over that axis, then add into xdv.d.
func (op repeatOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	xdv, ydv := getDV(inputs[0], output)

	var reps []int
	var repeats []Value
	for _, r := range inputs[1:] {
		repeats = append(repeats, r.Value())
	}

	if reps, err = valuesToInts(repeats); err != nil {
		return
	}

	xshape := xdv.Shape()
	var d Value
	d = ydv.d

	// we make it a colVec
	if xshape.IsVector() && !xshape.IsColVec() && !xshape.IsRowVec() {
		xshape = tensor.Shape{xshape[0], 1}
	}

	if xshape.IsScalar() {
		sum := newSumOp([]int{op.along}, output.shape, output.Dims())
		if d, err = sum.Do(d); err != nil {
			err = errors.Wrapf(err, doFail, sum)
			return
		}
	} else {
		axis := op.along
		if xshape[axis] == 1 {
			// input had size 1 along the axis: a plain sum suffices
			sum := newSumOp([]int{op.along}, output.shape, output.Dims())
			if d, err = sum.Do(d); err != nil {
				err = errors.Wrapf(err, doFail, sum)
				return
			}
		} else {
			// split the repeated axis into (original, reps) and sum over reps
			newShape := xshape.Clone()
			newShape = newShape[0 : axis+1]
			newShape = append(newShape, reps...)
			if axis+1 < xshape.Dims() {
				newShape = append(newShape, xshape[axis+1:]...)
			}

			along := []int{axis + 1}

			// a scalar can never get to this path
			t := d.(tensor.Tensor)
			if err = t.Reshape(newShape...); err != nil {
				err = errors.Wrapf(err, reshapeFail, newShape, t.DataSize())
				return
			}

			sum := newSumOp(along, newShape, len(newShape))
			if d, err = sum.Do(d); err != nil {
				err = errors.Wrapf(err, doFail, sum)
				return
			}
			// sum.Do leaves the dimension of size 1 behind, so reshape here.
			t = d.(tensor.Tensor)
			finalShape := newShape[:axis+1]
			if axis+1 < newShape.Dims() {
				finalShape = append(finalShape, newShape[axis+2:]...)
			}
			if err = t.Reshape(finalShape...); err != nil {
				err = errors.Wrapf(err, reshapeFail, newShape, t.DataSize())
				return
			}
		}

	}

	add := newEBOByType(addOpType, TypeOf(xdv.d), TypeOf(d))
	if d, err = add.UnsafeDo(xdv.d, d); err != nil {
		return
	}

	if !add.ReturnsPtr() || inputs[0].IsScalar() {
		err = xdv.SetDeriv(d)
	}

	return

}

func (op repeatOp) String() string { return fmt.Sprintf("Repeat%v", op.along) }

// Do performs a repeat on the value.
353 // TODO(anyone): implement for other types 354 func (op repeatOp) Do(inputs ...Value) (retVal Value, err error) { 355 if err = checkArity(op, len(inputs)); err != nil { 356 return 357 } 358 359 var rep int 360 if rep, err = valueToInt(inputs[1]); err != nil { 361 return nil, errors.Wrapf(err, "Cannot convert %v to an int", inputs[1]) 362 } 363 364 // process inputs[0] 365 var t tensor.Tensor 366 switch iv := inputs[0].(type) { 367 case Scalar: 368 s := iv.Data() 369 t = tensor.New(tensor.FromScalar(s)) 370 case tensor.Tensor: 371 // if iv.Shape().IsScalarEquiv() { 372 // log.Printf("SCALAR EQUIV %v", iv.Data()) 373 // t = iv.Clone().(tensor.Tensor) 374 // retVal = t 375 // return 376 // } 377 t = iv 378 default: 379 err = errors.Errorf(nyiTypeFail, "repeatOp.Do()", inputs[0]) 380 return 381 } 382 383 // actually do repeat 384 if rep == 1 { 385 goto fin 386 } 387 if t, err = tensor.Repeat(t, op.along, rep); err != nil { 388 err = errors.Wrapf(err, repFail, op.along, rep) 389 return 390 } 391 fin: 392 retVal = t 393 return 394 } 395 396 func (op repeatOp) WriteHash(h hash.Hash) { 397 fmt.Fprintf(h, "repeat %v %v", op.along, op.inputShape) 398 var arg0Dim int 399 if !op.inputShape.Eq(tensor.ScalarShape()) { 400 arg0Dim = op.inputShape[0] 401 } 402 if arg0Dim == 0 { 403 h.Write([]byte{1}) 404 } else { 405 h.Write([]byte{0}) 406 } 407 } 408 409 func (op repeatOp) Hashcode() uint32 { return simpleHash(op) } 410 411 func (op repeatOp) UsePreallocDo(prealloc Value, inputs ...Value) (retVal Value, err error) { 412 pt, ok := prealloc.(tensor.Tensor) 413 if !ok { 414 return nil, errors.Errorf("Expected Tensor as a preallocated value. 
Got %v of %T instead", prealloc, prealloc) 415 } 416 417 if err = checkArity(op, len(inputs)); err != nil { 418 return 419 } 420 421 var rep int 422 if rep, err = valueToInt(inputs[1]); err != nil { 423 return nil, errors.Wrapf(err, "Cannot convert %v to an int", inputs[1]) 424 } 425 426 // process inputs[0] 427 var t tensor.Tensor 428 switch iv := inputs[0].(type) { 429 case Scalar: 430 s := iv.Data() 431 pt.Memset(s) 432 retVal = pt 433 return 434 t = tensor.New(tensor.FromScalar(s)) 435 case tensor.Tensor: 436 if iv.Shape().IsScalarEquiv() { 437 data := iv.Data() 438 switch dt := data.(type) { 439 case float64: 440 ptd := pt.Data().([]float64) 441 for i := range ptd { 442 ptd[i] = dt 443 } 444 case float32: 445 ptd := pt.Data().([]float32) 446 for i := range ptd { 447 ptd[i] = dt 448 } 449 case []float64: 450 ptd := pt.Data().([]float64) 451 for i := range ptd { 452 ptd[i] = dt[0] 453 } 454 case []float32: 455 ptd := pt.Data().([]float32) 456 for i := range ptd { 457 ptd[i] = dt[0] 458 } 459 } 460 return pt, nil 461 } 462 t = iv 463 default: 464 err = errors.Errorf(nyiTypeFail, "repeatOp.Do()", inputs[0]) 465 return 466 } 467 if rep == 1 { 468 return Copy(pt, t) 469 } 470 471 return tensor.RepeatReuse(t, pt, op.along, rep) 472 } 473 474 // sliceOp represents a slicing operation. If end ⩽ start, it means ":" 475 type sliceOp struct { 476 tensor.Slice 477 478 along int // along which axis to slice? 
479 480 a int // along which axis of the original tensor 481 d int // how many dimensions were the original tensor 482 } 483 484 func (op *sliceOp) IsSlice() tensor.Slice { return op.Slice } 485 486 func newSliceOp(s tensor.Slice, along, d int) *sliceOp { 487 return &sliceOp{ 488 Slice: s, 489 along: along, 490 d: d, 491 } 492 } 493 494 func (op *sliceOp) Arity() int { return 1 } 495 496 // slicing a tensor value T[:] has type 497 // slice :: Tensor a → Tensor a 498 // slice :: Tensor a → a 499 // 500 // The latter is in the case where the resulting dimensions is 0, returning a scalar 501 func (op *sliceOp) Type() hm.Type { 502 a := hm.TypeVariable('a') 503 tt := makeTensorType(op.d, a) 504 505 var selection int 506 507 if op.Slice == nil { 508 selection = -1 509 } else { 510 selection = op.End() - op.Start() 511 } 512 513 if selection == 1 { 514 if op.d == 1 { 515 return hm.NewFnType(tt, a) 516 } 517 518 tt2 := makeTensorType(op.d-1, a) 519 return hm.NewFnType(tt, tt2) 520 } 521 522 return hm.NewFnType(tt, tt) 523 } 524 525 func (op *sliceOp) InferShape(inputs ...DimSizer) (s tensor.Shape, err error) { 526 input := inputs[0].(tensor.Shape) 527 slices := make([]tensor.Slice, op.along+1) 528 slices[op.along] = op.Slice 529 530 return input.S(slices...) 531 532 // return input.S(op.Slice) 533 } 534 535 func (op *sliceOp) DiffWRT(i int) []bool { 536 if i > 1 { 537 // error 538 err := errors.Errorf("sliceOp should only have one or more inputs. 
Got %v instead", i) 539 panic(err) 540 } 541 542 return []bool{true} 543 } 544 545 func (op *sliceOp) SymDiff(inputs Nodes, outputNode, gradNode *Node) (retVal Nodes, err error) { 546 if err = checkArity(op, len(inputs)); err != nil { 547 return 548 } 549 550 t := inputs[0] 551 incrOp := sliceIncrOp{op} 552 553 retVal = make(Nodes, 1) 554 retVal[0], err = ApplyOp(incrOp, t, gradNode) 555 return 556 } 557 558 func (op *sliceOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) { 559 if err = checkArity(op, len(inputs)); err != nil { 560 return 561 } 562 xdv, ydv := getDV(inputs[0], output) 563 564 // var d Value 565 incrOp := sliceIncrOp{op} 566 if _, err = incrOp.UsePreallocDo(xdv.d, xdv.d, ydv.d); err != nil { 567 return errors.Wrapf(err, doFail, incrOp) 568 } 569 570 // there is no need to handle scalars, because you can never slice a scalar 571 // add := newElemBinOp(addOpType, inputs[0], output) 572 // if _, err = add.UnsafeDo(xdv.d, d); err != nil { 573 // return errors.Wrapf(err, unsafeDoFail, add) 574 // } 575 576 return 577 } 578 579 func (op *sliceOp) Do(inputs ...Value) (retVal Value, err error) { 580 if err = checkArity(op, len(inputs)); err != nil { 581 return 582 } 583 584 t := inputs[0] 585 // prep the slices 586 var slices []tensor.Slice 587 slices = make([]tensor.Slice, len(t.Shape())) 588 589 if !op.all() { 590 slices[op.along] = op 591 } 592 switch T := t.(type) { 593 case tensor.Tensor: 594 var v tensor.Tensor 595 if v, err = T.Slice(slices...); err != nil { 596 return nil, errors.Wrapf(err, sliceFail, slices) 597 } 598 if v.IsScalar() { 599 retVal, _ = anyToScalar(v.ScalarValue()) 600 } else { 601 retVal = v.(tensor.View).Materialize() 602 } 603 case Scalar: 604 return nil, errors.New("Cannot slice a scalar value") 605 default: 606 return nil, errors.Errorf(nyiFail, "sliceOp.Do()", t) 607 } 608 return 609 } 610 611 func (op *sliceOp) ReturnsPtr() bool { return true } 612 func (op *sliceOp) CallsExtern() bool { return true } 
613 func (op *sliceOp) OverwritesInput() int { return -1 } 614 func (op sliceOp) WriteHash(h hash.Hash) { 615 h.Write([]byte("slice")) 616 if err := binary.Write(h, binary.LittleEndian, byte(op.d)); err != nil { 617 panic(err) 618 } 619 fmt.Fprintf(h, "%v", op.along) 620 if op.Slice == nil { 621 fmt.Fprintf(h, ":") 622 return 623 } 624 625 if err := binary.Write(h, binary.LittleEndian, byte(op.Start())); err != nil { 626 panic(err) 627 } 628 if err := binary.Write(h, binary.LittleEndian, byte(op.End())); err != nil { 629 panic(err) 630 } 631 if err := binary.Write(h, binary.LittleEndian, byte(op.Step())); err != nil { 632 panic(err) 633 } 634 635 } 636 func (op sliceOp) Hashcode() uint32 { return simpleHash(op) } 637 638 func (op sliceOp) String() string { 639 var buf bytes.Buffer 640 buf.WriteString("T[") 641 for i := 0; i < op.along; i++ { 642 buf.WriteString(":, ") 643 } 644 645 if op.all() { 646 buf.WriteString(":") 647 } else { 648 fmt.Fprintf(&buf, "%d:%d:%d", op.Start(), op.End(), op.Step()) 649 } 650 651 buf.WriteString("...]") 652 return buf.String() 653 } 654 655 // func (op sliceOp) CUDADo(extern External, dev Device, prealloc Value, inputs ...Value) (retVal Value, err error) { 656 // return op.Do(inputs...) 
// }

// func (op sliceOp) CUDAFuncName() string { return "" }

// all reports whether this op slices "everything" (i.e. T[:]): either no
// Slice was given, or End() ⩽ Start().
func (op sliceOp) all() bool { return op.Slice == nil || op.End() <= op.Start() }

// T[:] +=incr
// THIS IS AN UNSAFE OPERATION
type sliceIncrOp struct {
	*sliceOp
}

// slicing a tensor value T[:] has type
//	slice :: Tensor a → b → Tensor a
//
// b can be a or Vector a
func (op sliceIncrOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	// NOTE(review): the local is named b but the type variable is 'c' —
	// presumably to keep it distinct from 'a'; confirm before renaming.
	b := hm.TypeVariable('c')
	tt := makeTensorType(op.d, a)

	retVal := hm.NewFnType(tt, b, tt)
	return retVal
}

func (op sliceIncrOp) Arity() int { return 2 }

// InferShape: incrementing a slice does not change the tensor's shape.
func (op sliceIncrOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err error) {
	retVal = inputs[0].(tensor.Shape)
	return
}

// DiffWRT: differentiable w.r.t. the tensor being incremented, not the increment.
func (op sliceIncrOp) DiffWRT(i int) []bool {
	if err := checkArity(op, i); err != nil {
		panic(err)
	}

	return []bool{true, false}
}

// SymDiff: d(T[s] += incr)/dT is the gradient itself; d/d(incr) is the
// gradient sliced with the same slice.
func (op sliceIncrOp) SymDiff(inputs Nodes, outputNode, gradNode *Node) (retVal Nodes, err error) {
	var slicedRes *Node
	if slicedRes, err = ApplyOp(op.sliceOp, gradNode); err != nil {
		return nil, errors.Wrap(err, operationError)
	}
	retVal = Nodes{gradNode, slicedRes}

	return
}

func (op sliceIncrOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	xdv, ydv, zdv := getDV3(inputs[0], inputs[1], output)

	// dzdx
	add := newElemBinOp(addOpType, inputs[0], output)
	if _, err = add.UnsafeDo(xdv.d, zdv.d); err != nil {
		return errors.Wrapf(err, unsafeDoFail, add)
	}

	// dzdy
	var d Value
	if d, err = op.sliceOp.Do(zdv.d); err != nil {
		return errors.Wrapf(err, doFail, op)
	}

	add = newElemBinOp(addOpType, inputs[1], output)
	if _, err = add.UnsafeDo(ydv.d, d); err != nil {
		return errors.Wrapf(err, doFail, add)
	}
	return
}

// Do returns a fresh zero tensor shaped like inputs[0] with inputs[1] added
// into the sliced region only. It does NOT mutate inputs[0].
func (op sliceIncrOp) Do(inputs ...Value) (retVal Value, err error) {
	machineLogf("Doing %v", op)
	enterLogScope()
	defer leaveLogScope()

	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	t := inputs[0]
	incr := inputs[1]

	// prep the slices
	slices := make([]tensor.Slice, op.d)
	if !op.all() {
		slices[op.along] = op
	}

	switch T := t.(type) {
	case *tensor.Dense:
		// fresh zero-valued tensor; only the sliced view receives incr
		grad := tensor.NewDense(T.Dtype(), T.Shape().Clone())
		var v tensor.Tensor
		if v, err = grad.Slice(slices...); err != nil {
			return nil, errors.Wrapf(err, sliceFail, slices)
		}
		// NOTE(review): increment types other than *F64/*F32/*tensor.Dense
		// are silently ignored — confirm that is intended.
		switch i := incr.(type) {
		case *F64:
			tensor.Add(v, i.any(), tensor.UseUnsafe())
		case *F32:
			tensor.Add(v, i.any(), tensor.UseUnsafe())
		case *tensor.Dense:
			tensor.Add(v, i, tensor.UseUnsafe())
		}
		retVal = grad
	case Scalar:
		return nil, errors.New("Cannot slice a scalar value")
	default:
		return nil, errors.Errorf(nyiFail, "sliceIncrOp()", t)
	}
	return
}

// UsePreallocDo adds inputs[1] into the sliced region of prealloc, in place,
// and returns prealloc. Unlike Do, this mutates the preallocated value.
func (op sliceIncrOp) UsePreallocDo(prealloc Value, inputs ...Value) (retVal Value, err error) {
	machineLogf("Doing %v", op)
	enterLogScope()
	defer leaveLogScope()

	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	incr := inputs[1]

	// prep the slices
	slices := make([]tensor.Slice, op.d)
	if !op.all() {
		slices[op.along] = op
	}

	switch T := prealloc.(type) {
	case *tensor.Dense:
		var v tensor.Tensor
		if v, err = T.Slice(slices...); err != nil {
			return nil, errors.Wrapf(err, sliceFail, slices)
		}
		switch i := incr.(type) {
		case *F64:
			tensor.Add(v, i.any(), tensor.UseUnsafe())
		case *F32:
			tensor.Add(v, i.any(), tensor.UseUnsafe())
		case *tensor.Dense:
			tensor.Add(v, i, tensor.UseUnsafe())
		}
		retVal = T
	case Scalar:
		return nil, errors.New("Cannot slice a scalar value")
	default:
		return nil, errors.Errorf(nyiFail, "sliceIncrOp()", prealloc)
	}
	return
}
// OverwritesInput: sliceIncrOp overwrites its first input.
func (op sliceIncrOp) OverwritesInput() int { return 0 }

func (op sliceIncrOp) WriteHash(h hash.Hash) {
	h.Write([]byte("sliceIncr"))
	if err := binary.Write(h, binary.LittleEndian, byte(op.d)); err != nil {
		panic(err)
	}
	if err := binary.Write(h, binary.LittleEndian, byte(op.along)); err != nil {
		panic(err)
	}

	// a nil Slice means "take everything" and is hashed as ":"
	if op.Slice == nil {
		fmt.Fprintf(h, ":")
		return
	}

	if err := binary.Write(h, binary.LittleEndian, byte(op.Start())); err != nil {
		panic(err)
	}
	if err := binary.Write(h, binary.LittleEndian, byte(op.End())); err != nil {
		panic(err)
	}
	if err := binary.Write(h, binary.LittleEndian, byte(op.Step())); err != nil {
		panic(err)
	}
}

func (op sliceIncrOp) Hashcode() uint32 { return simpleHash(op) }

// String renders the op as an increment expression, e.g. T[:, 2:5:1...]+=...
func (op sliceIncrOp) String() string {
	var buf bytes.Buffer
	buf.WriteString("T[")

	for i := 0; i < op.along; i++ {
		buf.WriteString(":, ")
	}

	if op.all() {
		buf.WriteString(":")
	} else {
		fmt.Fprintf(&buf, "%d:%d:%d", op.Start(), op.End(), op.Step())
	}

	buf.WriteString("...]+=...")
	return buf.String()
}

// func (op sliceIncrOp) UsePreallocDo(val Value, inputs ...Value) (Value, error) {

// }

// transposeOp permutes the axes of its input according to pattern.
type transposeOp struct {
	pattern []int // axis permutation, len == d
	d       int   // dimensionality of the input
}

func (op transposeOp) Arity() int { return 1 }

// transposing a tensor has type
//	transpose :: Tensor a → Tensor a
func (op transposeOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	tt := makeTensorType(op.d, a)

	return hm.NewFnType(tt, tt)
}

// InferShape permutes the input shape by op.pattern. Scalars cannot be transposed.
func (op transposeOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err error) {
	input := inputs[0].(tensor.Shape)
	if input.IsScalar() {
		return nil, errors.Errorf(undefinedOnShape, op, input)
	}

	retVal = make(tensor.Shape, len(input))
	copy(retVal, input)
	err = tensor.UnsafePermute(op.pattern, retVal)

	return
}

func (op transposeOp) DiffWRT(i int) []bool {
	if err := checkArity(op, i); err != nil {
		panic(err)
	}

	return []bool{true}
}

// SymDiff: the gradient of a transpose is the transpose with the inverse permutation.
func (op transposeOp) SymDiff(inputs Nodes, outputNode, gradNode *Node) (retVal Nodes, err error) {
	// invert the permutation: newPattern[p] = i undoes pattern[i] = p
	newPattern := make([]int, len(op.pattern))
	for i, p := range op.pattern {
		newPattern[p] = i
	}
	op2 := transposeOp{pattern: newPattern, d: op.d}

	retVal = make(Nodes, 1)
	retVal[0], err = ApplyOp(op2, gradNode)
	return
}

// DoDiff transposes the output gradient with the inverse permutation and adds
// it into the input's derivative.
func (op transposeOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	xdv, zdv := getDV(inputs[0], output)

	newPattern := make([]int, len(op.pattern))
	for i, p := range op.pattern {
		newPattern[p] = i
	}

	var zdvdT tensor.Tensor
	var ok bool
	if zdvdT, ok = zdv.d.(tensor.Tensor); !ok {
		return errors.Errorf("Expected the gradient of the output node to be a Tensor. Got %v instead", zdv.d)
	}

	if err = zdvdT.T(newPattern...); err != nil {
		return errors.Wrap(err, "Failed to T()")
	}

	// materialize the lazy transpose, then undo the T() on the original
	d := tensor.Materialize(zdvdT)
	zdvdT.UT()

	add := newEBOByType(addOpType, inputs[0].t, TypeOf(zdvdT))
	if _, err = add.UnsafeDo(xdv.d, d); err != nil {
		err = errors.Wrapf(err, doFail, add)
	}
	return
}

// Do performs the transpose, returning a new (materialized) tensor.
func (op transposeOp) Do(inputs ...Value) (retVal Value, err error) {
	machineLogf("Doing %v", op)
	enterLogScope()
	defer leaveLogScope()

	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	t := inputs[0].(tensor.Tensor)

	// throwaway is borrowed because tensor.Transpose takes ownership of it
	throwaway := tensor.BorrowInts(len(op.pattern))
	copy(throwaway, op.pattern)
	// return tensor.T(t, throwaway...)

	return tensor.Transpose(t, throwaway...)

	// DEPRECATED
	// the reason for this is because the .T() method of a Tensor
	// will use the axes in the .transposedWith field
	// Later when .UT() is called, the .transposedWith field is recycled into the pool
	// throwaway := tensor.BorrowInts(len(op.pattern))
	// copy(throwaway, op.pattern)

	// t.T(throwaway...)
	// ret := t.Materialize()
	// t.UT()
}

func (op transposeOp) ReturnsPtr() bool     { return true }
func (op transposeOp) CallsExtern() bool    { return false }
func (op transposeOp) OverwritesInput() int { return 0 }

func (op transposeOp) WriteHash(h hash.Hash) {
	h.Write([]byte("transposeOp"))
	fmt.Fprintf(h, "%v", op.pattern)
	if err := binary.Write(h, binary.LittleEndian, byte(op.d)); err != nil {
		panic(err)
	}
}

func (op transposeOp) Hashcode() uint32 { return simpleHash(op) }

// String renders the op as Aᵀ{p0, p1, ...}.
func (op transposeOp) String() string {
	var buf bytes.Buffer
	buf.WriteString("Aᵀ{")
	for i, ax := range op.pattern {
		fmt.Fprintf(&buf, "%d", ax)
		if i < len(op.pattern)-1 {
			buf.WriteString(", ")
		}
	}

	buf.WriteString("}")
	return buf.String()
}

// concatOp concatenates a variable number of tensors along one axis.
type concatOp struct {
	axis     int // axis to concatenate along
	d        int // dimensionality of the inputs
	children int // number of inputs
}

// Arity is -1: concat takes a variable number of inputs.
func (op concatOp) Arity() int { return -1 }

// concat only works for Tensor types
//	concat :: Tensor a → Tensor a → ... → Tensor a
func (op concatOp) Type() hm.Type {
	tt := makeTensorType(op.d, hm.TypeVariable('a'))
	fnt := make([]hm.Type, op.children+1)
	for i := range fnt {
		fnt[i] = tt
	}

	return hm.NewFnType(fnt...)
}

func (op concatOp) InferShape(ds ...DimSizer) (tensor.Shape, error) {
	if len(ds) == 0 {
		return nil, errors.Errorf("No shapes passed in!")
	}
	shapes, err := DimSizersToShapes(ds)
	if err != nil {
		return nil, err
	}

	return shapes[0].Concat(op.axis, shapes[1:]...)
}

// Do concatenates the inputs; a single input is returned unchanged.
func (op concatOp) Do(vals ...Value) (Value, error) {
	if len(vals) == 1 {
		return vals[0], nil
	}

	ts, err := valuesToTensors(vals)
	if err != nil {
		return nil, err
	}

	return tensor.Concat(op.axis, ts[0], ts[1:]...)
}

func (op concatOp) ReturnsPtr() bool     { return true }
func (op concatOp) CallsExtern() bool    { return false }
func (op concatOp) OverwritesInput() int { return -1 }

func (op concatOp) WriteHash(h hash.Hash) {
	h.Write([]byte("concatOp"))
	fmt.Fprintf(h, "axis: %d, dims: %d", op.axis, op.d)
}

func (op concatOp) Hashcode() uint32 { return simpleHash(op) }

func (op concatOp) String() string {
	return fmt.Sprintf("Concat(axis=%d)", op.axis)
}

// DiffWRT: concat is differentiable w.r.t. every input.
func (op concatOp) DiffWRT(inputs int) []bool {
	retVal := make([]bool, inputs)
	for i := range retVal {
		retVal[i] = true
	}
	return retVal
}

// SymDiff: each input's gradient is the corresponding slice of the output gradient.
func (op concatOp) SymDiff(inputs Nodes, output *Node, grad *Node) (retVal Nodes, err error) {
	var start int

	retVal = make(Nodes, len(inputs))
	for i, in := range inputs {
		if op.axis >= len(in.shape) {
			return nil, errors.Errorf("Wanted dimension %d is larger than the shape %v", op.axis, in.shape)
		}
		// each input occupies [start, end) of the output along op.axis
		end := in.shape[op.axis] + start

		s := newSliceOp(S(start, end), op.axis, op.d)
		if retVal[i], err = ApplyOp(s, grad); err != nil {
			return
		}
		start = end
	}
	return
}

// DoDiff adds the matching slice of the output gradient into each input's derivative.
func (op concatOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error {
	odv := output.boundTo.(*dualValue)
	odvd := odv.d.(tensor.Tensor)

	var start int
	for _, in := range inputs {
		if op.axis >= len(in.shape) {
			return errors.Errorf("Wanted dimension %d is larger than the shape %v", op.axis, in.shape)
		}
		end := in.shape[op.axis] + start

		idv := in.boundTo.(*dualValue)
		idvd := idv.d.(tensor.Tensor)

		sliced, err := odvd.Slice(S(start, end))
		if err != nil {
			return err
		}

		// TODO: fix VAdd hack
		// add to odvd
		switch st := sliced.(type) {
		case *tensor.Dense:
			d := idvd.(*tensor.Dense)
			d.Add(st, tensor.UseUnsafe())
		default:
			return errors.Errorf(nyiTypeFail, "DoDiff (hack) ", st)
		}

		start = end
	}
	return nil
}

// reshapeOp reshapes its input from one shape to another of equal total size.
type reshapeOp struct {
	from, to tensor.Shape
}

func (op reshapeOp) Arity() int { return 1 }

func (op reshapeOp) Type() hm.Type {
	if op.from.Dims() != op.to.Dims() {
		fr := op.from.Dims()
		var frT hm.Type
		frT = newTensorType(fr, hm.TypeVariable('a'))
		if fr == 0 {
			// 0-dim input is a plain scalar, not a tensor type
			frT = hm.TypeVariable('a')
		}

		to := op.to.Dims()
		var toT hm.Type
		toT = newTensorType(to, hm.TypeVariable('a'))
		if to == 0 {
			toT = hm.TypeVariable('a')
		}
		return hm.NewFnType(frT, toT)
	}
	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'))
}

func (op reshapeOp) InferShape(ds ...DimSizer) (tensor.Shape, error) { return op.to.Clone(), nil }

// Do reshapes a copy (or shallow clone) of the input to op.to; the input
// itself is not mutated (contrast with UnsafeDo).
func (op reshapeOp) Do(vals ...Value) (Value, error) {
	if err := checkArity(op, len(vals)); err != nil {
		return nil, err
	}
	var val Value
	var err error
	switch vals[0].(type) {
	case tensor.Tensor:
		// views must be materialized first; otherwise a shallow clone suffices
		if v, ok := vals[0].(*tensor.Dense); ok {
			if v.IsView() {
				val = v.Materialize()
			} else {
				val = v.ShallowClone()
			}
		} else {
			if val, err = CloneValue(vals[0]); err != nil {
				return nil, errors.Wrapf(err, cloneFail, vals[0])
			}
		}
		if val.Shape().TotalSize() != op.from.TotalSize() {
			return nil, errors.Errorf("Shape mismatch. Input shape is %v. Expected %v", val.Shape(), op.from)
		}

		if err := val.(tensor.Tensor).Reshape(op.to...); err != nil {
			return nil, err
		}
		return val, nil
	case Scalar:
		v0 := ScalarAsTensor(vals[0], op.to.Dims(), nil)
		if err := v0.(tensor.Tensor).Reshape(op.to...); err != nil {
			return nil, err
		}
		return v0, nil
	default:
		return nil, errors.Errorf(nyiTypeFail, "reshape.Do", vals[0])
	}
}

func (op reshapeOp) ReturnsPtr() bool     { return true }
func (op reshapeOp) CallsExtern() bool    { return false }
func (op reshapeOp) OverwritesInput() int { return 0 }

func (op reshapeOp) WriteHash(h hash.Hash) {
	h.Write([]byte("reshapeOp"))
	fmt.Fprintf(h, "from: %v, dims: %v", op.from, op.to)
}

func (op reshapeOp) Hashcode() uint32 { return simpleHash(op) }

func (op reshapeOp) String() string { return fmt.Sprintf("Reshape%v", op.to) }

// UnsafeDo reshapes the input tensor in place.
func (op reshapeOp) UnsafeDo(vals ...Value) (Value, error) {
	if err := checkArity(op, len(vals)); err != nil {
		return nil, err
	}
	var val Value
	var err error
	switch vals[0].(type) {
	case tensor.Tensor:
		val = vals[0]
		err = val.(tensor.Tensor).Reshape(op.to...)

		return val, err
	case Scalar:
		v0 := ScalarAsTensor(vals[0], op.to.Dims(), nil)
		if err := v0.(tensor.Tensor).Reshape(op.to...); err != nil {
			return nil, err
		}
		return v0, nil
	default:
		return nil, errors.Errorf(nyiTypeFail, "reshape.Do", vals[0])
	}
}

// CUDADo reshapes in place; prealloc is unused since reshape moves no data.
func (op reshapeOp) CUDADo(extern External, dev Device, prealloc Value, vals ...Value) (retVal Value, err error) {
	if err := checkArity(op, len(vals)); err != nil {
		return nil, err
	}
	val := vals[0]
	switch v := val.(type) {
	case tensor.Tensor:
		if err := v.Reshape(op.to...); err != nil {
			return nil, err
		}
		return v, nil
	case Scalar:
		vT := ScalarAsTensor(v, op.to.Dims(), nil)
		if err := vT.(tensor.Tensor).Reshape(op.to...); err != nil {

			return nil, errors.Errorf(nyiTypeFail, "reshape.Do", "Scalar")
		}
		return vT, nil
	}

	panic("Unreachable")
}

func (op reshapeOp) DiffWRT(i int) []bool { return []bool{true} }

// SymDiff: the gradient of a reshape is the gradient reshaped back to op.from.
func (op reshapeOp) SymDiff(inputs Nodes, output *Node, grad *Node) (retVal Nodes, err error) {
	var ret *Node
	if ret, err = Reshape(grad, op.from); err != nil {
		return
	}
	ret.setGroup(gradClust)
	return Nodes{ret}, nil
}

// DoDiff reshapes the output gradient back to op.from and sets it as the
// input's derivative.
// NOTE(review): this reshapes the output's gradient tensor in place — confirm
// no other consumer relies on its original shape afterwards.
func (op reshapeOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	var grad Value
	if grad, err = output.Grad(); err != nil {
		return
	}
	T := grad.(tensor.Tensor)
	if err = T.Reshape(op.from...); err != nil {
		return
	}
	input := inputs[0]
	dv := input.boundTo.(*dualValue)
	return dv.SetDeriv(T)
}

/* PRIVATE FUNCTIONS */
// contains returns the index of value in slice, or -1 if it is absent.
// (Ranging over a nil slice is a no-op, so no explicit nil check is needed.)
func contains(slice []int, value int) int {
	for i, v := range slice {
		if v == value {
			return i
		}
	}
	return -1
}

// sortUniqueIntWithImitator sorts toBeSorted in place (its values are assumed
// to be unique) and applies the same permutation to imitator, so that
// imitator[i] stays paired with the element that moved to toBeSorted[i].
// TODO: This function is an overkill for a small number of axes...
func sortUniqueIntWithImitator(toBeSorted, imitator []int) {
	// snapshot both slices before sorting (copy replaces the manual loops)
	toBeSortedBackup := make([]int, len(toBeSorted))
	copy(toBeSortedBackup, toBeSorted)

	imitatorBackup := make([]int, len(imitator))
	copy(imitatorBackup, imitator)

	sort.Ints(toBeSorted)

	// Permutate the imitator accordingly: each original value's new position
	// is found by binary search in the sorted slice (valid because values are unique).
	for originalIndex, originalValue := range toBeSortedBackup {
		sortedIndex := sort.SearchInts(toBeSorted, originalValue)
		imitator[sortedIndex] = imitatorBackup[originalIndex]
	}
}