gorgonia.org/gorgonia@v0.9.17/op_tensor.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"hash"
     8  	"sort"
     9  
    10  	"github.com/chewxy/hm"
    11  	"github.com/pkg/errors"
    12  	"gorgonia.org/tensor"
    13  )
    14  
    15  /* This file contains tensor related Ops */
    16  
// atOp takes a Tensor and returns the value at the coordinates.
type atOp struct {
	coordinates coordinates // coordinates of the element to extract
	d           int         // dimensionality of the input tensor
}

// Arity is 1: atOp takes exactly one input (the tensor to index into).
func (op atOp) Arity() int { return 1 }

// atOp has this type
//		op :: Tensor a → a
func (op atOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	tt := makeTensorType(op.d, a)

	return hm.NewFnType(tt, a)
}

func (op atOp) ReturnsPtr() bool     { return false }
func (op atOp) OverwritesInput() int { return -1 }
func (op atOp) CallsExtern() bool    { return false }

// InferShape always returns a scalar shape: indexing yields a single element.
func (op atOp) InferShape(...DimSizer) (retVal tensor.Shape, err error) { return scalarShape, nil }

// DiffWRT returns an all-false slice: atOp is not differentiable.
func (op atOp) DiffWRT(i int) []bool { return make([]bool, i) }

// SymDiff always errors because atOp is not differentiable.
func (op atOp) SymDiff(Nodes, *Node, *Node) (Nodes, error) { return nil, nondiffErr(op) }

func (op atOp) String() string { return fmt.Sprintf("At(%v)", op.coordinates) }
    41  
    42  func (op atOp) Do(inputs ...Value) (retVal Value, err error) {
    43  	if err = checkArity(op, len(inputs)); err != nil {
    44  		return
    45  	}
    46  
    47  	switch tt := inputs[0].(type) {
    48  	case *tensor.Dense:
    49  		var r interface{}
    50  		if r, err = tt.At(op.coordinates...); err != nil {
    51  			err = errors.Wrap(err, opDoFail)
    52  			return
    53  		}
    54  
    55  		retVal, _, _, err = anyToValue(r)
    56  	default:
    57  		err = errors.Errorf(nyiTypeFail, "atOp.Do()", tt)
    58  	}
    59  	return
    60  }
    61  
// WriteHash writes a unique textual representation of the op to the hasher.
func (op atOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "atOp%v%v", op.d, op.coordinates)
}

func (op atOp) Hashcode() uint32 { return simpleHash(op) }

// isStmt flags atOp as a statement op.
func (op atOp) isStmt() bool { return true }

// sizeOp reports the size of one axis of its input.
type sizeOp struct {
	axis, d int // the axis being queried, and the input's dimensionality
	val     int // if we know ahead of time what the size is...
}
    74  
func (op sizeOp) Arity() int { return 1 }

// sizeOp is a function with this type:
//		sizeOp :: Tensor d a → a
func (op sizeOp) Type() hm.Type {
	a := hm.TypeVariable('a')

	// handle scalar cases
	if op.d == 0 {
		return hm.NewFnType(a, a)
	}

	tt := makeTensorType(op.d, a)
	return hm.NewFnType(tt, a)
}

func (op sizeOp) ReturnsPtr() bool                             { return false }
func (op sizeOp) OverwritesInput() int                         { return -1 }
func (op sizeOp) CallsExtern() bool                            { return false }
func (op sizeOp) InferShape(...DimSizer) (tensor.Shape, error) { return scalarShape, nil } // TODO: return error

// DiffWRT: the size of a tensor is not differentiable w.r.t. the tensor.
func (op sizeOp) DiffWRT(i int) []bool { return []bool{false} }

// String prints the known size when available, otherwise the axis queried.
func (op sizeOp) String() string {
	if op.val != 0 {
		return fmt.Sprintf("SizeOf=%d", op.val)
	}
	return fmt.Sprintf("SizeOf(%d)", op.axis)
}

// SymDiff always errors: sizeOp is not differentiable.
func (op sizeOp) SymDiff(inputs Nodes, output, gradNode *Node) (Nodes, error) {
	return nil, nondiffErr(op)
}
   106  
   107  func (op sizeOp) Do(inputs ...Value) (retVal Value, err error) {
   108  	if err = checkArity(op, len(inputs)); err != nil {
   109  		return
   110  	}
   111  
   112  	switch t := inputs[0].(type) {
   113  	case Scalar:
   114  		retVal = one(t.Dtype())
   115  
   116  		// bools are special
   117  		if _, ok := t.(*B); ok {
   118  			retVal = NewI(1)
   119  		}
   120  	case tensor.Tensor:
   121  		sh := t.Shape()
   122  		if op.axis >= len(sh) {
   123  			return nil, errors.Errorf("Shape is %v. Want size of %d", sh, op.axis)
   124  		}
   125  		size := sh[op.axis]
   126  
   127  		// cast as ... types
   128  		switch t.Dtype() {
   129  		case tensor.Float64:
   130  			retVal = NewF64(float64(size))
   131  		case tensor.Float32:
   132  			retVal = NewF32(float32(size))
   133  		case tensor.Int:
   134  			retVal = NewI(size)
   135  		default:
   136  			return nil, errors.Errorf(nyiFail, "sizeOf.Do()", t.Dtype())
   137  		}
   138  	}
   139  
   140  	return
   141  }
   142  
   143  func (op sizeOp) WriteHash(h hash.Hash) {
   144  	h.Write([]byte("sizeOf"))
   145  	if err := binary.Write(h, binary.LittleEndian, byte(op.d)); err != nil {
   146  		panic(err)
   147  	}
   148  	h.Write([]byte("on"))
   149  	if err := binary.Write(h, binary.LittleEndian, byte(op.axis)); err != nil {
   150  		panic(err)
   151  	}
   152  }
   153  
func (op sizeOp) Hashcode() uint32 { return simpleHash(op) }

// DimSize returns the statically-known size (op.val). It errors when d is
// not the axis this op was constructed for.
func (op sizeOp) DimSize(d int) (int, error) {
	if d != op.axis {
		return -1, errors.Errorf("Dimension mismatch. Size Op is for axis %d. Want Dim Size of %d", op.axis, d)
	}
	return op.val, nil
}

