gorgonia.org/gorgonia@v0.9.17/op_reduction.go

package gorgonia

/*
This file holds code for ndarray-related reduction Ops.
What this means is that we take an ndarray and reduce its dimensions down - typically to 1.
For example, summing all the values in a matrix, or finding the max value.
Each of these Ops has an additional field - the 'along' field - which lists the axes to reduce over. This is because we don't always want to reduce an ndarray down to a single scalar number.
*/

import (
	"encoding/binary"
	"fmt"
	"hash"
	"strings"

	"github.com/chewxy/hm"
	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

func reductionType(d int, along []int) hm.Type {
	a := hm.TypeVariable('a')
	t := makeTensorType(d-len(along), a)

	axes := make(map[int]bool)
	for _, axis := range along {
		if axis < d {
			axes[axis] = true
		}
	}

	if d == 1 || len(axes) == 0 || len(axes) == d {
		// then it reduces down to a scalar
		return hm.NewFnType(t, a)
	}

	var retType hm.Type
	if len(axes) == d-1 { // only 1 non-reduced dim, so we can reduce to a vector as before
		retType = makeTensorType(1, a)
	} else {
		retType = t
	}
	return hm.NewFnType(t, retType)
}

func reductionInferShape(along []int, in tensor.Shape) (tensor.Shape, error) {
	if len(along) == 0 {
		return tensor.ScalarShape(), nil
	}
	shape := in.Clone()
	for _, d := range along {
		if d >= shape.Dims() {
			return nil, fmt.Errorf("shape error, along %d is not a valid axis for shape %v", d, in)
		}
		shape[d] = 0
	}

	var dims []int
	for _, d := range shape {
		if d != 0 {
			dims = append(dims, d)
		}
	}
	if len(dims) == 0 {
		return tensor.ScalarShape(), nil
	}
	return tensor.Shape(dims), nil
}
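// The function below is an illustrative sketch and is not part of the original
// file: it shows how reductionInferShape collapses the reduced axes, assuming
// only the behaviour of the helper above. Reducing a (2, 3, 4) shape along
// axis 1 yields (2, 4); reducing along every axis, or along no axis at all,
// yields the scalar shape. The function name is ours, chosen for illustration.
func exampleReductionInferShape() {
	in := tensor.Shape{2, 3, 4}

	s, _ := reductionInferShape([]int{1}, in)
	fmt.Println(s) // (2, 4): axis 1 is collapsed

	s, _ = reductionInferShape([]int{0, 1, 2}, in)
	fmt.Println(s) // (): all axes reduced, so the result is a scalar shape

	s, _ = reductionInferShape(nil, in)
	fmt.Println(s) // (): no axes given also means "reduce everything"
}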
func reductionDo(op Op, s string, f func(*tensor.Dense, ...int) (*tensor.Dense, error), along []int, inputs ...Value) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	at := inputs[0].(tensor.Tensor)
	switch t := at.(type) {
	case *tensor.Dense:
		var ret *tensor.Dense
		if ret, err = f(t, along...); err == nil {
			if ret.IsScalar() {
				retVal, _ = anyToScalar(ret.ScalarValue())
			} else {
				// The tensor reduction ops remove the collapsed dimensions, but here we preserve them except in special cases,
				// so we reshape the result to ensure the dimensions match the inferred reduction shape.
				var sh tensor.Shape
				if sh, err = reductionInferShape(along, t.Shape()); err == nil {
					if err = ret.Reshape(sh...); err == nil {
						retVal = ret
					}
				}
			}
		} else {
			return nil, errors.Wrap(err, fmt.Sprintf("failed to apply *tensor.Dense.%s()", strings.Title(s)))
		}
	default:
		return nil, errors.Errorf(nyiFail, fmt.Sprintf("%sOp.Do()", s), at)
	}
	return
}

type maxOp struct {
	along axes
	d     int
}

func newMaxOp(along axes, dim int) *maxOp {
	return &maxOp{
		along: along,
		d:     dim,
	}
}

func (op maxOp) Arity() int { return 1 }

func (op maxOp) Type() hm.Type {
	return reductionType(op.d, op.along)
}

func (op maxOp) InferShape(dimsizers ...DimSizer) (tensor.Shape, error) {
	if len(dimsizers) != 1 {
		return nil, errors.Errorf("maxOp only takes one input shape to infer")
	}
	return reductionInferShape(op.along, dimsizers[0].(tensor.Shape))
}

func (op maxOp) DiffWRT(i int) []bool { return []bool{true} }

func (op maxOp) SymDiff(inputs Nodes, output, gradNode *Node) (retVal Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	t := inputs[0]
	opDim := len(t.Shape())

	var leftAxes []byte
	for i := 0; i < opDim; i++ {
		for _, ax := range op.along {
			if i == ax {
				leftAxes = append(leftAxes, byte(i))
				break
			}
		}
	}

	var a, b, a2, b2, eq *Node
	bcpat := NewBroadcastPattern(leftAxes, nil)
	if a, b, err = Broadcast(output, t, bcpat); err != nil {
		return nil, errors.Wrap(err, operationError)
	}
	if eq, err = Eq(a, b, true); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	if a2, b2, err = Broadcast(gradNode, eq, bcpat); err != nil {
		return nil, errors.Wrap(err, operationError)
	}
	retVal = make(Nodes, 1)
	if retVal[0], err = HadamardProd(a2, b2); err != nil {
		return nil, errors.Wrap(err, operationError)
	}
	return
}

func (op maxOp) Do(inputs ...Value) (retVal Value, err error) {
	return reductionDo(op, "max", (*tensor.Dense).Max, op.along, inputs...)
}

func (op maxOp) ReturnsPtr() bool     { return true }
func (op maxOp) OverwritesInput() int { return 0 }
func (op maxOp) CallsExtern() bool    { return false }

func (op maxOp) WriteHash(h hash.Hash) {
	h.Write([]byte("max"))
	if err := binary.Write(h, binary.LittleEndian, byte(op.d)); err != nil {
		panic(err)
	}
	fmt.Fprintf(h, "%v->%v", op.d, op.along)
}

func (op maxOp) Hashcode() uint32 { return simpleHash(op) }

func (op maxOp) String() string { return fmt.Sprintf("MaxAlong%v", op.along) }
func (op maxOp) isUnary() bool  { return true }
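// The function below is an illustrative sketch and is not part of the original
// file: it exercises maxOp.Do directly on a *tensor.Dense, assuming only the
// definitions above. With along = {1}, a (2, 3) matrix reduces to a length-2
// vector of row maxima. The function name is ours, chosen for illustration.
func exampleMaxOpDo() {
	op := newMaxOp(axes{1}, 2)
	t := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{1, 5, 2, 7, 0, 3}))

	v, err := op.Do(t)
	if err != nil {
		panic(err)
	}
	fmt.Println(v) // the row maxima [5  7], as a shape (2) vector
}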
/* ARGMAX OP */
// type argmaxOp struct {
// 	along int // axis
// }

// func (op argmaxOp) Type() hm.Type {
// 	a := hm.TypeVariable('a')

// }

/* SUM OP */

type sumOp struct {
	along      axes
	d          int
	inputShape tensor.Shape
}

func newSumOp(along axes, s tensor.Shape, d int) sumOp {
	return sumOp{
		along:      along,
		d:          d,
		inputShape: s,
	}
}

func (op sumOp) Arity() int { return 1 }

// sumOp is a function with this type:
// sumOp :: (Summable a) ⇒ Tensor d a → Tensor d-1 a
func (op sumOp) Type() hm.Type {
	return reductionType(op.d, op.along)
}

// InferShape infers the shape of a sumOp. Its purpose is to fulfil the Op interface.
// Only one input is expected, and it is expected to be a tensor.Shape.
func (op sumOp) InferShape(inputs ...DimSizer) (shape tensor.Shape, err error) {
	return reductionInferShape(op.along, inputs[0].(tensor.Shape))
}

func (op sumOp) DiffWRT(i int) []bool { return []bool{true} }

func (op sumOp) SymDiff(inputs Nodes, output, gradNode *Node) (retVal Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	newShape := calcBroadcastShape(gradNode, op.d, op.along)
	if gradNode, err = Reshape(gradNode, newShape); err != nil {
		return nil, errors.Wrapf(err, "Unable to reshape grad node to %v", newShape)
	}
	gradNode.setGroup(gradClust)

	children := make(Nodes, len(op.along)+1)
	children[0] = gradNode

	for i, a := range op.along {
		var n *Node
		if n, err = SizeOf(a, inputs[0]); err != nil {
			return nil, errors.Wrap(err, operationError)
		}
		WithGroupName(gradClust)(n)
		children[i+1] = n
	}

	retVal = make(Nodes, 1)
	if retVal[0], err = repeatedApply(op.along, children); err != nil {
		return nil, errors.Wrap(err, applyOpFail)
	}
	retVal[0].setGroup(gradClust)
	return
}
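// The function below is an illustrative sketch and is not part of the original
// file: it shows the repeat-back step that sumOp's gradient relies on, using
// only tensor.Repeat. The incoming gradient has the reduced shape; repeating it
// along each summed axis restores the input shape, so every input element
// receives its slice's upstream gradient. The function name is ours.
func exampleSumGradRepeat() {
	// Gradient of a sum along axis 1 of a (2, 3) input, reshaped to (2, 1).
	grad := tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{10, 20}))

	// Repeating 3 times along axis 1 recovers the (2, 3) input shape.
	full, err := tensor.Repeat(grad, 1, 3)
	if err != nil {
		panic(err)
	}
	fmt.Println(full) // [[10 10 10] [20 20 20]]
}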
func (op sumOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	x := inputs[0]
	xdv, ydv := getDV(x, output)
	xShape := xdv.Value.Shape()

	var T tensor.Tensor
	switch ydvd := ydv.d.(type) {
	case Scalar:
		dt := ydvd.Dtype()
		T = tensor.New(tensor.Of(dt), tensor.WithShape(xdv.d.Shape().Clone()...))
		T.Memset(ydvd.Data())
	case tensor.Tensor:
		// handle broadcasting
		if ydvd.Shape().Dims() == xdv.d.Shape().Dims()-len(op.along) {
			newShape := xdv.d.Shape().Clone()
			for _, a := range op.along {
				newShape[a] = 1
			}
			if err = ydvd.Reshape(newShape...); err != nil {
				return errors.Wrapf(err, "Unable to reshape the gradient to %v", newShape)
			}
		}

		T = ydvd
	default:
		err = errors.Errorf(nyiTypeFail, "sumOp.DoDiff()", ydv.d)
		return
	}

	var val Value
	if !T.Shape().Eq(xdv.d.Shape()) {
		// TODO: optimize - figure out a way to bunch it all up so the repeats happen in one call
		for _, a := range op.along {
			if xShape[a] == 1 {
				continue // don't need to repeat along axes of size 1
			}

			if T, err = tensor.Repeat(T, a, xShape[a]); err != nil {
				return errors.Wrapf(err, repFail, a, xShape[a])
			}
		}
	}
	val = T

	// then just add the two
	add := newEBOByType(addOpType, TypeOf(xdv.d), TypeOf(val))
	addOp := NewExternalOp(add, ctx, nil)
	addOp.UseUnsafe = true
	addOp.Device = x.Device()

	dev := x.Device()
	if output.Device() != dev && dev != CPU {
		var valOnDev Value
		if valOnDev, err = ctx.Transfer(dev, output.Device(), val, false); err != nil {
			return
		}
		defer ctx.PutValue(dev, valOnDev)
		val = valOnDev

		// Copy(valOnDev, val)
	}
	var xd, d Value
	var extra bool
	if xd, extra, err = x.GradOnDevice(dev, ctx.External); err != nil {
		return errors.Wrapf(err, gradOnDeviceFail, x, dev)
	}
	if extra {
		defer ctx.PutValue(dev, xd)
	}
	if d, err = addOp.Do(xd, val); err != nil {
		return errors.Wrapf(err, unsafeDoFail, add)
	}

	return xdv.SetDeriv(d)

	// var d Value
	// if d, err = add.UnsafeDo(xdv.d, val); err != nil {
	// 	return errors.Wrapf(err, unsafeDoFail, add)
	// }
}

func (op sumOp) Do(inputs ...Value) (retVal Value, err error) {
	return reductionDo(op, "sum", (*tensor.Dense).Sum, op.along, inputs...)
}

func (op sumOp) ReturnsPtr() bool      { return true }
func (op sumOp) OverwritesInput() int  { return 0 }
func (op sumOp) CallsExtern() bool     { return false }
func (op sumOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, "sum%v->%v", op.along, op.inputShape) }
func (op sumOp) Hashcode() uint32      { return simpleHash(op) }
func (op sumOp) String() string        { return fmt.Sprintf("Σ%v", op.along) }
func (op sumOp) isUnary() bool         { return true }
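// The function below is an illustrative sketch and is not part of the original
// file: it exercises sumOp.Do on a *tensor.Dense, assuming only the definitions
// above. Summing a (2, 3) matrix along axis 0 gives the column sums as a
// length-3 vector. The function name is ours, chosen for illustration.
func exampleSumOpDo() {
	in := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6}))
	op := newSumOp(axes{0}, in.Shape(), 2)

	v, err := op.Do(in)
	if err != nil {
		panic(err)
	}
	fmt.Println(v) // the column sums [5  7  9]
}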