gorgonia.org/gorgonia@v0.9.17/ops/nn/api_cuda.go (about) 1 // +build cuda 2 3 package nnops 4 5 import ( 6 G "gorgonia.org/gorgonia" 7 "gorgonia.org/tensor" 8 ) 9 10 func Conv2d(im, filter *G.Node, kernelShape tensor.Shape, pad, stride, dilation []int) (retVal *G.Node, err error) { 11 var op *convolution 12 if op, err = makeConvolutionOp(im, filter, kernelShape, pad, stride, dilation); err != nil { 13 return nil, err 14 } 15 return G.ApplyOp(op, im, filter) 16 } 17 18 func Conv1d(in, filter *G.Node, kernel, pad, stride, dilation int) (*G.Node, error) { 19 return Conv2d(in, filter, tensor.Shape{1, kernel}, []int{0, pad}, []int{1, stride}, []int{1, dilation}) 20 } 21 22 func MaxPool2D(x *G.Node, kernel tensor.Shape, pad, stride []int) (retVal *G.Node, err error) { 23 var op *maxpool 24 if op, err = newMaxPoolOp(x, kernel, pad, stride); err != nil { 25 return nil, err 26 } 27 return G.ApplyOp(op, x) 28 } 29 30 func Dropout(x *G.Node, prob float64) (retVal *G.Node, err error) { 31 var op *dropout 32 if op, err = newDropout(x, prob); err != nil { 33 return nil, err 34 } 35 36 // states := &scratchOp{x.Shape().Clone(), x.Dtype(), ""} 37 // m := G.NewUniqueNode(G.WithType(x.Type()), G.WithOp(states), G.In(x.Graph()), G.WithShape(states.shape...)) 38 39 retVal, err = G.ApplyOp(op, x) 40 return 41 } 42 43 func Rectify(x *G.Node) (retVal *G.Node, err error) { 44 var op *activation 45 if op, err = newRelu(); err != nil { 46 return nil, err 47 } 48 retVal, err = G.ApplyOp(op, x) 49 return 50 } 51 52 func BatchNorm(x, scale, bias *G.Node, momentum, epsilon float64) (retVal, γ, β *G.Node, op *BatchNormOp, err error) { 53 dt, err := dtypeOf(x.Type()) 54 if err != nil { 55 return nil, nil, nil, nil, err 56 } 57 58 // batches := x.Shape()[0] 59 channels := x.Shape()[1] 60 H, W := x.Shape()[2], x.Shape()[3] 61 // spatialDim := x.Shape().TotalSize() / (channels * batches) 62 scratchShape := tensor.Shape{1, channels, H, W} 63 64 // scaleScratch := &scratchOp{x.Shape().Clone(), dt, "scale"} 65 // biasScratch := &scratchOp{x.Shape().Clone(), dt, "bias"} 66 meanScratch := &gpuScratchOp{scratchOp{x.Shape().Clone(), dt, "mean"}} 67 varianceScratch := &gpuScratchOp{scratchOp{x.Shape().Clone(), dt, "variance"}} 68 cacheMeanScratch := &gpuScratchOp{scratchOp{scratchShape, dt, "cacheMean"}} 69 cacheVarianceScratch := &gpuScratchOp{scratchOp{scratchShape, dt, "cacheVariance"}} 70 71 g := x.Graph() 72 dims := len(x.Shape()) 73 mean := G.NewTensor(g, dt, dims, G.WithShape(scratchShape.Clone()...), G.WithName(x.Name()+"_mean"), G.WithOp(meanScratch)) 74 variance := G.NewTensor(g, dt, dims, G.WithShape(scratchShape.Clone()...), G.WithName(x.Name()+"_variance"), G.WithOp(varianceScratch)) 75 cacheMean := G.NewTensor(g, dt, dims, G.WithShape(scratchShape.Clone()...), G.WithOp(cacheMeanScratch)) 76 cacheVariance := G.NewTensor(g, dt, dims, G.WithShape(scratchShape.Clone()...), G.WithOp(cacheVarianceScratch)) 77 78 if scale == nil { 79 scale = G.NewTensor(g, dt, dims, G.WithShape(scratchShape.Clone()...), G.WithName(x.Name()+"_γ"), G.WithInit(G.GlorotN(1.0))) 80 } 81 82 if bias == nil { 83 bias = G.NewTensor(g, dt, dims, G.WithShape(scratchShape.Clone()...), G.WithName(x.Name()+"_β"), G.WithInit(G.GlorotN(1.0))) 84 } 85 86 op = newBatchNormOp(momentum, epsilon) 87 retVal, err = G.ApplyOp(op, x, scale, bias, mean, variance, cacheMean, cacheVariance) 88 return retVal, scale, bias, op, err 89 }