// +build cuda

package nnops

import (
	"fmt"
	"hash"

	"github.com/chewxy/hm"
	"gorgonia.org/cu/dnn"
	t2cudnn "gorgonia.org/cu/dnn/interop"
	"gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

// BatchNormOp is a CUDA-only batch-normalization op backed by cuDNN.
// It lazily creates tensor descriptors on first execution and switches
// between cuDNN's forward-training and forward-inference kernels based
// on the training flag.
type BatchNormOp struct {
	mode              cudnn.BatchNormMode // cuDNN normalization mode (set to PerActivation by newBatchNormOp)
	momentum          float64
	epsilon           float64

	xDesc     *cudnn.TensorDescriptor // descriptor for the input x (reused for y/dy/dx); built lazily in CUDADo
	bnScratch *cudnn.TensorDescriptor // descriptor for the scale/bias/mean/var tensors; built lazily in CUDADo

	training bool // true: ForwardTraining (updates running stats); false: ForwardInference
}

// momentum, epsilon are declared on one line in the original; kept split above for clarity only.

// newBatchNormOp creates a BatchNormOp in training mode using cuDNN's
// per-activation normalization mode.
// NOTE(review): PerActivation is used even though Type() declares 4-D
// inputs; for conv feature maps cuDNN also offers Spatial mode — confirm
// PerActivation is intended here.
func newBatchNormOp(momentum, epsilon float64) *BatchNormOp {
	return &BatchNormOp{
		mode:     cudnn.PerActivation,
		momentum: momentum,
		epsilon:  epsilon,
		training: true,
	}
}

// Arity returns 7: x, scale, bias, running mean, running variance,
// cached mean, cached variance.
func (op *BatchNormOp) Arity() int { return 7 }

// Type returns the op's type: all seven inputs and the return value are
// 4-D tensors of the same element type.
func (op *BatchNormOp) Type() hm.Type {
	t := gorgonia.TensorType{Dims: 4, Of: hm.TypeVariable('a')}
	return hm.NewFnType(t, // x
		t, // scale
		t, // bias
		t, // running mean / expected mean
		t, // running var / expected var
		t, // cached mean
		t, // cachedVar
		t) // retVal
}

// InferShape returns the shape of the first input (x): batch norm is
// shape-preserving.
func (op *BatchNormOp) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) {
	if err := checkArity(op, len(inputs)); err != nil {
		return nil, err
	}
	return inputs[0].(tensor.Shape).Clone(), nil
}

// Do is unimplemented: this op only runs on CUDA devices via CUDADo.
func (op *BatchNormOp) Do(...gorgonia.Value) (gorgonia.Value, error) { panic("not implemented") }

// ReturnsPtr reports that the op writes into preallocated memory.
func (op *BatchNormOp) ReturnsPtr() bool { return true }

// CallsExtern reports that the op calls an external library (cuDNN).
func (op *BatchNormOp) CallsExtern() bool { return true }

// OverwritesInput returns -1: no input is overwritten.
func (op *BatchNormOp) OverwritesInput() int { return -1 }

// WriteHash writes the op's identity (momentum and epsilon) to h.
func (op *BatchNormOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "BatchNorm %v %v", op.momentum, op.epsilon)
}

// Hashcode returns a hash of the op derived from WriteHash.
func (op *BatchNormOp) Hashcode() uint32 { return simpleHash(op) }

// String implements fmt.Stringer.
func (op *BatchNormOp) String() string { return fmt.Sprintf("BatchNorm %v %v", op.momentum, op.epsilon) }

