gorgonia.org/gorgonia@v0.9.17/ops/nn/convolution_cuda.go (about)

     1  // +build cuda
     2  
     3  package nnops
     4  
     5  import (
     6  	"fmt"
     7  	"hash"
     8  
     9  	"github.com/chewxy/hm"
    10  	cudnn "gorgonia.org/cu/dnn"
    11  	t2cudnn "gorgonia.org/cu/dnn/interop"
    12  	G "gorgonia.org/gorgonia"
    13  	"gorgonia.org/tensor"
    14  )
    15  
// Compile-time checks that *convolution satisfies both the graph-op
// interface and the CUDA execution interface.
var (
	_ G.Op       = &convolution{}
	_ G.CUDADoer = &convolution{}
)
    20  
// convolution is a CUDA-backed 2D convolution op. It wraps a cuDNN
// convolution descriptor and caches the tensor/filter descriptors that were
// derived from the nodes it was created with.
type convolution struct {
	*cudnn.Convolution

	// created with these attributes; retained for hashing and printing
	padding, stride, dilation []int
	inShape, filterShape      tensor.Shape

	// cached descriptors: x = input image, y = output (created lazily on
	// the first CUDADo call), w = filter
	xDesc, yDesc *cudnn.TensorDescriptor
	wDesc        *cudnn.Filter
}
    32  
    33  func makeConvolutionOp(im, filter *G.Node, kernelShape tensor.Shape, pad, stride, dilation []int) (retVal *convolution, err error) {
    34  	var xDesc *cudnn.TensorDescriptor
    35  	var wDesc *cudnn.Filter
    36  	if xDesc, err = t2cudnn.Describe(im); err != nil {
    37  		return nil, err
    38  	}
    39  	if wDesc, err = t2cudnn.DescribeAsFilter(filter, cudnn.NCHW); err != nil {
    40  		return nil, err
    41  	}
    42  	datatype := t2cudnn.Dtype2DataType(im.Dtype())
    43  	conv, err := cudnn.NewConvolution(cudnn.DefaultMath, 1, pad, stride, dilation, cudnn.StandardConvolution, datatype)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	return &convolution{
    49  		Convolution: conv,
    50  		padding:     pad,
    51  		stride:      stride,
    52  		dilation:    dilation,
    53  
    54  		inShape:     im.Shape().Clone(),
    55  		filterShape: filter.Shape().Clone(),
    56  
    57  		xDesc: xDesc,
    58  		wDesc: wDesc,
    59  	}, nil
    60  }
    61  
    62  func (c *convolution) Arity() int { return 2 }
    63  
    64  func (c *convolution) Type() hm.Type {
    65  	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a'))
    66  }
    67  
    68  func (c *convolution) InferShape(inputs ...G.DimSizer) (retVal tensor.Shape, err error) {
    69  	if err = checkArity(c, len(inputs)); err != nil {
    70  		return
    71  	}
    72  	return c.ForwardOutputShape(c.xDesc, c.wDesc, 2) //only conv2d is supported now
    73  }
    74  
    75  func (c *convolution) Do(inputs ...G.Value) (retVal G.Value, err error) {
    76  	panic("not implemented")
    77  }
    78  
    79  func (c *convolution) ReturnsPtr() bool { return true }
    80  
    81  func (c *convolution) CallsExtern() bool { return true }
    82  
    83  func (c *convolution) OverwritesInput() int { return -1 }
    84  
    85  func (c *convolution) WriteHash(h hash.Hash) {
    86  	fmt.Fprintf(h, "Convolution:%v-%v-%v", c.Padding(), c.FilterStride(), c.Dilation())
    87  }
    88  
    89  func (c *convolution) Hashcode() uint32 { return simpleHash(c) }
    90  
    91  func (c *convolution) String() string {
    92  	return fmt.Sprintf("Convolution:%v-%v-%v", c.Padding(), c.FilterStride(), c.Dilation())
    93  }
    94  
