gorgonia.org/gorgonia@v0.9.17/ops/nn/dropout_cuda.go (about) 1 // +build cuda 2 3 package nnops 4 5 import ( 6 "fmt" 7 "hash" 8 "time" 9 "unsafe" 10 11 "github.com/chewxy/hm" 12 "github.com/pkg/errors" 13 "gorgonia.org/cu" 14 cudnn "gorgonia.org/cu/dnn" 15 t2cudnn "gorgonia.org/cu/dnn/interop" 16 "gorgonia.org/gorgonia" 17 "gorgonia.org/tensor" 18 ) 19 20 type dropout struct { 21 *cudnn.Dropout 22 seed uint64 23 xDesc *cudnn.TensorDescriptor 24 } 25 26 func newDropout(x *gorgonia.Node, prob float64) (*dropout, error) { 27 xDesc, err := t2cudnn.Describe(x) 28 if err != nil { 29 return nil, err 30 } 31 32 internal, err := cudnn.NewDropout(prob) 33 if err != nil { 34 return nil, err 35 } 36 return &dropout{ 37 Dropout: internal, 38 xDesc: xDesc, 39 seed: uint64(time.Now().UnixNano()), 40 }, nil 41 } 42 43 func (op *dropout) Arity() int { return 1 } 44 45 func (op *dropout) Type() hm.Type { 46 return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a')) 47 } 48 49 func (op *dropout) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) { 50 if err := checkArity(op, len(inputs)); err != nil { 51 return nil, err 52 } 53 return inputs[0].(tensor.Shape).Clone(), nil 54 } 55 56 func (op *dropout) Do(...gorgonia.Value) (gorgonia.Value, error) { panic("not implemented") } 57 func (op *dropout) ReturnsPtr() bool { return true } 58 func (op *dropout) CallsExtern() bool { return true } 59 func (op *dropout) OverwritesInput() int { return 0 } 60 func (op *dropout) WriteHash(h hash.Hash) { fmt.Fprintf(h, "Dropout %v", op.Dropout.Dropout()) } 61 func (op *dropout) Hashcode() uint32 { return simpleHash(op) } 62 func (op *dropout) String() string { return fmt.Sprintf("Dropout %v", op.Dropout.Dropout()) } 63 func (op *dropout) DiffWRT(inputs int) []bool { return []bool{true, false} } // it technically should be []bool{true, false} 64 65 func (op *dropout) SymDiff(inputs gorgonia.Nodes, output *gorgonia.Node, grad *gorgonia.Node) (retVal gorgonia.Nodes, err error) { 66 diffOp := &dropoutDiff{op} 67 retVal = make(gorgonia.Nodes, 1) // retVal[1] will be nil 68 retVal[0], err = gorgonia.ApplyOp(diffOp, grad) 69 return 70 } 71 72 func (op *dropout) DoDiff(ctx gorgonia.ExecutionContext, inputs gorgonia.Nodes, output *gorgonia.Node) error { 73 panic("not implemented") 74 } 75 76 func (op *dropout) CUDADo(extern gorgonia.External, dev gorgonia.Device, prealloc gorgonia.Value, inputs ...gorgonia.Value) (retVal gorgonia.Value, err error) { 77 if err = checkArity(op, len(inputs)); err != nil { 78 return 79 } 80 81 x := inputs[0] 82 machine := extern.(gorgonia.CUDAMachine) 83 machine.Engines()[int(dev)].DoWork() 84 ctx := machine.CUDNNContexts()[int(dev)] 85 86 var s cudnn.Memory 87 var memsize uintptr 88 if memsize, err = op.RequiredStateSize(ctx); err != nil { 89 return nil, errors.Wrap(err, "Unable to get required state size for Dropout") 90 } 91 if !op.IsReady() { 92 // var x cu.DevicePtr 93 // machine.Engines()[int(dev)].DoWork() 94 // if x, err = machine.Contexts()[int(dev)].MemAlloc(int64(memsize)); err != nil { 95 // return nil, errors.Wrapf(err, "Unable to allocate %v bytes of memory of scratch space for Dropout", memsize) 96 // } 97 98 x, err := machine.Engines()[int(dev)].Get(int64(memsize)) 99 if err != nil { 100 return nil, errors.Wrapf(err, "Unable to allocate %v bytes of memory of scratch space for Dropout", memsize) 101 } 102 103 s = tmpWrapper(x.(cu.DevicePtr)) 104 // s = x.(cudnn.Memory) 105 if err = op.Use(ctx, s, memsize, op.seed); err != nil { 106 return nil, errors.Wrapf(err, "Unable to set dropout to use context %v", ctx) 107 } 108 } else { 109 s = op.States() 110 } 111 112 err = ctx.DropoutForward(op.Dropout, op.xDesc, x.(cudnn.Memory), op.xDesc, prealloc.(cudnn.Memory), s, memsize) 113 return prealloc, err 114 } 115 116 type dropoutDiff struct { 117 *dropout 118 } 119 120 func (op *dropoutDiff) Arity() int { return 1 } 121 122 func (op *dropoutDiff) Type() hm.Type { 123 return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a')) 124 } 125 126 func (op *dropoutDiff) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) { 127 if err := checkArity(op, len(inputs)); err != nil { 128 return nil, err 129 } 130 return inputs[0].(tensor.Shape).Clone(), nil 131 } 132 133 func (op *dropoutDiff) Do(...gorgonia.Value) (gorgonia.Value, error) { 134 panic("not implemented") 135 } 136 137 func (op *dropoutDiff) ReturnsPtr() bool { return true } 138 139 func (op *dropoutDiff) CallsExtern() bool { return true } 140 141 func (op *dropoutDiff) OverwritesInput() int { return -1 } 142 143 func (op *dropoutDiff) WriteHash(h hash.Hash) { fmt.Fprintf(h, "DropoutDiff %v", op.Dropout.Dropout()) } 144 145 func (op *dropoutDiff) Hashcode() uint32 { return simpleHash(op) } 146 147 func (op *dropoutDiff) String() string { return fmt.Sprintf("DropoutDiff %v", op.Dropout.Dropout()) } 148 149 func (op *dropoutDiff) CUDADo(extern gorgonia.External, dev gorgonia.Device, prealloc gorgonia.Value, inputs ...gorgonia.Value) (retVal gorgonia.Value, err error) { 150 if err = checkArity(op, len(inputs)); err != nil { 151 return 152 } 153 154 dy := inputs[0] 155 machine := extern.(gorgonia.CUDAMachine) 156 machine.Engines()[int(dev)].DoWork() 157 ctx := machine.CUDNNContexts()[int(dev)] 158 159 if !op.IsReady() { 160 return nil, errors.New("OP is not ready") 161 } 162 163 scratch := op.States() 164 memsize, _ := op.RequiredStateSize(ctx) 165 166 err = ctx.DropoutBackward(op.Dropout, op.xDesc, dy.(cudnn.Memory), 167 op.xDesc, prealloc.(cudnn.Memory), scratch, memsize) 168 return prealloc, err 169 } 170 171 type tmpWrapper cu.DevicePtr 172 173 func (p tmpWrapper) Uintptr() uintptr { return cu.DevicePtr(p).Uintptr() } 174 175 func (p tmpWrapper) Pointer() unsafe.Pointer { return unsafe.Pointer(cu.DevicePtr(p).Uintptr()) } 176 177 func (p tmpWrapper) IsNativelyAccessible() bool { return false }