// CUDADo performs batch normalization on the given device, writing the
// result into prealloc. Descriptors for x and for the scale/bias/stats
// tensors are created on first call and cached on the op.
func (op *BatchNormOp) CUDADo(extern gorgonia.External, dev gorgonia.Device, prealloc gorgonia.Value, inputs ...gorgonia.Value) (retVal gorgonia.Value, err error) {
	// panic("not implemented")

	machine := extern.(gorgonia.CUDAMachine)
	ctx := machine.CUDNNContexts()[int(dev)]

	x, bnScale, bnBias, mean, variance, cachedMean, cachedVar := inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5], inputs[6]
	// Lazily build and cache the descriptor for x (also reused as yDesc).
	if op.xDesc == nil {
		if op.xDesc, err = t2cudnn.Describe(x.(tensor.Tensor)); err != nil {
			return
		}
	}

	// Lazily build the descriptor shared by scale/bias/mean/variance
	// (described from mean, so all five must share its shape).
	if op.bnScratch == nil {
		if op.bnScratch, err = t2cudnn.Describe(mean.(tensor.Tensor)); err != nil {
			return
		}
	}

	// NOTE(review): cuDNN blends output as alpha*result + beta*priorY;
	// the conventional values are alpha=1, beta=0. Here alpha=0, beta=1 —
	// verify against the gorgonia.org/cu/dnn wrapper's argument
	// convention (it may swap or reinterpret these).
	alpha := 0.0
	beta := 1.0
	if op.training {
		// Training: updates running mean/variance with op.momentum and
		// saves the per-batch mean/variance into cachedMean/cachedVar
		// for reuse in the backward pass.
		err = ctx.BatchNormalizationForwardTraining(op.mode, alpha, beta,
			op.xDesc, x.(cudnn.Memory),
			op.xDesc, prealloc.(cudnn.Memory), // yDesc, y
			op.bnScratch,
			bnScale.(cudnn.Memory),
			bnBias.(cudnn.Memory),
			op.momentum,
			mean.(cudnn.Memory),     // running mean
			variance.(cudnn.Memory), // running variance
			op.epsilon,
			cachedMean.(cudnn.Memory),
			cachedVar.(cudnn.Memory),
		)
	} else {
		// Inference: uses the provided mean/variance as fixed population
		// statistics; cachedMean/cachedVar are not touched.
		err = ctx.BatchNormalizationForwardInference(op.mode, alpha, beta,
			op.xDesc, x.(cudnn.Memory),
			op.xDesc, prealloc.(cudnn.Memory),
			op.bnScratch,
			bnScale.(cudnn.Memory),
			bnBias.(cudnn.Memory),
			mean.(cudnn.Memory),     // expected mean
			variance.(cudnn.Memory), // expected variance
			op.epsilon)
	}
	return prealloc, err
}

// DiffWRT reports that only x, scale and bias are differentiable; the
// running and cached statistics are not.
func (op *BatchNormOp) DiffWRT(inputs int) []bool {
	return []bool{true, true, true, false, false, false, false}
}

// SymDiff builds the symbolic backward graph: it creates scratch nodes
// for dscale/dbias and applies batchNormDiffOp, which computes dx and
// fills dscale/dbias as side outputs. Slots 3-6 of the returned Nodes
// are left nil, matching DiffWRT.
func (op *BatchNormOp) SymDiff(inputs gorgonia.Nodes, output *gorgonia.Node, grad *gorgonia.Node) (retVal gorgonia.Nodes, err error) {
	x, scale, bias := inputs[0], inputs[1], inputs[2]
	cachedMean, cachedVar := inputs[5], inputs[6]
	dy := grad // rename for simplicity of reading

	// create new nodes for the diffs
	g := x.Graph()
	dt := x.Dtype()
	scaleScratch := &scratchOp{scale.Shape().Clone(), dt, scale.Name() + "Diff"}
	biasScratch := &scratchOp{bias.Shape().Clone(), dt, bias.Name() + "Diff"}
	dscale := gorgonia.NewTensor(g, dt, scale.Shape().Dims(), gorgonia.WithOp(scaleScratch))
	dbias := gorgonia.NewTensor(g, dt, bias.Shape().Dims(), gorgonia.WithOp(biasScratch))

	retVal = make(gorgonia.Nodes, 7)

	diffOp := &batchNormDiffOp{op}
	retVal[0], err = gorgonia.ApplyOp(diffOp, x, scale, dscale, dbias, dy, cachedMean, cachedVar)
	retVal[1] = dscale
	retVal[2] = dbias
	// Register the scratch nodes as derivatives of scale and bias so the
	// solver picks them up.
	gorgonia.SetDerivOf(dscale, scale)
	gorgonia.SetDerivOf(dbias, bias)

	return retVal, err
}

// DoDiff is unimplemented: differentiation happens symbolically via
// SymDiff / batchNormDiffOp on CUDA.
func (op *BatchNormOp) DoDiff(ctx gorgonia.ExecutionContext, inputs gorgonia.Nodes, output *gorgonia.Node) error {
	panic("not implemented")
}

