gorgonia.org/gorgonia@v0.9.17/op_reduction.go (about)

     1  package gorgonia
     2  
     3  /*
     4  This file holds code for ndarray-related reduction Ops.
     5  What this means is that we take an ndarray and reduce its dimensions down - typically to 1.
     6  For example, summing all the values in a matrix, or finding the max value.
     7  Each of these Ops has an additional field - the 'along' field - because we do not always want to reduce an ndarray down to a single scalar value.
     8  */
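        // Illustrative note (not part of the original source): for a matrix of shape
        // (2, 3), reducing along axis 0 yields a vector of shape (3), reducing along
        // axis 1 yields a vector of shape (2), and reducing along both axes - or, per
        // reductionInferShape below, along no axes at all - yields a scalar. The
        // 'along' field records which axes a given Op reduces over.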
     9  
    10  import (
    11  	"encoding/binary"
    12  	"fmt"
    13  	"hash"
    14  	"strings"
    15  
    16  	"github.com/chewxy/hm"
    17  	"github.com/pkg/errors"
    18  	"gorgonia.org/tensor"
    19  )
    20  
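        // reductionType constructs the Hindley-Milner function type of a reduction Op,
        // given the input's dimensionality d and the axes reduced over. Reducing over
        // every axis (or over none, or over a 1-dimensional input) yields the scalar
        // type variable; reducing over all but one axis yields a vector type; otherwise
        // a tensor type with correspondingly fewer dimensions is returned.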
    21  func reductionType(d int, along []int) hm.Type {
    22  	a := hm.TypeVariable('a')
    23  	t := makeTensorType(d-len(along), a)
    24  
    25  	axes := make(map[int]bool)
    26  	for _, axis := range along {
    27  		if axis < d {
    28  			axes[axis] = true
    29  		}
    30  	}
    31  
    32  	if d == 1 || len(axes) == 0 || len(axes) == d {
    33  		// then it reduces all the way down to a scalar
    34  		return hm.NewFnType(t, a)
    35  	}
    36  
    37  	var retType hm.Type
    38  	if len(axes) == d-1 { // Only 1 non-reduced dim, so we can reduce to a vector as before.
    39  		retType = makeTensorType(1, a)
    40  	} else {
    41  		retType = t
    42  	}
    43  	return hm.NewFnType(t, retType)
    44  }
    45  
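        // reductionInferShape computes the shape that results from reducing `in` along
        // the given axes: the reduced dimensions are dropped, reducing along no axes or
        // along all of them yields a scalar shape, and an axis that is out of range for
        // `in` is an error.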
    46  func reductionInferShape(along []int, in tensor.Shape) (tensor.Shape, error) {
    47  	if len(along) == 0 {
    48  		return tensor.ScalarShape(), nil
    49  	}
    50  	shape := in.Clone()
    51  	for _, d := range along {
    52  		if d >= shape.Dims() {
    53  			return nil, fmt.Errorf("shape error, along %d is not a valid axis for shape %v", d, in)
    54  		}
    55  		shape[d] = 0
    56  	}
    57  
    58  	var dims []int
    59  	for _, d := range shape {
    60  		if d != 0 {
    61  			dims = append(dims, d)
    62  		}
    63  	}
    64  	if len(dims) == 0 {
    65  		return tensor.ScalarShape(), nil
    66  	}
    67  	return tensor.Shape(dims), nil
    68  }
    69  
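        // reductionDo is the Do implementation shared by the reduction Ops. It checks
        // the arity, applies the given *tensor.Dense reduction method along the Op's
        // axes, and then either converts the result to a scalar Value or reshapes it to
        // the shape computed by reductionInferShape. Only *tensor.Dense inputs are
        // currently supported.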
    70  func reductionDo(op Op, s string, f func(*tensor.Dense, ...int) (*tensor.Dense, error), along []int, inputs ...Value) (retVal Value, err error) {
    71  	if err = checkArity(op, len(inputs)); err != nil {
    72  		return
    73  	}
    74  	at := inputs[0].(tensor.Tensor)
    75  	switch t := at.(type) {
    76  	case *tensor.Dense:
    77  		var ret *tensor.Dense
    78  		if ret, err = f(t, along...); err == nil {
    79  			if ret.IsScalar() {
    80  				retVal, _ = anyToScalar(ret.ScalarValue())
    81  			} else {
    82  				// The tensor reduction ops drop the reduced dimensions, so we reshape the
    83  				// result to ensure its dimensions match the shape inferred by reductionInferShape.
    84  				var sh tensor.Shape
    85  				if sh, err = reductionInferShape(along, t.Shape()); err == nil {
    86  					if err = ret.Reshape(sh...); err == nil {
    87  						retVal = ret
    88  					}
    89  				}
    90  			}
    91  		} else {
    92  			return nil, errors.Wrap(err, fmt.Sprintf("failed to apply *tensor.Dense.%s()", strings.Title(s)))
    93  		}
    94  	default:
    95  		return nil, errors.Errorf(nyiFail, fmt.Sprintf("%sOp.Do()", s), at)
    96  	}
    97  	return
    98  
    99  }
   100  
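        // maxOp is a reduction Op that takes the maximum of its input along the axes in
        // 'along'; d records the dimensionality of the input.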
   101  type maxOp struct {
   102  	along axes
   103  	d     int
   104  }
   105  
   106  func newMaxOp(along axes, dim int) *maxOp {
   107  	return &maxOp{
   108  		along: along,
   109  		d:     dim,
   110  	}
   111  }
   112  
   113  func (op maxOp) Arity() int { return 1 }
   114  
   115  func (op maxOp) Type() hm.Type {
   116  	return reductionType(op.d, op.along)
   117  }
   118  
   119  func (op maxOp) InferShape(dimsizers ...DimSizer) (tensor.Shape, error) {
   120  	if len(dimsizers) != 1 {
    121  		return nil, errors.Errorf("maxOp only takes one input shape to infer")
   122  	}
   123  	return reductionInferShape(op.along, dimsizers[0].(tensor.Shape))
   124  }
   125  func (op maxOp) DiffWRT(i int) []bool { return []bool{true} }
   126  
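        // SymDiff constructs the symbolic gradient of a max reduction: the output is
        // broadcast back against the input and compared for equality, and the incoming
        // gradient (also broadcast) is masked by that comparison, so only the positions
        // that attained the maximum receive gradient.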
   127  func (op maxOp) SymDiff(inputs Nodes, output, gradNode *Node) (retVal Nodes, err error) {
   128  	if err = checkArity(op, len(inputs)); err != nil {
   129  		return
   130  	}
   131  
   132  	t := inputs[0]
   133  	opDim := len(t.Shape())
   134  
   135  	var leftAxes []byte
   136  	for i := 0; i < opDim; i++ {
   137  		for _, ax := range op.along {
   138  			if i == ax {
   139  				leftAxes = append(leftAxes, byte(i))
   140  				break
   141  			}
   142  		}
   143  	}
   144  
   145  	var a, b, a2, b2, eq *Node
   146  	bcpat := NewBroadcastPattern(leftAxes, nil)
   147  	if a, b, err = Broadcast(output, t, bcpat); err != nil {
   148  		return nil, errors.Wrap(err, operationError)
   149  	}
   150  	if eq, err = Eq(a, b, true); err != nil {
   151  		return nil, errors.Wrap(err, operationError)
   152  	}
   153  
   154  	if a2, b2, err = Broadcast(gradNode, eq, bcpat); err != nil {
   155  		return nil, errors.Wrap(err, operationError)
   156  	}
   157  	retVal = make(Nodes, 1)
   158  	if retVal[0], err = HadamardProd(a2, b2); err != nil {
   159  		return nil, errors.Wrap(err, operationError)
   160  	}
   161  	return
   162  }
   163  
   164  func (op maxOp) Do(inputs ...Value) (retVal Value, err error) {
   165  	return reductionDo(op, "max", (*tensor.Dense).Max, op.along, inputs...)
   166  }
   167  
   168  func (op maxOp) ReturnsPtr() bool     { return true }
   169  func (op maxOp) OverwritesInput() int { return 0 }
   170  func (op maxOp) CallsExtern() bool    { return false }
   171  
   172  func (op maxOp) WriteHash(h hash.Hash) {
   173  	h.Write([]byte("max"))
   174  	if err := binary.Write(h, binary.LittleEndian, byte(op.d)); err != nil {
   175  		panic(err)
   176  	}
   177  	fmt.Fprintf(h, "%v->%v", op.d, op.along)
   178  }
   179  
   180  func (op maxOp) Hashcode() uint32 { return simpleHash(op) }
   181  
   182  func (op maxOp) String() string { return fmt.Sprintf("MaxAlong%v", op.along) }
   183  func (op maxOp) isUnary() bool  { return true }
   184  
   185  /* ARGMAX OP */
   186  // type argmaxOp struct {
   187  // 	along int // axis
   188  // }
   189  
   190  // func (op argmaxOp) Type() hm.Type {
   191  // 	a := hm.TypeVariable('a')
   192  
   193  // }
   194  
   195  /* SUM OP */
   196  
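        // sumOp is a reduction Op that sums its input along the axes in 'along';
        // d and inputShape record the dimensionality and shape of the input.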
   197  type sumOp struct {
   198  	along      axes
   199  	d          int
   200  	inputShape tensor.Shape
   201  }
   202  
   203  func newSumOp(along axes, s tensor.Shape, d int) sumOp {
   204  	return sumOp{
   205  		along:      along,
   206  		d:          d,
   207  		inputShape: s,
   208  	}
   209  }
   210  
   211  func (op sumOp) Arity() int { return 1 }
   212  
   213  // sumOp is a function with this type:
   214  //		sumOp :: (Summable a) ⇒ Tensor d a → Tensor d-1 a
   215  func (op sumOp) Type() hm.Type {
   216  	return reductionType(op.d, op.along)
   217  }
   218  
    219  // InferShape infers the shape of a sumOp's result. Its purpose is to fulfil the Op interface. Only one input is expected, and it is expected to be a tensor.Shape.
   220  func (op sumOp) InferShape(inputs ...DimSizer) (shape tensor.Shape, err error) {
   221  	return reductionInferShape(op.along, inputs[0].(tensor.Shape))
   222  }
   223  
   224  func (op sumOp) DiffWRT(i int) []bool { return []bool{true} }
   225  
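        // SymDiff constructs the symbolic gradient of a sum reduction: the incoming
        // gradient is reshaped to a broadcastable shape and then repeated along each
        // summed axis by that axis' size in the input, since every input element
        // contributes equally to the sum.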
   226  func (op sumOp) SymDiff(inputs Nodes, output, gradNode *Node) (retVal Nodes, err error) {
   227  	if err = checkArity(op, len(inputs)); err != nil {
   228  		return
   229  	}
   230  
   231  	newShape := calcBroadcastShape(gradNode, op.d, op.along)
   232  	if gradNode, err = Reshape(gradNode, newShape); err != nil {
   233  		return nil, errors.Wrapf(err, "Unable to reshape grad node to %v", newShape)
   234  	}
   235  	gradNode.setGroup(gradClust)
   236  
   237  	children := make(Nodes, len(op.along)+1)
   238  	children[0] = gradNode
   239  
   240  	for i, a := range op.along {
   241  		var n *Node
   242  		if n, err = SizeOf(a, inputs[0]); err != nil {
   243  			return nil, errors.Wrap(err, operationError)
   244  		}
   245  		WithGroupName(gradClust)(n)
   246  		children[i+1] = n
   247  	}
   248  
   249  	retVal = make(Nodes, 1)
   250  	if retVal[0], err = repeatedApply(op.along, children); err != nil {
   251  		return nil, errors.Wrap(err, applyOpFail)
   252  	}
   253  	retVal[0].setGroup(gradClust)
   254  	return
   255  }
   256  
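        // DoDiff performs the backward pass of a sum reduction for the value-based
        // execution path: the output's derivative is reshaped and repeated back to the
        // input's shape, transferred to the input's device if necessary, and then
        // accumulated into the input's derivative.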
   257  func (op sumOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
   258  	if err = checkArity(op, len(inputs)); err != nil {
   259  		return
   260  	}
   261  
   262  	x := inputs[0]
   263  	xdv, ydv := getDV(x, output)
   264  	xShape := xdv.Value.Shape()
   265  
   266  	var T tensor.Tensor
   267  	switch ydvd := ydv.d.(type) {
   268  	case Scalar:
   269  		dt := ydvd.Dtype()
   270  		T = tensor.New(tensor.Of(dt), tensor.WithShape(xdv.d.Shape().Clone()...))
   271  		T.Memset(ydvd.Data())
   272  	case tensor.Tensor:
   273  		// handle broadcasting
   274  		if ydvd.Shape().Dims() == xdv.d.Shape().Dims()-len(op.along) {
   275  			newShape := xdv.d.Shape().Clone()
   276  			for _, a := range op.along {
   277  				newShape[a] = 1
   278  			}
   279  			ydvd.Reshape(newShape...)
   280  		}
   281  
   282  		T = ydvd
   283  	default:
   284  		err = errors.Errorf(nyiTypeFail, "sumOp.DoDiff()", ydv.d)
   285  		return
   286  	}
   287  
   288  	var val Value
   289  	if !T.Shape().Eq(xdv.d.Shape()) {
    290  		// TODO: optimize - figure out a way to batch it all up so the repeats happen in one call
   291  		for _, a := range op.along {
   292  			if xShape[a] == 1 {
   293  				continue // don't need to repeat
   294  			}
   295  
   296  			if T, err = tensor.Repeat(T, a, xShape[a]); err != nil {
   297  				return errors.Wrapf(err, repFail, a, xShape[a])
   298  			}
   299  		}
   300  		val = T
   301  	} else {
   302  		val = T
   303  	}
   304  
   305  	// then just add the two
   306  	add := newEBOByType(addOpType, TypeOf(xdv.d), TypeOf(val))
   307  	addOp := NewExternalOp(add, ctx, nil)
   308  	addOp.UseUnsafe = true
   309  	addOp.Device = x.Device()
   310  
   311  	dev := x.Device()
   312  	if output.Device() != dev && dev != CPU {
   313  		var valOnDev Value
   314  		if valOnDev, err = ctx.Transfer(dev, output.Device(), val, false); err != nil {
   315  			return
   316  		}
   317  		defer ctx.PutValue(dev, valOnDev)
   318  		val = valOnDev
   319  
   320  		// Copy(valOnDev, val)
   321  	}
   322  	var xd, d Value
   323  	var extra bool
   324  	if xd, extra, err = x.GradOnDevice(dev, ctx.External); err != nil {
   325  		return errors.Wrapf(err, gradOnDeviceFail, x, dev)
   326  	}
   327  	if extra {
   328  		defer ctx.PutValue(dev, xd)
   329  	}
   330  	if d, err = addOp.Do(xd, val); err != nil {
   331  		return errors.Wrapf(err, unsafeDoFail, add)
   332  	}
   333  
   334  	return xdv.SetDeriv(d)
   335  
   336  	// var d Value
   337  	// if d, err = add.UnsafeDo(xdv.d, val); err != nil {
   338  	// 	return errors.Wrapf(err, unsafeDoFail, add)
   339  	// }
   340  }
   341  
   342  func (op sumOp) Do(inputs ...Value) (retVal Value, err error) {
   343  	return reductionDo(op, "sum", (*tensor.Dense).Sum, op.along, inputs...)
   344  }
   345  
   346  func (op sumOp) ReturnsPtr() bool      { return true }
   347  func (op sumOp) OverwritesInput() int  { return 0 }
   348  func (op sumOp) CallsExtern() bool     { return false }
   349  func (op sumOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, "sum%v->%v", op.along, op.inputShape) }
   350  func (op sumOp) Hashcode() uint32      { return simpleHash(op) }
   351  func (op sumOp) String() string        { return fmt.Sprintf("Σ%v", op.along) }
   352  func (op sumOp) isUnary() bool         { return true }
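
        // The sketch below is illustrative only and not part of the original source:
        // sumOpSketch is a made-up helper, assumed to live in this package, showing how
        // sumOp reduces a concrete (2, 3) matrix - summing along axis 0 gives a vector
        // of shape (3), and summing along every axis collapses to a scalar.
        func sumOpSketch() {
        	t := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6}))

        	op := newSumOp(axes{0}, t.Shape(), t.Dims())
        	v, err := op.Do(t)
        	if err != nil {
        		panic(err)
        	}
        	fmt.Println(v) // a vector of shape (3): [5 7 9]

        	opAll := newSumOp(axes{0, 1}, t.Shape(), t.Dims())
        	s, _ := opAll.Do(t)
        	fmt.Println(s) // a scalar Value: 21
        }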