gorgonia.org/gorgonia@v0.9.17/ops/nn/dropout_cuda.go (about)

     1  // +build cuda
     2  
     3  package nnops
     4  
     5  import (
     6  	"fmt"
     7  	"hash"
     8  	"time"
     9  	"unsafe"
    10  
    11  	"github.com/chewxy/hm"
    12  	"github.com/pkg/errors"
    13  	"gorgonia.org/cu"
    14  	cudnn "gorgonia.org/cu/dnn"
    15  	t2cudnn "gorgonia.org/cu/dnn/interop"
    16  	"gorgonia.org/gorgonia"
    17  	"gorgonia.org/tensor"
    18  )
    19  
// dropout is a CUDA-only dropout op implemented on top of cuDNN.
//
// It embeds the cuDNN dropout descriptor created by newDropout. CUDADo
// registers a state buffer with that descriptor on first use and reuses it
// afterwards. The struct also keeps the seed passed to the descriptor at that
// point and a cuDNN tensor descriptor derived from the input node.
type dropout struct {
	*cudnn.Dropout
	// seed is handed to Dropout.Use when CUDADo first sets up the state buffer;
	// newDropout initialises it from the wall-clock time.
	seed  uint64
	// xDesc describes the input tensor. It is used as both the input and the
	// output descriptor in the forward and backward cuDNN calls.
	xDesc *cudnn.TensorDescriptor
}
    25  
    26  func newDropout(x *gorgonia.Node, prob float64) (*dropout, error) {
    27  	xDesc, err := t2cudnn.Describe(x)
    28  	if err != nil {
    29  		return nil, err
    30  	}
    31  
    32  	internal, err := cudnn.NewDropout(prob)
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  	return &dropout{
    37  		Dropout: internal,
    38  		xDesc:   xDesc,
    39  		seed:    uint64(time.Now().UnixNano()),
    40  	}, nil
    41  }
    42  
    43  func (op *dropout) Arity() int { return 1 }
    44  
    45  func (op *dropout) Type() hm.Type {
    46  	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'))
    47  }
    48  
    49  func (op *dropout) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) {
    50  	if err := checkArity(op, len(inputs)); err != nil {
    51  		return nil, err
    52  	}
    53  	return inputs[0].(tensor.Shape).Clone(), nil
    54  }
    55  
    56  func (op *dropout) Do(...gorgonia.Value) (gorgonia.Value, error) { panic("not implemented") }
    57  func (op *dropout) ReturnsPtr() bool                             { return true }
    58  func (op *dropout) CallsExtern() bool                            { return true }
    59  func (op *dropout) OverwritesInput() int                         { return 0 }
    60  func (op *dropout) WriteHash(h hash.Hash)                        { fmt.Fprintf(h, "Dropout %v", op.Dropout.Dropout()) }
    61  func (op *dropout) Hashcode() uint32                             { return simpleHash(op) }
    62  func (op *dropout) String() string                               { return fmt.Sprintf("Dropout %v", op.Dropout.Dropout()) }
    63  func (op *dropout) DiffWRT(inputs int) []bool                    { return []bool{true, false} } // it technically should be []bool{true, false}
    64  
    65  func (op *dropout) SymDiff(inputs gorgonia.Nodes, output *gorgonia.Node, grad *gorgonia.Node) (retVal gorgonia.Nodes, err error) {
    66  	diffOp := &dropoutDiff{op}
    67  	retVal = make(gorgonia.Nodes, 1) // retVal[1] will be nil
    68  	retVal[0], err = gorgonia.ApplyOp(diffOp, grad)
    69  	return
    70  }
    71  
    72  func (op *dropout) DoDiff(ctx gorgonia.ExecutionContext, inputs gorgonia.Nodes, output *gorgonia.Node) error {
    73  	panic("not implemented")
    74  }
    75  
    76  func (op *dropout) CUDADo(extern gorgonia.External, dev gorgonia.Device, prealloc gorgonia.Value, inputs ...gorgonia.Value) (retVal gorgonia.Value, err error) {
    77  	if err = checkArity(op, len(inputs)); err != nil {
    78  		return
    79  	}
    80  
    81  	x := inputs[0]
    82  	machine := extern.(gorgonia.CUDAMachine)
    83  	machine.Engines()[int(dev)].DoWork()
    84  	ctx := machine.CUDNNContexts()[int(dev)]
    85  
    86  	var s cudnn.Memory
    87  	var memsize uintptr
    88  	if memsize, err = op.RequiredStateSize(ctx); err != nil {
    89  		return nil, errors.Wrap(err, "Unable to get required state size for Dropout")
    90  	}
    91  	if !op.IsReady() {
    92  		// var x cu.DevicePtr
    93  		// machine.Engines()[int(dev)].DoWork()
    94  		// if x, err = machine.Contexts()[int(dev)].MemAlloc(int64(memsize)); err != nil {
    95  		// 	return nil, errors.Wrapf(err, "Unable to allocate %v bytes of memory of scratch space for Dropout", memsize)
    96  		// }
    97  
    98  		x, err := machine.Engines()[int(dev)].Get(int64(memsize))
    99  		if err != nil {
   100  			return nil, errors.Wrapf(err, "Unable to allocate %v bytes of memory of scratch space for Dropout", memsize)
   101  		}
   102  
   103  		s = tmpWrapper(x.(cu.DevicePtr))
   104  		// s = x.(cudnn.Memory)
   105  		if err = op.Use(ctx, s, memsize, op.seed); err != nil {
   106  			return nil, errors.Wrapf(err, "Unable to set dropout to use context %v", ctx)
   107  		}
   108  	} else {
   109  		s = op.States()
   110  	}
   111  
   112  	err = ctx.DropoutForward(op.Dropout, op.xDesc, x.(cudnn.Memory), op.xDesc, prealloc.(cudnn.Memory), s, memsize)
   113  	return prealloc, err
   114  }
   115  