// repeatOp repeats its input along the `along` axis.
type repeatOp struct {
	along      int          // axis along which to repeat
	inputShape tensor.Shape // shape of the input node when the op was created
}

// newRepeatOp creates a repeatOp; the node's shape is cloned so later
// mutations of the node do not leak into the op.
func newRepeatOp(along int, a *Node) *repeatOp {
	return &repeatOp{
		along:      along,
		inputShape: a.Shape().Clone(),
	}
}
   174  
   175  func repeatedApply(along []int, children Nodes) (retVal *Node, err error) {
   176  	if len(children) != len(along)+1 {
   177  		return nil, errors.Errorf("Expected %v children. Got %v instead (hint: along axes and number of children must match)", len(along)+1, len(children))
   178  	}
   179  
   180  	retVal = children[0]
   181  	for i := range along {
   182  		op := newRepeatOp(along[i], retVal)
   183  		if retVal, err = ApplyOp(op, retVal, children[i+1]); err != nil {
   184  			return nil, err
   185  		}
   186  	}
   187  	return
   188  }
   189  
// Arity is 2: the value to repeat and the repeat count.
func (op repeatOp) Arity() int { return 2 }

// repeat is defined as one of the following:
//		repeat :: Tensor-n a → a → Tensor-n a
//		repeat :: a → Vector a
// The end result must have the same dimensions as the input
func (op repeatOp) Type() hm.Type {

	a := hm.TypeVariable('a')

	d := op.inputShape.Dims()

	var i0t hm.Type
	var rt hm.Type

	if d == 0 {
		// repeating a scalar produces a vector
		i0t = a
		rt = makeTensorType(d+1, a)
	} else {
		i0t = makeTensorType(d, a)
		rt = makeTensorType(d, a)
	}

	return hm.NewFnType(i0t, a, rt)
}

func (op repeatOp) ReturnsPtr() bool     { return true }
func (op repeatOp) OverwritesInput() int { return 0 }
func (op repeatOp) CallsExtern() bool    { return true } // set to true because we want to force the VM to use PreallocDo
   219  
// InferShape computes the output shape: the input shape with the `along`
// axis multiplied by the repeat count (taken from inputs[1].DimSize).
func (op repeatOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err error) {
	retVal = inputs[0].(tensor.Shape).Clone()
	rep, err := inputs[1].DimSize(op.along)
	if err != nil {
		return nil, err
	}

	// TODO: switch stmt
	if retVal.IsVector() && retVal.Dims() <= op.along {
		// extend the shape with size-1 axes until the `along` axis exists
		retVal = append(retVal, make(tensor.Shape, op.along-retVal.Dims()+1)...)
		for i := range retVal {
			if retVal[i] == 0 {
				retVal[i] = 1
			}
		}
	}
	if retVal.IsScalar() {
		retVal = tensor.Shape{1}
	}
	retVal[op.along] *= rep

	return
}

// DiffWRT reports that only the first input (the repeated value) is
// differentiable; the repeat count is not.
// NOTE(review): this indexes retVal[0], so i must be >= 1 — callers are
// expected to pass the arity.
func (op repeatOp) DiffWRT(i int) []bool {
	symdiffLogf("DiffWRT: %d", i)
	retVal := make([]bool, i)
	retVal[0] = true
	return retVal
}

// SymDiff: the gradient w.r.t. the repeated input is the output gradient
// summed along the repeated axis.
func (op repeatOp) SymDiff(inputs Nodes, output, gradNode *Node) (retVal Nodes, err error) {
	var n *Node
	if n, err = Sum(gradNode, op.along); err == nil {
		n.setGroup(gradClust)
	}
	retVal = make(Nodes, len(inputs))
	retVal[0] = n
	return
}
   261  
// DoDiff computes the gradient of the repeat and accumulates it into the
// deriv of inputs[0]. Strategy: reshape the output gradient so the repeats
// occupy their own axis, sum that axis away, then drop the leftover size-1
// dimension.
func (op repeatOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	xdv, ydv := getDV(inputs[0], output)

	// gather the repeat counts (inputs[1:]) as ints
	var reps []int
	var repeats []Value
	for _, r := range inputs[1:] {
		repeats = append(repeats, r.Value())
	}

	if reps, err = valuesToInts(repeats); err != nil {
		return
	}

	xshape := xdv.Shape()
	var d Value
	d = ydv.d

	// we make it a colVec
	if xshape.IsVector() && !xshape.IsColVec() && !xshape.IsRowVec() {
		xshape = tensor.Shape{xshape[0], 1}
	}

	if xshape.IsScalar() {
		// scalar input: summing the output gradient along the axis suffices
		sum := newSumOp([]int{op.along}, output.shape, output.Dims())
		if d, err = sum.Do(d); err != nil {
			err = errors.Wrapf(err, doFail, sum)
			return
		}
	} else {
		axis := op.along
		if xshape[axis] == 1 {
			// the repeated axis had size 1: a plain sum along it suffices
			sum := newSumOp([]int{op.along}, output.shape, output.Dims())
			if d, err = sum.Do(d); err != nil {
				err = errors.Wrapf(err, doFail, sum)
				return
			}
		} else {
			// general case: insert the repeat counts as extra axes after `axis`
			newShape := xshape.Clone()
			newShape = newShape[0 : axis+1]
			newShape = append(newShape, reps...)
			if axis+1 < xshape.Dims() {
				newShape = append(newShape, xshape[axis+1:]...)
			}

			along := []int{axis + 1}

			// a scalar can never get to this path
			t := d.(tensor.Tensor)
			if err = t.Reshape(newShape...); err != nil {
				err = errors.Wrapf(err, reshapeFail, newShape, t.DataSize())
				return
			}

			sum := newSumOp(along, newShape, len(newShape))
			if d, err = sum.Do(d); err != nil {
				err = errors.Wrapf(err, doFail, sum)
				return
			}
			// sum.Do leaves the dimension of size 1 behind, so reshape here.
			t = d.(tensor.Tensor)
			finalShape := newShape[:axis+1]
			if axis+1 < newShape.Dims() {
				finalShape = append(finalShape, newShape[axis+2:]...)
			}
			if err = t.Reshape(finalShape...); err != nil {
				err = errors.Wrapf(err, reshapeFail, newShape, t.DataSize())
				return
			}
		}

	}

	// accumulate the computed gradient into the input's deriv
	add := newEBOByType(addOpType, TypeOf(xdv.d), TypeOf(d))
	if d, err = add.UnsafeDo(xdv.d, d); err != nil {
		return
	}

	if !add.ReturnsPtr() || inputs[0].IsScalar() {
		err = xdv.SetDeriv(d)
	}

	return

}

