gorgonia.org/gorgonia@v0.9.17/ops/nn/maxpool_cuda.go (about)

     1  // +build cuda
     2  
     3  package nnops
     4  
     5  import (
     6  	"fmt"
     7  	"hash"
     8  
     9  	"github.com/chewxy/hm"
    10  	cudnn "gorgonia.org/cu/dnn"
    11  	t2cudnn "gorgonia.org/cu/dnn/interop"
    12  	G "gorgonia.org/gorgonia"
    13  	"gorgonia.org/tensor"
    14  )
    15  
    16  var (
    17  	_ G.Op       = &maxpool{}
    18  	_ G.CUDADoer = &maxpool{}
    19  	_ G.Op       = &maxpoolDiff{}
    20  	_ G.CUDADoer = &maxpoolDiff{}
    21  )
    22  
    23  type maxpool struct {
    24  	*cudnn.Pooling
    25  
    26  	xDesc *cudnn.TensorDescriptor
    27  	yDesc *cudnn.TensorDescriptor
    28  }
    29  
    30  func newMaxPoolOp(x *G.Node, kernel, pad, stride []int) (*maxpool, error) {
    31  	var xDesc *cudnn.TensorDescriptor
    32  	var err error
    33  	if xDesc, err = t2cudnn.Describe(x); err != nil {
    34  		return nil, err
    35  	}
    36  
    37  	var p *cudnn.Pooling
    38  	if p, err = cudnn.NewPooling(cudnn.MaxPooling, cudnn.NotPropagateNan, kernel, stride, pad); err != nil {
    39  		return nil, err
    40  	}
    41  	return &maxpool{
    42  		Pooling: p,
    43  		xDesc:   xDesc,
    44  	}, nil
    45  }
    46  
    47  func (p *maxpool) Arity() int { return 1 }
    48  
    49  func (p *maxpool) Type() hm.Type { return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a')) }
    50  
    51  func (p *maxpool) InferShape(inputs ...G.DimSizer) (tensor.Shape, error) {
    52  	if err := checkArity(p, len(inputs)); err != nil {
    53  		return nil, err
    54  	}
    55  	return p.OutputShape(p.xDesc, 2) // only maxpool2d for now
    56  }
    57  
    58  func (p *maxpool) Do(...G.Value) (G.Value, error) {
    59  	panic("not implemented")
    60  }
    61  
    62  func (p *maxpool) ReturnsPtr() bool { return true }
    63  
    64  func (p *maxpool) CallsExtern() bool { return true }
    65  
    66  func (p *maxpool) OverwritesInput() int { return -1 }
    67  
    68  func (p *maxpool) WriteHash(h hash.Hash) {
    69  	xShape := p.xDesc.Shape()
    70  	kernel := p.Shape()
    71  	padding := p.Padding()
    72  	strides := p.Strides()
    73  	fmt.Fprintf(h, "MaxPool{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
    74  		xShape[0], xShape[1], xShape[2], xShape[3],
    75  		kernel[0], kernel[1],
    76  		padding[0], padding[1],
    77  		strides[0], strides[1])
    78  }
    79  
    80  func (p *maxpool) Hashcode() uint32 { return simpleHash(p) }
    81  
    82  func (p *maxpool) String() string {
    83  	xShape := p.xDesc.Shape()
    84  	kernel := p.Shape()
    85  	padding := p.Padding()
    86  	strides := p.Strides()
    87  	return fmt.Sprintf("MaxPool{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
    88  		xShape[0], xShape[1], xShape[2], xShape[3],
    89  		kernel[0], kernel[1],
    90  		padding[0], padding[1],
    91  		strides[0], strides[1])
    92  }
    93  
    94  func (p *maxpool) CUDADo(extern G.External, dev G.Device, prealloc G.Value, inputs ...G.Value) (retVal G.Value, err error) {
    95  	if err = checkArity(p, len(inputs)); err != nil {
    96  		return
    97  	}
    98  	in := inputs[0]
    99  
   100  	if p.yDesc == nil {
   101  		if p.yDesc, err = t2cudnn.Describe(prealloc.(tensor.Tensor)); err != nil {
   102  			return
   103  		}
   104  	}
   105  
   106  	machine := extern.(G.CUDAMachine)
   107  	machine.Engines()[int(dev)].DoWork()
   108  	ctx := machine.CUDNNContexts()[int(dev)]
   109  	err = ctx.PoolingForward(p.Pooling, 1.0, p.xDesc, in.(cudnn.Memory), 0, p.yDesc, prealloc.(cudnn.Memory))
   110  	return prealloc, err
   111  }
   112  
   113  func (p *maxpool) DiffWRT(inputs int) []bool { return []bool{true} }
   114  
   115  func (p *maxpool) SymDiff(inputs G.Nodes, output *G.Node, grad *G.Node) (retVal G.Nodes, err error) {
   116  	if err = checkArity(p, len(inputs)); err != nil {
   117  		return
   118  	}
   119  	diff := (*maxpoolDiff)(p)
   120  	x := inputs[0]
   121  
   122  	retVal = make(G.Nodes, 1)
   123  	retVal[0], err = G.ApplyOp(diff, x, output, grad)
   124  	return
   125  }
   126  
   127  func (p *maxpool) DoDiff(ctx G.ExecutionContext, inputs G.Nodes, output *G.Node) error {
   128  	panic("not implemented")
   129  }
   130  
// maxpoolDiff is the backward (gradient) op for maxpool. It has the same
// underlying struct as maxpool, which lets SymDiff convert a *maxpool into a
// *maxpoolDiff so that both ops share the same pooling and tensor descriptors.
type maxpoolDiff maxpool
   132  
   133  func (op *maxpoolDiff) Arity() int { return 3 }
   134  
   135  func (op *maxpoolDiff) Type() hm.Type {
   136  	return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a'))
   137  }
   138  
   139  func (op *maxpoolDiff) InferShape(inputs ...G.DimSizer) (tensor.Shape, error) {
   140  	return inputs[0].(tensor.Shape).Clone(), nil
   141  }
   142  
   143  func (op *maxpoolDiff) Do(...G.Value) (G.Value, error) { panic("not implemented") }
   144  
   145  func (op *maxpoolDiff) ReturnsPtr() bool { return true }
   146  
   147  func (op *maxpoolDiff) CallsExtern() bool { return true }
   148  
   149  func (op *maxpoolDiff) OverwritesInput() int { return -1 }
   150  
   151  func (op *maxpoolDiff) WriteHash(h hash.Hash) {
   152  	xShape := op.xDesc.Shape()
   153  	kernel := op.Shape()
   154  	padding := op.Padding()
   155  	strides := op.Strides()
   156  	fmt.Fprintf(h, "MaxPoolDiff{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
   157  		xShape[0], xShape[1], xShape[2], xShape[3],
   158  		kernel[0], kernel[1],
   159  		padding[0], padding[1],
   160  		strides[0], strides[1])
   161  }
   162  
   163  func (op *maxpoolDiff) Hashcode() uint32 { return simpleHash(op) }
   164  
   165  func (op *maxpoolDiff) String() string {
   166  	xShape := op.xDesc.Shape()
   167  	kernel := op.Shape()
   168  	padding := op.Padding()
   169  	strides := op.Strides()
   170  	return fmt.Sprintf("MaxPoolDiff{%d, %d, %d, %d}(kernel: (%d, %d), pad: (%d, %d), stride: (%d, %d))",
   171  		xShape[0], xShape[1], xShape[2], xShape[3],
   172  		kernel[0], kernel[1],
   173  		padding[0], padding[1],
   174  		strides[0], strides[1])
   175  }
   176  
   177  func (op *maxpoolDiff) CUDADo(extern G.External, dev G.Device, prealloc G.Value, inputs ...G.Value) (retVal G.Value, err error) {
   178  	if err = checkArity(op, len(inputs)); err != nil {
   179  		return
   180  	}
   181  	x, y, dy := inputs[0], inputs[1], inputs[2]
   182  
   183  	machine := extern.(G.CUDAMachine)
   184  	machine.Engines()[int(dev)].DoWork()
   185  	ctx := machine.CUDNNContexts()[int(dev)]
   186  	err = ctx.PoolingBackward(op.Pooling, 1.0, op.yDesc, y.(cudnn.Memory), op.yDesc, dy.(cudnn.Memory), op.xDesc, x.(cudnn.Memory), 0, op.xDesc, prealloc.(cudnn.Memory))
   187  	return prealloc, err
   188  }