// CUDADo executes the forward convolution on the given device, writing the
// result into prealloc and returning it.
//
// inputs[0] is the image and inputs[1] is the filter; both, like prealloc,
// must be usable as cudnn.Memory (the type assertions panic otherwise).
func (c *convolution) CUDADo(extern G.External, dev G.Device, prealloc G.Value, inputs ...G.Value) (retVal G.Value, err error) {
	if err = checkArity(c, len(inputs)); err != nil {
		return
	}
	im, filter := inputs[0], inputs[1]

	// The output descriptor is created lazily on the first call, from the
	// preallocated output tensor.
	if c.yDesc == nil {
		if c.yDesc, err = t2cudnn.Describe(prealloc.(tensor.Tensor)); err != nil {
			return
		}
	}

	machine := extern.(G.CUDAMachine)
	// Presumably flushes pending work on this device's engine before using
	// the cuDNN context. NOTE(review): the backward ops below do not do
	// this — confirm whether that asymmetry is intentional.
	machine.Engines()[int(dev)].DoWork()
	ctx := machine.CUDNNContexts()[int(dev)]

	// Implicit-GEMM algorithm with no workspace (nomem{}). The 1.0/0/1.0
	// arguments look like scaling factors and a workspace size — TODO
	// confirm their exact meaning against the cu/dnn API.
	if err = ctx.ConvolutionForward(1.0,
		c.xDesc, im.(cudnn.Memory),
		c.wDesc, filter.(cudnn.Memory),
		c.Convolution,
		cudnn.ConvolutionFwdAlgoImplicitGemm, nomem{},
		0, 1.0,
		c.yDesc, prealloc.(cudnn.Memory)); err != nil {
		return
	}
	return prealloc, nil
}
   122  
// DoDiff is not implemented: differentiation is done symbolically, via
// SymDiff.
func (c *convolution) DoDiff(ctx G.ExecutionContext, inputs G.Nodes, output *G.Node) error {
	panic("not implemented")
}

// DiffWRT reports that the op is differentiable with respect to both of its
// inputs (the image and the filter).
func (c *convolution) DiffWRT(inputs int) []bool {
	return []bool{true, true}
}
   130  
   131  func (c *convolution) SymDiff(inputs G.Nodes, output *G.Node, grad *G.Node) (retVal G.Nodes, err error) {
   132  	var outDesc *cudnn.TensorDescriptor
   133  	if outDesc, err = t2cudnn.Describe(output); err != nil {
   134  		return nil, err
   135  	}
   136  	diffIm := &convDiffIm{
   137  		convolution: c,
   138  		outputDesc:  outDesc,
   139  	}
   140  	diffFilter := &convDiffFilter{
   141  		convolution: c,
   142  		outputDesc:  outDesc,
   143  	}
   144  
   145  	retVal = make(G.Nodes, 2)
   146  	if retVal[0], err = G.ApplyOp(diffIm, inputs[0], grad); err != nil {
   147  		return nil, err
   148  	}
   149  	if retVal[1], err = G.ApplyOp(diffFilter, inputs[1], grad); err != nil {
   150  		return nil, err
   151  	}
   152  
   153  	return
   154  }
   155  
   156  // convDiffIm is the d(z)/d(im) operation. See also convDiffFilter
// convDiffIm is the d(z)/d(im) operation. See also convDiffFilter
type convDiffIm struct {
	*convolution
	// outputDesc describes the forward output, i.e. the shape of the
	// incoming gradient
	outputDesc *cudnn.TensorDescriptor
}
   161  
   162  func (c *convDiffIm) Arity() int { return 2 }
   163  
   164  func (c *convDiffIm) Type() hm.Type {
   165  	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a'))
   166  }
   167  
   168  func (c *convDiffIm) InferShape(shps ...G.DimSizer) (tensor.Shape, error) {
   169  	return c.inShape.Clone(), nil
   170  }
   171  
   172  func (c *convDiffIm) Do(...G.Value) (G.Value, error) {
   173  	panic("not implemented")
   174  }
   175  
   176  func (c *convDiffIm) ReturnsPtr() bool { return true }
   177  
   178  func (c *convDiffIm) CallsExtern() bool { return true }
   179  
   180  func (c *convDiffIm) OverwritesInput() int { return -1 }
   181  
   182  func (c *convDiffIm) WriteHash(h hash.Hash) {
   183  	fmt.Fprintf(h, "ConvolutionImDiff:%v-%v-%v", c.Padding(), c.FilterStride(), c.Dilation())
   184  }
   185  
   186  func (c *convDiffIm) Hashcode() uint32 { return simpleHash(c) }
   187  
   188  func (c *convDiffIm) String() string {
   189  	return fmt.Sprintf("ConvolutionImDiff:%v-%v-%v", c.Padding(), c.FilterStride(), c.Dilation())
   190  }
   191  