func (op repeatOp) String() string { return fmt.Sprintf("Repeat%v", op.along) }
   351  
   352  // Do performs a repeat on the value.
   353  // TODO(anyone): implement for other types
   354  func (op repeatOp) Do(inputs ...Value) (retVal Value, err error) {
   355  	if err = checkArity(op, len(inputs)); err != nil {
   356  		return
   357  	}
   358  
   359  	var rep int
   360  	if rep, err = valueToInt(inputs[1]); err != nil {
   361  		return nil, errors.Wrapf(err, "Cannot convert %v to an int", inputs[1])
   362  	}
   363  
   364  	// process inputs[0]
   365  	var t tensor.Tensor
   366  	switch iv := inputs[0].(type) {
   367  	case Scalar:
   368  		s := iv.Data()
   369  		t = tensor.New(tensor.FromScalar(s))
   370  	case tensor.Tensor:
   371  		// if iv.Shape().IsScalarEquiv() {
   372  		// 	log.Printf("SCALAR EQUIV %v", iv.Data())
   373  		// 	t = iv.Clone().(tensor.Tensor)
   374  		// 	retVal = t
   375  		// 	return
   376  		// }
   377  		t = iv
   378  	default:
   379  		err = errors.Errorf(nyiTypeFail, "repeatOp.Do()", inputs[0])
   380  		return
   381  	}
   382  
   383  	// actually do repeat
   384  	if rep == 1 {
   385  		goto fin
   386  	}
   387  	if t, err = tensor.Repeat(t, op.along, rep); err != nil {
   388  		err = errors.Wrapf(err, repFail, op.along, rep)
   389  		return
   390  	}
   391  fin:
   392  	retVal = t
   393  	return
   394  }
   395  
// WriteHash writes the op's identity to the hasher. A trailing byte records
// whether the first axis of the (non-scalar) input shape is of size 0.
func (op repeatOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "repeat %v %v", op.along, op.inputShape)
	var arg0Dim int
	if !op.inputShape.Eq(tensor.ScalarShape()) {
		arg0Dim = op.inputShape[0]
	}
	if arg0Dim == 0 {
		h.Write([]byte{1})
	} else {
		h.Write([]byte{0})
	}
}

func (op repeatOp) Hashcode() uint32 { return simpleHash(op) }
   410  
   411  func (op repeatOp) UsePreallocDo(prealloc Value, inputs ...Value) (retVal Value, err error) {
   412  	pt, ok := prealloc.(tensor.Tensor)
   413  	if !ok {
   414  		return nil, errors.Errorf("Expected Tensor as a preallocated value. Got %v of %T instead", prealloc, prealloc)
   415  	}
   416  
   417  	if err = checkArity(op, len(inputs)); err != nil {
   418  		return
   419  	}
   420  
   421  	var rep int
   422  	if rep, err = valueToInt(inputs[1]); err != nil {
   423  		return nil, errors.Wrapf(err, "Cannot convert %v to an int", inputs[1])
   424  	}
   425  
   426  	// process inputs[0]
   427  	var t tensor.Tensor
   428  	switch iv := inputs[0].(type) {
   429  	case Scalar:
   430  		s := iv.Data()
   431  		pt.Memset(s)
   432  		retVal = pt
   433  		return
   434  		t = tensor.New(tensor.FromScalar(s))
   435  	case tensor.Tensor:
   436  		if iv.Shape().IsScalarEquiv() {
   437  			data := iv.Data()
   438  			switch dt := data.(type) {
   439  			case float64:
   440  				ptd := pt.Data().([]float64)
   441  				for i := range ptd {
   442  					ptd[i] = dt
   443  				}
   444  			case float32:
   445  				ptd := pt.Data().([]float32)
   446  				for i := range ptd {
   447  					ptd[i] = dt
   448  				}
   449  			case []float64:
   450  				ptd := pt.Data().([]float64)
   451  				for i := range ptd {
   452  					ptd[i] = dt[0]
   453  				}
   454  			case []float32:
   455  				ptd := pt.Data().([]float32)
   456  				for i := range ptd {
   457  					ptd[i] = dt[0]
   458  				}
   459  			}
   460  			return pt, nil
   461  		}
   462  		t = iv
   463  	default:
   464  		err = errors.Errorf(nyiTypeFail, "repeatOp.Do()", inputs[0])
   465  		return
   466  	}
   467  	if rep == 1 {
   468  		return Copy(pt, t)
   469  	}
   470  
   471  	return tensor.RepeatReuse(t, pt, op.along, rep)
   472  }
   473  
// sliceOp represents a slicing operation. If end ⩽ start, it means ":"
type sliceOp struct {
	tensor.Slice

	along int // along which axis to slice?

	a int // along which axis of the original tensor
	d int // how many dimensions were the original tensor
}

// IsSlice returns the underlying slice spec (nil means ":").
func (op *sliceOp) IsSlice() tensor.Slice { return op.Slice }

// newSliceOp creates a sliceOp that slices axis `along` of a d-dimensional
// tensor. NOTE(review): the `a` field is left at its zero value here —
// confirm whether anything reads it.
func newSliceOp(s tensor.Slice, along, d int) *sliceOp {
	return &sliceOp{
		Slice: s,
		along: along,
		d:     d,
	}
}

func (op *sliceOp) Arity() int { return 1 }
   495  
