gorgonia.org/gorgonia@v0.9.17/nn.go

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/pkg/errors"
     7  	"gorgonia.org/gorgonia/internal/encoding"
     8  	"gorgonia.org/tensor"
     9  )
    10  
    11  // BinaryXent is a convenience function for computing binary cross entropy.
    12  // The formula is as below:
    13  // 		-(y * log(p) + (1-y) * log(1-p))
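        //
        // A minimal usage sketch (the graph, shapes and node names below are illustrative only):
        //
        //	g := NewGraph()
        //	prob := NewMatrix(g, Float64, WithShape(4, 1), WithName("prob")) // predicted probabilities in (0, 1)
        //	y := NewMatrix(g, Float64, WithShape(4, 1), WithName("y"))       // binary targets
        //	losses, err := BinaryXent(prob, y)
        //	if err != nil {
        //		// handle error
        //	}
        //	cost, _ := Mean(losses) // reduce the elementwise losses to a scalar cost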
    14  func BinaryXent(output, target *Node) (retVal *Node, err error) {
    15  	var one *Node
    16  	var logO, omt, omo, tLogO *Node
    17  
    18  	// which constant one to use?
    19  	var dt tensor.Dtype
    20  	if dt, err = dtypeOf(output.t); err != nil {
    21  		return nil, errors.Wrapf(err, dtypeExtractionFail, output.t)
    22  	}
    23  
    24  	switch dt {
    25  	case Float64:
    26  		one = onef64
    27  	case Float32:
    28  		one = onef32
    29  	default:
    30  		return nil, errors.Errorf(nyiFail, "BinaryXent", dt)
    31  	}
    32  
    33  	if logO, err = Log(output); err != nil {
    34  		return nil, errors.Wrap(err, operationError)
    35  	}
    36  
    37  	if omt, err = Sub(one, target); err != nil {
    38  		return nil, errors.Wrap(err, operationError)
    39  	}
    40  
    41  	if omo, err = Sub(one, output); err != nil {
    42  		return nil, errors.Wrap(err, operationError)
    43  	}
    44  
    45  	if tLogO, err = HadamardProd(target, logO); err != nil {
    46  		return nil, errors.Wrap(err, operationError)
    47  	}
    48  
    49  	if retVal, err = Log(omo); err != nil {
    50  		return nil, errors.Wrap(err, operationError)
    51  	}
    52  
    53  	if retVal, err = HadamardProd(omt, retVal); err != nil {
    54  		return nil, errors.Wrap(err, operationError)
    55  	}
    56  
    57  	if retVal, err = Add(tLogO, retVal); err != nil {
    58  		return nil, errors.Wrap(err, operationError)
    59  	}
    60  
    61  	return Neg(retVal)
    62  }
    63  
    64  // Dropout is a convenience function to implement dropout.
    65  // It randomly zeroes out elements of the input *Tensor with probability dropProb,
    66  // using a mask drawn from a uniform distribution, and rescales the kept values by 1/(1-dropProb).
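        //
        // A minimal usage sketch (shapes, names and the drop probability are illustrative):
        //
        //	g := NewGraph()
        //	x := NewMatrix(g, Float64, WithShape(32, 128), WithName("x"))
        //	h, err := Dropout(x, 0.2) // zeroes roughly 20% of the activations and rescales the rest
        //	if err != nil {
        //		// handle error
        //	}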
    67  func Dropout(x *Node, dropProb float64) (retVal *Node, err error) {
    68  	return dropout(x, dropProb, UniformRandomNode)
    69  }
    70  
    71  type dropoutRandFn func(g *ExprGraph, dt tensor.Dtype, low, high float64, shape ...int) *Node
    72  
    73  func dropout(x *Node, dropProb float64, randFn dropoutRandFn) (retVal *Node, err error) {
    74  	if dropProb == 0.0 {
    75  		return x, nil
    76  	}
    77  	keepProb := 1.0 - dropProb
    78  
    79  	var dt tensor.Dtype
    80  	if dt, err = dtypeOf(x.t); err != nil {
    81  		return nil, errors.Wrap(err, dtypeOfFail)
    82  	}
    83  
    84  	var pr Value
    85  	switch dt {
    86  	case Float64:
    87  		pr, _ = anyToScalar(keepProb)
    88  	case Float32:
    89  		pr, _ = anyToScalar(float32(keepProb))
    90  	default:
    91  		return nil, errors.Errorf(nyiTypeFail, "Dropout()", dt)
    92  	}
    93  
    94  	p := NewConstant(pr)
    95  
    96  	m := randFn(x.g, dt, 0, 1, x.shape...)
    97  	if retVal, err = Lt(m, p, true); err != nil {
    98  		return nil, errors.Wrap(err, "Less Than failed")
    99  	}
   100  
   101  	if retVal, err = HadamardProd(x, retVal); err != nil {
   102  		return nil, errors.Wrap(err, mulFail)
   103  	}
   104  
   105  	return HadamardDiv(retVal, p)
   106  }
   107  
   108  // LeakyRelu returns a node whose underlying value is:
   109  //   f(x) = alpha * x if x < 0
   110  //   f(x) = x for x ⩾ 0
   111  // applied elementwise.
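        //
        // A minimal usage sketch (the 0.01 slope and the shapes are illustrative):
        //
        //	g := NewGraph()
        //	x := NewMatrix(g, Float64, WithShape(8, 16), WithName("x"))
        //	act, err := LeakyRelu(x, 0.01)
        //	if err != nil {
        //		// handle error
        //	}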
   112  func LeakyRelu(x *Node, alpha float64) (*Node, error) {
   113  	var zero *Node
   114  	var dt tensor.Dtype
   115  	var err error
   116  	var alphaN *Node
   117  
   118  	// which zero to use?
   119  	if dt, err = dtypeOf(x.t); err != nil {
   120  		return nil, errors.Wrap(err, dtypeOfFail)
   121  	}
   122  	switch dt {
   123  	case Float64:
   124  		zero = zerof64
   125  		alphaN = NewConstant(alpha)
   126  	case Float32:
   127  		zero = zerof32
   128  		alphaN = NewConstant(float32(alpha))
   129  	default:
   130  		return nil, errors.Errorf(nyiFail, "LeakyRelu", dt)
   131  	}
   132  
   133  	gteZeroOp := newElemBinOp(gteOpType, x, zero)
   134  	gteZeroOp.retSame = true
   135  
   136  	xGteZeroCmp, err := ApplyOp(gteZeroOp, x, zero)
   137  	if err != nil {
   138  		return nil, errors.Wrap(err, applyOpFail)
   139  	}
   140  	ltZeroOp := newElemBinOp(ltOpType, x, zero)
   141  	ltZeroOp.retSame = true
   142  
   143  	xLtZeroCmp, err := ApplyOp(ltZeroOp, x, zero)
   144  	if err != nil {
   145  		return nil, errors.Wrap(err, applyOpFail)
   146  	}
   147  	xGteZero, err := HadamardProd(x, xGteZeroCmp)
   148  	if err != nil {
   149  		return nil, errors.Wrap(err, applyOpFail)
   150  	}
   151  	xLtZero, err := HadamardProd(x, xLtZeroCmp)
   152  	if err != nil {
   153  		return nil, errors.Wrap(err, applyOpFail)
   154  	}
   155  	xLtZeroAlpha, err := HadamardProd(xLtZero, alphaN)
   156  	if err != nil {
   157  		return nil, errors.Wrap(err, applyOpFail)
   158  	}
   159  	return Add(xGteZero, xLtZeroAlpha)
   160  }
   161  
   162  // Rectify is a convenience function for creating rectified linear units activation functions.
   163  // This function uses ⩾, which is the canonical version. If you want to use > instead, you can create
   164  // your own version by following this implementation.
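        //
        // A minimal usage sketch (shapes and names are illustrative):
        //
        //	g := NewGraph()
        //	x := NewMatrix(g, Float32, WithShape(8, 16), WithName("x"))
        //	act, err := Rectify(x)
        //	if err != nil {
        //		// handle error
        //	}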
   165  func Rectify(x *Node) (retVal *Node, err error) {
   166  	var zero *Node
   167  	var dt tensor.Dtype
   168  	group := encoding.NewGroup("Rectify")
   169  
   170  	// which zero to use?
   171  	if dt, err = dtypeOf(x.t); err != nil {
   172  		return nil, errors.Wrap(err, dtypeOfFail)
   173  	}
   174  	switch dt {
   175  	case Float64:
   176  		zero = zerof64
   177  	case Float32:
   178  		zero = zerof32
   179  	default:
   180  		return nil, errors.Errorf(nyiFail, "ReLu", dt)
   181  	}
   182  
   183  	cmp := newElemBinOp(gteOpType, x, zero)
   184  	cmp.retSame = true
   185  
   186  	if retVal, err = ApplyOp(cmp, x, zero); err != nil {
   187  		return nil, errors.Wrap(err, applyOpFail)
   188  	}
   189  	retVal.groups = retVal.groups.Upsert(group)
   190  
   191  	return HadamardProd(x, retVal)
   192  }
   193  
   194  // Im2Col converts a BCHW image block to columns. The kernel, pad, stride and dilation parameters must each be a shape of size 2, no more, no less.
   195  // This poor naming scheme clearly comes from Matlab.
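        //
        // A minimal usage sketch (a single 1-channel 28x28 image, 3x3 kernel, no padding, stride 1,
        // dilation 1; all values are illustrative):
        //
        //	g := NewGraph()
        //	im := NewTensor(g, Float64, 4, WithShape(1, 1, 28, 28), WithName("im"))
        //	cols, err := Im2Col(im, tensor.Shape{3, 3}, tensor.Shape{0, 0}, tensor.Shape{1, 1}, tensor.Shape{1, 1})
        //	if err != nil {
        //		// handle error
        //	}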
   196  func Im2Col(n *Node, kernel, pad, stride, dilation tensor.Shape) (retVal *Node, err error) {
   197  	if kernel.Dims() != 2 {
   198  		return nil, errors.Errorf("kernel shape is supposed to have a dim of 2")
   199  	}
   200  	if pad.Dims() != 2 {
   201  		return nil, errors.Errorf("pad is supposed to have a dim of 2")
   202  	}
   203  	if stride.Dims() != 2 {
   204  		return nil, errors.Errorf("strides is supposed to have a dim of 2")
   205  	}
   206  	if dilation.Dims() != 2 {
   207  		return nil, errors.Errorf("dilation is supposed to have a dim of 2")
   208  	}
   209  
   210  	if kernel[0] <= 0 || kernel[1] <= 0 {
   211  		return nil, errors.Errorf("cannot have negative or 0 in kernel shape")
   212  	}
   213  
   214  	if stride[0] <= 0 || stride[1] <= 0 {
   215  		return nil, errors.Errorf("cannot have negative or 0 in stride: %v", stride)
   216  	}
   217  
   218  	if pad[0] < 0 || pad[1] < 0 {
   219  		return nil, errors.Errorf("cannot have negative padding")
   220  	}
   221  
   222  	if dilation[0] <= 0 || dilation[1] <= 0 {
   223  		return nil, errors.Errorf("cannot have negative or 0 in dilation. %v", dilation)
   224  	}
   225  
   226  	op := makeIm2ColOp(kernel[0], kernel[1], pad[0], pad[1], stride[0], stride[1], dilation[0], dilation[1])
   227  	return ApplyOp(op, n)
   228  }
   229  
   230  // Conv2d is a simple 2D convolution, to be used for CPU computation only.
   231  // If CuDNN is used, use the CUDAConv2D function.
   232  // These are the properties the inputs must fulfil:
   233  //
   234  // - im: must have 4D shape. Expected format is BCHW (batch, channels, height, width)
   235  // - filter: must have 4D shape: (output channels, input channels, kernel height, kernel width)
   236  // - kernelShape: shape of the filter kernel
   237  // - pad: len(pad) == 2, defaults to []int{0, 0} if nil is passed
   238  // - stride: len(stride) == 2, example: []int{1, 1}
   239  // - dilation: len(dilation) == 2, defaults to []int{1, 1} if nil is passed
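        //
        // A minimal usage sketch (a batch of 8 RGB 32x32 images convolved with 16 filters of size 3x3;
        // shapes and names are illustrative):
        //
        //	g := NewGraph()
        //	im := NewTensor(g, Float64, 4, WithShape(8, 3, 32, 32), WithName("im"))
        //	w := NewTensor(g, Float64, 4, WithShape(16, 3, 3, 3), WithName("w"), WithInit(GlorotN(1.0)))
        //	out, err := Conv2d(im, w, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1})
        //	// with this padding and stride, out has shape (8, 16, 32, 32)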
   240  func Conv2d(im, filter *Node, kernelShape tensor.Shape, pad, stride, dilation []int) (retVal *Node, err error) {
   241  	group := encoding.NewGroup("Convolution")
   242  	// niceness for defaults
   243  	if pad == nil {
   244  		pad = []int{0, 0}
   245  	}
   246  	if dilation == nil {
   247  		dilation = []int{1, 1}
   248  	}
   249  
   250  	if im.Shape().Dims() != 4 {
   251  		return nil, fmt.Errorf("im should have 4 dims, got %v dims", im.Shape().Dims())
   252  	}
   253  
   254  	if filter.Shape().Dims() != 4 {
   255  		return nil, fmt.Errorf("filter should have 4 dims, got %v dims", filter.Shape().Dims())
   256  	}
   257  
   258  	// checks
   259  	for _, s := range stride {
   260  		if s <= 0 {
   261  			return nil, errors.Errorf("Cannot use strides of less than or equal 0: %v", stride)
   262  		}
   263  	}
   264  
   265  	for _, p := range pad {
   266  		if p < 0 {
   267  			return nil, errors.Errorf("Cannot use padding of less than 0: %v", pad)
   268  		}
   269  	}
   270  
   271  	for _, d := range dilation {
   272  		if d <= 0 {
   273  			return nil, errors.Errorf("Cannot use dilation less than or eq 0 %v", dilation)
   274  		}
   275  	}
   276  
   277  	var colIm *Node
   278  	if colIm, err = Im2Col(im, kernelShape, pad, stride, dilation); err != nil {
   279  		return nil, fmt.Errorf("Im2Col failed: %w", err)
   280  	}
   281  	colIm.groups = colIm.groups.Upsert(group)
   282  
   283  	layer := filter.Shape()[0]  // output channels
   284  	kernel := filter.Shape()[1] // input channels
   285  	row := filter.Shape()[2]    // kernel height
   286  	col := filter.Shape()[3]    // kernel width
   287  
   288  	if colIm.Shape()[3] != kernel*row*col {
   289  		return nil, fmt.Errorf("%d (input channels) * %d (kernel height) * %d (kernel width) must be %d, got %d", kernel, row, col, colIm.Shape()[3], kernel*row*col)
   290  	}
   291  
   292  	var flattened *Node
   293  	if flattened, err = Reshape(filter, tensor.Shape{layer, kernel * row * col}); err != nil {
   294  		return nil, fmt.Errorf("reshaping filter from %v to (%v, %v * %v * %v) failed: %w", filter.Shape(), layer, kernel, row, col, err)
   295  	}
   296  	flattened.groups = flattened.groups.Upsert(group)
   297  
   298  	// extract patch
   299  	batch := colIm.Shape()[0]
   300  	m := colIm.Shape()[1]
   301  	n := colIm.Shape()[2]
   302  	z := colIm.Shape()[3]
   303  
   304  	var patch, colImLayer *Node
   305  	if patch, err = Reshape(colIm, tensor.Shape{batch * m * n, z}); err != nil {
   306  		return nil, fmt.Errorf("reshaping colIm from %v to (%v * %v * %v * %v) failed: %w", colIm.Shape(), batch, m, n, z, err)
   307  	}
   308  	patch.groups = patch.groups.Upsert(group)
   309  
   310  	op := linAlgBinOp{
   311  		āBinaryOperator: matMulOperator,
   312  		transA:          false,
   313  		transB:          true,
   314  	}
   315  
   316  	if colImLayer, err = ApplyOp(op, patch, flattened); err != nil {
   317  		return nil, fmt.Errorf("failed to apply op: %w", err)
   318  	}
   319  	colImLayer.groups = colImLayer.groups.Upsert(group)
   320  
   321  	// now reshape and transpose the values back into the original order
   322  	var res *Node
   323  	if res, err = Reshape(colImLayer, tensor.Shape{batch, m, n, layer}); err != nil {
   324  		return nil, fmt.Errorf("failed to reshape %v to (%v, %v, %v, %v): %w", colImLayer.Shape(), batch, m, n, layer, err)
   325  	}
   326  	res.groups = res.groups.Upsert(group)
   327  	ret, err := Transpose(res, 0, 3, 1, 2)
   328  	if err != nil {
   329  		return nil, fmt.Errorf("transpose %v failed: %w", res.Shape(), err)
   330  	}
   331  
   332  	ret.groups = ret.groups.Upsert(group)
   333  	return ret, nil
   334  }
   335  
   336  // Conv1d is a 1D convolution. It relies on Conv2d.
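        //
        // A minimal usage sketch (a batch of 8 single-channel sequences of length 100, stored as 4D
        // tensors with height 1; all values are illustrative):
        //
        //	g := NewGraph()
        //	in := NewTensor(g, Float64, 4, WithShape(8, 1, 1, 100), WithName("in"))
        //	w := NewTensor(g, Float64, 4, WithShape(16, 1, 1, 3), WithName("w"), WithInit(GlorotN(1.0)))
        //	out, err := Conv1d(in, w, 3, 1, 1, 1) // kernel 3, pad 1, stride 1, dilation 1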
   337  func Conv1d(in, filter *Node, kernel, pad, stride, dilation int) (*Node, error) {
   338  	return Conv2d(in, filter, tensor.Shape{1, kernel}, []int{0, pad}, []int{1, stride}, []int{1, dilation})
   339  }
   340  
   341  // MaxPool2D applies the kernel filter to the input node.
   342  // The pad slice can have two different lengths.
   343  //
   344  // - if len(pad) == 2, padding is assumed to be symmetric, and the padding is added to both sides of each dimension:
   345  //   paddedOutputH = pad[0] + inputH + pad[0]
   346  //   paddedOutputW = pad[1] + inputW + pad[1]
   347  //
   348  // - if len(pad) == 4, padding is explicit and can be asymmetric.
   349  //   paddedOutputH = pad[0] + inputH + pad[1]
   350  //   paddedOutputW = pad[2] + inputW + pad[3]
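        //
        // A minimal usage sketch (2x2 pooling with stride 2 and no padding; shapes are illustrative):
        //
        //	g := NewGraph()
        //	x := NewTensor(g, Float64, 4, WithShape(8, 16, 32, 32), WithName("x"))
        //	pooled, err := MaxPool2D(x, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2})
        //	// pooled has shape (8, 16, 16, 16)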
   351  func MaxPool2D(x *Node, kernel tensor.Shape, pad, stride []int) (*Node, error) {
   352  	group := encoding.NewGroup("Maxpool")
   353  	xShape := x.Shape()
   354  
   355  	// check shapes before indexing into them
   356  	if xShape.Dims() != 4 {
   357  		return nil, errors.Errorf("Expected input to have a shape with dimension 4")
   358  	}
   359  	if kernel.Dims() != 2 {
   360  		return nil, errors.Errorf("Expected kernel to have a shape of dimension 2")
   361  	}
   362  	h, w := xShape[2], xShape[3]
   363  	kh, kw := kernel[0], kernel[1]
   364  
   365  	// checks
   366  	for _, s := range stride {
   367  		if s <= 0 {
   368  			return nil, errors.Errorf("Cannot use strides of less than or equal 0: %v", stride)
   369  		}
   370  	}
   371  
   372  	for _, p := range pad {
   373  		if p < 0 {
   374  			return nil, errors.Errorf("Cannot use padding of less than 0: %v", pad)
   375  		}
   376  	}
   377  
   378  	padNorth := pad[0]
   379  	padWest := pad[1]
   380  	padSouth := pad[0]
   381  	padEast := pad[1]
   382  	if len(pad) == 4 {
   383  		padNorth = pad[0]
   384  		padSouth = pad[1]
   385  		padWest = pad[2]
   386  		padEast = pad[3]
   387  	}
   388  
   389  	if h-kh+padNorth+padSouth < 0 {
   390  		// error
   391  		return nil, errors.New("Impossible height/kernel/pad combination")
   392  	}
   393  
   394  	if w-kw+padWest+padEast < 0 {
   395  		// error
   396  		return nil, errors.New("Impossible width/kernel/pad combination")
   397  	}
   398  
   399  	op := newMaxPoolOp(xShape, kernel, pad, stride)
   400  	retVal, err := ApplyOp(op, x)
        	if err != nil {
        		return nil, err
        	}
   401  	retVal.groups = retVal.groups.Upsert(group)
   402  	return retVal, nil
   403  }
   404  
   405  // MaxPool1D applies a maxpool on the node x.
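        //
        // A minimal usage sketch (a length-100 signal stored as a 4D tensor with height 1; values are illustrative):
        //
        //	g := NewGraph()
        //	x := NewTensor(g, Float64, 4, WithShape(8, 16, 1, 100), WithName("x"))
        //	pooled, err := MaxPool1D(x, 2, 0, 2) // kernel 2, pad 0, stride 2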
   406  func MaxPool1D(x *Node, kernel, pad, stride int) (*Node, error) {
   407  	return MaxPool2D(x, tensor.Shape{1, kernel}, []int{0, pad}, []int{1, stride})
   408  }
   409  
   410  // BatchNorm applies batch normalization. This operator can be used in a forward pass or for training.
   411  // For evaluation only, the "op" output can be discarded.
   412  // In the training phase, γ and β can be discarded and the op should be used.
   413  // Input must be a matrix with shape (B, N) or a 4d tensor with shape (B, C, H, W).
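        //
        // A minimal usage sketch (training-time use on a (B, N) input; shapes, names and hyperparameters are illustrative):
        //
        //	g := NewGraph()
        //	x := NewMatrix(g, Float64, WithShape(32, 100), WithName("x"))
        //	out, gamma, beta, bnOp, err := BatchNorm(x, nil, nil, 0.9, 1e-5)
        //	if err != nil {
        //		// handle error
        //	}
        //	// gamma and beta are the learnable scale and bias; keep bnOp around if it is needed for training.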
   414  func BatchNorm(x, scale, bias *Node, momentum, epsilon float64) (retVal, γ, β *Node, op *BatchNormOp, err error) {
   415  	dt, err := dtypeOf(x.Type())
   416  	if err != nil {
   417  		return nil, nil, nil, nil, err
   418  	}
   419  	batches := x.Shape()[0]
   420  	channels := x.Shape()[1]
   421  	spatialDim := x.Shape().TotalSize() / (channels * batches)
   422  
   423  	mean := tensor.New(tensor.Of(dt), tensor.WithShape(channels))
   424  	variance := tensor.New(tensor.Of(dt), tensor.WithShape(channels))
   425  	ma := tensor.New(tensor.Of(dt), tensor.WithShape(1))
   426  
   427  	meanTmp := tensor.New(tensor.Of(dt), tensor.WithShape(channels))
   428  	varianceTmp := tensor.New(tensor.Of(dt), tensor.WithShape(channels))
   429  	tmp := tensor.New(tensor.Of(dt), tensor.WithShape(x.Shape().Clone()...))
   430  	xNorm := tensor.New(tensor.Of(dt), tensor.WithShape(x.Shape().Clone()...))
   431  	batchSumMultiplier := tensor.New(tensor.Of(dt), tensor.WithShape(batches))
   432  
   433  	var uno interface{}
   434  	switch dt {
   435  	case Float64:
   436  		uno = float64(1)
   437  	case Float32:
   438  		uno = float32(1)
   439  	}
   440  	spatialSumMultiplier := tensor.New(tensor.Of(dt), tensor.WithShape(spatialDim))
   441  	if err = spatialSumMultiplier.Memset(uno); err != nil {
   442  		return nil, nil, nil, nil, err
   443  	}
   444  
   445  	numByChans := tensor.New(tensor.Of(dt), tensor.WithShape(channels*batches))
   446  	if err = batchSumMultiplier.Memset(uno); err != nil {
   447  		return nil, nil, nil, nil, err
   448  	}
   449  
   450  	op = &BatchNormOp{
   451  		momentum: momentum,
   452  		epsilon:  epsilon,
   453  
   454  		mean:     mean,
   455  		variance: variance,
   456  		ma:       ma,
   457  
   458  		meanTmp:              meanTmp,
   459  		varianceTmp:          varianceTmp,
   460  		tmpSpace:             tmp,
   461  		xNorm:                xNorm,
   462  		batchSumMultiplier:   batchSumMultiplier,
   463  		numByChans:           numByChans,
   464  		spatialSumMultiplier: spatialSumMultiplier,
   465  
   466  		training: true,
   467  		dims:     x.Dims(),
   468  	}
   469  	g := x.Graph()
   470  	dims := x.Shape().Dims()
   471  
   472  	if scale == nil {
   473  		scale = NewTensor(g, dt, dims, WithShape(x.Shape().Clone()...), WithName(x.Name()+"_γ"), WithInit(GlorotN(1.0)))
   474  	}
   475  	if bias == nil {
   476  		bias = NewTensor(g, dt, dims, WithShape(x.Shape().Clone()...), WithName(x.Name()+"_β"), WithInit(GlorotN(1.0)))
   477  	}
   478  
   479  	if retVal, err = ApplyOp(op, x); err != nil {
   480  		return nil, nil, nil, nil, err
   481  	}
   482  	if retVal, err = Auto(BroadcastHadamardProd, scale, retVal); err != nil {
   483  		return nil, nil, nil, nil, err
   484  	}
   485  	retVal, err = Auto(BroadcastAdd, retVal, bias)
   486  
   487  	return retVal, scale, bias, op, err
   488  }
   489  
   490  // GlobalAveragePool2D consumes an input tensor X and applies average pooling across the values in the same channel.
   491  // The expected input shape is BCHW where B is the batch size, C is the number of channels, and H and W are the height and the width of the data.
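        //
        // A minimal usage sketch (shapes and names are illustrative):
        //
        //	g := NewGraph()
        //	x := NewTensor(g, Float64, 4, WithShape(8, 16, 32, 32), WithName("x"))
        //	pooled, err := GlobalAveragePool2D(x) // averages over H and W for each channel
        //	if err != nil {
        //		// handle error
        //	}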
   492  func GlobalAveragePool2D(x *Node) (*Node, error) {
   493  	return ApplyOp(&globalAveragePoolOp{}, x)
   494  }