gorgonia.org/gorgonia@v0.9.17/ops/nn/activation_cuda.go

// +build cuda

package nnops

import (
	"fmt"
	"hash"

	"github.com/chewxy/hm"
	"gorgonia.org/cu/dnn"
	t2cudnn "gorgonia.org/cu/dnn/interop"
	"gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

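// activation wraps a cuDNN activation descriptor together with lazily
// created tensor descriptors for its input (xDesc) and output (yDesc).
// It implements gorgonia.Op; execution happens on the GPU via CUDADo.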
type activation struct {
	*cudnn.Activation
	xDesc, yDesc *cudnn.TensorDescriptor
}

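// newRelu returns an activation op configured for ReLU, with NaN
// propagation enabled and a coefficient of 1.0 (cuDNN ignores the
// coefficient for plain ReLU; it only matters for clipped ReLU and ELU).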
func newRelu() (*activation, error) {
	act, err := cudnn.NewActivation(cudnn.ReLU, cudnn.PropagateNan, 1.0)
	if err != nil {
		return nil, err
	}
	return &activation{Activation: act}, nil
}

func (op *activation) Arity() int { return 1 }

func (op *activation) Type() hm.Type {
	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'))
}

func (op *activation) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) {
	if err := checkArity(op, len(inputs)); err != nil {
		return nil, err
	}
	return inputs[0].(tensor.Shape).Clone(), nil
}

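// Do is unimplemented: this op only exists in CUDA builds, so the VM
// executes it through CUDADo instead of the plain Do path.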
func (op *activation) Do(...gorgonia.Value) (gorgonia.Value, error) {
	panic("not implemented")
}

func (op *activation) ReturnsPtr() bool { return true }

func (op *activation) CallsExtern() bool { return true }

func (op *activation) OverwritesInput() int { return -1 }

func (op *activation) WriteHash(h hash.Hash) { fmt.Fprintf(h, "%v", op.Activation.Mode()) }

func (op *activation) Hashcode() uint32 { return simpleHash(op) }

func (op *activation) String() string { return fmt.Sprintf("%v", op.Activation.Mode()) }

func (op *activation) DiffWRT(inputs int) []bool { return []bool{true} }

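// SymDiff constructs the symbolic derivative of the activation: a single
// activationDiff node applied to the input, the forward output, and the
// incoming gradient.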
func (op *activation) SymDiff(inputs gorgonia.Nodes, output *gorgonia.Node, grad *gorgonia.Node) (retVal gorgonia.Nodes, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	diffOp := &activationDiff{activation: op}

	retVal = make(gorgonia.Nodes, 1)
	retVal[0], err = gorgonia.ApplyOp(diffOp, inputs[0], output, grad)
	return
}

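// CUDADo runs the forward activation on the given device. The tensor
// descriptors are created on first use and cached on the op. The result is
// written into prealloc via cudnnActivationForward with alpha=1 and beta=0,
// i.e. prealloc is overwritten rather than accumulated into.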
func (op *activation) CUDADo(extern gorgonia.External, dev gorgonia.Device, prealloc gorgonia.Value, inputs ...gorgonia.Value) (retVal gorgonia.Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	x := inputs[0]

	if op.xDesc == nil {
		if op.xDesc, err = t2cudnn.Describe(x.(tensor.Tensor)); err != nil {
			return
		}
	}
	if op.yDesc == nil {
		if op.yDesc, err = t2cudnn.Describe(prealloc.(tensor.Tensor)); err != nil {
			return
		}
	}

	machine := extern.(gorgonia.CUDAMachine)
	ctx := machine.CUDNNContexts()[int(dev)]
	err = ctx.ActivationForward(op.Activation, 1, op.xDesc, x.(cudnn.Memory), 0, op.yDesc, prealloc.(cudnn.Memory))
	return prealloc, err
}

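// activationDiff is the gradient op for activation. It embeds the forward
// op so it can reuse the cached activation mode and the x/y descriptors,
// and adds descriptors for the gradients dy and dx.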
type activationDiff struct {
	*activation
	dyDesc, dxDesc *cudnn.TensorDescriptor
}

func (op *activationDiff) Arity() int {
	return 3 // x, y, dy; dx is the preallocated output, not an input
}

func (op *activationDiff) Type() hm.Type {
	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a'))
}

func (op *activationDiff) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) {
	if err := checkArity(op, len(inputs)); err != nil {
		return nil, err
	}
	return inputs[0].(tensor.Shape).Clone(), nil
}

func (op *activationDiff) Do(...gorgonia.Value) (gorgonia.Value, error) { panic("not implemented") }

func (op *activationDiff) ReturnsPtr() bool { return true }

func (op *activationDiff) CallsExtern() bool { return true }

func (op *activationDiff) OverwritesInput() int { return -1 }

func (op *activationDiff) WriteHash(h hash.Hash) { fmt.Fprintf(h, "DIFF%v", op.Activation.Mode()) }

func (op *activationDiff) Hashcode() uint32 { return simpleHash(op) }

func (op *activationDiff) String() string { return fmt.Sprintf("DIFF %v", op.Activation.Mode()) }

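// CUDADo runs the backward activation. It first flushes any queued work on
// the device's engine (DoWork) so that x, y and dy are fully materialized,
// then calls cudnnActivationBackward with alpha=1 and beta=0, writing the
// input gradient into prealloc. The x and y descriptors are reused from the
// embedded forward op, which will have been populated by the forward pass.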
func (op *activationDiff) CUDADo(extern gorgonia.External, dev gorgonia.Device, prealloc gorgonia.Value, inputs ...gorgonia.Value) (retVal gorgonia.Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	x, y, dy := inputs[0], inputs[1], inputs[2]
	if op.dxDesc == nil {
		if op.dxDesc, err = t2cudnn.Describe(prealloc.(tensor.Tensor)); err != nil {
			return
		}
	}
	if op.dyDesc == nil {
		if op.dyDesc, err = t2cudnn.Describe(dy.(tensor.Tensor)); err != nil {
			return
		}
	}
	machine := extern.(gorgonia.CUDAMachine)
	machine.Engines()[int(dev)].DoWork()
	ctx := machine.CUDNNContexts()[int(dev)]

	err = ctx.ActivationBackward(op.Activation, 1,
		op.yDesc, y.(cudnn.Memory),
		op.dyDesc, dy.(cudnn.Memory),
		op.xDesc, x.(cudnn.Memory),
		0,
		op.dxDesc, prealloc.(cudnn.Memory))
	return prealloc, err
}
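
// A minimal usage sketch (illustrative, not part of this file): in a
// CUDA-enabled build, the package's Rectify wrapper is assumed to construct
// this op and apply it to a node:
//
//	g := gorgonia.NewGraph()
//	x := gorgonia.NewMatrix(g, tensor.Float32, gorgonia.WithShape(64, 10), gorgonia.WithName("x"))
//	y, err := Rectify(x) // newRelu() + gorgonia.ApplyOp under the hood
//	_, _ = y, err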