// slicing a tensor value T[:] has type
// 		slice :: Tensor a → Tensor a
// 		slice :: Tensor a → a
//
// The latter is in the case where the resulting dimensions is 0, returning a scalar
func (op *sliceOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	tt := makeTensorType(op.d, a)

	var selection int

	// a nil Slice means ":" (select the whole axis)
	if op.Slice == nil {
		selection = -1
	} else {
		selection = op.End() - op.Start()
	}

	// selecting exactly one element drops a dimension
	if selection == 1 {
		if op.d == 1 {
			return hm.NewFnType(tt, a)
		}

		tt2 := makeTensorType(op.d-1, a)
		return hm.NewFnType(tt, tt2)
	}

	return hm.NewFnType(tt, tt)
}
   524  
   525  func (op *sliceOp) InferShape(inputs ...DimSizer) (s tensor.Shape, err error) {
   526  	input := inputs[0].(tensor.Shape)
   527  	slices := make([]tensor.Slice, op.along+1)
   528  	slices[op.along] = op.Slice
   529  
   530  	return input.S(slices...)
   531  
   532  	// return input.S(op.Slice)
   533  }
   534  
   535  func (op *sliceOp) DiffWRT(i int) []bool {
   536  	if i > 1 {
   537  		// error
   538  		err := errors.Errorf("sliceOp should only have one or more inputs. Got %v instead", i)
   539  		panic(err)
   540  	}
   541  
   542  	return []bool{true}
   543  }
   544  
// SymDiff: the gradient of a slice is the incoming gradient scattered back
// into a zero tensor of the input's shape, which is what sliceIncrOp does.
func (op *sliceOp) SymDiff(inputs Nodes, outputNode, gradNode *Node) (retVal Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	t := inputs[0]
	incrOp := sliceIncrOp{op}

	retVal = make(Nodes, 1)
	retVal[0], err = ApplyOp(incrOp, t, gradNode)
	return
}

// DoDiff increments the sliced region of the input's deriv by the output's
// deriv, in place via sliceIncrOp.UsePreallocDo.
func (op *sliceOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	xdv, ydv := getDV(inputs[0], output)

	// var d Value
	incrOp := sliceIncrOp{op}
	if _, err = incrOp.UsePreallocDo(xdv.d, xdv.d, ydv.d); err != nil {
		return errors.Wrapf(err, doFail, incrOp)
	}

	// there is no need to handle scalars, because you can never slice a scalar
	// add := newElemBinOp(addOpType, inputs[0], output)
	// if _, err = add.UnsafeDo(xdv.d, d); err != nil {
	// 	return errors.Wrapf(err, unsafeDoFail, add)
	// }

	return
}
   578  
// Do slices the input tensor along op.along. A single-element result is
// returned as a scalar; otherwise the sliced view is materialized.
func (op *sliceOp) Do(inputs ...Value) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	t := inputs[0]
	// prep the slices
	var slices []tensor.Slice
	slices = make([]tensor.Slice, len(t.Shape()))

	if !op.all() {
		slices[op.along] = op
	}
	switch T := t.(type) {
	case tensor.Tensor:
		var v tensor.Tensor
		if v, err = T.Slice(slices...); err != nil {
			return nil, errors.Wrapf(err, sliceFail, slices)
		}
		if v.IsScalar() {
			retVal, _ = anyToScalar(v.ScalarValue())
		} else {
			// materialize so the result is not a view aliasing the input
			retVal = v.(tensor.View).Materialize()
		}
	case Scalar:
		return nil, errors.New("Cannot slice a scalar value")
	default:
		return nil, errors.Errorf(nyiFail, "sliceOp.Do()", t)
	}
	return
}

func (op *sliceOp) ReturnsPtr() bool     { return true }
func (op *sliceOp) CallsExtern() bool    { return true }
func (op *sliceOp) OverwritesInput() int { return -1 }
// WriteHash writes the op's identity (dims, axis, and slice bounds) to the hasher.
func (op sliceOp) WriteHash(h hash.Hash) {
	h.Write([]byte("slice"))
	if err := binary.Write(h, binary.LittleEndian, byte(op.d)); err != nil {
		panic(err)
	}
	fmt.Fprintf(h, "%v", op.along)
	// a nil Slice hashes as ":"
	if op.Slice == nil {
		fmt.Fprintf(h, ":")
		return
	}

	if err := binary.Write(h, binary.LittleEndian, byte(op.Start())); err != nil {
		panic(err)
	}
	if err := binary.Write(h, binary.LittleEndian, byte(op.End())); err != nil {
		panic(err)
	}
	if err := binary.Write(h, binary.LittleEndian, byte(op.Step())); err != nil {
		panic(err)
	}

}
func (op sliceOp) Hashcode() uint32 { return simpleHash(op) }

// String renders the op in "T[:, :, start:end:step...]" form.
func (op sliceOp) String() string {
	var buf bytes.Buffer
	buf.WriteString("T[")
	for i := 0; i < op.along; i++ {
		buf.WriteString(":, ")
	}

	if op.all() {
		buf.WriteString(":")
	} else {
		fmt.Fprintf(&buf, "%d:%d:%d", op.Start(), op.End(), op.Step())
	}

	buf.WriteString("...]")
	return buf.String()
}

// func (op sliceOp) CUDADo(extern External, dev Device, prealloc Value, inputs ...Value) (retVal Value, err error) {
// 	return op.Do(inputs...)
// }

// func (op sliceOp) CUDAFuncName() string { return "" }

// all reports whether the op selects the whole axis (nil slice, or end ⩽ start).
func (op sliceOp) all() bool { return op.Slice == nil || op.End() <= op.Start() }
   662  
// T[:] +=incr
// THIS IS AN UNSAFE OPERATION
type sliceIncrOp struct {
	*sliceOp
}

// slicing a tensor value T[:] has type
// 		slice :: Tensor a → b → Tensor a
//
// b can be a or Vector a
func (op sliceIncrOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	b := hm.TypeVariable('c')
	tt := makeTensorType(op.d, a)

	retVal := hm.NewFnType(tt, b, tt)
	return retVal
}

// Arity is 2: the tensor to scatter into, and the increment.
func (op sliceIncrOp) Arity() int { return 2 }

// InferShape: the output has the same shape as the first input.
func (op sliceIncrOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err error) {
	retVal = inputs[0].(tensor.Shape)
	return
}

