gorgonia.org/gorgonia@v0.9.17/ops/nn/api_cuda.go

// +build cuda

package nnops

import (
	G "gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

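// Conv2d applies a 2D convolution of filter over im with the given kernel
// shape, padding, stride and dilation, backed by the package's CUDA
// convolution op.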
func Conv2d(im, filter *G.Node, kernelShape tensor.Shape, pad, stride, dilation []int) (retVal *G.Node, err error) {
	var op *convolution
	if op, err = makeConvolutionOp(im, filter, kernelShape, pad, stride, dilation); err != nil {
		return nil, err
	}
	return G.ApplyOp(op, im, filter)
}

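// Conv1d is a convenience wrapper around Conv2d for 1D convolutions: the
// kernel, stride and dilation are lifted to 2D with a leading size of 1,
// and the padding with a leading 0.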
func Conv1d(in, filter *G.Node, kernel, pad, stride, dilation int) (*G.Node, error) {
	return Conv2d(in, filter, tensor.Shape{1, kernel}, []int{0, pad}, []int{1, stride}, []int{1, dilation})
}

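// MaxPool2D applies 2D max-pooling over x with the given kernel shape,
// padding and stride.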
func MaxPool2D(x *G.Node, kernel tensor.Shape, pad, stride []int) (retVal *G.Node, err error) {
	var op *maxpool
	if op, err = newMaxPoolOp(x, kernel, pad, stride); err != nil {
		return nil, err
	}
	return G.ApplyOp(op, x)
}

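// Dropout applies dropout to x with the given dropout probability prob.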
func Dropout(x *G.Node, prob float64) (retVal *G.Node, err error) {
	var op *dropout
	if op, err = newDropout(x, prob); err != nil {
		return nil, err
	}

	// states := &scratchOp{x.Shape().Clone(), x.Dtype(), ""}
	// m := G.NewUniqueNode(G.WithType(x.Type()), G.WithOp(states), G.In(x.Graph()), G.WithShape(states.shape...))

	retVal, err = G.ApplyOp(op, x)
	return
}

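// Rectify applies the rectified linear unit (ReLU) activation to x.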
func Rectify(x *G.Node) (retVal *G.Node, err error) {
	var op *activation
	if op, err = newRelu(); err != nil {
		return nil, err
	}
	retVal, err = G.ApplyOp(op, x)
	return
}

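// BatchNorm applies batch normalization to x, which is expected to be in NCHW
// layout (channels in x.Shape()[1]). If scale (γ) or bias (β) is nil, a new
// learnable node initialized with Glorot initialization is created in its
// place. It returns the normalized node, the scale and bias nodes actually
// used, and the underlying *BatchNormOp.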
func BatchNorm(x, scale, bias *G.Node, momentum, epsilon float64) (retVal, γ, β *G.Node, op *BatchNormOp, err error) {
	dt, err := dtypeOf(x.Type())
	if err != nil {
		return nil, nil, nil, nil, err
	}

	// batches := x.Shape()[0]
	channels := x.Shape()[1]
	H, W := x.Shape()[2], x.Shape()[3]
	// spatialDim := x.Shape().TotalSize() / (channels * batches)
	scratchShape := tensor.Shape{1, channels, H, W}

	// scaleScratch := &scratchOp{x.Shape().Clone(), dt, "scale"}
	// biasScratch := &scratchOp{x.Shape().Clone(), dt, "bias"}
	meanScratch := &gpuScratchOp{scratchOp{x.Shape().Clone(), dt, "mean"}}
	varianceScratch := &gpuScratchOp{scratchOp{x.Shape().Clone(), dt, "variance"}}
	cacheMeanScratch := &gpuScratchOp{scratchOp{scratchShape, dt, "cacheMean"}}
	cacheVarianceScratch := &gpuScratchOp{scratchOp{scratchShape, dt, "cacheVariance"}}

	g := x.Graph()
	dims := len(x.Shape())
	mean := G.NewTensor(g, dt, dims, G.WithShape(scratchShape.Clone()...), G.WithName(x.Name()+"_mean"), G.WithOp(meanScratch))
	variance := G.NewTensor(g, dt, dims, G.WithShape(scratchShape.Clone()...), G.WithName(x.Name()+"_variance"), G.WithOp(varianceScratch))
	cacheMean := G.NewTensor(g, dt, dims, G.WithShape(scratchShape.Clone()...), G.WithOp(cacheMeanScratch))
	cacheVariance := G.NewTensor(g, dt, dims, G.WithShape(scratchShape.Clone()...), G.WithOp(cacheVarianceScratch))

	if scale == nil {
		scale = G.NewTensor(g, dt, dims, G.WithShape(scratchShape.Clone()...), G.WithName(x.Name()+"_γ"), G.WithInit(G.GlorotN(1.0)))
	}

	if bias == nil {
		bias = G.NewTensor(g, dt, dims, G.WithShape(scratchShape.Clone()...), G.WithName(x.Name()+"_β"), G.WithInit(G.GlorotN(1.0)))
	}

	op = newBatchNormOp(momentum, epsilon)
	retVal, err = G.ApplyOp(op, x, scale, bias, mean, variance, cacheMean, cacheVariance)
	return retVal, scale, bias, op, err
}