gorgonia.org/gorgonia@v0.9.17/op_nn.go

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"hash"
     6  	"time"
     7  
     8  	"github.com/chewxy/hm"
     9  	rng "github.com/leesper/go_rng"
    10  	"github.com/pkg/errors"
    11  	"gonum.org/v1/gonum/blas"
    12  	"gorgonia.org/tensor"
    13  	"gorgonia.org/vecf32"
    14  	"gorgonia.org/vecf64"
    15  )
    16  
    17  // Sanity checks
    18  var (
    19  	_ SDOp = im2colOp{}
    20  	_ Op   = col2imOp{}
    21  	_ Op   = &maxPoolOp{}
    22  	_ Op   = &maxPoolDiffOp{}
    23  	_ Op   = &BatchNormOp{}
    24  	_ Op   = &batchnormDiffOp{}
    25  	_ Op   = &globalAveragePoolOp{}
    26  )
    27  
    28  /*
    29  	This file contains all the Ops related to building a neural network.
    30  
    31  	Bear in mind that not everything related to a neural network is here; not
    32  	everything is encoded as an Op the way Theano does it.
    33  
    34  	See also: nn.go for functions that relate to neural networks
    35  */
    36  
    37  type randomness byte
    38  
    39  const (
    40  	uniform randomness = iota
    41  	gaussian
    42  	binomial
    43  )
    44  
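        // randomOp is a nullary op (it takes no inputs); every call to Do draws
        // a fresh sample from the selected distribution, seeded from the wall
        // clock.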
    45  type randomOp struct {
    46  	which randomness
    47  	shape tensor.Shape
    48  	dt    tensor.Dtype
    49  
    50  	a, b float64 // when uniform, a,b = low, high; when gaussian, a,b = mean, stdev; when binomial, a,b = trials, prob
    51  }
    52  
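        // makeRandomOp constructs a randomOp. As an illustrative sketch (not a
        // call that appears elsewhere in this file), a 2×2 Float64 tensor drawn
        // from N(0, 1) would be made with:
        //	op := makeRandomOp(gaussian, tensor.Float64, 0, 1, 2, 2)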
    53  func makeRandomOp(which randomness, dt tensor.Dtype, a, b float64, shape ...int) randomOp {
    54  	return randomOp{
    55  		which: which,
    56  		shape: tensor.Shape(shape),
    57  		dt:    dt,
    58  		a:     a,
    59  		b:     b,
    60  	}
    61  }
    62  
    63  func (op randomOp) Arity() int { return 0 }
    64  
    65  // randomOp :: a        (when the shape is scalar)
    66  // randomOp :: Tensor a (otherwise)
    67  func (op randomOp) Type() hm.Type {
    68  	if op.shape.IsScalar() {
    69  		return op.dt
    70  	}
    71  	tt := newTensorType(op.shape.Dims(), op.dt)
    72  	return tt
    73  }
    74  
    75  func (op randomOp) InferShape(...DimSizer) (tensor.Shape, error) { return op.shape, nil }
    76  
    77  func (op randomOp) Do(...Value) (retVal Value, err error) {
    78  	if op.shape.IsScalar() {
    79  		var v interface{}
    80  		switch op.dt {
    81  		case Float64:
    82  			switch op.which {
    83  			case uniform:
    84  				rand := rng.NewUniformGenerator(time.Now().UnixNano())
    85  				v = rand.Float64Range(op.a, op.b)
    86  			case gaussian:
    87  				rand := rng.NewGaussianGenerator(time.Now().UnixNano())
    88  				v = rand.Gaussian(op.a, op.b)
    89  			case binomial:
    90  				rand := rng.NewBinomialGenerator(time.Now().UnixNano())
    91  				v = float64(rand.Binomial(int64(op.a), op.b))
    92  			}
    93  		case Float32:
    94  			switch op.which {
    95  			case uniform:
    96  				rand := rng.NewUniformGenerator(time.Now().UnixNano())
    97  				v = rand.Float32Range(float32(op.a), float32(op.b))
    98  			case gaussian:
    99  				rand := rng.NewGaussianGenerator(time.Now().UnixNano())
   100  				v = float32(rand.Gaussian(op.a, op.b))
   101  			case binomial:
   102  				rand := rng.NewBinomialGenerator(time.Now().UnixNano())
   103  				v = float32(rand.Binomial(int64(op.a), op.b))
   104  			}
   105  		default:
   106  			return nil, errors.Errorf(nyiFail, "randomOp.do()", op.dt)
   107  		}
   108  
   109  		retVal, _ = anyToScalar(v)
   110  		return
   111  	}
   112  
   113  	switch op.dt {
   114  	case Float64:
   115  		switch op.which {
   116  		case uniform:
   117  			backing := Uniform64(op.a, op.b, op.shape...)
   118  			retVal = tensor.New(tensor.WithBacking(backing), tensor.WithShape(op.shape...))
   119  		case gaussian:
   120  			backing := Gaussian64(op.a, op.b, op.shape...)
   121  			retVal = tensor.New(tensor.WithBacking(backing), tensor.WithShape(op.shape...))
   122  		case binomial:
   123  			backing := Binomial64(op.a, op.b, op.shape...)
   124  			retVal = tensor.New(tensor.WithBacking(backing), tensor.WithShape(op.shape...))
   125  		}
   126  		return
   127  	case Float32:
   128  		switch op.which {
   129  		case uniform:
   130  			backing := Uniform32(op.a, op.b, op.shape...)
   131  			retVal = tensor.New(tensor.WithBacking(backing), tensor.WithShape(op.shape...))
   132  		case gaussian:
   133  			backing := Gaussian32(op.a, op.b, op.shape...)
   134  			retVal = tensor.New(tensor.WithBacking(backing), tensor.WithShape(op.shape...))
   135  		case binomial:
   136  			backing := Binomial32(op.a, op.b, op.shape...)
   137  			retVal = tensor.New(tensor.WithBacking(backing), tensor.WithShape(op.shape...))
   138  		}
   139  		return
   140  	default:
   141  		return nil, errors.Errorf(nyiFail, "randomOp.do() for non-scalar", op.dt)
   142  	}
   143  }
   144  
   145  func (op randomOp) ReturnsPtr() bool     { return false }
   146  func (op randomOp) CallsExtern() bool    { return false }
   147  func (op randomOp) OverwritesInput() int { return -1 }
   148  func (op randomOp) WriteHash(h hash.Hash) {
   149  	fmt.Fprintf(h, "%d%v%v%f%f", op.which, op.shape, op.dt, op.a, op.b) // include the dtype to avoid hash collisions
   150  }
   151  
   152  func (op randomOp) Hashcode() uint32 { return simpleHash(op) }
   153  
   154  func (op randomOp) String() string {
   155  	return fmt.Sprintf("%v(%v, %v) - %v", op.which, op.a, op.b, op.shape)
   156  }
   157  
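        // im2colOp unfolds sliding kernel windows of a BCHW image into rows of a
        // matrix so that convolution can be computed as a single matrix multiply.
        // For an input of shape (B, C, H, W), the output shape is
        // (B, H', W', C·kH·kW), with H' and W' as computed by retHW.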
   158  type im2colOp struct {
   159  	h, w                 int // kernel height and width
   160  	padH, padW           int
   161  	strideH, strideW     int
   162  	dilationH, dilationW int
   163  }
   164  
   165  func makeIm2ColOp(kernelHeight, kernelWidth, padHeight, padWidth, strideHeight, strideWidth, dilationHeight, dilationWidth int) im2colOp {
   166  	return im2colOp{
   167  		h:         kernelHeight,
   168  		w:         kernelWidth,
   169  		padH:      padHeight,
   170  		padW:      padWidth,
   171  		strideH:   strideHeight,
   172  		strideW:   strideWidth,
   173  		dilationH: dilationHeight,
   174  		dilationW: dilationWidth,
   175  	}
   176  }
   177  
   178  func (op im2colOp) Arity() int { return 1 }
   179  
   180  // im2col :: (Floats a) ⇒ Tensor a →  Tensor a
   181  func (op im2colOp) Type() hm.Type {
   182  	t := makeTensorType(4, hm.TypeVariable('a'))
   183  	return hm.NewFnType(t, t)
   184  }
   185  
   186  func (op im2colOp) InferShape(shapes ...DimSizer) (retVal tensor.Shape, err error) {
   187  	if err = checkArity(op, len(shapes)); err != nil {
   188  		return
   189  	}
   190  
   191  	if s, ok := shapes[0].(tensor.Shape); ok {
   192  		return op.calcShape(s), nil
   193  	}
   194  	return nil, errors.Errorf("expected tensor.Shape. got %T instead", shapes[0])
   195  }
   196  
   197  func (op im2colOp) Do(inputs ...Value) (retVal Value, err error) {
   198  	if err = checkArity(op, len(inputs)); err != nil {
   199  		return
   200  	}
   201  
   202  	im := inputs[0]
   203  
   204  	// todo type check values
   205  	// todo shape check values
   206  
   207  	retShape := op.calcShape(im.Shape())
   208  	prealloc := tensor.New(tensor.Of(im.Dtype()), tensor.WithShape(retShape...))
   209  
   210  	return op.do(prealloc, im)
   211  }
   212  
   213  func (op im2colOp) ReturnsPtr() bool     { return false }
   214  func (op im2colOp) CallsExtern() bool    { return false }
   215  func (op im2colOp) OverwritesInput() int { return -1 }
   216  
   217  func (op im2colOp) WriteHash(h hash.Hash) {
   218  	fmt.Fprintf(h, "im2col:%d-%d-%d-%d-%d-%d-%d-%d", op.h, op.w, op.padH, op.padW, op.strideH, op.strideW, op.dilationH, op.dilationW) // dilation is part of the op's identity
   219  }
   220  
   221  func (op im2colOp) Hashcode() uint32 { return simpleHash(op) }
   222  
   223  func (op im2colOp) String() string {
   224  	return fmt.Sprintf("im2col<(%d,%d), (%d, %d), (%d,%d) (%d, %d)>", op.h, op.w, op.padH, op.padW, op.strideH, op.strideW, op.dilationH, op.dilationW)
   225  }
   226  
   227  func (op im2colOp) DiffWRT(i int) []bool { return []bool{true} }
   228  
   229  func (op im2colOp) SymDiff(inputs Nodes, output, grad *Node) (retVal Nodes, err error) {
   230  	if err = checkArity(op, len(inputs)); err != nil {
   231  		return
   232  	}
   233  	im := inputs[0]
   234  	s := im.Shape()
   235  	if s.Dims() != 4 {
   236  		return nil, errors.Errorf("Expected input to have a shape with 4 dims")
   237  	}
   238  	var unpaddedB, unpaddedC, unpaddedH, unpaddedW int
   239  	unpaddedB, unpaddedC, unpaddedH, unpaddedW = s[0], s[1], s[2], s[3]
   240  	diffOp := col2imOp{
   241  		unpaddedB: unpaddedB,
   242  		unpaddedC: unpaddedC,
   243  		unpaddedH: unpaddedH,
   244  		unpaddedW: unpaddedW,
   245  
   246  		im2colOp: op,
   247  	}
   248  
   249  	var ret *Node
   250  	if ret, err = ApplyOp(diffOp, grad); err != nil {
   251  		return
   252  	}
   253  	retVal = Nodes{ret}
   254  	return
   255  }
   256  
   257  func (op im2colOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
   258  	if err = checkArity(op, len(inputs)); err != nil {
   259  		return
   260  	}
   261  
   262  	im := inputs[0]
   263  	s := im.Shape()
   264  	imv, colv := getDV(im, output)
   265  
   266  	var unpaddedB, unpaddedC, unpaddedH, unpaddedW int
   267  	unpaddedB, unpaddedC, unpaddedH, unpaddedW = s[0], s[1], s[2], s[3]
   268  	diffOp := col2imOp{
   269  		unpaddedB: unpaddedB,
   270  		unpaddedC: unpaddedC,
   271  		unpaddedH: unpaddedH,
   272  		unpaddedW: unpaddedW,
   273  
   274  		im2colOp: op,
   275  	}
   276  
   277  	if _, err = diffOp.UsePreallocDo(imv.d, colv.d); err != nil {
   278  		return errors.Wrapf(err, doFail, diffOp)
   279  	}
   280  	return
   281  }
   282  
   283  func (op im2colOp) calcShape(s tensor.Shape) (retVal tensor.Shape) {
   284  	b := s[0]
   285  	c := s[1]
   286  	h := s[2]
   287  	w := s[3]
   288  
   289  	retHeight, retWidth := op.retHW(h, w)
   290  	retVal = tensor.Shape(tensor.BorrowInts(4))
   291  
   292  	// todo: double check this with tests
   293  	retVal[0] = b
   294  	retVal[1] = retHeight
   295  	retVal[2] = retWidth
   296  	retVal[3] = c * op.w * op.h
   297  
   298  	return
   299  }
   300  
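        // retHW computes the output spatial size using the standard convolution
        // arithmetic. As a worked check (illustrative, not from the original
        // source): a 28×28 input with a 3×3 kernel, padding 1, stride 1 and
        // dilation 1 gives (28 + 2·1 - (1·(3-1)+1))/1 + 1 = 28, preserving the
        // spatial size.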
   301  func (op im2colOp) retHW(h, w int) (retHeight, retWidth int) {
   302  	retHeight = (h+2*op.padH-(op.dilationH*(op.h-1)+1))/op.strideH + 1
   303  	retWidth = (w+2*op.padW-(op.dilationW*(op.w-1)+1))/op.strideW + 1
   304  	return
   305  }
   306  
   307  func (op im2colOp) do(prealloc, input Value) (retVal Value, err error) {
   308  	inputT := input.(*tensor.Dense)
   309  	outputT := prealloc.(*tensor.Dense)
   310  
   311  	// extract bchw - this bit can be expanded in the future, but for now we only support bchw
   312  	s := inputT.Shape()
   313  	b := s[0]
   314  	c := s[1]
   315  	h := s[2]
   316  	w := s[3]
   317  
   318  	inputStrides := inputT.Strides()
   319  	retHeight, retWidth := op.retHW(h, w)
   320  	batchStrideIm := inputStrides[0]
   321  	batchStrideCol := outputT.Strides()[0]
   322  	chanStride := h * w
   323  	inRowStride := inputStrides[2]
   324  
   325  	switch input.Dtype() {
   326  	case tensor.Float64:
   327  		imData := input.Data().([]float64)
   328  		colData := prealloc.Data().([]float64)
   329  		for i := 0; i < b; i++ {
   330  			imStart := i * batchStrideIm
   331  			colStart := i * batchStrideCol
   332  			imEnd := imStart + batchStrideIm
   333  			colEnd := colStart + batchStrideCol
   334  
   335  			if imEnd >= len(imData) {
   336  				imEnd = len(imData)
   337  			}
   338  			if colEnd >= len(colData) {
   339  				colEnd = len(colData)
   340  			}
   341  
   342  			op.f64s(c, h, w, chanStride, inRowStride, retHeight, retWidth, imData[imStart:imEnd], colData[colStart:colEnd])
   343  		}
   344  	case tensor.Float32:
   345  		imData := input.Data().([]float32)
   346  		colData := prealloc.Data().([]float32)
   347  		for i := 0; i < b; i++ {
   348  			imStart := i * batchStrideIm
   349  			colStart := i * batchStrideCol
   350  			imEnd := imStart + batchStrideIm
   351  			colEnd := colStart + batchStrideCol
   352  
   353  			if imEnd >= len(imData) {
   354  				imEnd = len(imData)
   355  			}
   356  			if colEnd >= len(colData) {
   357  				colEnd = len(colData)
   358  			}
   359  
   360  			op.f32s(c, h, w, chanStride, inRowStride, retHeight, retWidth, imData[imStart:imEnd], colData[colStart:colEnd])
   361  		}
   362  	default:
   363  		return nil, errors.Errorf(nyiFail, "im2col", input.Dtype())
   364  	}
   365  	return prealloc, nil
   366  }
   367  
   368  func (op im2colOp) f64s(chans, height, width, chanStride, inRowStride, retHeight, retWidth int, im, col []float64) {
   369  	colIdx := 0
   370  	var inputRow int
   371  	var inputCol int
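        	// walk every output position and kernel element: positions that fall
        	// inside the padding are written as 0; all others copy from im.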
   372  	for outputRow := 0; outputRow < retHeight; outputRow++ {
   373  		for outputCol := 0; outputCol < retWidth; outputCol++ {
   374  			for ch := 0; ch < chans; ch++ {
   375  				for kernelRow := 0; kernelRow < op.h; kernelRow++ {
   376  					inputRow = -op.padH + kernelRow*op.dilationH + outputRow*op.strideH
   377  					for kernelCol := 0; kernelCol < op.w; kernelCol++ {
   378  						if inputRow < 0 || inputRow >= height {
   379  							col[colIdx] = 0
   380  							colIdx++
   381  							continue
   382  						}
   383  						inputCol = -op.padW + kernelCol*op.dilationW + outputCol*op.strideW
   384  						if inputCol < 0 || inputCol >= width {
   385  							col[colIdx] = 0
   386  							colIdx++
   387  						} else {
   388  							imIdx := chanStride*ch + inputRow*width + inputCol
   389  							col[colIdx] = im[imIdx]
   390  							colIdx++
   391  						}
   392  					}
   393  				}
   394  			}
   395  		}
   396  	}
   397  }
   398  
   399  func (op im2colOp) f32s(chans, height, width, chanStride, inRowStride, retHeight, retWidth int, im, col []float32) {
   400  	colIdx := 0
   401  	var inputRow int
   402  	var inputCol int
   403  	for outputRow := 0; outputRow < retHeight; outputRow++ {
   404  		for outputCol := 0; outputCol < retWidth; outputCol++ {
   405  			for ch := 0; ch < chans; ch++ {
   406  				for kernelRow := 0; kernelRow < op.h; kernelRow++ {
   407  					inputRow = -op.padH + kernelRow*op.dilationH + outputRow*op.strideH
   408  					for kernelCol := 0; kernelCol < op.w; kernelCol++ {
   409  						if inputRow < 0 || inputRow >= height {
   410  							col[colIdx] = 0
   411  							colIdx++
   412  							continue
   413  						}
   414  						inputCol = -op.padW + kernelCol*op.dilationW + outputCol*op.strideW
   415  						if inputCol < 0 || inputCol >= width {
   416  							col[colIdx] = 0
   417  							colIdx++
   418  						} else {
   419  							imIdx := chanStride*ch + inputRow*width + inputCol
   420  							col[colIdx] = im[imIdx]
   421  							colIdx++
   422  						}
   423  					}
   424  				}
   425  			}
   426  		}
   427  	}
   428  }
   429  
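        // col2imOp is the adjoint of im2colOp: it folds a column matrix back
        // into an image of the original (unpadded) shape, summing the
        // contributions of overlapping patches. It serves as im2col's gradient.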
   430  type col2imOp struct {
   431  	// input shapes of im2col
   432  	unpaddedB int
   433  	unpaddedC int
   434  	unpaddedH int
   435  	unpaddedW int
   436  
   437  	im2colOp
   438  }
   439  
   440  func (op col2imOp) Arity() int { return 1 }
   441  
   442  // col2im :: (Floats a) ⇒ a →  a
   443  func (op col2imOp) Type() hm.Type {
   444  	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'))
   445  }
   446  
   447  func (op col2imOp) InferShape(shapes ...DimSizer) (retVal tensor.Shape, err error) {
   448  	return tensor.Shape{op.unpaddedB, op.unpaddedC, op.unpaddedH, op.unpaddedW}, nil
   449  }
   450  
   451  func (op col2imOp) Do(inputs ...Value) (retVal Value, err error) {
   452  	if err = checkArity(op, len(inputs)); err != nil {
   453  		return
   454  	}
   455  
   456  	im := inputs[0]
   457  
   458  	// todo type check values
   459  	// todo shape check values
   460  
   461  	retShape := tensor.Shape{op.unpaddedB, op.unpaddedC, op.unpaddedH, op.unpaddedW}
   462  	prealloc := tensor.New(tensor.Of(im.Dtype()), tensor.WithShape(retShape...))
   463  
   464  	return op.do(prealloc, im)
   465  }
   466  
   467  func (op col2imOp) ReturnsPtr() bool     { return false }
   468  func (op col2imOp) CallsExtern() bool    { return false }
   469  func (op col2imOp) OverwritesInput() int { return -1 }
   470  
   471  func (op col2imOp) WriteHash(h hash.Hash) {
   472  	fmt.Fprintf(h, "col2im:%d-%d-%d-%d-%d-%d-%d-%d", op.h, op.w, op.padH, op.padW, op.strideH, op.strideW, op.dilationH, op.dilationW) // dilation is part of the op's identity
   473  }
   474  
   475  func (op col2imOp) Hashcode() uint32 { return simpleHash(op) }
   476  
   477  func (op col2imOp) String() string {
   478  	return fmt.Sprintf("col2im<(%d,%d), (%d, %d), (%d,%d)>", op.h, op.w, op.padH, op.padW, op.strideH, op.strideW)
   479  }
   480  
   481  func (op col2imOp) UsePreallocDo(prealloc Value, inputs ...Value) (Value, error) {
   482  	if err := checkArity(op, len(inputs)); err != nil {
   483  		return nil, err
   484  	}
   485  	return op.do(prealloc, inputs[0])
   486  }
   487  
   488  func (op col2imOp) do(prealloc, input Value) (retVal Value, err error) {
   489  	b := op.unpaddedB
   490  	c := op.unpaddedC
   491  	retHeight := op.unpaddedH
   492  	retWidth := op.unpaddedW
   493  	batchStrideIm := c * retHeight * retWidth
   494  
   495  	s := input.Shape()
   496  	h := s[1]
   497  	w := s[2]
   498  	chanStride := retHeight * retWidth
   499  	batchStrideCol := h * w * s[3]
   500  
   501  	var imStart, imEnd, colStart, colEnd int
   502  	imEnd = imStart + batchStrideIm
   503  	colEnd = colStart + batchStrideCol
   504  
   505  	switch input.Dtype() {
   506  	case tensor.Float64:
   507  		colData := input.Data().([]float64)
   508  		imData := prealloc.Data().([]float64)
   509  		for i := 0; i < b; i++ {
   510  			op.f64s(c, retHeight, retWidth, chanStride, h, w, colData[colStart:colEnd], imData[imStart:imEnd])
   511  
   512  			colStart += batchStrideCol
   513  			colEnd += batchStrideCol
   514  
   515  			imStart += batchStrideIm
   516  			imEnd += batchStrideIm
   517  
   518  			if imEnd > len(imData) {
   519  				imEnd = len(imData)
   520  			}
   521  			if colEnd > len(colData) {
   522  				colEnd = len(colData)
   523  			}
   524  		}
   525  	case tensor.Float32:
   526  		colData := input.Data().([]float32)
   527  		imData := prealloc.Data().([]float32)
   528  		for i := 0; i < b; i++ {
   529  			op.f32s(c, retHeight, retWidth, chanStride, h, w, colData[colStart:colEnd], imData[imStart:imEnd])
   530  
   531  			colStart += batchStrideCol
   532  			colEnd += batchStrideCol
   533  
   534  			imStart += batchStrideIm
   535  			imEnd += batchStrideIm
   536  
   537  			if imEnd > len(imData) {
   538  				imEnd = len(imData)
   539  			}
   540  			if colEnd > len(colData) {
   541  				colEnd = len(colData)
   542  			}
   543  		}
   544  	default:
   545  		return nil, errors.Errorf(nyiFail, "col2im", input.Dtype())
   546  	}
   547  
   548  	return prealloc, nil
   549  }
   550  
   551  func (op col2imOp) f64s(chans, height, width, chanStride, retHeight, retWidth int, col, im []float64) {
   552  	// memset im to 0
   553  	for i := 0; i < len(im); i++ {
   554  		im[i] = 0
   555  	}
   556  	colIdx := 0
   557  	var inputRow int
   558  	var inputCol int
   559  	for outputRow := 0; outputRow < retHeight; outputRow++ {
   560  		for outputCol := 0; outputCol < retWidth; outputCol++ {
   561  			for ch := 0; ch < chans; ch++ {
   562  				for kernelRow := 0; kernelRow < op.h; kernelRow++ {
   563  					inputRow = -op.padH + kernelRow*op.dilationH + outputRow*op.strideH
   564  					for kernelCol := 0; kernelCol < op.w; kernelCol++ {
   565  						if inputRow < 0 || inputRow >= height {
   566  							colIdx++
   567  							continue
   568  						}
   569  						inputCol = -op.padW + kernelCol*op.dilationW + outputCol*op.strideW
   570  						if inputCol >= 0 && inputCol < width {
   571  							imIdx := chanStride*ch + inputRow*width + inputCol
   572  							im[imIdx] += col[colIdx]
   573  						}
   574  						colIdx++
   575  					}
   576  				}
   577  			}
   578  		}
   579  	}
   580  }
   581  
   582  func (op col2imOp) f32s(chans, height, width, chanStride, retHeight, retWidth int, col, im []float32) {
   583  	// memset im to 0
   584  	for i := 0; i < len(im); i++ {
   585  		im[i] = 0
   586  	}
   587  	colIdx := 0
   588  	var inputRow int
   589  	var inputCol int
   590  	for outputRow := 0; outputRow < retHeight; outputRow++ {
   591  		for outputCol := 0; outputCol < retWidth; outputCol++ {
   592  			for ch := 0; ch < chans; ch++ {
   593  				for kernelRow := 0; kernelRow < op.h; kernelRow++ {
   594  					inputRow = -op.padH + kernelRow*op.dilationH + outputRow*op.strideH
   595  					for kernelCol := 0; kernelCol < op.w; kernelCol++ {
   596  						if inputRow < 0 || inputRow >= height {
   597  							colIdx++
   598  							continue
   599  						}
   600  						inputCol = -op.padW + kernelCol*op.dilationW + outputCol*op.strideW
   601  						if inputCol >= 0 && inputCol < width {
   602  							imIdx := chanStride*ch + inputRow*width + inputCol
   603  							im[imIdx] += col[colIdx]
   604  						}
   605  						colIdx++
   606  					}
   607  				}
   608  			}
   609  		}
   610  	}
   611  }
   612  
   613  // maxPoolOp implements 2D max pooling. Note that this op actually produces
   614  // TWO values - an argmax, which will be used as a mask, and the pooled value.
   615  //
   616  // The argmax is stored as internal state and is not exposed to anything outside the op.
   617  // There are alternative ways of designing this op, but none of them seemed
   618  // particularly nice. Caffe's technique seemed the nicest.
   619  type maxPoolOp struct {
   620  	// Shape of Input
   621  	unpaddedB int
   622  	unpaddedC int
   623  	unpaddedH int
   624  	unpaddedW int
   625  
   626  	h, w              int // patch height and width
   627  	padNorth, padWest int
   628  	padSouth, padEast int
   629  	explicitPadding   bool
   630  	strideH, strideW  int
   631  
   632  	// execution state
   633  	// the mask is only filled at execution time
   634  	mask tensor.Tensor
   635  }
   636  
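        // newMaxPoolOp creates a *maxPoolOp. pad is either {padH, padW}, or
        // {north, south, west, east} when all four sides are given explicitly;
        // the argmax mask is preallocated from the inferred output shape.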
   637  func newMaxPoolOp(inputShape, kernel tensor.Shape, pad, stride []int) *maxPoolOp {
   638  	padNorth := pad[0]
   639  	padWest := pad[1]
   640  	padSouth := pad[0]
   641  	padEast := pad[1]
   642  	explicitPadding := false
   643  	if len(pad) == 4 {
   644  		explicitPadding = true
   645  		padNorth = pad[0]
   646  		padSouth = pad[1]
   647  		padWest = pad[2]
   648  		padEast = pad[3]
   649  	}
   650  	maxpoolOp := &maxPoolOp{
   651  		// Shape of Input
   652  		unpaddedB: inputShape[0],
   653  		unpaddedC: inputShape[1],
   654  		unpaddedH: inputShape[2],
   655  		unpaddedW: inputShape[3],
   656  
   657  		h:               kernel[0],
   658  		w:               kernel[1],
   659  		padNorth:        padNorth,
   660  		padWest:         padWest,
   661  		padSouth:        padSouth,
   662  		padEast:         padEast,
   663  		explicitPadding: explicitPadding,
   664  		strideH:         stride[0],
   665  		strideW:         stride[1],
   666  	}
   667  	maxpoolOp.mask = tensor.New(tensor.Of(tensor.Int), tensor.WithShape(maxpoolOp.calcShape(inputShape)...))
   668  	return maxpoolOp
   669  }
   670  
   671  func (op *maxPoolOp) Arity() int { return 1 }
   672  
   673  // maxPoolOp has this type:
   674  // 		op :: Tensor a →  Tensor a
   675  func (op *maxPoolOp) Type() hm.Type {
   676  	a := hm.TypeVariable('a')
   677  	t := newTensorType(4, a)
   678  	return hm.NewFnType(t, t)
   679  }
   680  func (op *maxPoolOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
   681  	if s, ok := inputs[0].(tensor.Shape); ok {
   682  		return op.calcShape(s), nil
   683  	}
   684  	return nil, errors.Errorf("Expected a shape")
   685  }
   686  
   687  func (op *maxPoolOp) Do(inputs ...Value) (retVal Value, err error) {
   688  	var in, out tensor.Tensor
   689  	if in, err = op.checkInput(inputs...); err != nil {
   690  		return nil, err
   691  	}
   692  	inShp := in.Shape()
   693  	out = tensor.New(tensor.Of(in.Dtype()), tensor.WithShape(op.calcShape(inShp)...), tensor.WithEngine(in.Engine()))
   694  	op.do(out, in)
   695  	return out, nil
   696  }
   697  
   698  func (op *maxPoolOp) ReturnsPtr() bool     { return false }
   699  func (op *maxPoolOp) CallsExtern() bool    { return false }
   700  func (op *maxPoolOp) OverwritesInput() int { return -1 }
   701  func (op *maxPoolOp) WriteHash(h hash.Hash) {
   702  	fmt.Fprintf(h, "MaxPool{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
   703  		op.unpaddedB, op.unpaddedC, op.unpaddedH, op.unpaddedW,
   704  		op.h, op.w, op.padNorth, op.padWest, op.strideH, op.strideW)
   705  }
   706  
   707  func (op *maxPoolOp) Hashcode() uint32 { return simpleHash(op) }
   708  
   709  func (op *maxPoolOp) String() string {
   710  	return fmt.Sprintf("MaxPool{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
   711  		op.unpaddedB, op.unpaddedC, op.unpaddedH, op.unpaddedW,
   712  		op.h, op.w, op.padNorth, op.padWest, op.strideH, op.strideW)
   713  }
   714  
   715  func (op *maxPoolOp) UsePreallocDo(prealloc Value, inputs ...Value) (Value, error) {
   716  	var in tensor.Tensor
   717  	var err error
   718  	if in, err = op.checkInput(inputs...); err != nil {
   719  		return nil, err
   720  	}
   721  
   722  	if p, ok := prealloc.(tensor.Tensor); ok {
   723  		op.do(p, in)
   724  		return p, nil
   725  	}
   726  	return nil, errors.Errorf("Expected prealloc to be a tensor")
   727  }
   728  
   729  func (op *maxPoolOp) DiffWRT(inputs int) []bool { return []bool{true} }
   730  
   731  func (op *maxPoolOp) SymDiff(inputs Nodes, output, grad *Node) (retVal Nodes, err error) {
   732  	if err = checkArity(op, len(inputs)); err != nil {
   733  		return
   734  	}
   735  	input := inputs[0]
   736  
   737  	// shallow copy: the diff op shares the forward op's argmax mask
   738  	op2 := *op
   739  	diff := &maxPoolDiffOp{op2}
   740  
   741  	var ret *Node
   742  	if ret, err = ApplyOp(diff, input, output, grad); err != nil {
   743  		return nil, err
   744  	}
   745  	return Nodes{ret}, nil
   746  }
   747  
   748  func (op *maxPoolOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err error) {
   749  	if err = checkArity(op, len(inputs)); err != nil {
   750  		return
   751  	}
   752  	input := inputs[0]
   753  	inputDV, outDV := getDV(input, output)
   754  
   755  	// shallow copy: the diff op shares the forward op's argmax mask
   756  	op2 := *op
   757  	diff := &maxPoolDiffOp{op2}
   758  
   759  	if _, err = diff.UsePreallocDo(inputDV.d, inputDV.Value, outDV.Value, outDV.d); err != nil {
   760  		return errors.Wrapf(err, doFail, diff)
   761  	}
   762  	return
   763  }
   764  
   765  func (op *maxPoolOp) checkInput(inputs ...Value) (tensor.Tensor, error) {
   766  	if err := checkArity(op, len(inputs)); err != nil {
   767  		return nil, err
   768  	}
   769  
   770  	var in tensor.Tensor
   771  	var ok bool
   772  	if in, ok = inputs[0].(tensor.Tensor); !ok {
   773  		return nil, errors.Errorf("Expected input to be a tensor")
   774  	}
   775  
   776  	if in.Shape().Dims() != 4 {
   777  		return nil, errors.Errorf("Expected input to have 4 dimensions")
   778  	}
   779  	return in, nil
   780  }
   781  
   782  // calcShape calculates the output shape given an input shape
   783  func (op *maxPoolOp) calcShape(s tensor.Shape) tensor.Shape {
   784  	b, c, h, w := s[0], s[1], s[2], s[3]
   785  
   786  	pooledH := (h+op.padSouth+op.padNorth-(op.h-1)-1)/op.strideH + 1
   787  	pooledW := (w+op.padEast+op.padWest-(op.w-1)-1)/op.strideW + 1
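        	// e.g. a 28×28 plane pooled with a 2×2 kernel, stride 2 and no
        	// padding yields (28+0-(2-1)-1)/2 + 1 = 14 per spatial axis.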
   788  	return tensor.Shape{b, c, pooledH, pooledW}
   789  }
   790  
   791  // do prepares the data, and then dispatches it to the correct (computation) kernel.
   792  // out is the preallocated tensor
   793  func (op *maxPoolOp) do(out, in tensor.Tensor) {
   794  	outShape := out.Shape()
   795  	outStride := out.Strides()[1]
   796  	inShape := in.Shape()
   797  	inStride := in.Strides()[1]
   798  
   799  	b, c, h, w := outShape[0], outShape[1], outShape[2], outShape[3]
   800  	inH, inW := inShape[2], inShape[3]
   801  
   802  	// the mask must exist before its strides can be queried
   803  	if op.mask == nil {
   804  		op.mask = tensor.New(tensor.Of(tensor.Int), tensor.WithShape(op.calcShape(inShape)...))
   805  	}
   806  	maskStride := op.mask.Strides()[1]
   807  	maskData := op.mask.Data().([]int)
   808  
   809  	switch in.Dtype() {
   810  	case tensor.Float64:
   811  		op.f64s(b, c, h, w, inH, inW,
   812  			outStride, inStride, maskStride,
   813  			out.Data().([]float64), in.Data().([]float64),
   814  			maskData)
   815  	case tensor.Float32:
   816  		op.f32s(b, c, h, w, inH, inW,
   817  			outStride, inStride, maskStride,
   818  			out.Data().([]float32), in.Data().([]float32),
   819  			maskData)
   820  	}
   821  }
   822  
   823  func (op *maxPoolOp) f32s(batches, channels, outH, outW, inH, inW,
   824  	outStride, inStride, maskStride int,
   825  	outData, inData []float32,
   826  	maskData []int) {
   827  
   828  	// set values
   829  	for i := range outData {
   830  		outData[i] = -maxFloat32
   831  		maskData[i] = -1
   832  	}
   833  	padH := op.padNorth
   834  	padW := op.padWest
   835  	if op.explicitPadding {
   836  		padH = op.padSouth
   837  		padW = op.padEast
   838  	}
   839  
   840  	for b := 0; b < batches; b++ {
   841  		for c := 0; c < channels; c++ {
   842  			for ph := 0; ph < outH; ph++ {
   843  				for pw := 0; pw < outW; pw++ {
   844  
   845  					hStart := ph*op.strideH - padH
   846  					wStart := pw*op.strideW - padW
   847  					hEnd := minInt(hStart+op.h, inH)
   848  					wEnd := minInt(wStart+op.w, inW)
   849  					hStart = maxInt(hStart, 0)
   850  					wStart = maxInt(wStart, 0)
   851  
   852  					poolIndex := ph*outW + pw
   853  					for hi := hStart; hi < hEnd; hi++ {
   854  						for wi := wStart; wi < wEnd; wi++ {
   855  							i := hi*inW + wi
   856  							if inData[i] > outData[poolIndex] {
   857  								outData[poolIndex] = inData[i]
   858  								maskData[poolIndex] = i
   859  							}
   860  						}
   861  					}
   862  				}
   863  			}
   864  			// skip by strides
   865  			inData = inData[inStride:]
   866  			outData = outData[outStride:]
   867  			maskData = maskData[maskStride:]
   868  		}
   869  	}
   870  }
   871  
   872  func (op *maxPoolOp) f64s(batches, channels, outH, outW, inH, inW,
   873  	outStride, inStride, maskStride int,
   874  	outData, inData []float64,
   875  	maskData []int) {
   876  
   877  	// set values
   878  	for i := range outData {
   879  		outData[i] = -maxFloat64
   880  		maskData[i] = -1
   881  	}
   882  	padH := op.padNorth
   883  	padW := op.padWest
   884  	if op.explicitPadding {
   885  		padH = op.padSouth
   886  		padW = op.padEast
   887  	}
   888  
   889  	for b := 0; b < batches; b++ {
   890  		for c := 0; c < channels; c++ {
   891  			for ph := 0; ph < outH; ph++ {
   892  				for pw := 0; pw < outW; pw++ {
   893  					hStart := ph*op.strideH - padH
   894  					wStart := pw*op.strideW - padW
   895  					hEnd := minInt(hStart+op.h, inH)
   896  					wEnd := minInt(wStart+op.w, inW)
   897  					hStart = maxInt(hStart, 0)
   898  					wStart = maxInt(wStart, 0)
   899  
   900  					poolIndex := ph*outW + pw
   901  
   902  					for hi := hStart; hi < hEnd; hi++ {
   903  						for wi := wStart; wi < wEnd; wi++ {
   904  							i := hi*inW + wi
   905  							if inData[i] > outData[poolIndex] {
   906  								outData[poolIndex] = inData[i]
   907  								maskData[poolIndex] = i
   908  							}
   909  						}
   910  					}
   911  				}
   912  			}
   913  			// skip by strides
   914  			inData = inData[inStride:]
   915  			outData = outData[outStride:]
   916  			maskData = maskData[maskStride:]
   917  		}
   918  	}
   919  }
   920  
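        // maxPoolDiffOp is the gradient op of maxPoolOp: it scatters each pooled
        // gradient back to the input position recorded in the forward pass's
        // argmax mask.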
   921  type maxPoolDiffOp struct {
   922  	maxPoolOp
   923  }
   924  
   925  func (op *maxPoolDiffOp) Arity() int { return 3 }
   926  func (op *maxPoolDiffOp) Type() hm.Type {
   927  	a := hm.TypeVariable('a')
   928  	t := newTensorType(4, a)
   929  	return hm.NewFnType(t, t, t, t)
   930  }
   931  
   932  func (op *maxPoolDiffOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
   933  	s := inputs[0].(tensor.Shape).Clone()
   934  	return s, nil
   935  }
   936  
   937  func (op *maxPoolDiffOp) Do(inputs ...Value) (Value, error) {
   938  	var in, out, pooled, pooledGrad tensor.Tensor
   939  	var err error
   940  	if in, pooled, pooledGrad, err = op.checkInput(inputs...); err != nil {
   941  		return nil, err
   942  	}
   943  
   944  	// out is the gradient of in
   945  	out = tensor.New(tensor.Of(in.Dtype()), tensor.WithShape(in.Shape().Clone()...), tensor.WithEngine(in.Engine()))
   946  	op.do(out, in, pooled, pooledGrad)
   947  	return out, nil
   948  }
   949  func (op *maxPoolDiffOp) ReturnsPtr() bool     { return true }
   950  func (op *maxPoolDiffOp) CallsExtern() bool    { return false }
   951  func (op *maxPoolDiffOp) OverwritesInput() int { return -1 }
   952  func (op *maxPoolDiffOp) WriteHash(h hash.Hash) {
   953  	fmt.Fprintf(h, "MaxPoolDiff{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
   954  		op.unpaddedB, op.unpaddedC, op.unpaddedH, op.unpaddedW,
   955  		op.h, op.w, op.padNorth, op.padWest, op.strideH, op.strideW)
   956  }
   957  
   958  func (op *maxPoolDiffOp) Hashcode() uint32 { return simpleHash(op) }
   959  
   960  func (op *maxPoolDiffOp) String() string {
   961  	return fmt.Sprintf("MaxPoolDiff{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
   962  		op.unpaddedB, op.unpaddedC, op.unpaddedH, op.unpaddedW,
   963  		op.h, op.w, op.padNorth, op.padWest, op.strideH, op.strideW)
   964  }
   965  
   966  func (op *maxPoolDiffOp) UsePreallocDo(prealloc Value, inputs ...Value) (Value, error) {
   967  	var in, pooled, pooledGrad tensor.Tensor
   968  	var err error
   969  	if in, pooled, pooledGrad, err = op.checkInput(inputs...); err != nil {
   970  		return nil, err
   971  	}
   972  	if p, ok := prealloc.(tensor.Tensor); ok {
   973  		op.do(p, in, pooled, pooledGrad)
   974  		return prealloc, nil
   975  	}
   976  	return nil, errors.Errorf("Cannot do with PreallocDo - expected PreAlloc to be tensor")
   977  }
   978  
   979  func (op *maxPoolDiffOp) checkInput(inputs ...Value) (in, pooled, pooledGrad tensor.Tensor, err error) {
   980  	if err = checkArity(op, len(inputs)); err != nil {
   981  		return
   982  	}
   983  
   984  	var ok bool
   985  	if in, ok = inputs[0].(tensor.Tensor); !ok {
   986  		err = errors.Errorf("Expected input to be a tensor")
   987  		return
   988  	}
   989  	if in.Shape().Dims() != 4 {
   990  		err = errors.Errorf("Expected input to have 4 dimensions")
   991  		return
   992  	}
   993  
   994  	if pooled, ok = inputs[1].(tensor.Tensor); !ok {
   995  		err = errors.Errorf("Expected pooled to be a tensor")
   996  		return
   997  	}
   998  	if pooledGrad, ok = inputs[2].(tensor.Tensor); !ok {
   999  		err = errors.Errorf("Expected pooledGrad to be a tensor")
  1000  		return
  1001  	}
  1002  	return
  1003  }
  1004  
  1005  func (op *maxPoolDiffOp) do(inGrad, in, pooled, pooledGrad tensor.Tensor) {
  1006  	pooledShape := pooled.Shape()
  1007  	pooledStride := pooled.Strides()[1]
  1008  	inStride := in.Strides()[1]
  1009  	maskStride := op.mask.Strides()[1]
  1010  	maskData := op.mask.Data().([]int)
  1011  
  1012  	b, c, h, w := pooledShape[0], pooledShape[1], pooledShape[2], pooledShape[3]
  1013  	switch in.Dtype() {
  1014  	case tensor.Float32:
  1015  		inGradData := inGrad.Data().([]float32)
  1016  		pooledGradData := pooledGrad.Data().([]float32)
  1017  		op.f32s(b, c, h, w,
  1018  			inStride, pooledStride, maskStride,
  1019  			inGradData, pooledGradData, maskData)
  1020  	case tensor.Float64:
  1021  		inGradData := inGrad.Data().([]float64)
  1022  		pooledGradData := pooledGrad.Data().([]float64)
  1023  		op.f64s(b, c, h, w,
  1024  			inStride, pooledStride, maskStride,
  1025  			inGradData, pooledGradData, maskData)
  1026  	}
  1027  }
  1028  
  1029  // in is the "bottom", while out is the "top" (bottom being the unpooled, and top being the pooled)
  1030  func (op *maxPoolDiffOp) f32s(batches, channels, pooledH, pooledW int,
  1031  	inStride, outStride, maskStride int,
  1032  	inDiffData, outDiffData []float32,
  1033  	maskData []int) {
  1034  
  1035  	// zero out. let's hope go's optimizer is smart enough
  1036  	for i := range inDiffData {
  1037  		inDiffData[i] = 0
  1038  	}
  1039  
  1040  	// this loop can be goroutine'd
  1041  	for b := 0; b < batches; b++ {
  1042  		for c := 0; c < channels; c++ {
  1043  			for ph := 0; ph < pooledH; ph++ {
  1044  				for pw := 0; pw < pooledW; pw++ {
  1045  					index := ph*pooledW + pw
  1046  					inIndex := maskData[index]
  1047  					inDiffData[inIndex] += outDiffData[index]
  1048  				}
  1049  			}
  1050  			outDiffData = outDiffData[outStride:]
  1051  			inDiffData = inDiffData[inStride:]
  1052  			maskData = maskData[maskStride:]
  1053  		}
  1054  	}
  1055  }
  1056  
  1057  // in is the "bottom", while out is the "top" (bottom being the unpooled, and top being the pooled)
  1058  func (op *maxPoolDiffOp) f64s(batches, channels, pooledH, pooledW int,
  1059  	inStride, outStride, maskStride int,
  1060  	inDiffData, outDiffData []float64,
  1061  	maskData []int) {
  1062  
  1063  	// zero out. let's hope go's optimizer is smart enough
  1064  	for i := range inDiffData {
  1065  		inDiffData[i] = 0
  1066  	}
  1067  
  1068  	// this loop can be goroutine'd
  1069  	for b := 0; b < batches; b++ {
  1070  		for c := 0; c < channels; c++ {
  1071  			for ph := 0; ph < pooledH; ph++ {
  1072  				for pw := 0; pw < pooledW; pw++ {
  1073  					index := ph*pooledW + pw
  1074  					inIndex := maskData[index]
  1075  					inDiffData[inIndex] += outDiffData[index]
  1076  				}
  1077  			}
  1078  			outDiffData = outDiffData[outStride:]
  1079  			inDiffData = inDiffData[inStride:]
  1080  			maskData = maskData[maskStride:]
  1081  		}
  1082  	}
  1083  }
  1084  
  1085  // clampOp is a constant clamping operation
  1086  type clampOp struct {
  1087  	min, max Scalar
  1088  }
  1089  
  1090  func (op *clampOp) Arity() int { return 1 }
  1091  
  1092  func (op *clampOp) Type() hm.Type {
  1093  	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'))
  1094  }
  1095  
  1096  func (op *clampOp) InferShape(shps ...DimSizer) (tensor.Shape, error) {
  1097  	return shps[0].(tensor.Shape), nil
  1098  }
  1099  
  1100  func (op *clampOp) Do(vals ...Value) (Value, error) {
  1101  	return nil, nil // not yet implemented
  1102  }
  1103  
  1104  func (op *clampOp) ReturnsPtr() bool { return true }
  1105  
  1106  func (op *clampOp) CallsExtern() bool { return false }
  1107  
  1108  func (op *clampOp) OverwritesInput() int { return 0 }
  1109  
  1110  func (op *clampOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, "ConstClamp{%f, %f}()", op.min, op.max) }
  1111  
  1112  func (op *clampOp) Hashcode() uint32 { return simpleHash(op) }
  1113  func (op *clampOp) String() string   { return fmt.Sprintf("ConstClamp{%f, %f}()", op.min, op.max) }
  1114  
  1115  // BatchNormOp is a batch normalization process as described by Ioffe and Szegedy (2015) -
  1116  // http://arxiv.org/abs/1502.03167
  1117  //
  1118  // Normalization is done as:
  1119  // 	γ(x - μ) / σ + β
  1120  // γ is the scaling factor and β is the offset factor. These are created by BatchNorm()
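        //
        // This op itself computes only the normalization (x - μ) / sqrt(σ² + ε);
        // γ and β are applied outside it. In training mode the moving mean and
        // variance are updated on every call.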
  1121  type BatchNormOp struct {
  1122  	momentum float64 // momentum for the moving average
  1123  	epsilon  float64 // small variance to be added to avoid dividing by 0
  1124  	dims     int     // 2 or 4. defaults to 4
  1125  
  1126  	// learnables
  1127  	mean, variance, ma *tensor.Dense
  1128  
  1129  	// scratch space
  1130  	meanTmp, varianceTmp, tmpSpace, xNorm                *tensor.Dense
  1131  	batchSumMultiplier, numByChans, spatialSumMultiplier *tensor.Dense
  1132  
  1133  	// training? if training then update movingMean and movingVar
  1134  	training bool
  1135  }
  1136  
  1137  // Arity returns 1
  1138  func (op *BatchNormOp) Arity() int { return 1 }
  1139  
  1140  // Type ...
  1141  func (op *BatchNormOp) Type() hm.Type {
  1142  	dims := op.dims
  1143  	if dims == 0 {
  1144  		dims = 4 // default to 4 if not set
  1145  	}
  1146  
  1147  	t := TensorType{Dims: dims, Of: hm.TypeVariable('a')}
  1148  	return hm.NewFnType(t, t)
  1149  }
  1150  
  1151  // InferShape from the input values
  1152  func (op *BatchNormOp) InferShape(ns ...DimSizer) (tensor.Shape, error) {
  1153  	if err := checkArity(op, len(ns)); err != nil {
  1154  		return nil, errors.Wrapf(err, "batchNorm")
  1155  	}
  1156  
  1157  	return ns[0].(tensor.Shape).Clone(), nil
  1158  }
  1159  
  1160  // Do performs the batchnorm computation on the values
  1161  func (op *BatchNormOp) Do(values ...Value) (retVal Value, err error) {
  1162  	if err := checkArity(op, len(values)); err != nil {
  1163  		return nil, errors.Wrapf(err, "batchNorm Do")
  1164  	}
  1165  	var v, out Value
  1166  	v = values[0]
  1167  	if out, err = CloneValue(v); err != nil {
  1168  		return nil, err
  1169  	}
  1170  	return op.UsePreallocDo(out, v)
  1171  }
  1172  
  1173  // ReturnsPtr is true
  1174  func (op *BatchNormOp) ReturnsPtr() bool { return true }
  1175  
  1176  // CallsExtern is false
  1177  func (op *BatchNormOp) CallsExtern() bool { return false }
  1178  
  1179  // OverwritesInput is -1 (operator doesn't overwrite any input value)
  1180  func (op *BatchNormOp) OverwritesInput() int { return -1 }
  1181  
  1182  // WriteHash ...
  1183  func (op *BatchNormOp) WriteHash(h hash.Hash) {
  1184  	fmt.Fprintf(h, "batchnorm-%1.1f-%1.1f", op.momentum, op.epsilon)
  1185  }
  1186  
  1187  // Hashcode ...
  1188  func (op *BatchNormOp) Hashcode() uint32 { return simpleHash(op) }
  1189  
  1190  func (op *BatchNormOp) String() string {
  1191  	return fmt.Sprintf("batchnorm-%1.1f-%1.1f", op.momentum, op.epsilon)
  1192  }
  1193  
  1194  // DoDiff does the gradient computation
  1195  func (op *BatchNormOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error {
  1196  	diff := &batchnormDiffOp{op}
  1197  	xdv, ydv := getDV(inputs[0], output)
  1198  	_, err := diff.UsePreallocDo(xdv.d, xdv.Value, ydv.d)
  1199  	return err
  1200  }
  1201  
  1202  // DiffWRT ...
  1203  func (op *BatchNormOp) DiffWRT(inputs int) []bool { return []bool{true} }
  1204  
  1205  // SymDiff ...
  1206  func (op *BatchNormOp) SymDiff(inputs Nodes, output *Node, grad *Node) (retVal Nodes, err error) {
  1207  	if err = checkArity(op, len(inputs)); err != nil {
  1208  		return
  1209  	}
  1210  	input := inputs[0]
  1211  	diff := &batchnormDiffOp{op}
  1212  
  1213  	var ret *Node
  1214  	if ret, err = ApplyOp(diff, input, grad); err != nil {
  1215  		return nil, err
  1216  	}
  1217  	return Nodes{ret}, nil
  1218  }
  1219  
  1220  // UsePreallocDo ...
  1221  func (op *BatchNormOp) UsePreallocDo(prealloc Value, inputs ...Value) (retVal Value, err error) {
  1222  	v := inputs[0]
  1223  	switch v.Dtype() {
  1224  	case Float64:
  1225  		err = op.f64s(v.(*tensor.Dense), prealloc.(*tensor.Dense))
  1226  	case Float32:
  1227  		err = op.f32s(v.(*tensor.Dense), prealloc.(*tensor.Dense))
  1228  	default:
  1229  		return nil, nyi("BatchNorm Do", v.Dtype())
  1230  	}
  1231  	return prealloc, err
  1232  }
  1233  
  1234  // SetTraining configures the op for training mode.
  1235  // A call to this function implicitly calls the Reset() method.
  1236  func (op *BatchNormOp) SetTraining() { op.Reset(); op.training = true }
  1237  
  1238  // SetTesting configures the op for testing (inference) mode.
  1239  func (op *BatchNormOp) SetTesting() { op.training = false }
  1240  
  1241  // Reset resets the operator by zeroing its internal scratch spaces and setting the sum multipliers to one
  1242  func (op *BatchNormOp) Reset() error {
  1243  	dt := op.ma.Dtype()
  1244  	var uno interface{}
  1245  	switch dt {
  1246  	case Float64:
  1247  		uno = float64(1)
  1248  	case Float32:
  1249  		uno = float32(1)
  1250  	}
  1251  
  1252  	if err := op.spatialSumMultiplier.Memset(uno); err != nil {
  1253  		return err
  1254  	}
  1255  
  1256  	if err := op.batchSumMultiplier.Memset(uno); err != nil {
  1257  		return err
  1258  	}
  1259  
  1260  	op.mean.Zero()
  1261  	op.variance.Zero()
  1262  	op.ma.Zero()
  1263  	op.meanTmp.Zero()
  1264  	op.varianceTmp.Zero()
  1265  	op.tmpSpace.Zero()
  1266  	op.numByChans.Zero()
  1267  	return nil
  1268  }
  1269  
  1270  func (op *BatchNormOp) f64s(input, output *tensor.Dense) (err error) {
  1271  	n := input.Shape()[0]
  1272  	channels := input.Shape()[1]
  1273  	nc := channels * n
  1274  	spatialDim := input.Shape().TotalSize() / (nc)
  1275  
  1276  	inputF64s := input.Float64s()
  1277  	outputF64s := output.Float64s()
  1278  	copy(outputF64s, inputF64s)
  1279  
  1280  	meanTmp := op.meanTmp.Float64s()
  1281  	mean := op.mean.Float64s()
  1282  	varianceTmp := op.varianceTmp.Float64s()
  1283  	variance := op.variance.Float64s()
  1284  	tmp := op.tmpSpace.Float64s()
  1285  	ssm := op.spatialSumMultiplier.Float64s()
  1286  	nbc := op.numByChans.Float64s()
  1287  	bsm := op.batchSumMultiplier.Float64s()
  1288  
  1289  	momentum := op.momentum
  1290  	eps := op.epsilon
  1291  
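        	// ssm and bsm hold all-ones vectors (set by Reset); multiplying by
        	// them with Dgemv/Dgemm sums or broadcasts along the spatial and
        	// batch dimensions, which is how the per-channel statistics are
        	// reduced and re-expanded below.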
  1292  	if !op.training {
  1293  		// use stored mean/variance estimates
  1294  		scaleFactor := float64(1)
  1295  		if fst := op.ma.Float64s()[0]; fst != 1 {
  1296  			scaleFactor = fst
  1297  		}
  1298  		copy(meanTmp, mean)
  1299  		whichblas.Dscal(len(meanTmp), scaleFactor, meanTmp, 1)
  1300  		copy(varianceTmp, variance)
  1301  		whichblas.Dscal(len(varianceTmp), scaleFactor, varianceTmp, 1)
  1302  	} else {
  1303  		// compute mean
  1304  		alpha := 1.0 / float64(n*spatialDim)
  1305  		whichblas.Dgemv(blas.NoTrans, nc, spatialDim, alpha, inputF64s, spatialDim, ssm, 1, 0, nbc, 1)
  1306  		whichblas.Dgemv(blas.Trans, n, channels, 1, nbc, channels, bsm, 1, 0, meanTmp, 1)
  1307  	}
  1308  
  1309  	// subtract mean
  1310  	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, meanTmp, channels, 0, nbc, channels)
  1311  	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, -1, nbc, 1, ssm, spatialDim, 1, outputF64s, spatialDim)
  1312  
  1313  	if op.training {
  1314  		// compute variance using var(X) = E((X-EX)²)
  1315  		copy(tmp, outputF64s)
  1316  		vecf64.Mul(tmp, tmp) // (X-EX) ^ 2
  1317  
  1318  		whichblas.Dgemv(blas.NoTrans, nc, spatialDim, 1.0/(float64(n*spatialDim)), tmp, spatialDim, ssm, 1, 0, nbc, 1)
  1319  		whichblas.Dgemv(blas.Trans, n, channels, 1.0, nbc, channels, bsm, 1, 0, varianceTmp, 1) // E((X-EX)²)
  1320  
  1321  		// compute and save moving average
  1322  		op.ma.Float64s()[0] *= momentum
  1323  		op.ma.Float64s()[0]++
  1324  
  1325  		// TODO: write axpby for gonum
  1326  		whichblas.Dscal(len(mean), momentum, mean, 1)
  1327  		whichblas.Daxpy(len(meanTmp), 1.0, meanTmp, 1, mean, 1)
  1328  
  1329  		m := len(inputF64s) / channels
  1330  		correctionFactor := float64(1)
  1331  		if m > 1 {
  1332  			correctionFactor = float64(m) / (float64(m - 1))
  1333  		}
  1334  		whichblas.Dscal(len(variance), momentum, variance, 1)
  1335  		whichblas.Daxpy(len(varianceTmp), correctionFactor, varianceTmp, 1, variance, 1)
  1336  	}
  1337  
  1338  	// normalize variance
  1339  	vecf64.Trans(varianceTmp, eps)
  1340  	vecf64.Sqrt(varianceTmp)
  1341  
  1342  	// replicate variance to inputsize
  1343  	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, varianceTmp, channels, 0, nbc, channels)
  1344  	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, 1, nbc, 1, ssm, spatialDim, 0, tmp, spatialDim)
  1345  	vecf64.Div(outputF64s, tmp)
  1346  	copy(op.xNorm.Float64s(), outputF64s) // caching
  1347  
  1348  	return nil
  1349  }
  1350  
  1351  func (op *BatchNormOp) f32s(input, output *tensor.Dense) (err error) {
  1352  	n := input.Shape()[0]
  1353  	channels := input.Shape()[1]
  1354  	nc := channels * n
  1355  	spatialDim := input.Shape().TotalSize() / (nc)
  1356  
  1357  	inputF32s := input.Float32s()
  1358  	outputF32s := output.Float32s()
  1359  	copy(outputF32s, inputF32s)
  1360  
  1361  	meanTmp := op.meanTmp.Float32s()
  1362  	mean := op.mean.Float32s()
  1363  	varianceTmp := op.varianceTmp.Float32s()
  1364  	variance := op.variance.Float32s()
  1365  	tmp := op.tmpSpace.Float32s()
  1366  	ssm := op.spatialSumMultiplier.Float32s()
  1367  	nbc := op.numByChans.Float32s()
  1368  	bsm := op.batchSumMultiplier.Float32s()
  1369  
  1370  	momentum := float32(op.momentum)
  1371  	eps := float32(op.epsilon)
  1372  
  1373  	if !op.training {
  1374  		// use stored mean/variance estimates
  1375  		scaleFactor := float32(1)
  1376  		if fst := op.ma.Float32s()[0]; fst != 1 {
  1377  			scaleFactor = fst
  1378  		}
  1379  		copy(meanTmp, mean)
  1380  		whichblas.Sscal(len(meanTmp), scaleFactor, meanTmp, 1)
  1381  		copy(varianceTmp, variance)
  1382  		whichblas.Sscal(len(varianceTmp), scaleFactor, varianceTmp, 1)
  1383  	} else {
  1384  		// compute mean
  1385  		alpha := 1.0 / float32(n*spatialDim)
  1386  		whichblas.Sgemv(blas.NoTrans, nc, spatialDim, alpha, inputF32s, spatialDim, ssm, 1, 0, nbc, 1)
  1387  		whichblas.Sgemv(blas.Trans, n, channels, 1, nbc, channels, bsm, 1, 0, meanTmp, 1)
  1388  	}
  1389  
  1390  	// subtract mean
  1391  	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, meanTmp, channels, 0, nbc, channels)
  1392  	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, -1, nbc, 1, ssm, spatialDim, 1, outputF32s, spatialDim)
  1393  
  1394  	if op.training {
  1395  		// compute variance using var(X) = E((X-EX)²)
  1396  		copy(tmp, outputF32s)
  1397  		vecf32.Mul(tmp, tmp) // (X-EX) ^ 2
  1398  
  1399  		whichblas.Sgemv(blas.NoTrans, nc, spatialDim, 1.0/(float32(n*spatialDim)), tmp, spatialDim, ssm, 1, 0, nbc, 1)
  1400  		whichblas.Sgemv(blas.Trans, n, channels, 1.0, nbc, channels, bsm, 1, 0, varianceTmp, 1) // E((X-EX)²)
  1401  
  1402  		// compute and save moving average
  1403  		op.ma.Float32s()[0] *= momentum
  1404  		op.ma.Float32s()[0]++
  1405  
  1406  		// TODO: write axpby for gonum
  1407  		whichblas.Sscal(len(mean), momentum, mean, 1)
  1408  		whichblas.Saxpy(len(meanTmp), 1.0, meanTmp, 1, mean, 1)
  1409  
  1410  		m := len(inputF32s) / channels
  1411  		correctionFactor := float32(1)
  1412  		if m > 1 {
  1413  			correctionFactor = float32(m) / (float32(m - 1))
  1414  		}
  1415  		whichblas.Sscal(len(variance), momentum, variance, 1)
  1416  		whichblas.Saxpy(len(varianceTmp), correctionFactor, varianceTmp, 1, variance, 1)
  1417  	}
  1418  
  1419  	// normalize variance
  1420  	vecf32.Trans(varianceTmp, eps)
  1421  	vecf32.Sqrt(varianceTmp)
  1422  
  1423  	// replicate variance to inputsize
  1424  	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, varianceTmp, channels, 0, nbc, channels)
  1425  	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, 1, nbc, 1, ssm, spatialDim, 0, tmp, spatialDim)
  1426  	vecf32.Div(outputF32s, tmp)
  1427  	copy(op.xNorm.Float32s(), outputF32s) // caching
  1428  
  1429  	return nil
  1430  }
  1431  
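        // batchnormDiffOp computes the gradient of BatchNormOp. It is not itself
        // differentiable: second-order gradients are not supported.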
  1432  type batchnormDiffOp struct{ *BatchNormOp }
  1433  
  1434  func (op *batchnormDiffOp) Arity() int { return 2 }
  1435  
  1436  func (op *batchnormDiffOp) Type() hm.Type {
  1437  	dims := op.dims
  1438  	if dims == 0 {
  1439  		dims = 4
  1440  	}
  1441  
  1442  	t := TensorType{Dims: dims, Of: hm.TypeVariable('a')}
  1443  	return hm.NewFnType(t, t, t)
  1444  }
  1445  
  1446  func (op *batchnormDiffOp) InferShape(ns ...DimSizer) (tensor.Shape, error) {
  1447  	if err := checkArity(op, len(ns)); err != nil {
  1448  		return nil, errors.Wrapf(err, "batchNorm")
  1449  	}
  1450  
  1451  	return ns[0].(tensor.Shape).Clone(), nil
  1452  }
  1453  
  1454  func (op *batchnormDiffOp) Do(values ...Value) (Value, error) {
  1455  	input := values[0].(*tensor.Dense)
  1456  	grad := values[1].(*tensor.Dense)
  1457  	inputGrad := input.Clone().(*tensor.Dense)
  1458  	return op.UsePreallocDo(inputGrad, input, grad)
  1459  }
  1460  
  1461  // ReturnsPtr, CallsExtern and OverwritesInput are inherited from the
  1462  // embedded *BatchNormOp and therefore have exactly the same
  1463  // characteristics as batchnorm.
  1464  
  1465  func (op *batchnormDiffOp) WriteHash(h hash.Hash) {
  1466  	fmt.Fprintf(h, "batchnormdiff-%1.1f-%1.1f", op.momentum, op.epsilon)
  1467  }
  1468  
  1469  func (op *batchnormDiffOp) Hashcode() uint32 { return simpleHash(op) }
  1470  
  1471  func (op *batchnormDiffOp) String() string {
  1472  	return fmt.Sprintf("batchnormdiff-%1.1f-%1.1f", op.momentum, op.epsilon)
  1473  }
  1474  
  1475  func (op *batchnormDiffOp) DiffWRT(inputs int) []bool {
  1476  	// god help those who want to do 2nd order differentiation on batchnorm
  1477  	return []bool{false, false}
  1478  }
  1479  
  1480  func (op *batchnormDiffOp) SymDiff(inputs Nodes, output *Node, grad *Node) (retVal Nodes, err error) {
  1481  	// god help those who want to do 2nd order differentiation on batchnorm
  1482  	return nil, nyi("SymDiff", "batchNormDiffOp")
  1483  }
  1484  
  1485  func (op *batchnormDiffOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error {
  1486  	// god help those who want to do 2nd order differentiation on batchnorm
  1487  	return nyi("DoDiff", "batchnormDiffOp")
  1488  }
  1489  
  1490  func (op *batchnormDiffOp) UsePreallocDo(prealloc Value, inputs ...Value) (retVal Value, err error) {
  1491  	input := inputs[0].(*tensor.Dense)
  1492  	inGrad := prealloc.(*tensor.Dense)
  1493  	outGrad := inputs[1].(*tensor.Dense)
  1494  
  1495  	switch input.Dtype() {
  1496  	case Float64:
  1497  		err = op.f64s(input, inGrad, outGrad)
  1498  	case Float32:
  1499  		err = op.f32s(input, inGrad, outGrad)
  1500  	default:
  1501  		return nil, nyi("batchnormDiffOp Do", input.Dtype())
  1502  	}
  1503  	return prealloc, err
  1504  }
  1505  
  1506  func (op *batchnormDiffOp) f64s(input, inGrad, outGrad *tensor.Dense) (err error) {
  1507  	in := input.Float64s()
  1508  	ig := inGrad.Float64s()
  1509  	og := outGrad.Float64s()
  1510  	tmp := op.tmpSpace.Float64s()
  1511  	out := op.xNorm.Float64s()
  1512  	ssm := op.spatialSumMultiplier.Float64s()
  1513  	nbc := op.numByChans.Float64s()
  1514  	bsm := op.batchSumMultiplier.Float64s()
  1515  	meanTmp := op.meanTmp.Float64s()
  1516  
  1517  	if !op.training {
  1518  		copy(ig, og)
  1519  		vecf64.Div(ig, tmp) // inGrad = outGrad / sqrt(var(X)+eps); og must not be mutated
  1520  		return nil
  1521  	}
  1522  
  1523  	n := input.Shape()[0]
  1524  	channels := input.Shape()[1]
  1525  	nc := n * channels
  1526  	spatialDim := len(in) / nc
  1527  
  1528  	// if Y = (X-mean(X))/(sqrt(var(X)+eps)), then
  1529  	//
  1530  	// dE(Y)/dX =
  1531  	//   (dE/dY - mean(dE/dY) - mean(dE/dY ⋅ Y) ⋅ Y)
  1532  	//     ./ sqrt(var(X) + eps)
  1533  	//
  1534  	// where ⋅ and ./ are hadamard product and elementwise division,
  1535  	// respectively, dE/dY is the top diff, and mean/var/sum are all computed
  1536  	// along all dimensions except the channels dimension.  In the above
  1537  	// equation, the operations allow for expansion (i.e. broadcast) along all
  1538  	// dimensions except the channels dimension where required.
  1539  
  1540  	// sum(dE/dY ⋅ Y)
  1541  	copy(ig, out)
  1542  	vecf64.Mul(ig, og)
  1543  	whichblas.Dgemv(blas.NoTrans, nc, spatialDim, 1, ig, spatialDim, ssm, 1, 0, nbc, 1)
  1544  	whichblas.Dgemv(blas.Trans, n, channels, 1, nbc, channels, bsm, 1, 0, meanTmp, 1)
  1545  
  1546  	// reshape (broadcast) the above
  1547  	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, meanTmp, channels, 0, nbc, channels)
  1548  	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, 1, nbc, 1, ssm, spatialDim, 0, ig, spatialDim)
  1549  
  1550  	// sum(dE/dY ⋅ Y) ⋅ Y
  1551  	vecf64.Mul(ig, out)
  1552  
  1553  	// sum(dE/dY)-sum(dE/dY ⋅ Y) ⋅ Y
  1554  	whichblas.Dgemv(blas.NoTrans, nc, spatialDim, 1, og, spatialDim, ssm, 1, 0, nbc, 1)
  1555  	whichblas.Dgemv(blas.Trans, n, channels, 1, nbc, channels, bsm, 1, 0, meanTmp, 1)
  1556  
  1557  	// reshape (broadcast) the above to make
  1558  	// sum(dE/dY)-sum(dE/dY ⋅ Y) ⋅ Y
  1559  	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, meanTmp, channels, 0, nbc, channels)
  1560  	whichblas.Dgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, 1, nbc, 1, ssm, spatialDim, 1, ig, spatialDim)
  1561  
  1562  	// dE/dY - mean(dE/dY)-mean(dE/dY ⋅ Y) ⋅ Y
  1563  	beta := (-1.0 / float64(n*spatialDim)) // scale by -1/(N·spatialDim), the number of elements averaged per channel
  1564  
  1565  	vecf64.Scale(ig, beta)
  1566  	vecf64.Add(ig, og)
  1567  
  1568  	// note: temp_ still contains sqrt(var(X)+eps), computed during the forward
  1569  	// pass.
  1570  	vecf64.Div(ig, tmp)
  1571  	return nil
  1572  
  1573  }
  1574  
  1575  func (op *batchnormDiffOp) f32s(input, inGrad, outGrad *tensor.Dense) (err error) {
  1576  	in := input.Float32s()
  1577  	ig := inGrad.Float32s()
  1578  	og := outGrad.Float32s()
  1579  	tmp := op.tmpSpace.Float32s()
  1580  	out := op.xNorm.Float32s()
  1581  	ssm := op.spatialSumMultiplier.Float32s()
  1582  	nbc := op.numByChans.Float32s()
  1583  	bsm := op.batchSumMultiplier.Float32s()
  1584  	meanTmp := op.meanTmp.Float32s()
  1585  
  1586  	if !op.training {
  1587  		copy(ig, og)
  1588  		vecf32.Div(ig, tmp) // inGrad = outGrad / sqrt(var(X)+eps); og must not be mutated
  1589  		return nil
  1590  	}
  1591  
  1592  	n := input.Shape()[0]
  1593  	channels := input.Shape()[1]
  1594  	nc := n * channels
  1595  	spatialDim := len(in) / nc
  1596  
  1597  	// if Y = (X-mean(X))/(sqrt(var(X)+eps)), then
  1598  	//
  1599  	// dE(Y)/dX =
  1600  	//   (dE/dY - mean(dE/dY) - mean(dE/dY ⋅ Y) ⋅ Y)
  1601  	//     ./ sqrt(var(X) + eps)
  1602  	//
  1603  	// where ⋅ and ./ are hadamard product and elementwise division,
  1604  	// respectively, dE/dY is the top diff, and mean/var/sum are all computed
  1605  	// along all dimensions except the channels dimension.  In the above
  1606  	// equation, the operations allow for expansion (i.e. broadcast) along all
  1607  	// dimensions except the channels dimension where required.
  1608  
  1609  	// sum(dE/dY ⋅ Y)
  1610  	copy(ig, out)
  1611  	vecf32.Mul(ig, og)
  1612  	whichblas.Sgemv(blas.NoTrans, nc, spatialDim, 1, ig, spatialDim, ssm, 1, 0, nbc, 1)
  1613  	whichblas.Sgemv(blas.Trans, n, channels, 1, nbc, channels, bsm, 1, 0, meanTmp, 1)
  1614  
  1615  	// reshape (broadcast) the above
  1616  	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, meanTmp, channels, 0, nbc, channels)
  1617  	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, 1, nbc, 1, ssm, spatialDim, 0, ig, spatialDim)
  1618  
  1619  	// sum(dE/dY ⋅ Y) ⋅ Y
  1620  	vecf32.Mul(ig, out)
  1621  
  1622  	// sum(dE/dY)-sum(dE/dY ⋅ Y) ⋅ Y
  1623  	whichblas.Sgemv(blas.NoTrans, nc, spatialDim, 1, og, spatialDim, ssm, 1, 0, nbc, 1)
  1624  	whichblas.Sgemv(blas.Trans, n, channels, 1, nbc, channels, bsm, 1, 0, meanTmp, 1)
  1625  
  1626  	// reshape (broadcast) the above to make
  1627  	// sum(dE/dY)-sum(dE/dY ⋅ Y) ⋅ Y
  1628  	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, n, channels, 1, 1, bsm, 1, meanTmp, channels, 0, nbc, channels)
  1629  	whichblas.Sgemm(blas.NoTrans, blas.NoTrans, nc, spatialDim, 1, 1, nbc, 1, ssm, spatialDim, 1, ig, spatialDim)
  1630  
  1631  	// dE/dY - mean(dE/dY)-mean(dE/dY ⋅ Y) ⋅ Y
  1632  	beta := (-1.0 / float32(n*spatialDim))
  1633  	vecf32.Scale(ig, beta)
  1634  	vecf32.Add(ig, og)
  1635  
  1636  	// note: temp_ still contains sqrt(var(X)+eps), computed during the forward
  1637  	// pass.
  1638  	vecf32.Div(ig, tmp)
  1639  	return nil
  1640  
  1641  }
  1642  
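        // globalAveragePoolOp reduces each (H, W) feature map of a 4D BCHW
        // tensor to its mean, producing an output of shape (B, C, 1, 1)
        // (cf. ONNX's GlobalAveragePool operator).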
  1643  type globalAveragePoolOp struct{}
  1644  
  1645  func (g *globalAveragePoolOp) Arity() int {
  1646  	return 1
  1647  }
  1648  
  1649  func (g *globalAveragePoolOp) Type() hm.Type {
  1650  	a := hm.TypeVariable('a')
  1651  	t := newTensorType(4, a)
  1652  	return hm.NewFnType(t, t)
  1653  }
  1654  
  1655  func (g *globalAveragePoolOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
  1656  	b, err := inputs[0].DimSize(0)
  1657  	if err != nil {
  1658  		return nil, err
  1659  	}
  1660  	c, err := inputs[0].DimSize(1)
  1661  	if err != nil {
  1662  		return nil, err
  1663  	}
  1664  	// verify that dims 2 and 3 exist, i.e. that the input is 4D
  1665  	if _, err := inputs[0].DimSize(2); err != nil {
  1666  		return nil, err
  1667  	}
  1668  	if _, err := inputs[0].DimSize(3); err != nil {
  1669  		return nil, err
  1670  	}
  1671  	return tensor.Shape{b, c, 1, 1}, nil
  1672  }
  1673  
  1674  func (g *globalAveragePoolOp) Do(inputs ...Value) (Value, error) {
  1675  	im := inputs[0]
  1676  	switch v := im.(type) {
  1677  	case tensor.Tensor:
  1678  		// v is the input in BCHW layout
  1679  		B, C, H, W := v.Shape()[0], v.Shape()[1], v.Shape()[2], v.Shape()[3]
  1680  		s, err := g.InferShape(v.Shape())
  1681  		if err != nil {
  1682  			return nil, err
  1683  		}
  1684  		output := tensor.New(tensor.Of(v.Dtype()), tensor.WithShape(s...))
  1685  		switch v.Dtype() {
  1686  		case tensor.Float64:
  1687  			for b := 0; b < B; b++ {
  1688  				for c := 0; c < C; c++ {
  1689  					var sum float64
  1690  					for h := 0; h < H; h++ {
  1691  						for w := 0; w < W; w++ {
  1692  							val, err := v.At(b, c, h, w)
  1693  							if err != nil {
  1694  								return nil, err
  1695  							}
  1696  							sum += val.(float64)
  1697  						}
  1698  					}
  1699  					err := output.SetAt(sum/float64(H*W), b, c, 0, 0)
  1700  					if err != nil {
  1701  						return nil, err
  1702  					}
  1703  				}
  1704  			}
  1705  		case tensor.Float32:
  1706  			for b := 0; b < B; b++ {
  1707  				for c := 0; c < C; c++ {
  1708  					var sum float32
  1709  					for h := 0; h < H; h++ {
  1710  						for w := 0; w < W; w++ {
  1711  							val, err := v.At(b, c, h, w)
  1712  							if err != nil {
  1713  								return nil, err
  1714  							}
  1715  							sum += val.(float32)
  1716  						}
  1717  					}
  1718  					err := output.SetAt(sum/float32(H*W), b, c, 0, 0)
  1719  					if err != nil {
  1720  						return nil, err
  1721  					}
  1722  				}
  1723  			}
  1724  		default:
  1725  			return nil, nyi("Global Average Pool", v.Dtype())
  1726  		}
  1727  
  1728  		return output, nil
  1729  
  1730  	default:
  1731  		return nil, nyi("globalAveragePoolOp", inputs)
  1732  	}
  1733  }
  1734  
  1735  func (g *globalAveragePoolOp) ReturnsPtr() bool {
  1736  	return false
  1737  }
  1738  
  1739  func (g *globalAveragePoolOp) CallsExtern() bool {
  1740  	return false
  1741  }
  1742  
  1743  func (g *globalAveragePoolOp) OverwritesInput() int {
  1744  	return -1
  1745  }
  1746  
  1747  func (g *globalAveragePoolOp) WriteHash(h hash.Hash) {
  1748  	fmt.Fprintf(h, "GlobalAveragePool")
  1749  }
  1750  
  1751  func (g *globalAveragePoolOp) Hashcode() uint32 {
  1752  	return simpleHash(g)
  1753  }
  1754  
  1755  func (g *globalAveragePoolOp) String() string {
  1756  	return "GlobalAveragePool"
  1757  }