// DiffWRT: differentiable w.r.t. the tensor but not the increment.
func (op sliceIncrOp) DiffWRT(i int) []bool {
	if err := checkArity(op, i); err != nil {
		panic(err)
	}

	return []bool{true, false}
}
   696  
// SymDiff: d/dT is the incoming gradient itself; d/dincr is the gradient
// sliced down to the increment's region.
func (op sliceIncrOp) SymDiff(inputs Nodes, outputNode, gradNode *Node) (retVal Nodes, err error) {
	var slicedRes *Node
	if slicedRes, err = ApplyOp(op.sliceOp, gradNode); err != nil {
		return nil, errors.Wrap(err, operationError)
	}
	retVal = Nodes{gradNode, slicedRes}

	return
}

// DoDiff accumulates the output deriv into both inputs' derivs: the whole
// deriv for the tensor, the sliced deriv for the increment.
func (op sliceIncrOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	xdv, ydv, zdv := getDV3(inputs[0], inputs[1], output)

	// dzdx
	add := newElemBinOp(addOpType, inputs[0], output)
	if _, err = add.UnsafeDo(xdv.d, zdv.d); err != nil {
		return errors.Wrapf(err, unsafeDoFail, add)
	}

	// dzdy
	var d Value
	if d, err = op.sliceOp.Do(zdv.d); err != nil {
		return errors.Wrapf(err, doFail, op)
	}

	add = newElemBinOp(addOpType, inputs[1], output)
	if _, err = add.UnsafeDo(ydv.d, d); err != nil {
		return errors.Wrapf(err, doFail, add)
	}
	return
}
   728  
// Do builds a zero tensor shaped like the input and adds the increment into
// its sliced region. The input tensor itself is not modified.
func (op sliceIncrOp) Do(inputs ...Value) (retVal Value, err error) {
	machineLogf("Doing %v", op)
	enterLogScope()
	defer leaveLogScope()

	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	t := inputs[0]
	incr := inputs[1]

	// prep the slices
	slices := make([]tensor.Slice, op.d)
	if !op.all() {
		slices[op.along] = op
	}

	switch T := t.(type) {
	case *tensor.Dense:
		grad := tensor.NewDense(T.Dtype(), T.Shape().Clone())
		var v tensor.Tensor
		if v, err = grad.Slice(slices...); err != nil {
			return nil, errors.Wrapf(err, sliceFail, slices)
		}
		// NOTE(review): increment types other than *F64/*F32/*tensor.Dense
		// are silently ignored here — confirm this is intended.
		switch i := incr.(type) {
		case *F64:
			tensor.Add(v, i.any(), tensor.UseUnsafe())
		case *F32:
			tensor.Add(v, i.any(), tensor.UseUnsafe())
		case *tensor.Dense:
			tensor.Add(v, i, tensor.UseUnsafe())
		}
		retVal = grad
	case Scalar:
		return nil, errors.New("Cannot slice a scalar value")
	default:
		return nil, errors.Errorf(nyiFail, "sliceIncrOp()", t)
	}
	return
}
   770  
// UsePreallocDo adds the increment to the sliced region of prealloc in place
// and returns prealloc.
func (op sliceIncrOp) UsePreallocDo(prealloc Value, inputs ...Value) (retVal Value, err error) {
	machineLogf("Doing %v", op)
	enterLogScope()
	defer leaveLogScope()

	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	incr := inputs[1]

	// prep the slices
	slices := make([]tensor.Slice, op.d)
	if !op.all() {
		slices[op.along] = op
	}

	switch T := prealloc.(type) {
	case *tensor.Dense:
		var v tensor.Tensor
		if v, err = T.Slice(slices...); err != nil {
			return nil, errors.Wrapf(err, sliceFail, slices)
		}
		// NOTE(review): increment types other than *F64/*F32/*tensor.Dense
		// are silently ignored here — confirm this is intended.
		switch i := incr.(type) {
		case *F64:
			tensor.Add(v, i.any(), tensor.UseUnsafe())
		case *F32:
			tensor.Add(v, i.any(), tensor.UseUnsafe())
		case *tensor.Dense:
			tensor.Add(v, i, tensor.UseUnsafe())
		}
		retVal = T
	case Scalar:
		return nil, errors.New("Cannot slice a scalar value")
	default:
		return nil, errors.Errorf(nyiFail, "sliceIncrOp()", prealloc)
	}
	return
}

// OverwritesInput: this op overwrites its first input.
func (op sliceIncrOp) OverwritesInput() int { return 0 }
   811  
// WriteHash writes the op's identity (dims, axis, and slice bounds) to the hasher.
func (op sliceIncrOp) WriteHash(h hash.Hash) {
	h.Write([]byte("sliceIncr"))
	if err := binary.Write(h, binary.LittleEndian, byte(op.d)); err != nil {
		panic(err)
	}
	if err := binary.Write(h, binary.LittleEndian, byte(op.along)); err != nil {
		panic(err)
	}

	// a nil Slice hashes as ":"
	if op.Slice == nil {
		fmt.Fprintf(h, ":")
		return
	}

	if err := binary.Write(h, binary.LittleEndian, byte(op.Start())); err != nil {
		panic(err)
	}
	if err := binary.Write(h, binary.LittleEndian, byte(op.End())); err != nil {
		panic(err)
	}
	if err := binary.Write(h, binary.LittleEndian, byte(op.Step())); err != nil {
		panic(err)
	}
}

func (op sliceIncrOp) Hashcode() uint32 { return simpleHash(op) }
   838  
   839  func (op sliceIncrOp) String() string {
   840  	var buf bytes.Buffer
   841  	buf.WriteString("T[")
   842  
   843  	for i := 0; i < op.along; i++ {
   844  		buf.WriteString(":, ")
   845  	}
   846  
   847  	if op.all() {
   848  		buf.WriteString(":")
   849  	} else {
   850  		fmt.Fprintf(&buf, "%d:%d:%d", op.Start(), op.End(), op.Step())
   851  	}
   852  
   853  	buf.WriteString("...]+=...")
   854  	return buf.String()
   855  }
   856  
   857  // func (op sliceIncrOp) UsePreallocDo(val Value, inputs ...Value) (Value, error) {
   858  
   859  // }
   860  
