gorgonia.org/gorgonia@v0.9.17/ops/nn/convolution_cuda.go (about) 1 // +build cuda 2 3 package nnops 4 5 import ( 6 "fmt" 7 "hash" 8 9 "github.com/chewxy/hm" 10 cudnn "gorgonia.org/cu/dnn" 11 t2cudnn "gorgonia.org/cu/dnn/interop" 12 G "gorgonia.org/gorgonia" 13 "gorgonia.org/tensor" 14 ) 15 16 var ( 17 _ G.Op = &convolution{} 18 _ G.CUDADoer = &convolution{} 19 ) 20 21 type convolution struct { 22 *cudnn.Convolution 23 24 // created with these attributes 25 padding, stride, dilation []int 26 inShape, filterShape tensor.Shape 27 28 // cached descriptors 29 xDesc, yDesc *cudnn.TensorDescriptor 30 wDesc *cudnn.Filter 31 } 32 33 func makeConvolutionOp(im, filter *G.Node, kernelShape tensor.Shape, pad, stride, dilation []int) (retVal *convolution, err error) { 34 var xDesc *cudnn.TensorDescriptor 35 var wDesc *cudnn.Filter 36 if xDesc, err = t2cudnn.Describe(im); err != nil { 37 return nil, err 38 } 39 if wDesc, err = t2cudnn.DescribeAsFilter(filter, cudnn.NCHW); err != nil { 40 return nil, err 41 } 42 datatype := t2cudnn.Dtype2DataType(im.Dtype()) 43 conv, err := cudnn.NewConvolution(cudnn.DefaultMath, 1, pad, stride, dilation, cudnn.StandardConvolution, datatype) 44 if err != nil { 45 return nil, err 46 } 47 48 return &convolution{ 49 Convolution: conv, 50 padding: pad, 51 stride: stride, 52 dilation: dilation, 53 54 inShape: im.Shape().Clone(), 55 filterShape: filter.Shape().Clone(), 56 57 xDesc: xDesc, 58 wDesc: wDesc, 59 }, nil 60 } 61 62 func (c *convolution) Arity() int { return 2 } 63 64 func (c *convolution) Type() hm.Type { 65 return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a')) 66 } 67 68 func (c *convolution) InferShape(inputs ...G.DimSizer) (retVal tensor.Shape, err error) { 69 if err = checkArity(c, len(inputs)); err != nil { 70 return 71 } 72 return c.ForwardOutputShape(c.xDesc, c.wDesc, 2) //only conv2d is supported now 73 } 74 75 func (c *convolution) Do(inputs ...G.Value) (retVal G.Value, err error) { 76 panic("not implemented") 77 } 78 79 func (c *convolution) ReturnsPtr() bool { return true } 80 81 func (c *convolution) CallsExtern() bool { return true } 82 83 func (c *convolution) OverwritesInput() int { return -1 } 84 85 func (c *convolution) WriteHash(h hash.Hash) { 86 fmt.Fprintf(h, "Convolution:%v-%v-%v", c.Padding(), c.FilterStride(), c.Dilation()) 87 } 88 89 func (c *convolution) Hashcode() uint32 { return simpleHash(c) } 90 91 func (c *convolution) String() string { 92 return fmt.Sprintf("Convolution:%v-%v-%v", c.Padding(), c.FilterStride(), c.Dilation()) 93 } 94 95 func (c *convolution) CUDADo(extern G.External, dev G.Device, prealloc G.Value, inputs ...G.Value) (retVal G.Value, err error) { 96 if err = checkArity(c, len(inputs)); err != nil { 97 return 98 } 99 im, filter := inputs[0], inputs[1] 100 101 if c.yDesc == nil { 102 if c.yDesc, err = t2cudnn.Describe(prealloc.(tensor.Tensor)); err != nil { 103 return 104 } 105 } 106 107 machine := extern.(G.CUDAMachine) 108 machine.Engines()[int(dev)].DoWork() 109 ctx := machine.CUDNNContexts()[int(dev)] 110 111 if err = ctx.ConvolutionForward(1.0, 112 c.xDesc, im.(cudnn.Memory), 113 c.wDesc, filter.(cudnn.Memory), 114 c.Convolution, 115 cudnn.ConvolutionFwdAlgoImplicitGemm, nomem{}, 116 0, 1.0, 117 c.yDesc, prealloc.(cudnn.Memory)); err != nil { 118 return 119 } 120 return prealloc, nil 121 } 122 123 func (c *convolution) DoDiff(ctx G.ExecutionContext, inputs G.Nodes, output *G.Node) error { 124 panic("not implemented") 125 } 126 127 func (c *convolution) DiffWRT(inputs int) []bool { 128 return []bool{true, true} 129 } 130 131 func (c *convolution) SymDiff(inputs G.Nodes, output *G.Node, grad *G.Node) (retVal G.Nodes, err error) { 132 var outDesc *cudnn.TensorDescriptor 133 if outDesc, err = t2cudnn.Describe(output); err != nil { 134 return nil, err 135 } 136 diffIm := &convDiffIm{ 137 convolution: c, 138 outputDesc: outDesc, 139 } 140 diffFilter := &convDiffFilter{ 141 convolution: c, 142 outputDesc: outDesc, 143 } 144 145 retVal = make(G.Nodes, 2) 146 if retVal[0], err = G.ApplyOp(diffIm, inputs[0], grad); err != nil { 147 return nil, err 148 } 149 if retVal[1], err = G.ApplyOp(diffFilter, inputs[1], grad); err != nil { 150 return nil, err 151 } 152 153 return 154 } 155 156 // convDiffIm is the d(z)/d(im) operation. See also convDiffFilter 157 type convDiffIm struct { 158 *convolution 159 outputDesc *cudnn.TensorDescriptor 160 } 161 162 func (c *convDiffIm) Arity() int { return 2 } 163 164 func (c *convDiffIm) Type() hm.Type { 165 return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a')) 166 } 167 168 func (c *convDiffIm) InferShape(shps ...G.DimSizer) (tensor.Shape, error) { 169 return c.inShape.Clone(), nil 170 } 171 172 func (c *convDiffIm) Do(...G.Value) (G.Value, error) { 173 panic("not implemented") 174 } 175 176 func (c *convDiffIm) ReturnsPtr() bool { return true } 177 178 func (c *convDiffIm) CallsExtern() bool { return true } 179 180 func (c *convDiffIm) OverwritesInput() int { return -1 } 181 182 func (c *convDiffIm) WriteHash(h hash.Hash) { 183 fmt.Fprintf(h, "ConvolutionImDiff:%v-%v-%v", c.Padding(), c.FilterStride(), c.Dilation()) 184 } 185 186 func (c *convDiffIm) Hashcode() uint32 { return simpleHash(c) } 187 188 func (c *convDiffIm) String() string { 189 return fmt.Sprintf("ConvolutionImDiff:%v-%v-%v", c.Padding(), c.FilterStride(), c.Dilation()) 190 } 191 192 func (c *convDiffIm) CUDADo(extern G.External, dev G.Device, prealloc G.Value, inputs ...G.Value) (retVal G.Value, err error) { 193 if err = checkArity(c, len(inputs)); err != nil { 194 return 195 } 196 filter, grad := inputs[0], inputs[1] 197 198 machine := extern.(G.CUDAMachine) 199 ctx := machine.CUDNNContexts()[int(dev)] 200 201 if err = ctx.ConvolutionBackwardData(1.0, 202 c.wDesc, filter.(cudnn.Memory), 203 c.outputDesc, grad.(cudnn.Memory), 204 c.Convolution, 205 cudnn.ConvolutionBwdDataAlgo0, nomem{}, 206 0, 1.0, 207 c.xDesc, prealloc.(cudnn.Memory)); err != nil { 208 return 209 } 210 return prealloc, nil 211 } 212 213 type convDiffFilter struct { 214 *convolution // shared struct as convDiffIm 215 outputDesc *cudnn.TensorDescriptor // shared output descriptor with convDiffIm 216 } 217 218 func (c *convDiffFilter) Arity() int { return 2 } 219 220 func (c *convDiffFilter) Type() hm.Type { 221 return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a')) 222 } 223 224 func (c *convDiffFilter) InferShape(...G.DimSizer) (tensor.Shape, error) { 225 return c.filterShape.Clone(), nil 226 } 227 228 func (c *convDiffFilter) Do(...G.Value) (G.Value, error) { 229 panic("not implemented") 230 } 231 232 func (c *convDiffFilter) ReturnsPtr() bool { return true } 233 234 func (c *convDiffFilter) CallsExtern() bool { return true } 235 236 func (c *convDiffFilter) OverwritesInput() int { return -1 } 237 238 func (c *convDiffFilter) WriteHash(h hash.Hash) { 239 fmt.Fprintf(h, "ConvolutionFilterDiff:%v-%v-%v", c.Padding(), c.FilterStride(), c.Dilation()) 240 } 241 242 func (c *convDiffFilter) Hashcode() uint32 { return simpleHash(c) } 243 244 func (c *convDiffFilter) String() string { 245 return fmt.Sprintf("ConvolutionFilterDiff:%v-%v-%v", c.Padding(), c.FilterStride(), c.Dilation()) 246 } 247 248 func (c *convDiffFilter) CUDADo(extern G.External, dev G.Device, prealloc G.Value, inputs ...G.Value) (retVal G.Value, err error) { 249 if err = checkArity(c, len(inputs)); err != nil { 250 return 251 } 252 im, grad := inputs[0], inputs[1] 253 254 machine := extern.(G.CUDAMachine) 255 ctx := machine.CUDNNContexts()[int(dev)] 256 257 if err = ctx.ConvolutionBackwardFilter(1.0, 258 c.xDesc, im.(cudnn.Memory), 259 c.outputDesc, grad.(cudnn.Memory), 260 c.Convolution, 261 cudnn.ConvolutionBwdFilterAlgo0, nomem{}, 262 0, 1.0, 263 c.wDesc, prealloc.(cudnn.Memory)); err != nil { 264 return 265 } 266 return prealloc, nil 267 }