gorgonia.org/gorgonia@v0.9.17/op_nn.go

package gorgonia

import (
	"fmt"
	"hash"
	"time"

	"github.com/chewxy/hm"
	rng "github.com/leesper/go_rng"
	"github.com/pkg/errors"
	"gonum.org/v1/gonum/blas"
	"gorgonia.org/tensor"
	"gorgonia.org/vecf32"
	"gorgonia.org/vecf64"
)

// Sanity checks
var (
	_ SDOp = im2colOp{}
	_ Op   = col2imOp{}
	_ Op   = &maxPoolOp{}
	_ Op   = &maxPoolDiffOp{}
	_ Op   = &BatchNormOp{}
	_ Op   = &batchnormDiffOp{}
	_ Op   = &globalAveragePoolOp{}
)

/*
This file contains all the Ops related to building a neural network.

Bear in mind that not everything related to a neural network lives here, as not
everything is encoded as an Op the way Theano does it.

See also: nn.go for functions that relate to neural networks.
*/

type randomness byte

const (
	uniform randomness = iota
	gaussian
	binomial
)

type randomOp struct {
	which randomness
	shape tensor.Shape
	dt    tensor.Dtype

	a, b float64 // when uniform, a,b = low, high; when gaussian, a,b = mean, stdev
}

func makeRandomOp(which randomness, dt tensor.Dtype, a, b float64, shape ...int) randomOp {
	return randomOp{
		which: which,
		shape: tensor.Shape(shape),
		dt:    dt,
		a:     a,
		b:     b,
	}
}

func (op randomOp) Arity() int { return 0 }

// randomOp :: a
// randomOp :: Tensor a
func (op randomOp) Type() hm.Type {
	if op.shape.IsScalar() {
		return op.dt
	}
	tt := newTensorType(op.shape.Dims(), op.dt)
	return tt
}

func (op randomOp) InferShape(...DimSizer) (tensor.Shape, error) { return op.shape, nil }

func (op randomOp) Do(...Value) (retVal Value, err error) {
	if op.shape.IsScalar() {
		var v interface{}
		switch op.dt {
		case Float64:
			switch op.which {
			case uniform:
				rand := rng.NewUniformGenerator(time.Now().UnixNano())
				v = rand.Float64Range(op.a, op.b)
			case gaussian:
				rand := rng.NewGaussianGenerator(time.Now().UnixNano())
				v = rand.Gaussian(op.a, op.b)
			case binomial:
				rand := rng.NewBinomialGenerator(time.Now().UnixNano())
				v = float64(rand.Binomial(int64(op.a), op.b))
			}
		case Float32:
			switch op.which {
			case uniform:
				rand := rng.NewUniformGenerator(time.Now().UnixNano())
				v = rand.Float32Range(float32(op.a), float32(op.b))
			case gaussian:
				rand := rng.NewGaussianGenerator(time.Now().UnixNano())
				v = float32(rand.Gaussian(op.a, op.b))
			case binomial:
				rand := rng.NewBinomialGenerator(time.Now().UnixNano())
				v = float32(rand.Binomial(int64(op.a), op.b))
			}
		default:
			return nil, errors.Errorf(nyiFail, "randomOp.do()", op.dt)
		}

		retVal, _ = anyToScalar(v)
		return
	}

	switch op.dt {
	case Float64:
		switch op.which {
		case uniform:
			backing := Uniform64(op.a, op.b, op.shape...)
			retVal = tensor.New(tensor.WithBacking(backing), tensor.WithShape(op.shape...))
		case gaussian:
			backing := Gaussian64(op.a, op.b, op.shape...)
			retVal = tensor.New(tensor.WithBacking(backing), tensor.WithShape(op.shape...))
		case binomial:
			backing := Binomial64(op.a, op.b, op.shape...)
			retVal = tensor.New(tensor.WithBacking(backing), tensor.WithShape(op.shape...))
		}
		return
	case Float32:
		switch op.which {
		case uniform:
			backing := Uniform32(op.a, op.b, op.shape...)
			retVal = tensor.New(tensor.WithBacking(backing), tensor.WithShape(op.shape...))
		case gaussian:
			backing := Gaussian32(op.a, op.b, op.shape...)
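			// wrap the freshly drawn backing slice in a dense tensor of the op's shape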
			retVal = tensor.New(tensor.WithBacking(backing), tensor.WithShape(op.shape...))
		case binomial:
			backing := Binomial32(op.a, op.b, op.shape...)
			retVal = tensor.New(tensor.WithBacking(backing), tensor.WithShape(op.shape...))
		}
		return
	default:
		return nil, errors.Errorf(nyiFail, "randomOp.do() for non-scalar", op.dt)
	}
}

func (op randomOp) ReturnsPtr() bool     { return false }
func (op randomOp) CallsExtern() bool    { return false }
func (op randomOp) OverwritesInput() int { return -1 }
func (op randomOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "%d%v%f%f", op.which, op.shape, op.a, op.b)
}

func (op randomOp) Hashcode() uint32 { return simpleHash(op) }

func (op randomOp) String() string {
	return fmt.Sprintf("%v(%v, %v) - %v", op.which, op.a, op.b, op.shape)
}

type im2colOp struct {
	h, w                 int // kernel height and width
	padH, padW           int
	strideH, strideW     int
	dilationH, dilationW int
}

func makeIm2ColOp(kernelHeight, kernelWidth, padHeight, padWidth, strideHeight, strideWidth, dilationHeight, dilationWidth int) im2colOp {
	return im2colOp{
		h:         kernelHeight,
		w:         kernelWidth,
		padH:      padHeight,
		padW:      padWidth,
		strideH:   strideHeight,
		strideW:   strideWidth,
		dilationH: dilationHeight,
		dilationW: dilationWidth,
	}
}

func (op im2colOp) Arity() int { return 1 }

// im2col :: (Floats a) ⇒ Tensor a → Tensor a
func (op im2colOp) Type() hm.Type {
	t := makeTensorType(4, hm.TypeVariable('a'))
	return hm.NewFnType(t, t)
}

func (op im2colOp) InferShape(shapes ...DimSizer) (retVal tensor.Shape, err error) {
	if err = checkArity(op, len(shapes)); err != nil {
		return
	}

	if s, ok := shapes[0].(tensor.Shape); ok {
		return op.calcShape(s), nil
	}
	return nil, errors.Errorf("expected tensor.Shape. got %T instead", shapes[0])
}

func (op im2colOp) Do(inputs ...Value) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	im := inputs[0]

	// todo type check values
	// todo shape check values

	retShape := op.calcShape(im.Shape())
	prealloc := tensor.New(tensor.Of(im.Dtype()), tensor.WithShape(retShape...))

	return op.do(prealloc, im)
}

func (op im2colOp) ReturnsPtr() bool     { return false }
func (op im2colOp) CallsExtern() bool    { return false }
func (op im2colOp) OverwritesInput() int { return -1 }

func (op im2colOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "im2col:%d-%d-%d-%d-%d-%d", op.h, op.w, op.padH, op.padW, op.strideH, op.strideW)
}

func (op im2colOp) Hashcode() uint32 { return simpleHash(op) }

func (op im2colOp) String() string {
	return fmt.Sprintf("im2col<(%d,%d), (%d, %d), (%d,%d) (%d, %d)>", op.h, op.w, op.padH, op.padW, op.strideH, op.strideW, op.dilationH, op.dilationW)
}

func (op im2colOp) DiffWRT(i int) []bool { return []bool{true} }

func (op im2colOp) SymDiff(inputs Nodes, output, grad *Node) (retVal Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	im := inputs[0]
	s := im.Shape()
	if s.Dims() != 4 {
		return nil, errors.Errorf("Expected input to have a shape with 4 dims")
	}
	var unpaddedB, unpaddedC, unpaddedH, unpaddedW int
	unpaddedB, unpaddedC, unpaddedH, unpaddedW = s[0], s[1], s[2], s[3]
	diffOp := col2imOp{
		unpaddedB: unpaddedB,
		unpaddedC: unpaddedC,
		unpaddedH: unpaddedH,
		unpaddedW: unpaddedW,

		im2colOp: op,
	}

	var ret *Node
	if ret, err = ApplyOp(diffOp, grad); err != nil {
		return
	}
	retVal = Nodes{ret}
	return
}

func (op im2colOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	im := inputs[0]
	s := im.Shape()
	imv, colv := getDV(im, output)

	var unpaddedB, unpaddedC, unpaddedH, unpaddedW int
	unpaddedB, unpaddedC, unpaddedH, unpaddedW = s[0], s[1], s[2], s[3]
	diffOp := col2imOp{
		unpaddedB: unpaddedB,
		unpaddedC: unpaddedC,
		unpaddedH: unpaddedH,
		unpaddedW: unpaddedW,

		im2colOp: op,
	}

	if _, err = diffOp.UsePreallocDo(imv.d, colv.d); err != nil {
		return errors.Wrapf(err, doFail, diffOp)
	}
	return
}

func (op im2colOp) calcShape(s tensor.Shape) (retVal tensor.Shape) {
	b := s[0]
	c := s[1]
	h := s[2]
	w := s[3]

	retHeight, retWidth := op.retHW(h, w)
	retVal = tensor.Shape(tensor.BorrowInts(4))

	// todo: double check this with tests
	retVal[0] = b
	retVal[1] = retHeight
	retVal[2] = retWidth
	retVal[3] = c * op.w * op.h

	return
}

func (op im2colOp) retHW(h, w int) (retHeight, retWidth int) {
	retHeight = (h+2*op.padH-(op.dilationH*(op.h-1)+1))/op.strideH + 1
	retWidth = (w+2*op.padW-(op.dilationW*(op.w-1)+1))/op.strideW + 1
	return
}

func (op im2colOp) do(prealloc, input Value) (retVal Value, err error) {
	inputT := input.(*tensor.Dense)
	outputT := prealloc.(*tensor.Dense)

	// extract bchw - this bit can be expanded in the future, but for now we only support bchw
	s := inputT.Shape()
	b := s[0]
	c := s[1]
	h := s[2]
	w := s[3]

	inputStrides := inputT.Strides()
	retHeight, retWidth := op.retHW(h, w)
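	// retHeight/retWidth follow the usual convolution output-size formula:
	// (in + 2*pad - (dilation*(kernel-1) + 1))/stride + 1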
	batchStrideIm := inputStrides[0]
	batchStrideCol := outputT.Strides()[0]
	chanStride := h * w
	inRowStride := inputStrides[2]

	switch input.Dtype() {
	case tensor.Float64:
		imData := input.Data().([]float64)
		colData := prealloc.Data().([]float64)
		for i := 0; i < b; i++ {
			imStart := i * batchStrideIm
			colStart := i * batchStrideCol
			imEnd := imStart + batchStrideIm
			colEnd := colStart + batchStrideCol

			if imEnd >= len(imData) {
				imEnd = len(imData)
			}
			if colEnd >= len(colData) {
				colEnd = len(colData)
			}

			op.f64s(c, h, w, chanStride, inRowStride, retHeight, retWidth, imData[imStart:imEnd], colData[colStart:colEnd])
		}
	case tensor.Float32:
		imData := input.Data().([]float32)
		colData := prealloc.Data().([]float32)
		for i := 0; i < b; i++ {
			imStart := i * batchStrideIm
			colStart := i * batchStrideCol
			imEnd := imStart + batchStrideIm
			colEnd := colStart + batchStrideCol

			if imEnd >= len(imData) {
				imEnd = len(imData)
			}
			if colEnd >= len(colData) {
				colEnd = len(colData)
			}

			op.f32s(c, h, w, chanStride, inRowStride, retHeight, retWidth, imData[imStart:imEnd], colData[colStart:colEnd])
		}
	default:
		return nil, errors.Errorf(nyiFail, "im2col", input.Dtype())
	}
	return prealloc, nil
}

func (op im2colOp) f64s(chans, height, width, chanStride, inRowStride, retHeight, retWidth int, im, col []float64) {
	colIdx := 0
	var inputRow int
	var inputCol int
	for outputRow := 0; outputRow < retHeight; outputRow++ {
		for outputCol := 0; outputCol < retWidth; outputCol++ {
			for ch := 0; ch < chans; ch++ {
				for kernelRow := 0; kernelRow < op.h; kernelRow++ {
					inputRow = -op.padH + kernelRow*op.dilationH + outputRow*op.strideH
					for kernelCol := 0; kernelCol < op.w; kernelCol++ {
						if inputRow < 0 || inputRow >= height {
							col[colIdx] = 0
							colIdx++
							continue
						}
						inputCol = -op.padW + kernelCol*op.dilationW + outputCol*op.strideW
						if inputCol < 0 || inputCol >= width {
							col[colIdx] = 0
							colIdx++
						} else {
							imIdx := chanStride*ch + inputRow*width + inputCol
							col[colIdx] = im[imIdx]
							colIdx++
						}
					}
				}
			}
		}
	}
}

func (op im2colOp) f32s(chans, height, width, chanStride, inRowStride, retHeight, retWidth int, im, col []float32) {
	colIdx := 0
	var inputRow int
	var inputCol int
	for outputRow := 0; outputRow < retHeight; outputRow++ {
		for outputCol := 0; outputCol < retWidth; outputCol++ {
			for ch := 0; ch < chans; ch++ {
				for kernelRow := 0; kernelRow < op.h; kernelRow++ {
					inputRow = -op.padH + kernelRow*op.dilationH + outputRow*op.strideH
					for kernelCol := 0; kernelCol < op.w; kernelCol++ {
						if inputRow < 0 || inputRow >= height {
							col[colIdx] = 0
							colIdx++
							continue
						}
						inputCol = -op.padW + kernelCol*op.dilationW + outputCol*op.strideW
						if inputCol < 0 || inputCol >= width {
							col[colIdx] = 0
							colIdx++
						} else {
							imIdx := chanStride*ch + inputRow*width + inputCol
							col[colIdx] = im[imIdx]
							colIdx++
						}
					}
				}
			}
		}
	}
}

type col2imOp struct {
	// input shapes of im2col
	unpaddedB int
	unpaddedC int
	unpaddedH int
	unpaddedW int

	im2colOp
}

func (op col2imOp) Arity() int { return 1 }

// col2im :: (Floats a) ⇒ a → a
func (op col2imOp) Type() hm.Type {
	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'))
}

func (op col2imOp) InferShape(shapes ...DimSizer) (retVal tensor.Shape, err error) {
	return tensor.Shape{op.unpaddedB, op.unpaddedC, op.unpaddedH, op.unpaddedW}, nil
}

func (op col2imOp) Do(inputs ...Value) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	im := inputs[0]

	// todo type check values
	// todo shape check values

	retShape := tensor.Shape{op.unpaddedB, op.unpaddedC, op.unpaddedH, op.unpaddedW}
	prealloc := tensor.New(tensor.Of(im.Dtype()), tensor.WithShape(retShape...))

	return op.do(prealloc, im)
}

func (op col2imOp) ReturnsPtr() bool     { return false }
func (op col2imOp) CallsExtern() bool    { return false }
func (op col2imOp) OverwritesInput() int { return -1 }

func (op col2imOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "col2im:%d-%d-%d-%d-%d-%d", op.h, op.w, op.padH, op.padW, op.strideH, op.strideW)
}

func (op col2imOp) Hashcode() uint32 { return simpleHash(op) }

func (op col2imOp) String() string {
	return fmt.Sprintf("col2im<(%d,%d), (%d, %d), (%d,%d)>", op.h, op.w, op.padH, op.padW, op.strideH, op.strideW)
}

func (op col2imOp) UsePreallocDo(prealloc Value, inputs ...Value) (Value, error) {
	if err := checkArity(op, len(inputs)); err != nil {
		return nil, err
	}
	return op.do(prealloc, inputs[0])
}

func (op col2imOp) do(prealloc, input Value) (retVal Value, err error) {
	b := op.unpaddedB
	c := op.unpaddedC
	retHeight := op.unpaddedH
	retWidth := op.unpaddedW
	batchStrideIm := c * retHeight * retWidth

	s := input.Shape()
	h := s[1]
	w := s[2]
	chanStride := retHeight * retWidth
	batchStrideCol := h * w * s[3]

	var imStart, imEnd, colStart, colEnd int
	imEnd = imStart + batchStrideIm
	colEnd = colStart + batchStrideCol

	switch input.Dtype() {
	case tensor.Float64:
		colData := input.Data().([]float64)
		imData := prealloc.Data().([]float64)
		for i := 0; i < b; i++ {
			op.f64s(c, retHeight, retWidth, chanStride, h, w, colData[colStart:colEnd], imData[imStart:imEnd])

			colStart += batchStrideCol
			colEnd += batchStrideCol

			imStart += batchStrideIm
			imEnd += batchStrideIm

			if imEnd > len(imData) {
				imEnd = len(imData)
			}
			if colEnd > len(colData) {
				colEnd = len(colData)
			}
		}
	case tensor.Float32:
		colData := input.Data().([]float32)
		imData := prealloc.Data().([]float32)
		for i := 0; i < b; i++ {
			op.f32s(c, retHeight, retWidth, chanStride, h, w, colData[colStart:colEnd], imData[imStart:imEnd])

			colStart += batchStrideCol
			colEnd += batchStrideCol

			imStart += batchStrideIm
			imEnd += batchStrideIm

			if imEnd > len(imData) {
				imEnd = len(imData)
			}
			if colEnd > len(colData) {
				colEnd = len(colData)
			}
		}
	default:
		return nil, errors.Errorf(nyiFail, "col2im", input.Dtype())
	}

	return prealloc, nil
}

func (op col2imOp) f64s(chans, height, width, chanStride, retHeight, retWidth int, col, im []float64) {
	// memset im to 0
	for i := 0; i < len(im); i++ {
		im[i] = 0
	}
	colIdx := 0
	var inputRow int
	var inputCol int
	for outputRow := 0; outputRow < retHeight; outputRow++ {
		for outputCol := 0; outputCol < retWidth; outputCol++ {
			for ch := 0; ch < chans; ch++ {
				for kernelRow := 0; kernelRow < op.h; kernelRow++ {
					inputRow = -op.padH + kernelRow*op.dilationH + outputRow*op.strideH
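					// positions that fall in the padding region are skipped below;
					// only in-bounds pixels are accumulated back into im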
					for kernelCol := 0; kernelCol < op.w; kernelCol++ {
						if inputRow < 0 || inputRow >= height {
							colIdx++
							continue
						}
						inputCol = -op.padW + kernelCol*op.dilationW + outputCol*op.strideW
						if inputCol >= 0 && inputCol < width {
							imIdx := chanStride*ch + inputRow*width + inputCol
							im[imIdx] += col[colIdx]
						}
						colIdx++
					}
				}
			}
		}
	}
}

func (op col2imOp) f32s(chans, height, width, chanStride, retHeight, retWidth int, col, im []float32) {
	// memset im to 0
	for i := 0; i < len(im); i++ {
		im[i] = 0
	}
	colIdx := 0
	var inputRow int
	var inputCol int
	for outputRow := 0; outputRow < retHeight; outputRow++ {
		for outputCol := 0; outputCol < retWidth; outputCol++ {
			for ch := 0; ch < chans; ch++ {
				for kernelRow := 0; kernelRow < op.h; kernelRow++ {
					inputRow = -op.padH + kernelRow*op.dilationH + outputRow*op.strideH
					for kernelCol := 0; kernelCol < op.w; kernelCol++ {
						if inputRow < 0 || inputRow >= height {
							colIdx++
							continue
						}
						inputCol = -op.padW + kernelCol*op.dilationW + outputCol*op.strideW
						if inputCol >= 0 && inputCol < width {
							imIdx := chanStride*ch + inputRow*width + inputCol
							im[imIdx] += col[colIdx]
						}
						colIdx++
					}
				}
			}
		}
	}
}

// It's important to note that this op actually produces TWO values - one argmax, which will be used
// as a mask, and the actual pooled value.
//
// The argmax is stored as internal state and is not exposed to anything outside the op.
// There are alternative ways of designing this op, but none of them seemed particularly nice.
// Caffe's technique seemed the nicest.
type maxPoolOp struct {
	// Shape of Input
	unpaddedB int
	unpaddedC int
	unpaddedH int
	unpaddedW int

	h, w              int // patch height and width
	padNorth, padWest int
	padSouth, padEast int
	explicitPadding   bool
	strideH, strideW  int

	// execution state
	// the mask is only filled at execution time
	mask tensor.Tensor
}

func newMaxPoolOp(inputShape, kernel tensor.Shape, pad, stride []int) *maxPoolOp {
	padNorth := pad[0]
	padWest := pad[1]
	padSouth := pad[0]
	padEast := pad[1]
	explicitPadding := false
	if len(pad) == 4 {
		explicitPadding = true
		padNorth = pad[0]
		padSouth = pad[1]
		padWest = pad[2]
		padEast = pad[3]
	}
	maxpoolOp := &maxPoolOp{
		// Shape of Input
		unpaddedB: inputShape[0],
		unpaddedC: inputShape[1],
		unpaddedH: inputShape[2],
		unpaddedW: inputShape[3],

		h:               kernel[0],
		w:               kernel[1],
		padNorth:        padNorth,
		padWest:         padWest,
		padSouth:        padSouth,
		padEast:         padEast,
		explicitPadding: explicitPadding,
		strideH:         stride[0],
		strideW:         stride[1],
	}
	maxpoolOp.mask = tensor.New(tensor.Of(tensor.Int), tensor.WithShape(maxpoolOp.calcShape(inputShape)...))
	return maxpoolOp
}

func (op *maxPoolOp) Arity() int { return 1 }

// maxPoolOp has this type:
//	op :: Tensor a → Tensor a
func (op *maxPoolOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	t := newTensorType(4, a)
	return hm.NewFnType(t, t)
}

func (op *maxPoolOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
	if s, ok := inputs[0].(tensor.Shape); ok {
		return op.calcShape(s), nil
	}
	return nil, errors.Errorf("Expected a shape")
}

func (op *maxPoolOp) Do(inputs ...Value) (retVal Value, err error) {
	var in, out tensor.Tensor
	if in, err = op.checkInput(inputs...); err != nil {
		return nil, err
	}
	inShp := in.Shape()
	out = tensor.New(tensor.Of(in.Dtype()), tensor.WithShape(op.calcShape(inShp)...), tensor.WithEngine(in.Engine()))
	op.do(out, in)
	return out, nil
}

func (op *maxPoolOp) ReturnsPtr() bool     { return false }
func (op *maxPoolOp) CallsExtern() bool    { return false }
func (op *maxPoolOp) OverwritesInput() int { return -1 }
func (op *maxPoolOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "MaxPool{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
		op.unpaddedB, op.unpaddedC, op.unpaddedH, op.unpaddedW,
		op.h, op.w, op.padNorth, op.padWest, op.strideH, op.strideW)
}

func (op *maxPoolOp) Hashcode() uint32 { return simpleHash(op) }

func (op *maxPoolOp) String() string {
	return fmt.Sprintf("MaxPool{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
		op.unpaddedB, op.unpaddedC, op.unpaddedH, op.unpaddedW,
		op.h, op.w, op.padNorth, op.padWest, op.strideH, op.strideW)
}

func (op *maxPoolOp) UsePreallocDo(prealloc Value, inputs ...Value) (Value, error) {
	var in tensor.Tensor
	var err error
	if in, err = op.checkInput(inputs...); err != nil {
		return nil, err
	}

	if p, ok := prealloc.(tensor.Tensor); ok {
		op.do(p, in)
		return p, nil
	}
	return nil, errors.Errorf("Expected prealloc to be a tensor")
}

func (op *maxPoolOp) DiffWRT(inputs int) []bool { return []bool{true} }

func (op *maxPoolOp) SymDiff(inputs Nodes, output, grad *Node) (retVal Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	input := inputs[0]

	var op2 maxPoolOp
	op2 = *op
	diff := &maxPoolDiffOp{op2}

	var ret *Node
	if ret, err = ApplyOp(diff, input, output, grad); err != nil {
		return nil, err
	}
	return Nodes{ret}, nil
}

func (op *maxPoolOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	input := inputs[0]
	inputDV, outDV := getDV(input, output)

	var op2 maxPoolOp
	op2 = *op
	diff := &maxPoolDiffOp{op2}

	if _, err = diff.UsePreallocDo(inputDV.d, inputDV.Value, outDV.Value, outDV.d); err != nil {
		return errors.Wrapf(err, doFail, diff)
	}
	return
}

func (op *maxPoolOp) checkInput(inputs ...Value) (tensor.Tensor, error) {
	if err := checkArity(op, len(inputs)); err != nil {
		return nil, err
	}

	var in tensor.Tensor
	var ok bool
	if in, ok = inputs[0].(tensor.Tensor); !ok {
		return nil, errors.Errorf("Expected input to be a tensor")
	}

	if in.Shape().Dims() != 4 {
		return nil, errors.Errorf("Expected input to have 4 dimensions")
	}
	return in, nil
}

// calcShape calculates the output shape given an input shape
func (op *maxPoolOp) calcShape(s tensor.Shape) tensor.Shape {
	b, c, h, w := s[0], s[1], s[2], s[3]

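	// standard pooling output size: out = (in + padBefore + padAfter - kernel)/stride + 1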
	pooledH := (h+op.padSouth+op.padNorth-(op.h-1)-1)/op.strideH + 1
	pooledW := (w+op.padEast+op.padWest-(op.w-1)-1)/op.strideW + 1
	return tensor.Shape{b, c, pooledH, pooledW}
}

// do prepares the data, and then dispatches it to the correct (computation) kernel.
// out is the preallocated tensor
func (op *maxPoolOp) do(out, in tensor.Tensor) {
	outShape := out.Shape()
	outStride := out.Strides()[1]
	inShape := in.Shape()
	inStride := in.Strides()[1]
	maskStride := op.mask.Strides()[1]

	b, c, h, w := outShape[0], outShape[1], outShape[2], outShape[3]
	inH, inW := inShape[2], inShape[3]

	if op.mask == nil {
		op.mask = tensor.New(tensor.Of(tensor.Int), tensor.WithShape(op.calcShape(inShape)...))
	}

	maskData := op.mask.Data().([]int)

	switch in.Dtype() {
	case tensor.Float64:
		op.f64s(b, c, h, w, inH, inW,
			outStride, inStride, maskStride,
			out.Data().([]float64), in.Data().([]float64),
			maskData)
	case tensor.Float32:
		op.f32s(b, c, h, w, inH, inW,
			outStride, inStride, maskStride,
			out.Data().([]float32), in.Data().([]float32),
			maskData)
	}
}

func (op *maxPoolOp) f32s(batches, channels, outH, outW, inH, inW,
	outStride, inStride, maskStride int,
	outData, inData []float32,
	maskData []int) {

	// set values
	for i := range outData {
		outData[i] = -maxFloat32
		maskData[i] = -1
	}
	padH := op.padNorth
	padW := op.padWest
	if op.explicitPadding {
		padH = op.padSouth
		padW = op.padEast
	}

	for b := 0; b < batches; b++ {
		for c := 0; c < channels; c++ {
			for ph := 0; ph < outH; ph++ {
				for pw := 0; pw < outW; pw++ {

					hStart := ph*op.strideH - padH
					wStart := pw*op.strideW - padW
					hEnd := minInt(hStart+op.h, inH)
					wEnd := minInt(wStart+op.w, inW)
					hStart = maxInt(hStart, 0)
					wStart = maxInt(wStart, 0)

					poolIndex := ph*outW + pw
					for hi := hStart; hi < hEnd; hi++ {
						for wi := wStart; wi < wEnd; wi++ {
							i := hi*inW + wi
							if inData[i] > outData[poolIndex] {
								outData[poolIndex] = inData[i]
								maskData[poolIndex] = i
							}
						}
					}
				}
			}
			// skip by strides
			inData = inData[inStride:]
			outData = outData[outStride:]
			maskData = maskData[maskStride:]
		}
	}
}

func (op *maxPoolOp) f64s(batches, channels, outH, outW, inH, inW,
	outStride, inStride, maskStride int,
	outData, inData []float64,
	maskData []int) {

	// set values
	for i := range outData {
		outData[i] = -maxFloat64
		maskData[i] = -1
	}
	padH := op.padNorth
	padW := op.padWest
	if op.explicitPadding {
		padH = op.padSouth
		padW = op.padEast
	}

	for b := 0; b < batches; b++ {
		for c := 0; c < channels; c++ {
			for ph := 0; ph < outH; ph++ {
				for pw := 0; pw < outW; pw++ {
					hStart := ph*op.strideH - padH
					wStart := pw*op.strideW - padW
					hEnd := minInt(hStart+op.h, inH)
					wEnd := minInt(wStart+op.w, inW)
					hStart = maxInt(hStart, 0)
					wStart = maxInt(wStart, 0)

					poolIndex := ph*outW + pw

					for hi := hStart; hi < hEnd; hi++ {
						for wi := wStart; wi < wEnd; wi++ {
							i := hi*inW + wi
							if inData[i] > outData[poolIndex] {
								outData[poolIndex] = inData[i]
								maskData[poolIndex] = i
							}
						}
					}
				}
			}
			// skip by strides
			inData = inData[inStride:]
			outData = outData[outStride:]
			maskData = maskData[maskStride:]
		}
	}
}

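// maxPoolDiffOp is the backwards pass of maxPoolOp: it scatters the incoming
// gradient back to the input positions recorded in the mask during the
// forward pass.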
type maxPoolDiffOp struct {
	maxPoolOp
}

func (op *maxPoolDiffOp) Arity() int { return 3 }
func (op *maxPoolDiffOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	t := newTensorType(4, a)
	return hm.NewFnType(t, t, t, t)
}

func (op *maxPoolDiffOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
	s := inputs[0].(tensor.Shape).Clone()
	return s, nil
}

func (op *maxPoolDiffOp) Do(inputs ...Value) (Value, error) {
	var in, out, pooled, pooledGrad tensor.Tensor
	var err error
	if in, pooled, pooledGrad, err = op.checkInput(inputs...); err != nil {
		return nil, err
	}

	// out is the gradient of in
	out = tensor.New(tensor.Of(in.Dtype()), tensor.WithShape(in.Shape().Clone()...), tensor.WithEngine(in.Engine()))
	op.do(out, in, pooled, pooledGrad)
	return out, nil
}

func (op *maxPoolDiffOp) ReturnsPtr() bool     { return true }
func (op *maxPoolDiffOp) CallsExtern() bool    { return false }
func (op *maxPoolDiffOp) OverwritesInput() int { return -1 }
func (op *maxPoolDiffOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "MaxPoolDiff{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
		op.unpaddedB, op.unpaddedC, op.unpaddedH, op.unpaddedW,
		op.h, op.w, op.padNorth, op.padWest, op.strideH, op.strideW)
}

func (op *maxPoolDiffOp) Hashcode() uint32 { return simpleHash(op) }

func (op *maxPoolDiffOp) String() string {
	return fmt.Sprintf("MaxPoolDiff{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
		op.unpaddedB, op.unpaddedC, op.unpaddedH, op.unpaddedW,
		op.h, op.w, op.padNorth, op.padWest, op.strideH, op.strideW)
}

func (op *maxPoolDiffOp) UsePreallocDo(prealloc Value, inputs ...Value) (Value, error) {
	var in, pooled, pooledGrad tensor.Tensor
	var err error
	if in, pooled, pooledGrad, err = op.checkInput(inputs...); err != nil {
		return nil, err
	}
	if p, ok := prealloc.(tensor.Tensor); ok {
		op.do(p, in, pooled, pooledGrad)
		return prealloc, nil
	}
	return nil, errors.Errorf("Cannot do with PreallocDo - expected PreAlloc to be tensor")
}

func (op *maxPoolDiffOp) checkInput(inputs ...Value) (in, pooled, pooledGrad tensor.Tensor, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	var ok bool
	if in, ok = inputs[0].(tensor.Tensor); !ok {
		err = errors.Errorf("Expected input to be a tensor")
		return
	}
	if in.Shape().Dims() != 4 {
		err = errors.Errorf("Expected input to have 4 dimensions")
		return
	}

	if pooled, ok = inputs[1].(tensor.Tensor); !ok {
		err = errors.Errorf("Expected pooled to be a tensor")
		return
	}
	if pooledGrad, ok = inputs[2].(tensor.Tensor); !ok {
		err = errors.Errorf("Expected pooledGrad to be a tensor")
		return
	}
	return
}

func (op *maxPoolDiffOp) do(inGrad, in, pooled, pooledGrad tensor.Tensor) {
	pooledShape := pooled.Shape()
	pooledStride := pooled.Strides()[1]
	inStride := in.Strides()[1]
	maskStride := op.mask.Strides()[1]
	maskData := op.mask.Data().([]int)

	b, c, h, w := pooledShape[0], pooledShape[1], pooledShape[2], pooledShape[3]
	switch in.Dtype() {
	case tensor.Float32:
		inGradData := inGrad.Data().([]float32)
		pooledGradData := pooledGrad.Data().([]float32)
		op.f32s(b, c, h, w,
			inStride, pooledStride, maskStride,
			inGradData, pooledGradData, maskData)
	case tensor.Float64:
		inGradData := inGrad.Data().([]float64)
		pooledGradData := pooledGrad.Data().([]float64)
		op.f64s(b, c, h, w,
			inStride, pooledStride, maskStride,
			inGradData, pooledGradData, maskData)
	}
}

// in is the "bottom", while out is the "top" (bottom being the unpooled, and top being the pooled)
func (op *maxPoolDiffOp) f32s(batches, channels, pooledH, pooledW int,
	inStride, outStride, maskStride int,
	inDiffData, outDiffData []float32,
	maskData []int) {

	// zero out. let's hope Go's optimizer is smart enough
	for i := range inDiffData {
		inDiffData[i] = 0
	}

	// this loop can be goroutine'd
	for b := 0; b < batches; b++ {
		for c := 0; c < channels; c++ {
			for ph := 0; ph < pooledH; ph++ {
				for pw := 0; pw < pooledW; pw++ {
					index := ph*pooledW + pw
					inIndex := maskData[index]
					inDiffData[inIndex] += outDiffData[index]
				}
			}
			outDiffData = outDiffData[outStride:]
			inDiffData = inDiffData[inStride:]
			maskData = maskData[maskStride:]
		}
	}
}

// in is the "bottom", while out is the "top" (bottom being the unpooled, and top being the pooled)
func (op *maxPoolDiffOp) f64s(batches, channels, pooledH, pooledW int,
	inStride, outStride, maskStride int,
	inDiffData, outDiffData []float64,
	maskData []int) {

	// zero out. let's hope Go's optimizer is smart enough
	for i := range inDiffData {
		inDiffData[i] = 0
	}

	// this loop can be goroutine'd
	for b := 0; b < batches; b++ {
		for c := 0; c < channels; c++ {
			for ph := 0; ph < pooledH; ph++ {
				for pw := 0; pw < pooledW; pw++ {
					index := ph*pooledW + pw
					inIndex := maskData[index]
					inDiffData[inIndex] += outDiffData[index]
				}
			}
			outDiffData = outDiffData[outStride:]
			inDiffData = inDiffData[inStride:]
			maskData = maskData[maskStride:]
		}
	}
}

// clampOp is a constant clamping operation
type clampOp struct {
	min, max Scalar
}

func (op *clampOp) Arity() int { return 1 }

func (op *clampOp) Type() hm.Type {
	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'))
}

func (op *clampOp) InferShape(shps ...DimSizer) (tensor.Shape, error) {
	return shps[0].(tensor.Shape), nil
}

func (op *clampOp) Do(vals ...Value) (Value, error) {
	return nil, nil
}

func (op *clampOp) ReturnsPtr() bool { return true }

func (op *clampOp) CallsExtern() bool { return false }

func (op *clampOp) OverwritesInput() int { return 0 }

func (op *clampOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, "ConstClamp{%f, %f}()", op.min, op.max) }

func (op *clampOp) Hashcode() uint32 { return simpleHash(op) }
func (op *clampOp) String() string   { return fmt.Sprintf("ConstClamp{%f, %f}()", op.min, op.max) }

// BatchNormOp is a batch normalization process as described by Ioffe and Szegedy (2015) -
// http://arxiv.org/abs/1502.03167
//
// Normalization is done as:
//	γ(x - μ) / σ + β
// γ is the scaling factor and β is the offset factor. These are created by BatchNorm()
type BatchNormOp struct {
	momentum float64 // momentum for the moving average
	epsilon  float64 // small variance to be added to avoid dividing by 0
	dims     int     // 2 or 4. defaults to 4
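
	// mean, variance and ma hold the running statistics: mean and variance are
	// the estimates used at inference time, while ma accumulates the
	// moving-average scale factor that normalizes them.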

	// learnables
	mean, variance, ma *tensor.Dense

	// scratch space
	meanTmp, varianceTmp, tmpSpace, xNorm                *tensor.Dense
	batchSumMultiplier, numByChans, spatialSumMultiplier *tensor.Dense

	// training? if training then update movingMean and movingVar
	training bool
}

// Arity returns 1
func (op *BatchNormOp) Arity() int { return 1 }

// Type ...
func (op *BatchNormOp) Type() hm.Type {
	dims := op.dims
	if dims == 0 {
		dims = 4 // default to 4 if not set
	}

	t := TensorType{Dims: dims, Of: hm.TypeVariable('a')}
	return hm.NewFnType(t, t)
}

// InferShape from the input values
func (op *BatchNormOp) InferShape(ns ...DimSizer) (tensor.Shape, error) {
	if err := checkArity(op, len(ns)); err != nil {
		return nil, errors.Wrapf(err, "batchNorm")
	}

	return ns[0].(tensor.Shape).Clone(), nil
}

// Do performs the batchnorm computation on the values
func (op *BatchNormOp) Do(values ...Value) (retVal Value, err error) {
	if err := checkArity(op, len(values)); err != nil {
		return nil, errors.Wrapf(err, "batchNorm Do")
	}
	var v, out Value
	v = values[0]
	if out, err = CloneValue(v); err != nil {
		return nil, err
	}
	return op.UsePreallocDo(out, v)
}

// ReturnsPtr is true
func (op *BatchNormOp) ReturnsPtr() bool { return true }

// CallsExtern is false
func (op *BatchNormOp) CallsExtern() bool { return false }

// OverwritesInput is -1 (operator doesn't overwrite any input value)
func (op *BatchNormOp) OverwritesInput() int { return -1 }

// WriteHash ...
func (op *BatchNormOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "batchnorm-%1.1f-%1.1f", op.momentum, op.epsilon)
}

// Hashcode ...
func (op *BatchNormOp) Hashcode() uint32 { return simpleHash(op) }

func (op *BatchNormOp) String() string {
	return fmt.Sprintf("batchnorm-%1.1f-%1.1f", op.momentum, op.epsilon)
}

// DoDiff does the gradient computation
func (op *BatchNormOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error {
	diff := &batchnormDiffOp{op}
	xdv, ydv := getDV(inputs[0], output)
	_, err := diff.UsePreallocDo(xdv.d, xdv.Value, ydv.d)
	return err
}

// DiffWRT ...
func (op *BatchNormOp) DiffWRT(inputs int) []bool { return []bool{true} }

// SymDiff ...
func (op *BatchNormOp) SymDiff(inputs Nodes, output *Node, grad *Node) (retVal Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	input := inputs[0]
	diff := &batchnormDiffOp{op}

	var ret *Node
	if ret, err = ApplyOp(diff, input, grad); err != nil {
		return nil, err
	}
	return Nodes{ret}, nil
}

// UsePreallocDo ...
func (op *BatchNormOp) UsePreallocDo(prealloc Value, inputs ...Value) (retVal Value, err error) {
	v := inputs[0]
	switch v.Dtype() {
	case Float64:
		err = op.f64s(v.(*tensor.Dense), prealloc.(*tensor.Dense))
	case Float32:
		err = op.f32s(v.(*tensor.Dense), prealloc.(*tensor.Dense))
	default:
		return nil, nyi("BatchNorm Do", v.Dtype())
	}
	return prealloc, err
}

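// Typical usage sketch (assumes a *BatchNormOp wired up by the BatchNorm
// constructor in nn.go):
//
//	op.SetTraining() // update the running statistics while fitting
//	// ... run the training loop ...
//	op.SetTesting()  // switch to the stored estimates for inference
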
// SetTraining configures the op for training mode.
// A call to this function implicitly calls the Reset() method.
func (op *BatchNormOp) SetTraining() { op.Reset(); op.training = true }

// SetTesting configures the op for testing mode.
func (op *BatchNormOp) SetTesting() { op.training = false }

// Reset the operator by zeroing the internal scratch spaces.
func (op *BatchNormOp) Reset() error {
	dt := op.ma.Dtype()
	var uno interface{}
	switch dt {
	case Float64:
		uno = float64(1)
	case Float32:
		uno = float32(1)
	}

	if err := op.spatialSumMultiplier.Memset(uno); err != nil {
		return err
	}

	if err := op.batchSumMultiplier.Memset(uno); err != nil {
		return err
	}

	op.mean.Zero()
	op.variance.Zero()
	op.ma.Zero()
	op.meanTmp.Zero()
	op.varianceTmp.Zero()
	op.tmpSpace.Zero()
	op.numByChans.Zero()
	return nil
}

func (op *BatchNormOp) f64s(input, output *tensor.Dense) (err error) {
	n := input.Shape()[0]
	channels := input.Shape()[1]
	nc := channels * n
	spatialDim := input.Shape().TotalSize() / (nc)

	inputF64s := input.Float64s()
	outputF64s := output.Float64s()
	copy(outputF64s, inputF64s)

	meanTmp := op.meanTmp.Float64s()
	mean := op.mean.Float64s()
	varianceTmp := op.varianceTmp.Float64s()
	variance := op.variance.Float64s()
	tmp := op.tmpSpace.Float64s()
	ssm := op.spatialSumMultiplier.Float64s()
	nbc := op.numByChans.Float64s()
	bsm := op.batchSumMultiplier.Float64s()

	momentum := op.momentum
	eps := op.epsilon

	if !op.training {
		// use stored mean/variance estimates
		scaleFactor := float64(1)
		if fst := op.ma.Float64s()[0]; fst != 1 {
			scaleFactor = fst
		}
		copy(meanTmp, mean)
		whichblas.Dscal(len(meanTmp), scaleFactor, meanTmp, 1)
		copy(varianceTmp, variance)
		whichblas.Dscal(len(varianceTmp), scaleFactor, varianceTmp, 1)
	} else {
		// compute mean
		alpha := 1.0 / float64(n*spatialDim)
		whichblas.Dgemv(blas.NoTrans, nc, spatialDim, alpha, inputF64s, spatialDim, ssm, 1, 0, nbc, 1)
		whichblas.Dgemv(blas.Trans, n, channels, 1, nbc, channels, bsm, 1, 0, meanTmp, 1)
	}

	// subtract mean
	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, meanTmp, channels, 0, nbc, channels)
	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, -1, nbc, 1, ssm, spatialDim, 1, outputF64s, spatialDim)

	if op.training {
		// compute variance using var(X) = E((X-EX)²)
		copy(tmp, outputF64s)
		vecf64.Mul(tmp, tmp) // (X-EX) ^ 2

		whichblas.Dgemv(blas.NoTrans, nc, spatialDim, 1.0/(float64(n*spatialDim)), tmp, spatialDim, ssm, 1, 0, nbc, 1)
		whichblas.Dgemv(blas.Trans, n, channels, 1.0, nbc, channels, bsm, 1, 0, varianceTmp, 1) // E((X-EX)²)

		// compute and save moving average
		op.ma.Float64s()[0] *= momentum
		op.ma.Float64s()[0]++

		// TODO: write axpby for gonum
		whichblas.Dscal(len(mean), momentum, mean, 1)
		whichblas.Daxpy(len(meanTmp), 1.0, meanTmp, 1, mean, 1)

		m := len(inputF64s) / channels
		correctionFactor := float64(1)
		if m > 1 {
			correctionFactor = float64(m) / (float64(m - 1))
		}
		whichblas.Dscal(len(variance), momentum, variance, 1)
		whichblas.Daxpy(len(varianceTmp), correctionFactor, varianceTmp, 1, variance, 1)
	}

	// normalize variance
	vecf64.Trans(varianceTmp, eps)
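	// varianceTmp now holds var(x)+ε; its square root is the denominator of the normalization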
	vecf64.Sqrt(varianceTmp)

	// replicate variance to inputsize
	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, varianceTmp, channels, 0, nbc, channels)
	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, 1, nbc, 1, ssm, spatialDim, 0, tmp, spatialDim)
	vecf64.Div(outputF64s, tmp)
	copy(op.xNorm.Float64s(), outputF64s) // caching

	return nil
}

func (op *BatchNormOp) f32s(input, output *tensor.Dense) (err error) {
	n := input.Shape()[0]
	channels := input.Shape()[1]
	nc := channels * n
	spatialDim := input.Shape().TotalSize() / (nc)

	inputF32s := input.Float32s()
	outputF32s := output.Float32s()
	copy(outputF32s, inputF32s)

	meanTmp := op.meanTmp.Float32s()
	mean := op.mean.Float32s()
	varianceTmp := op.varianceTmp.Float32s()
	variance := op.variance.Float32s()
	tmp := op.tmpSpace.Float32s()
	ssm := op.spatialSumMultiplier.Float32s()
	nbc := op.numByChans.Float32s()
	bsm := op.batchSumMultiplier.Float32s()

	momentum := float32(op.momentum)
	eps := float32(op.epsilon)

	if !op.training {
		// use stored mean/variance estimates
		scaleFactor := float32(1)
		if fst := op.ma.Float32s()[0]; fst != 1 {
			scaleFactor = fst
		}
		copy(meanTmp, mean)
		whichblas.Sscal(len(meanTmp), scaleFactor, meanTmp, 1)
		copy(varianceTmp, variance)
		whichblas.Sscal(len(varianceTmp), scaleFactor, varianceTmp, 1)
	} else {
		// compute mean
		alpha := 1.0 / float32(n*spatialDim)
		whichblas.Sgemv(blas.NoTrans, nc, spatialDim, alpha, inputF32s, spatialDim, ssm, 1, 0, nbc, 1)
		whichblas.Sgemv(blas.Trans, n, channels, 1, nbc, channels, bsm, 1, 0, meanTmp, 1)
	}

	// subtract mean
	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, meanTmp, channels, 0, nbc, channels)
	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, -1, nbc, 1, ssm, spatialDim, 1, outputF32s, spatialDim)

	if op.training {
		// compute variance using var(X) = E((X-EX)²)
		copy(tmp, outputF32s)
		vecf32.Mul(tmp, tmp) // (X-EX) ^ 2

		whichblas.Sgemv(blas.NoTrans, nc, spatialDim, 1.0/(float32(n*spatialDim)), tmp, spatialDim, ssm, 1, 0, nbc, 1)
		whichblas.Sgemv(blas.Trans, n, channels, 1.0, nbc, channels, bsm, 1, 0, varianceTmp, 1) // E((X-EX)²)

		// compute and save moving average
		op.ma.Float32s()[0] *= momentum
		op.ma.Float32s()[0]++

		// TODO: write axpby for gonum
		whichblas.Sscal(len(mean), momentum, mean, 1)
		whichblas.Saxpy(len(meanTmp), 1.0, meanTmp, 1, mean, 1)

		m := len(inputF32s) / channels
		correctionFactor := float32(1)
		if m > 1 {
			correctionFactor = float32(m) / (float32(m - 1))
		}
		whichblas.Sscal(len(variance), momentum, variance, 1)
		whichblas.Saxpy(len(varianceTmp), correctionFactor, varianceTmp, 1, variance, 1)
	}

	// normalize variance
	vecf32.Trans(varianceTmp, eps)
	vecf32.Sqrt(varianceTmp)

	// replicate variance to inputsize
	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, varianceTmp, channels, 0, nbc, channels)
	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, 1, nbc, 1, ssm, spatialDim, 0, tmp, spatialDim)
	vecf32.Div(outputF32s, tmp)
	copy(op.xNorm.Float32s(), outputF32s) // caching

	return nil
}

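// batchnormDiffOp computes the gradient of BatchNormOp with respect to its
// input. It embeds *BatchNormOp so that it can reuse the statistics and
// scratch buffers (xNorm, tmpSpace, ...) cached during the forward pass.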
type batchnormDiffOp struct{ *BatchNormOp }

func (op *batchnormDiffOp) Arity() int { return 2 }

func (op *batchnormDiffOp) Type() hm.Type {
	dims := op.dims
	if dims == 0 {
		dims = 4
	}

	t := TensorType{Dims: dims, Of: hm.TypeVariable('a')}
	return hm.NewFnType(t, t, t)
}

func (op *batchnormDiffOp) InferShape(ns ...DimSizer) (tensor.Shape, error) {
	if err := checkArity(op, len(ns)); err != nil {
		return nil, errors.Wrapf(err, "batchNorm")
	}

	return ns[0].(tensor.Shape).Clone(), nil
}

func (op *batchnormDiffOp) Do(values ...Value) (Value, error) {
	input := values[0].(*tensor.Dense)
	grad := values[1].(*tensor.Dense)
	inputGrad := input.Clone().(*tensor.Dense)
	return op.UsePreallocDo(inputGrad, input, grad)
}

// ReturnsPtr, CallsExtern and OverwritesInput have exactly the same
// characteristics as BatchNormOp (they are provided by the embedded *BatchNormOp).

func (op *batchnormDiffOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "batchnormdiff-%1.1f-%1.1f", op.momentum, op.epsilon)
}

func (op *batchnormDiffOp) Hashcode() uint32 { return simpleHash(op) }

func (op *batchnormDiffOp) String() string {
	return fmt.Sprintf("batchnormdiff-%1.1f-%1.1f", op.momentum, op.epsilon)
}

func (op *batchnormDiffOp) DiffWRT(inputs int) []bool {
	// god help those who want to do 2nd order differentiation on batchnorm
	return []bool{false, false}
}

func (op *batchnormDiffOp) SymDiff(inputs Nodes, output *Node, grad *Node) (retVal Nodes, err error) {
	// god help those who want to do 2nd order differentiation on batchnorm
	return nil, nyi("SymDiff", "batchNormDiffOp")
}

func (op *batchnormDiffOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error {
	// god help those who want to do 2nd order differentiation on batchnorm
	return nyi("DoDiff", "batchnormDiffOp")
}

func (op *batchnormDiffOp) UsePreallocDo(prealloc Value, inputs ...Value) (retVal Value, err error) {
	input := inputs[0].(*tensor.Dense)
	inGrad := prealloc.(*tensor.Dense)
	outGrad := inputs[1].(*tensor.Dense)

	switch input.Dtype() {
	case Float64:
		err = op.f64s(input, inGrad, outGrad)
	case Float32:
		err = op.f32s(input, inGrad, outGrad)
	default:
		return nil, nyi("batchnormDiffOp", "Do")
	}
	return prealloc, err
}

func (op *batchnormDiffOp) f64s(input, inGrad, outGrad *tensor.Dense) (err error) {
	in := input.Float64s()
	ig := inGrad.Float64s()
	og := outGrad.Float64s()
	tmp := op.tmpSpace.Float64s()
	out := op.xNorm.Float64s()
	ssm := op.spatialSumMultiplier.Float64s()
	nbc := op.numByChans.Float64s()
	bsm := op.batchSumMultiplier.Float64s()
	meanTmp := op.meanTmp.Float64s()

	if !op.training {
		copy(ig, og)
		vecf64.Div(og, tmp)
		return nil
	}

	n := input.Shape()[0]
	channels := input.Shape()[1]
	nc := n * channels
	spatialDim := len(in) / nc

	// if Y = (X-mean(X))/(sqrt(var(X)+eps)), then
	//
	// dE(Y)/dX =
	//   (dE/dY - mean(dE/dY) - mean(dE/dY ⋅ Y) ⋅ Y)
	//     ./ sqrt(var(X) + eps)
	//
	// where ⋅ and ./ are hadamard product and elementwise division,
	// respectively, dE/dY is the top diff, and mean/var/sum are all computed
	// along all dimensions except the channels dimension. In the above
	// equation, the operations allow for expansion (i.e. broadcast) along all
	// dimensions except the channels dimension where required.

	// sum(dE/dY ⋅ Y)
	copy(ig, out)
	vecf64.Mul(ig, og)
	whichblas.Dgemv(blas.NoTrans, nc, spatialDim, 1, ig, spatialDim, ssm, 1, 0, nbc, 1)
	whichblas.Dgemv(blas.Trans, n, channels, 1, nbc, channels, bsm, 1, 0, meanTmp, 1)

	// reshape (broadcast) the above
	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, meanTmp, channels, 0, nbc, channels)
	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, 1, nbc, 1, ssm, spatialDim, 0, ig, spatialDim)

	// sum(dE/dY ⋅ Y) ⋅ Y
	vecf64.Mul(ig, out)

	// sum(dE/dY)-sum(dE/dY ⋅ Y) ⋅ Y
	whichblas.Dgemv(blas.NoTrans, nc, spatialDim, 1, og, spatialDim, ssm, 1, 0, nbc, 1)
	whichblas.Dgemv(blas.Trans, n, channels, 1, nbc, channels, bsm, 1, 0, meanTmp, 1)

	// reshape (broadcast) the above to make
	// sum(dE/dY)-sum(dE/dY ⋅ Y) ⋅ Y
	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, meanTmp, channels, 0, nbc, channels)
	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, 1, nbc, 1, ssm, spatialDim, 1, ig, spatialDim)

	// dE/dY - mean(dE/dY)-mean(dE/dY ⋅ Y) ⋅ Y
	beta := (-1.0 / float64(nc))

	vecf64.Scale(ig, beta)
	vecf64.Add(ig, og)

	// note: temp_ still contains sqrt(var(X)+eps), computed during the forward
	// pass.
	vecf64.Div(ig, tmp)
	return nil

}

func (op *batchnormDiffOp) f32s(input, inGrad, outGrad *tensor.Dense) (err error) {
	in := input.Float32s()
	ig := inGrad.Float32s()
	og := outGrad.Float32s()
	tmp := op.tmpSpace.Float32s()
	out := op.xNorm.Float32s()
	ssm := op.spatialSumMultiplier.Float32s()
	nbc := op.numByChans.Float32s()
	bsm := op.batchSumMultiplier.Float32s()
	meanTmp := op.meanTmp.Float32s()

	if !op.training {
		copy(ig, og)
		vecf32.Div(og, tmp)
		return nil
	}

	n := input.Shape()[0]
	channels := input.Shape()[1]
	nc := n * channels
	spatialDim := len(in) / nc

	// if Y = (X-mean(X))/(sqrt(var(X)+eps)), then
	//
	// dE(Y)/dX =
	//   (dE/dY - mean(dE/dY) - mean(dE/dY ⋅ Y) ⋅ Y)
	//     ./ sqrt(var(X) + eps)
	//
	// where ⋅ and ./ are hadamard product and elementwise division,
	// respectively, dE/dY is the top diff, and mean/var/sum are all computed
	// along all dimensions except the channels dimension. In the above
	// equation, the operations allow for expansion (i.e. broadcast) along all
	// dimensions except the channels dimension where required.

	// sum(dE/dY ⋅ Y)
	copy(ig, out)
	vecf32.Mul(ig, og)
	whichblas.Sgemv(blas.NoTrans, nc, spatialDim, 1, ig, spatialDim, ssm, 1, 0, nbc, 1)
	whichblas.Sgemv(blas.Trans, n, channels, 1, nbc, channels, bsm, 1, 0, meanTmp, 1)

	// reshape (broadcast) the above
	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, meanTmp, channels, 0, nbc, channels)
	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, 1, nbc, 1, ssm, spatialDim, 0, ig, spatialDim)

	// sum(dE/dY ⋅ Y) ⋅ Y
	vecf32.Mul(ig, out)

	// sum(dE/dY)-sum(dE/dY ⋅ Y) ⋅ Y
	whichblas.Sgemv(blas.NoTrans, nc, spatialDim, 1, og, spatialDim, ssm, 1, 0, nbc, 1)
	whichblas.Sgemv(blas.Trans, n, channels, 1, nbc, channels, bsm, 1, 0, meanTmp, 1)

	// reshape (broadcast) the above to make
	// sum(dE/dY)-sum(dE/dY ⋅ Y) ⋅ Y
	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, meanTmp, channels, 0, nbc, channels)
	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, 1, nbc, 1, ssm, spatialDim, 1, ig, spatialDim)

	// dE/dY - mean(dE/dY)-mean(dE/dY ⋅ Y) ⋅ Y
	beta := (-1.0 / float32(n*spatialDim))
	vecf32.Scale(ig, beta)
	vecf32.Add(ig, og)

	// note: temp_ still contains sqrt(var(X)+eps), computed during the forward
	// pass.
	vecf32.Div(ig, tmp)
	return nil

}

type globalAveragePoolOp struct{}

func (g *globalAveragePoolOp) Arity() int {
	return 1
}

func (g *globalAveragePoolOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	t := newTensorType(4, a)
	return hm.NewFnType(t, t)
}

func (g *globalAveragePoolOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
	b, err := inputs[0].DimSize(0)
	if err != nil {
		return nil, err
	}
	c, err := inputs[0].DimSize(1)
	if err != nil {
		return nil, err
	}
	// check if the shape is correct without doing type inference
	if _, err := inputs[0].DimSize(2); err != nil {
		return nil, err
	}
	if _, err := inputs[0].DimSize(3); err != nil {
		return nil, err
	}
	return tensor.Shape{b, c, 1, 1}, nil
}

func (g *globalAveragePoolOp) Do(inputs ...Value) (Value, error) {
	im := inputs[0]
	switch im.(type) {
	case tensor.Tensor:
		v := im.(tensor.Tensor)
		B, C, H, W := v.Shape()[0], v.Shape()[1], v.Shape()[2], v.Shape()[3]
		s, err := g.InferShape(v.Shape())
		if err != nil {
			return nil, err
		}
		output := tensor.New(tensor.Of(v.Dtype()), tensor.WithShape(s...))
		switch v.Dtype() {
		case tensor.Float64:
			for b := 0; b < B; b++ {
				for c := 0; c < C; c++ {
					var sum float64
					for h := 0; h < H; h++ {
						for w := 0; w < W; w++ {
							val, err := v.At(b, c, h, w)
							if err != nil {
								return nil, err
							}
							sum += val.(float64)
						}
					}
					err := output.SetAt(sum/float64(H*W), b, c, 0, 0)
					if err != nil {
						return nil, err
					}
				}
			}
		case tensor.Float32:
			for b := 0; b < B; b++ {
				for c := 0; c < C; c++ {
					var sum float32
					for h := 0; h < H; h++ {
						for w := 0; w < W; w++ {
							val, err := v.At(b, c, h, w)
							if err != nil {
								return nil, err
							}
							sum += val.(float32)
						}
					}
					err := output.SetAt(sum/float32(H*W), b, c, 0, 0)
					if err != nil {
						return nil, err
					}
				}
			}
		default:
			return nil, nyi("Global Average Pool", v.Dtype())
		}

		return output, nil

	default:
		return nil, nyi("globalAveragePoolOp", inputs)
	}
}

func (g *globalAveragePoolOp) ReturnsPtr() bool {
	return false
}

func (g *globalAveragePoolOp) CallsExtern() bool {
	return false
}

func (g *globalAveragePoolOp) OverwritesInput() int {
	return -1
}

func (g *globalAveragePoolOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "GlobalAveragePool")
}

func (g *globalAveragePoolOp) Hashcode() uint32 {
	return simpleHash(g)
}

func (g *globalAveragePoolOp) String() string {
	return "GlobalAveragePool"
}
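
// Illustrative sketch (the ops above are unexported, so this is only usable
// from within the package): the ops can be exercised directly through their Do
// methods. For example, drawing a Gaussian-filled tensor with randomOp and
// then collapsing it with globalAveragePoolOp:
//
//	op := makeRandomOp(gaussian, Float64, 0, 1, 2, 3, 5, 5) // mean 0, stdev 1, shape (2, 3, 5, 5)
//	v, err := op.Do()
//	if err != nil {
//		// handle error
//	}
//	gap := &globalAveragePoolOp{}
//	pooled, err := gap.Do(v) // pooled has shape (2, 3, 1, 1)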