// transposeOp permutes the axes of its input according to pattern.
type transposeOp struct {
	pattern []int // the axis permutation
	d       int   // dimensionality of the input
}

func (op transposeOp) Arity() int { return 1 }

// transposing a tensor has type
// 		transpose :: Tensor a → Tensor a
func (op transposeOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	tt := makeTensorType(op.d, a)

	return hm.NewFnType(tt, tt)
}
   876  
   877  func (op transposeOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err error) {
   878  	input := inputs[0].(tensor.Shape)
   879  	if input.IsScalar() {
   880  		return nil, errors.Errorf(undefinedOnShape, op, input)
   881  	}
   882  
   883  	retVal = make(tensor.Shape, len(input))
   884  	copy(retVal, input)
   885  	err = tensor.UnsafePermute(op.pattern, retVal)
   886  	return
   887  }
   888  
// DiffWRT: the transpose is differentiable w.r.t. its single input.
func (op transposeOp) DiffWRT(i int) []bool {
	if err := checkArity(op, i); err != nil {
		panic(err)
	}

	return []bool{true}
}
   896  
   897  func (op transposeOp) SymDiff(inputs Nodes, outputNode, gradNode *Node) (retVal Nodes, err error) {
   898  	newPattern := make([]int, len(op.pattern))
   899  	for i, p := range op.pattern {
   900  		newPattern[p] = i
   901  	}
   902  	op2 := transposeOp{pattern: newPattern, d: op.d}
   903  
   904  	retVal = make(Nodes, 1)
   905  	retVal[0], err = ApplyOp(op2, gradNode)
   906  	return
   907  }
   908  
// DoDiff transposes the output's deriv by the inverse permutation and adds
// it into the input's deriv.
func (op transposeOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	xdv, zdv := getDV(inputs[0], output)

	// invert the permutation
	newPattern := make([]int, len(op.pattern))
	for i, p := range op.pattern {
		newPattern[p] = i
	}

	var zdvdT tensor.Tensor
	var ok bool
	if zdvdT, ok = zdv.d.(tensor.Tensor); !ok {
		return errors.Errorf("Expected the gradient of the output node to be a Tensor. Got %v instead", zdv.d)
	}

	if err = zdvdT.T(newPattern...); err != nil {
		return errors.Wrap(err, "Failed to T()")
	}

	// materialize the transposed view, then undo the T() so zdvdT is left as it was
	d := tensor.Materialize(zdvdT)
	zdvdT.UT()

	add := newEBOByType(addOpType, inputs[0].t, TypeOf(zdvdT))
	if _, err = add.UnsafeDo(xdv.d, d); err != nil {
		err = errors.Wrapf(err, doFail, add)
	}
	return
}
   936  
   937  func (op transposeOp) Do(inputs ...Value) (retVal Value, err error) {
   938  	machineLogf("Doing %v", op)
   939  	enterLogScope()
   940  	defer leaveLogScope()
   941  
   942  	if err = checkArity(op, len(inputs)); err != nil {
   943  		return
   944  	}
   945  
   946  	t := inputs[0].(tensor.Tensor)
   947  
   948  	throwaway := tensor.BorrowInts(len(op.pattern))
   949  	copy(throwaway, op.pattern)
   950  	// return tensor.T(t, throwaway...)
   951  
   952  	return tensor.Transpose(t, throwaway...)
   953  
   954  	// DEPRECATED
   955  	// the reason for this is because the .T() method of a Tensor
   956  	// will use the axes in the .transposedWith field
   957  	// Later when .UT() is called, the .transposedWith field is recycled into the pool
   958  	// throwaway := tensor.BorrowInts(len(op.pattern))
   959  	// copy(throwaway, op.pattern)
   960  
   961  	// t.T(throwaway...)
   962  	// ret := t.Materialize()
   963  	// t.UT()
   964  }
   965  
func (op transposeOp) ReturnsPtr() bool     { return true }
func (op transposeOp) CallsExtern() bool    { return false }
func (op transposeOp) OverwritesInput() int { return 0 }

// WriteHash writes the permutation pattern and dims to the hasher.
func (op transposeOp) WriteHash(h hash.Hash) {
	h.Write([]byte("transposeOp"))
	fmt.Fprintf(h, "%v", op.pattern)
	if err := binary.Write(h, binary.LittleEndian, byte(op.d)); err != nil {
		panic(err)
	}
}

func (op transposeOp) Hashcode() uint32 { return simpleHash(op) }
   979  
   980  func (op transposeOp) String() string {
   981  	var buf bytes.Buffer
   982  	buf.WriteString("Aᵀ{")
   983  	for i, ax := range op.pattern {
   984  		fmt.Fprintf(&buf, "%d", ax)
   985  		if i < len(op.pattern)-1 {
   986  			buf.WriteString(", ")
   987  		}
   988  	}
   989  
   990  	buf.WriteString("}")
   991  	return buf.String()
   992  }
   993  
// concatOp concatenates several tensors along an axis.
type concatOp struct {
	axis     int // axis along which to concatenate
	d        int // dimensionality of each input
	children int // number of inputs
}

// Arity is -1: concat is variadic.
func (op concatOp) Arity() int { return -1 }

// concat only works for Tensor types
//		concat :: Tensor a → Tensor a → ... → Tensor a
func (op concatOp) Type() hm.Type {
	tt := makeTensorType(op.d, hm.TypeVariable('a'))
	// op.children inputs plus one return type
	fnt := make([]hm.Type, op.children+1)
	for i := range fnt {
		fnt[i] = tt
	}

	return hm.NewFnType(fnt...)
}
  1013  
  1014  func (op concatOp) InferShape(ds ...DimSizer) (tensor.Shape, error) {
  1015  	if len(ds) == 0 {
  1016  		return nil, errors.Errorf("No shapes passed in!")
  1017  	}
  1018  	shapes, err := DimSizersToShapes(ds)
  1019  	if err != nil {
  1020  		return nil, err
  1021  	}
  1022  
  1023  	return shapes[0].Concat(op.axis, shapes[1:]...)
  1024  }
  1025  
  1026  func (op concatOp) Do(vals ...Value) (Value, error) {
  1027  	if len(vals) == 1 {
  1028  		return vals[0], nil
  1029  	}
  1030  
  1031  	ts, err := valuesToTensors(vals)
  1032  	if err != nil {
  1033  		return nil, err
  1034  	}
  1035  
  1036  	return tensor.Concat(op.axis, ts[0], ts[1:]...)
  1037  }
  1038  