// SetTraining switches the op to use the forward-training kernel.
func (op *BatchNormOp) SetTraining() { op.training = true }

// SetTesting switches the op to use the forward-inference kernel.
func (op *BatchNormOp) SetTesting() { op.training = false }

// Reset is a no-op.
func (op *BatchNormOp) Reset() error { return nil }

// batchNormDiffOp is the backward op for BatchNormOp. It embeds the
// forward op to reuse its cached descriptors, mode, epsilon and the
// methods it does not override.
type batchNormDiffOp struct {
	*BatchNormOp
}

// Arity is the same exact function as BatchNormOp (7)

// Type is exactly the same as BatchNormOp, but the semantics are different:
// return hm.NewFnType(
//	t, // x
//	t, // scale
//	t, // dscale
//	t, // dbias
//	t, // dy
//	t, // cachedMean
//	t, // cachedVar
//	t  // retVal
// )

// InferShape is the same exact function as BatchNormOp

// Do is unimplemented: this op only runs on CUDA devices via CUDADo.
func (op *batchNormDiffOp) Do(...gorgonia.Value) (gorgonia.Value, error) {
	panic("not implemented")
}

// ReturnsPtr reports that the op writes into preallocated memory.
func (op *batchNormDiffOp) ReturnsPtr() bool { return true }

// CallsExtern reports that the op calls an external library (cuDNN).
func (op *batchNormDiffOp) CallsExtern() bool { return true }

// OverwritesInput returns -1: no input is overwritten.
func (op *batchNormDiffOp) OverwritesInput() int { return -1 }

// WriteHash writes the op's identity (momentum and epsilon) to h,
// distinguished from the forward op by the "BatchNormDiff" prefix.
func (op *batchNormDiffOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "BatchNormDiff %v %v", op.momentum, op.epsilon)
}

// HashCode is exactly the same as BatchNormOp

// String implements fmt.Stringer.
func (op *batchNormDiffOp) String() string {
	return fmt.Sprintf("BatchNormDiff %v %v", op.momentum, op.epsilon)
}

// CUDADo runs the cuDNN batch-norm backward pass on the given device.
// It computes dx into prealloc and writes dscale/dbias in place, reusing
// the descriptors cached on the embedded forward op (so the forward
// CUDADo must have run at least once before this is called — xDesc and
// bnScratch are nil otherwise; TODO confirm the runtime guarantees this
// ordering).
func (op *batchNormDiffOp) CUDADo(extern gorgonia.External, dev gorgonia.Device, prealloc gorgonia.Value, inputs ...gorgonia.Value) (retVal gorgonia.Value, err error) {
	machine := extern.(gorgonia.CUDAMachine)
	e := &machine.Engines()[int(dev)]
	ctx := machine.CUDNNContexts()[int(dev)]

	x, scale, dscale, dbias, dy, cachedMean, cachedVariance := inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5], inputs[6]
	// Promote scalar dscale/dbias to 4-D tensors so they satisfy the
	// descriptor's expected rank.
	dscale = gorgonia.ScalarAsTensor(dscale, 4, e)
	dbias = gorgonia.ScalarAsTensor(dbias, 4, e)

	// NOTE(review): cuDNN blends results as alpha*result + beta*prior;
	// the conventional values are alpha=1, beta=0. Here alpha=0, beta=1
	// (mirroring the forward pass) — verify against the
	// gorgonia.org/cu/dnn wrapper's argument convention.
	alpha := 0.0
	beta := 1.0
	err = ctx.BatchNormalizationBackward(op.mode,
		alpha, beta, // for data
		alpha, beta, // for param
		op.xDesc,
		x.(cudnn.Memory),
		op.xDesc, // dyDesc
		dy.(cudnn.Memory),
		op.xDesc, // dxDesc
		prealloc.(cudnn.Memory),
		op.bnScratch, // scratch space descriptor
		scale.(cudnn.Memory),
		dscale.(cudnn.Memory), // deriv of scale
		dbias.(cudnn.Memory),  // deriv of bias
		op.epsilon,
		// cachedMean/cachedVariance are the per-batch statistics saved
		// by the forward training pass.
		cachedMean.(cudnn.Memory), cachedVariance.(cudnn.Memory))
	return prealloc, err
}