// dropoutDiff is the gradient op for dropout, created by dropout.SymDiff.
// It embeds the forward op so its CUDADo can reuse the forward op's cuDNN
// dropout descriptor, input tensor descriptor and state buffer.
type dropoutDiff struct {
	*dropout
}
   119  
   120  func (op *dropoutDiff) Arity() int { return 1 }
   121  
   122  func (op *dropoutDiff) Type() hm.Type {
   123  	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'))
   124  }
   125  
   126  func (op *dropoutDiff) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) {
   127  	if err := checkArity(op, len(inputs)); err != nil {
   128  		return nil, err
   129  	}
   130  	return inputs[0].(tensor.Shape).Clone(), nil
   131  }
   132  
   133  func (op *dropoutDiff) Do(...gorgonia.Value) (gorgonia.Value, error) {
   134  	panic("not implemented")
   135  }
   136  
   137  func (op *dropoutDiff) ReturnsPtr() bool { return true }
   138  
   139  func (op *dropoutDiff) CallsExtern() bool { return true }
   140  
   141  func (op *dropoutDiff) OverwritesInput() int { return -1 }
   142  
   143  func (op *dropoutDiff) WriteHash(h hash.Hash) { fmt.Fprintf(h, "DropoutDiff %v", op.Dropout.Dropout()) }
   144  
   145  func (op *dropoutDiff) Hashcode() uint32 { return simpleHash(op) }
   146  
   147  func (op *dropoutDiff) String() string { return fmt.Sprintf("DropoutDiff %v", op.Dropout.Dropout()) }
   148  
   149  func (op *dropoutDiff) CUDADo(extern gorgonia.External, dev gorgonia.Device, prealloc gorgonia.Value, inputs ...gorgonia.Value) (retVal gorgonia.Value, err error) {
   150  	if err = checkArity(op, len(inputs)); err != nil {
   151  		return
   152  	}
   153  
   154  	dy := inputs[0]
   155  	machine := extern.(gorgonia.CUDAMachine)
   156  	machine.Engines()[int(dev)].DoWork()
   157  	ctx := machine.CUDNNContexts()[int(dev)]
   158  
   159  	if !op.IsReady() {
   160  		return nil, errors.New("OP is not ready")
   161  	}
   162  
   163  	scratch := op.States()
   164  	memsize, _ := op.RequiredStateSize(ctx)
   165  
   166  	err = ctx.DropoutBackward(op.Dropout, op.xDesc, dy.(cudnn.Memory),
   167  		op.xDesc, prealloc.(cudnn.Memory), scratch, memsize)
   168  	return prealloc, err
   169  }
   170  
   171  type tmpWrapper cu.DevicePtr
   172  
   173  func (p tmpWrapper) Uintptr() uintptr { return cu.DevicePtr(p).Uintptr() }
   174  
   175  func (p tmpWrapper) Pointer() unsafe.Pointer { return unsafe.Pointer(cu.DevicePtr(p).Uintptr()) }
   176  
   177  func (p tmpWrapper) IsNativelyAccessible() bool { return false }