// ReturnsPtr is true: the concatenated result is a pointer-backed tensor.
func (op concatOp) ReturnsPtr() bool     { return true }

// CallsExtern is false: concat runs entirely in Go, no external device call.
func (op concatOp) CallsExtern() bool    { return false }

// OverwritesInput returns -1: no input buffer is reused for the output.
func (op concatOp) OverwritesInput() int { return -1 }

// WriteHash feeds the op's identifying fields into h.
// NOTE(review): op.children is not hashed, so two concatOps differing only in
// arity hash identically — confirm this cannot cause bad op deduplication.
func (op concatOp) WriteHash(h hash.Hash) {
	h.Write([]byte("concatOp"))
	fmt.Fprintf(h, "axis: %d, dims: %d", op.axis, op.d)
}

// Hashcode returns a hash of the op derived from WriteHash.
func (op concatOp) Hashcode() uint32 { return simpleHash(op) }
  1049  
// String renders the op as Concat(axis=N).
func (op concatOp) String() string {
	return fmt.Sprintf("Concat(axis=%d)", op.axis)
}
  1053  
  1054  func (op concatOp) DiffWRT(inputs int) []bool {
  1055  	retVal := make([]bool, inputs)
  1056  	for i := range retVal {
  1057  		retVal[i] = true
  1058  	}
  1059  	return retVal
  1060  }
  1061  