// CUDADo computes the gradient with respect to the image on the given
// device, writing it into prealloc and returning it.
//
// inputs[0] is consumed as the filter (it is paired with wDesc) and
// inputs[1] as the incoming output gradient (paired with outputDesc).
func (c *convDiffIm) CUDADo(extern G.External, dev G.Device, prealloc G.Value, inputs ...G.Value) (retVal G.Value, err error) {
	if err = checkArity(c, len(inputs)); err != nil {
		return
	}
	filter, grad := inputs[0], inputs[1]

	machine := extern.(G.CUDAMachine)
	// NOTE(review): unlike the forward op's CUDADo, no DoWork() is called
	// on the device engine before using the context — confirm intentional.
	ctx := machine.CUDNNContexts()[int(dev)]

	// Backward-data algo 0, no workspace (nomem{}). The 1.0/0/1.0 values
	// look like scaling factors and a workspace size — TODO confirm against
	// the cu/dnn API.
	if err = ctx.ConvolutionBackwardData(1.0,
		c.wDesc, filter.(cudnn.Memory),
		c.outputDesc, grad.(cudnn.Memory),
		c.Convolution,
		cudnn.ConvolutionBwdDataAlgo0, nomem{},
		0, 1.0,
		c.xDesc, prealloc.(cudnn.Memory)); err != nil {
		return
	}
	return prealloc, nil
}
   212  
// convDiffFilter is the d(z)/d(filter) operation. See also convDiffIm.
type convDiffFilter struct {
	*convolution                         // shared struct as convDiffIm
	outputDesc   *cudnn.TensorDescriptor // shared output descriptor with convDiffIm
}
   217  
   218  func (c *convDiffFilter) Arity() int { return 2 }
   219  
   220  func (c *convDiffFilter) Type() hm.Type {
   221  	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a'))
   222  }
   223  
   224  func (c *convDiffFilter) InferShape(...G.DimSizer) (tensor.Shape, error) {
   225  	return c.filterShape.Clone(), nil
   226  }
   227  
   228  func (c *convDiffFilter) Do(...G.Value) (G.Value, error) {
   229  	panic("not implemented")
   230  }
   231  
   232  func (c *convDiffFilter) ReturnsPtr() bool { return true }
   233  
   234  func (c *convDiffFilter) CallsExtern() bool { return true }
   235  
   236  func (c *convDiffFilter) OverwritesInput() int { return -1 }
   237  
   238  func (c *convDiffFilter) WriteHash(h hash.Hash) {
   239  	fmt.Fprintf(h, "ConvolutionFilterDiff:%v-%v-%v", c.Padding(), c.FilterStride(), c.Dilation())
   240  }
   241  
   242  func (c *convDiffFilter) Hashcode() uint32 { return simpleHash(c) }
   243  
   244  func (c *convDiffFilter) String() string {
   245  	return fmt.Sprintf("ConvolutionFilterDiff:%v-%v-%v", c.Padding(), c.FilterStride(), c.Dilation())
   246  }
   247  
// CUDADo computes the gradient with respect to the filter on the given
// device, writing it into prealloc and returning it.
//
// inputs[0] is consumed as the input image (it is paired with xDesc) and
// inputs[1] as the incoming output gradient (paired with outputDesc).
func (c *convDiffFilter) CUDADo(extern G.External, dev G.Device, prealloc G.Value, inputs ...G.Value) (retVal G.Value, err error) {
	if err = checkArity(c, len(inputs)); err != nil {
		return
	}
	im, grad := inputs[0], inputs[1]

	machine := extern.(G.CUDAMachine)
	// NOTE(review): unlike the forward op's CUDADo, no DoWork() is called
	// on the device engine before using the context — confirm intentional.
	ctx := machine.CUDNNContexts()[int(dev)]

	// Backward-filter algo 0, no workspace (nomem{}). The 1.0/0/1.0 values
	// look like scaling factors and a workspace size — TODO confirm against
	// the cu/dnn API.
	if err = ctx.ConvolutionBackwardFilter(1.0,
		c.xDesc, im.(cudnn.Memory),
		c.outputDesc, grad.(cudnn.Memory),
		c.Convolution,
		cudnn.ConvolutionBwdFilterAlgo0, nomem{},
		0, 1.0,
		c.wDesc, prealloc.(cudnn.Memory)); err != nil {
		return
	}
	return prealloc, nil
}