// +build cuda

package nnops

import (
	"fmt"
	"hash"

	"github.com/chewxy/hm"
	cudnn "gorgonia.org/cu/dnn"
	t2cudnn "gorgonia.org/cu/dnn/interop"
	G "gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

// Compile-time assertions that both ops satisfy the gorgonia Op and
// CUDADoer interfaces.
var (
	_ G.Op       = &maxpool{}
	_ G.CUDADoer = &maxpool{}
	_ G.Op       = &maxpoolDiff{}
	_ G.CUDADoer = &maxpoolDiff{}
)

// maxpool is a CUDA-backed max-pooling op. It embeds the cuDNN pooling
// descriptor (which carries kernel, padding and stride configuration)
// and caches the tensor descriptors for its input (x) and output (y).
type maxpool struct {
	*cudnn.Pooling

	xDesc *cudnn.TensorDescriptor
	// yDesc is created lazily on the first CUDADo call, from the
	// preallocated output value.
	yDesc *cudnn.TensorDescriptor
}

// newMaxPoolOp creates a max-pooling op for the given input node with the
// given kernel size, padding and stride. It builds a cuDNN tensor
// descriptor from x and a cuDNN pooling descriptor in MaxPooling mode
// with NaN propagation disabled.
func newMaxPoolOp(x *G.Node, kernel, pad, stride []int) (*maxpool, error) {
	var xDesc *cudnn.TensorDescriptor
	var err error
	if xDesc, err = t2cudnn.Describe(x); err != nil {
		return nil, err
	}

	var p *cudnn.Pooling
	if p, err = cudnn.NewPooling(cudnn.MaxPooling, cudnn.NotPropagateNan, kernel, stride, pad); err != nil {
		return nil, err
	}
	return &maxpool{
		Pooling: p,
		xDesc:   xDesc,
	}, nil
}

// Arity returns 1: the op takes a single input tensor.
func (p *maxpool) Arity() int { return 1 }

// Type returns a -> a: the output has the same type (and rank) as the input.
func (p *maxpool) Type() hm.Type { return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a')) }

// InferShape computes the output shape from the cached input descriptor
// via cuDNN.
func (p *maxpool) InferShape(inputs ...G.DimSizer) (tensor.Shape, error) {
	if err := checkArity(p, len(inputs)); err != nil {
		return nil, err
	}
	return p.OutputShape(p.xDesc, 2) // only maxpool2d for now
}

// Do is unimplemented: this op only runs on CUDA via CUDADo.
func (p *maxpool) Do(...G.Value) (G.Value, error) {
	panic("not implemented")
}

func (p *maxpool) ReturnsPtr() bool { return true }

func (p *maxpool) CallsExtern() bool { return true }

func (p *maxpool) OverwritesInput() int { return -1 }

// WriteHash writes the op's full configuration (input shape, kernel,
// padding, stride) into h so that structurally identical ops hash equal.
// The formatting assumes a 4-D input shape (it indexes xShape[0..3]) and
// 2-D kernel/pad/stride — consistent with the maxpool2d-only InferShape.
func (p *maxpool) WriteHash(h hash.Hash) {
	xShape := p.xDesc.Shape()
	kernel := p.Shape()
	padding := p.Padding()
	strides := p.Strides()
	fmt.Fprintf(h, "MaxPool{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
		xShape[0], xShape[1], xShape[2], xShape[3],
		kernel[0], kernel[1],
		padding[0], padding[1],
		strides[0], strides[1])

}

func (p *maxpool) Hashcode() uint32 { return simpleHash(p) }

// String renders the same configuration summary that WriteHash hashes.
func (p *maxpool) String() string {
	xShape := p.xDesc.Shape()
	kernel := p.Shape()
	padding := p.Padding()
	strides := p.Strides()
	return fmt.Sprintf("MaxPool{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
		xShape[0], xShape[1], xShape[2], xShape[3],
		kernel[0], kernel[1],
		padding[0], padding[1],
		strides[0], strides[1])
}

// CUDADo performs the forward max-pooling on device dev, writing the
// result into prealloc. The output descriptor is created from prealloc on
// first use and cached. The 1.0/0 arguments are cuDNN's alpha/beta
// blending scalars (y = alpha*pool(x) + beta*y).
func (p *maxpool) CUDADo(extern G.External, dev G.Device, prealloc G.Value, inputs ...G.Value) (retVal G.Value, err error) {
	if err = checkArity(p, len(inputs)); err != nil {
		return
	}
	in := inputs[0]

	if p.yDesc == nil {
		if p.yDesc, err = t2cudnn.Describe(prealloc.(tensor.Tensor)); err != nil {
			return
		}
	}

	machine := extern.(G.CUDAMachine)
	// Drain any pending work on this device's engine before issuing the
	// cuDNN call.
	machine.Engines()[int(dev)].DoWork()
	ctx := machine.CUDNNContexts()[int(dev)]
	err = ctx.PoolingForward(p.Pooling, 1.0, p.xDesc, in.(cudnn.Memory), 0, p.yDesc, prealloc.(cudnn.Memory))
	return prealloc, err
}

// DiffWRT reports that the op is differentiable with respect to its
// single input.
func (p *maxpool) DiffWRT(inputs int) []bool { return []bool{true} }

// SymDiff builds the symbolic gradient node. The diff op is a direct
// type conversion of p (maxpoolDiff shares maxpool's layout), so it
// reuses the same pooling configuration and cached descriptors. It is
// applied to (x, y, dy): the input, the forward output, and the incoming
// gradient.
func (p *maxpool) SymDiff(inputs G.Nodes, output *G.Node, grad *G.Node) (retVal G.Nodes, err error) {
	if err = checkArity(p, len(inputs)); err != nil {
		return
	}
	diff := (*maxpoolDiff)(p)
	x := inputs[0]

	retVal = make(G.Nodes, 1)
	retVal[0], err = G.ApplyOp(diff, x, output, grad)
	return
}

// DoDiff is unimplemented: gradients are computed via SymDiff/maxpoolDiff.
func (p *maxpool) DoDiff(ctx G.ExecutionContext, inputs G.Nodes, output *G.Node) error {
	panic("not implemented")
}

// maxpoolDiff is the backward op for maxpool. It has the same layout as
// maxpool (see the conversion in SymDiff) and reuses its descriptors.
type maxpoolDiff maxpool

// Arity returns 3: (x, y, dy).
func (op *maxpoolDiff) Arity() int { return 3 }

// Type returns a -> a -> a -> a: all three inputs and the output dx share
// a type variable.
func (op *maxpoolDiff) Type() hm.Type {
	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a'))
}

// InferShape returns a copy of the first input's shape: dx has the same
// shape as x. Note: unlike the forward op, no arity check is performed
// here.
func (op *maxpoolDiff) InferShape(inputs ...G.DimSizer) (tensor.Shape, error) {
	return inputs[0].(tensor.Shape).Clone(), nil
}

// Do is unimplemented: this op only runs on CUDA via CUDADo.
func (op *maxpoolDiff) Do(...G.Value) (G.Value, error) { panic("not implemented") }

func (op *maxpoolDiff) ReturnsPtr() bool { return true }

func (op *maxpoolDiff) CallsExtern() bool { return true }

func (op *maxpoolDiff) OverwritesInput() int { return -1 }

// WriteHash mirrors maxpool.WriteHash with a distinct "MaxPoolDiff" tag
// so forward and backward ops hash differently for the same
// configuration.
func (op *maxpoolDiff) WriteHash(h hash.Hash) {
	xShape := op.xDesc.Shape()
	kernel := op.Shape()
	padding := op.Padding()
	strides := op.Strides()
	fmt.Fprintf(h, "MaxPoolDiff{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
		xShape[0], xShape[1], xShape[2], xShape[3],
		kernel[0], kernel[1],
		padding[0], padding[1],
		strides[0], strides[1])
}

func (op *maxpoolDiff) Hashcode() uint32 { return simpleHash(op) }

// String renders the same configuration summary that WriteHash hashes.
func (op *maxpoolDiff) String() string {
	xShape := op.xDesc.Shape()
	kernel := op.Shape()
	padding := op.Padding()
	strides := op.Strides()
	return fmt.Sprintf("MaxPoolDiff{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
		xShape[0], xShape[1], xShape[2], xShape[3],
		kernel[0], kernel[1],
		padding[0], padding[1],
		strides[0], strides[1])
}

// CUDADo computes dx = pooling-backward(x, y, dy) into prealloc.
// Arguments follow cuDNN's pooling-backward order: alpha, (yDesc, y),
// (dyDesc, dy), (xDesc, x), beta, (dxDesc, dx) — with yDesc reused for
// dy and xDesc reused for dx, since gradients share their primal's
// shape.
func (op *maxpoolDiff) CUDADo(extern G.External, dev G.Device, prealloc G.Value, inputs ...G.Value) (retVal G.Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	x, y, dy := inputs[0], inputs[1], inputs[2]

	machine := extern.(G.CUDAMachine)
	machine.Engines()[int(dev)].DoWork()
	ctx := machine.CUDNNContexts()[int(dev)]
	err = ctx.PoolingBackward(op.Pooling, 1.0, op.yDesc, y.(cudnn.Memory), op.yDesc, dy.(cudnn.Memory), op.xDesc, x.(cudnn.Memory), 0, op.xDesc, prealloc.(cudnn.Memory))
	return prealloc, err
}