// SymDiff builds the symbolic gradient of concat: each input's gradient is the
// slice of the output gradient covering that input's extent along op.axis.
func (op concatOp) SymDiff(inputs Nodes, output *Node, grad *Node) (retVal Nodes, err error) {
	var start int

	retVal = make(Nodes, len(inputs))
	for i, in := range inputs {
		if op.axis >= len(in.shape) {
			return nil, errors.Errorf("Wanted dimension %d is larger than the shape %v", op.axis, in.shape)
		}
		// This input occupies [start, end) along op.axis of the output.
		end := in.shape[op.axis] + start

		// Slice the incoming gradient along op.axis to recover this input's share.
		s := newSliceOp(S(start, end), op.axis, op.d)
		if retVal[i], err = ApplyOp(s, grad); err != nil {
			return
		}
		start = end
	}
	return
}
  1080  
  1081  func (op concatOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error {
  1082  	odv := output.boundTo.(*dualValue)
  1083  	odvd := odv.d.(tensor.Tensor)
  1084  
  1085  	var start int
  1086  	for _, in := range inputs {
  1087  		if op.axis >= len(in.shape) {
  1088  			return errors.Errorf("Wanted dimension %d is larger than the shape %v", op.axis, in.shape)
  1089  		}
  1090  		end := in.shape[op.axis] + start
  1091  
  1092  		idv := in.boundTo.(*dualValue)
  1093  		idvd := idv.d.(tensor.Tensor)
  1094  
  1095  		sliced, err := odvd.Slice(S(start, end))
  1096  		if err != nil {
  1097  			return err
  1098  		}
  1099  
  1100  		// TODO: fix VAdd hack
  1101  		// add to odvd
  1102  		switch st := sliced.(type) {
  1103  		case *tensor.Dense:
  1104  			d := idvd.(*tensor.Dense)
  1105  			d.Add(st, tensor.UseUnsafe())
  1106  		default:
  1107  			return errors.Errorf(nyiTypeFail, "DoDiff (hack) ", st)
  1108  		}
  1109  
  1110  		start = end
  1111  	}
  1112  	return nil
  1113  }
  1114  
// reshapeOp reshapes a value between two shapes of equal total size.
type reshapeOp struct {
	from, to tensor.Shape // input shape and the desired output shape
}

// Arity returns 1: reshape takes exactly one input.
func (op reshapeOp) Arity() int { return 1 }
  1120  func (op reshapeOp) Type() hm.Type {
  1121  	if op.from.Dims() != op.to.Dims() {
  1122  		fr := op.from.Dims()
  1123  		var frT hm.Type
  1124  		frT = newTensorType(fr, hm.TypeVariable('a'))
  1125  		if fr == 0 {
  1126  			frT = hm.TypeVariable('a')
  1127  		}
  1128  
  1129  		to := op.to.Dims()
  1130  		var toT hm.Type
  1131  		toT = newTensorType(to, hm.TypeVariable('a'))
  1132  		if to == 0 {
  1133  			toT = hm.TypeVariable('a')
  1134  		}
  1135  		return hm.NewFnType(frT, toT)
  1136  	}
  1137  	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'))
  1138  }
  1139  func (op reshapeOp) InferShape(ds ...DimSizer) (tensor.Shape, error) { return op.to.Clone(), nil }
  1140  
  1141  func (op reshapeOp) Do(vals ...Value) (Value, error) {
  1142  	if err := checkArity(op, len(vals)); err != nil {
  1143  		return nil, err
  1144  	}
  1145  	var val Value
  1146  	var err error
  1147  	switch vals[0].(type) {
  1148  	case tensor.Tensor:
  1149  		if v, ok := vals[0].(*tensor.Dense); ok {
  1150  			if v.IsView() {
  1151  				val = v.Materialize()
  1152  			} else {
  1153  				val = v.ShallowClone()
  1154  			}
  1155  		} else {
  1156  			if val, err = CloneValue(vals[0]); err != nil {
  1157  				return nil, errors.Wrapf(err, cloneFail, vals[0])
  1158  			}
  1159  		}
  1160  		if val.Shape().TotalSize() != op.from.TotalSize() {
  1161  			return nil, errors.Errorf("Shape mismatch. Input shape is %v. Expected %v", val.Shape(), op.from)
  1162  		}
  1163  
  1164  		if err := val.(tensor.Tensor).Reshape(op.to...); err != nil {
  1165  			return nil, err
  1166  		}
  1167  		return val, nil
  1168  	case Scalar:
  1169  		v0 := ScalarAsTensor(vals[0], op.to.Dims(), nil)
  1170  		if err := v0.(tensor.Tensor).Reshape(op.to...); err != nil {
  1171  			return nil, err
  1172  		}
  1173  		return v0, nil
  1174  	default:
  1175  		return nil, errors.Errorf(nyiTypeFail, "reshape.Do", vals[0])
  1176  	}
  1177  }
  1178  
// ReturnsPtr is true: the reshaped value is pointer-backed (Do may return a
// shallow clone sharing storage with the input).
func (op reshapeOp) ReturnsPtr() bool     { return true }

// CallsExtern is false: reshape runs entirely in Go, no external device call.
func (op reshapeOp) CallsExtern() bool    { return false }

// OverwritesInput returns 0: input 0 may be reused in place (see UnsafeDo).
func (op reshapeOp) OverwritesInput() int { return 0 }

// WriteHash feeds the op's identifying shapes into h.
func (op reshapeOp) WriteHash(h hash.Hash) {
	h.Write([]byte("reshapeOp"))
	fmt.Fprintf(h, "from: %v, dims: %v", op.from, op.to)
}

// Hashcode returns a hash of the op derived from WriteHash.
func (op reshapeOp) Hashcode() uint32 { return simpleHash(op) }

// String renders the op as "Reshape" followed by the target shape.
func (op reshapeOp) String() string { return fmt.Sprintf("Reshape%v", op.to) }
  1190  
  1191  func (op reshapeOp) UnsafeDo(vals ...Value) (Value, error) {
  1192  	if err := checkArity(op, len(vals)); err != nil {
  1193  		return nil, err
  1194  	}
  1195  	var val Value
  1196  	var err error
  1197  	switch vals[0].(type) {
  1198  	case tensor.Tensor:
  1199  		val = vals[0]
  1200  		err = val.(tensor.Tensor).Reshape(op.to...)
  1201  
  1202  		return val, err
  1203  	case Scalar:
  1204  		v0 := ScalarAsTensor(vals[0], op.to.Dims(), nil)
  1205  		if err := v0.(tensor.Tensor).Reshape(op.to...); err != nil {
  1206  			return nil, err
  1207  		}
  1208  		return v0, nil
  1209  	default:
  1210  		return nil, errors.Errorf(nyiTypeFail, "reshape.Do", vals[0])
  1211  	}
  1212  }
  1213  
// CUDADo reshapes val in place for the CUDA execution path. extern, dev and
// prealloc satisfy the CUDA op interface but are unused here: reshape only
// rewrites shape metadata, so no device work is issued.
func (op reshapeOp) CUDADo(extern External, dev Device, prealloc Value, vals ...Value) (retVal Value, err error) {
	if err := checkArity(op, len(vals)); err != nil {
		return nil, err
	}
	val := vals[0]
	switch v := val.(type) {
	case tensor.Tensor:
		if err := v.Reshape(op.to...); err != nil {
			return nil, err
		}
		return v, nil
	case Scalar:
		vT := ScalarAsTensor(v, op.to.Dims(), nil)
		if err := vT.(tensor.Tensor).Reshape(op.to...); err != nil {
			// NOTE(review): the underlying Reshape error is discarded and the
			// message names "reshape.Do" — presumably copy-pasted; confirm.
			return nil, errors.Errorf(nyiTypeFail, "reshape.Do", "Scalar")
		}
		return vT, nil
	}

	// Reached only when val is neither tensor.Tensor nor Scalar.
	panic("Unreachable")
}
  1236  
  1237  func (op reshapeOp) DiffWRT(i int) []bool { return []bool{true} }
  1238  
  1239  func (op reshapeOp) SymDiff(inputs Nodes, output *Node, grad *Node) (retVal Nodes, err error) {
  1240  	var ret *Node
  1241  	if ret, err = Reshape(grad, op.from); err != nil {
  1242  		return
  1243  	}
  1244  	ret.setGroup(gradClust)
  1245  	return Nodes{ret}, nil
  1246  }
  1247  
// DoDiff performs the backward pass of reshape for the non-symbolic execution
// path: the output's gradient, reshaped back to the input's original shape,
// becomes the input's derivative.
func (op reshapeOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	var grad Value
	if grad, err = output.Grad(); err != nil {
		return
	}
	// NOTE(review): this reshapes the output's gradient tensor in place and
	// then shares it with the input's dual value — confirm no other consumer
	// reads the output grad at its original shape afterwards.
	T := grad.(tensor.Tensor)
	if err = T.Reshape(op.from...); err != nil {
		return
	}
	input := inputs[0]
	dv := input.boundTo.(*dualValue)
	return dv.SetDeriv(T)
}
  1261  
  1262  /* PRIVATE FUNCTIONS */
  1263  
  1264  // if value is contained in slice, contains returns the corresp. index in slice, -1 otherwise
  1265  func contains(slice []int, value int) int {
  1266  	if nil == slice {
  1267  		return -1
  1268  	}
  1269  
  1270  	for sliceIndex, sliceValue := range slice {
  1271  		if value == sliceValue {
  1272  			return sliceIndex
  1273  		}
  1274  	}
  1275  
  1276  	return -1
  1277  }
  1278  
  1279  // TODO: This function is an overkill for a small number of axes...
  1280  func sortUniqueIntWithImitator(toBeSorted, imitator []int) {
  1281  	toBeSortedBackup := make([]int, len(toBeSorted))
  1282  	for index, value := range toBeSorted {
  1283  		toBeSortedBackup[index] = value
  1284  	}
  1285  
  1286  	imitatorBackup := make([]int, len(imitator))
  1287  	for index, value := range imitator {
  1288  		imitatorBackup[index] = value
  1289  	}
  1290  
  1291  	sort.Ints(toBeSorted)
  1292  
  1293  	// Permutate the imitator accordingly
  1294  	for originalIndex, originalValue := range toBeSortedBackup {
  1295  		sortedIndex := sort.SearchInts(toBeSorted, originalValue)
  1296  
  1297  		imitator[sortedIndex] = imitatorBackup[originalIndex]
  1298  	}
  1299  
  1300  	return
  1301  }