gorgonia.org/gorgonia@v0.9.17/ops/nn/activation_cuda.go (about) 1 // +build cuda 2 3 package nnops 4 5 import ( 6 "fmt" 7 "hash" 8 9 "github.com/chewxy/hm" 10 "gorgonia.org/cu/dnn" 11 t2cudnn "gorgonia.org/cu/dnn/interop" 12 "gorgonia.org/gorgonia" 13 "gorgonia.org/tensor" 14 ) 15 16 type activation struct { 17 *cudnn.Activation 18 xDesc, yDesc *cudnn.TensorDescriptor 19 } 20 21 func newRelu() (*activation, error) { 22 act, err := cudnn.NewActivation(cudnn.ReLU, cudnn.PropagateNan, 1.0) 23 if err != nil { 24 return nil, err 25 } 26 return &activation{Activation: act}, nil 27 } 28 29 func (op *activation) Arity() int { return 1 } 30 31 func (op *activation) Type() hm.Type { 32 return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a')) 33 } 34 35 func (op *activation) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) { 36 if err := checkArity(op, len(inputs)); err != nil { 37 return nil, err 38 } 39 return inputs[0].(tensor.Shape).Clone(), nil 40 } 41 42 func (op *activation) Do(...gorgonia.Value) (gorgonia.Value, error) { 43 panic("not implemented") 44 } 45 46 func (op *activation) ReturnsPtr() bool { return true } 47 48 func (op *activation) CallsExtern() bool { return true } 49 50 func (op *activation) OverwritesInput() int { return -1 } 51 52 func (op *activation) WriteHash(h hash.Hash) { fmt.Fprintf(h, "%v", op.Activation.Mode()) } 53 54 func (op *activation) Hashcode() uint32 { return simpleHash(op) } 55 56 func (op *activation) String() string { return fmt.Sprintf("%v", op.Activation.Mode()) } 57 58 func (op *activation) DiffWRT(inputs int) []bool { return []bool{true} } 59 60 func (op *activation) SymDiff(inputs gorgonia.Nodes, output *gorgonia.Node, grad *gorgonia.Node) (retVal gorgonia.Nodes, err error) { 61 if err = checkArity(op, len(inputs)); err != nil { 62 return 63 } 64 65 diffOp := &activationDiff{activation: op} 66 67 retVal = make(gorgonia.Nodes, 1) 68 retVal[0], err = gorgonia.ApplyOp(diffOp, inputs[0], output, grad) 69 return 70 } 71 72 func (op *activation) CUDADo(extern gorgonia.External, dev gorgonia.Device, prealloc gorgonia.Value, inputs ...gorgonia.Value) (retVal gorgonia.Value, err error) { 73 if err = checkArity(op, len(inputs)); err != nil { 74 return 75 } 76 77 x := inputs[0] 78 79 if op.xDesc == nil { 80 if op.xDesc, err = t2cudnn.Describe(x.(tensor.Tensor)); err != nil { 81 return 82 } 83 } 84 if op.yDesc == nil { 85 if op.yDesc, err = t2cudnn.Describe(prealloc.(tensor.Tensor)); err != nil { 86 return 87 } 88 } 89 90 machine := extern.(gorgonia.CUDAMachine) 91 ctx := machine.CUDNNContexts()[int(dev)] 92 err = ctx.ActivationForward(op.Activation, 1, op.xDesc, x.(cudnn.Memory), 0, op.yDesc, prealloc.(cudnn.Memory)) 93 return prealloc, err 94 } 95 96 type activationDiff struct { 97 *activation 98 dyDesc, dxDesc *cudnn.TensorDescriptor 99 } 100 101 func (op *activationDiff) Arity() int { 102 return 3 // x, y, dy, dx 103 } 104 105 func (op *activationDiff) Type() hm.Type { 106 return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a')) 107 } 108 109 func (op *activationDiff) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) { 110 if err := checkArity(op, len(inputs)); err != nil { 111 return nil, err 112 } 113 return inputs[0].(tensor.Shape).Clone(), nil 114 } 115 116 func (op *activationDiff) Do(...gorgonia.Value) (gorgonia.Value, error) { panic("not implemented") } 117 118 func (op *activationDiff) ReturnsPtr() bool { return true } 119 120 func (op *activationDiff) CallsExtern() bool { return true } 121 122 func (op *activationDiff) OverwritesInput() int { return -1 } 123 124 func (op *activationDiff) WriteHash(h hash.Hash) { fmt.Fprintf(h, "DIFF%v", op.Activation.Mode()) } 125 126 func (op *activationDiff) Hashcode() uint32 { return simpleHash(op) } 127 128 func (op *activationDiff) String() string { return fmt.Sprintf("DIFF %v", op.Activation.Mode()) } 129 130 func (op *activationDiff) CUDADo(extern gorgonia.External, dev gorgonia.Device, prealloc gorgonia.Value, inputs ...gorgonia.Value) (retVal gorgonia.Value, err error) { 131 x, y, dy := inputs[0], inputs[1], inputs[2] 132 if op.dxDesc == nil { 133 if op.dxDesc, err = t2cudnn.Describe(prealloc.(tensor.Tensor)); err != nil { 134 return 135 } 136 } 137 if op.dyDesc == nil { 138 if op.dyDesc, err = t2cudnn.Describe(dy.(tensor.Tensor)); err != nil { 139 return 140 } 141 } 142 machine := extern.(gorgonia.CUDAMachine) 143 machine.Engines()[int(dev)].DoWork() 144 ctx := machine.CUDNNContexts()[int(dev)] 145 146 err = ctx.ActivationBackward(op.Activation, 1, 147 op.yDesc, y.(cudnn.Memory), 148 op.dyDesc, dy.(cudnn.Memory), 149 op.xDesc, x.(cudnn.Memory), 150 0, 151 op.dxDesc, prealloc.(cudnn.Memory)) 152 return prealloc, err 153 }