gorgonia.org/gorgonia@v0.9.17/ops/nn/batchnorm_cuda.go

// +build cuda

package nnops

import (
	"fmt"
	"hash"

	"github.com/chewxy/hm"
	"gorgonia.org/cu/dnn"
	t2cudnn "gorgonia.org/cu/dnn/interop"
	"gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

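// BatchNormOp is the cuDNN-backed batch normalization operation. It lazily
// creates and caches the cuDNN tensor descriptors for the input (xDesc) and
// for the scale/bias/mean/variance tensors (bnScratch) on first execution,
// and switches between the training and inference kernels via the training
// flag.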
type BatchNormOp struct {
	mode              cudnn.BatchNormMode
	momentum, epsilon float64

	xDesc     *cudnn.TensorDescriptor
	bnScratch *cudnn.TensorDescriptor

	training bool
}

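// newBatchNormOp returns a *BatchNormOp in training mode. Note that it
// hard-codes cudnn.PerActivation as the normalization mode.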
func newBatchNormOp(momentum, epsilon float64) *BatchNormOp {
	return &BatchNormOp{
		mode:     cudnn.PerActivation,
		momentum: momentum,
		epsilon:  epsilon,
		training: true,
	}
}

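// Arity returns 7: x, scale, bias, running mean, running variance,
// cached mean and cached variance.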
func (op *BatchNormOp) Arity() int { return 7 }

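// Type describes the op as a function from seven rank-4 tensors of the same
// element type to a rank-4 tensor of that type.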
func (op *BatchNormOp) Type() hm.Type {
	t := gorgonia.TensorType{Dims: 4, Of: hm.TypeVariable('a')}
	return hm.NewFnType(t, // x
		t, // scale
		t, // bias
		t, // running mean / expected mean
		t, // running var / expected var
		t, // cached mean
		t, // cached var
		t) // retVal
}

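// InferShape returns a clone of the input shape: batch normalization
// preserves the shape of x.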
func (op *BatchNormOp) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) {
	if err := checkArity(op, len(inputs)); err != nil {
		return nil, err
	}
	return inputs[0].(tensor.Shape).Clone(), nil
}

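// Do is never called on the CUDA path; execution happens in CUDADo.
// ReturnsPtr and CallsExtern are true because the op writes into
// preallocated device memory through the external cuDNN library, and
// OverwritesInput returns -1 because no input is overwritten in place.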
func (op *BatchNormOp) Do(...gorgonia.Value) (gorgonia.Value, error) { panic("not implemented") }
func (op *BatchNormOp) ReturnsPtr() bool                             { return true }
func (op *BatchNormOp) CallsExtern() bool                            { return true }
func (op *BatchNormOp) OverwritesInput() int                         { return -1 }
func (op *BatchNormOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "BatchNorm %v %v", op.momentum, op.epsilon)
}
func (op *BatchNormOp) Hashcode() uint32 { return simpleHash(op) }
func (op *BatchNormOp) String() string   { return fmt.Sprintf("BatchNorm %v %v", op.momentum, op.epsilon) }

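// CUDADo runs batch normalization on the device. Descriptors are created
// lazily from the first inputs seen, then the call is dispatched to cuDNN's
// forward-training or forward-inference kernel depending on op.training.
// The result is written into prealloc.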
func (op *BatchNormOp) CUDADo(extern gorgonia.External, dev gorgonia.Device, prealloc gorgonia.Value, inputs ...gorgonia.Value) (retVal gorgonia.Value, err error) {
	machine := extern.(gorgonia.CUDAMachine)
	ctx := machine.CUDNNContexts()[int(dev)]

	x, bnScale, bnBias, mean, variance, cachedMean, cachedVar := inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5], inputs[6]
	if op.xDesc == nil {
		if op.xDesc, err = t2cudnn.Describe(x.(tensor.Tensor)); err != nil {
			return
		}
	}

	if op.bnScratch == nil {
		if op.bnScratch, err = t2cudnn.Describe(mean.(tensor.Tensor)); err != nil {
			return
		}
	}

	// cuDNN blends the result as dst = alpha*result + beta*dst, so alpha = 1
	// and beta = 0 overwrite the output with the freshly computed values.
	alpha := 1.0
	beta := 0.0
	if op.training {
		err = ctx.BatchNormalizationForwardTraining(op.mode, alpha, beta,
			op.xDesc, x.(cudnn.Memory),
			op.xDesc, prealloc.(cudnn.Memory), // yDesc, y
			op.bnScratch,
			bnScale.(cudnn.Memory),
			bnBias.(cudnn.Memory),
			op.momentum,
			mean.(cudnn.Memory),     // running mean
			variance.(cudnn.Memory), // running variance
			op.epsilon,
			cachedMean.(cudnn.Memory),
			cachedVar.(cudnn.Memory),
		)
	} else {
		err = ctx.BatchNormalizationForwardInference(op.mode, alpha, beta,
			op.xDesc, x.(cudnn.Memory),
			op.xDesc, prealloc.(cudnn.Memory),
			op.bnScratch,
			bnScale.(cudnn.Memory),
			bnBias.(cudnn.Memory),
			mean.(cudnn.Memory),     // expected mean
			variance.(cudnn.Memory), // expected variance
			op.epsilon)
	}
	return prealloc, err
}

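// DiffWRT marks only x, scale and bias as differentiable; the running and
// cached statistics are bookkeeping values, not parameters.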
func (op *BatchNormOp) DiffWRT(inputs int) []bool {
	return []bool{true, true, true, false, false, false, false}
}

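// SymDiff builds the symbolic backward pass: scratch nodes are allocated to
// receive dscale and dbias, and a batchNormDiffOp node computes dx. Only the
// first three return values are populated, matching DiffWRT.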
func (op *BatchNormOp) SymDiff(inputs gorgonia.Nodes, output *gorgonia.Node, grad *gorgonia.Node) (retVal gorgonia.Nodes, err error) {
	x, scale, bias := inputs[0], inputs[1], inputs[2]
	cachedMean, cachedVar := inputs[5], inputs[6]
	dy := grad // renamed for readability

	// create new nodes for the diffs
	g := x.Graph()
	dt := x.Dtype()
	scaleScratch := &scratchOp{scale.Shape().Clone(), dt, scale.Name() + "Diff"}
	biasScratch := &scratchOp{bias.Shape().Clone(), dt, bias.Name() + "Diff"}
	dscale := gorgonia.NewTensor(g, dt, scale.Shape().Dims(), gorgonia.WithOp(scaleScratch))
	dbias := gorgonia.NewTensor(g, dt, bias.Shape().Dims(), gorgonia.WithOp(biasScratch))

	retVal = make(gorgonia.Nodes, 7)

	diffOp := &batchNormDiffOp{op}
	retVal[0], err = gorgonia.ApplyOp(diffOp, x, scale, dscale, dbias, dy, cachedMean, cachedVar)
	retVal[1] = dscale
	retVal[2] = dbias
	gorgonia.SetDerivOf(dscale, scale)
	gorgonia.SetDerivOf(dbias, bias)

	return retVal, err
}

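// DoDiff is not implemented; on the CUDA path gradients are computed
// symbolically via SymDiff.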
func (op *BatchNormOp) DoDiff(ctx gorgonia.ExecutionContext, inputs gorgonia.Nodes, output *gorgonia.Node) error {
	panic("not implemented")
}

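// SetTraining and SetTesting switch the op between cuDNN's forward-training
// and forward-inference kernels; Reset is a no-op.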
func (op *BatchNormOp) SetTraining() { op.training = true }
func (op *BatchNormOp) SetTesting()  { op.training = false }
func (op *BatchNormOp) Reset() error { return nil }

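// batchNormDiffOp is the gradient op of BatchNormOp. It embeds the forward op
// so that it shares the same descriptors, mode, momentum and epsilon.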
type batchNormDiffOp struct {
	*BatchNormOp
}

// Arity is promoted from the embedded *BatchNormOp (7).

// Type is promoted from the embedded *BatchNormOp, but the semantics of the
// inputs differ:
// 	return hm.NewFnType(
//		t, // x
// 		t, // scale
// 		t, // dscale
// 		t, // dbias
// 		t, // dy
// 		t, // cachedMean
// 		t, // cachedVar
// 		t  // retVal
//	)

// InferShape is promoted from the embedded *BatchNormOp.

func (op *batchNormDiffOp) Do(...gorgonia.Value) (gorgonia.Value, error) {
	panic("not implemented")
}

func (op *batchNormDiffOp) ReturnsPtr() bool { return true }

func (op *batchNormDiffOp) CallsExtern() bool { return true }

func (op *batchNormDiffOp) OverwritesInput() int { return -1 }

func (op *batchNormDiffOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "BatchNormDiff %v %v", op.momentum, op.epsilon)
}

// Hashcode is promoted from the embedded *BatchNormOp.

func (op *batchNormDiffOp) String() string {
	return fmt.Sprintf("BatchNormDiff %v %v", op.momentum, op.epsilon)
}

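// CUDADo computes dx (written into prealloc), dscale and dbias in a single
// cudnnBatchNormalizationBackward call, reusing the mean and variance that
// the forward pass cached.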
func (op *batchNormDiffOp) CUDADo(extern gorgonia.External, dev gorgonia.Device, prealloc gorgonia.Value, inputs ...gorgonia.Value) (retVal gorgonia.Value, err error) {
	machine := extern.(gorgonia.CUDAMachine)
	e := &machine.Engines()[int(dev)]
	ctx := machine.CUDNNContexts()[int(dev)]

	x, scale, dscale, dbias, dy, cachedMean, cachedVariance := inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5], inputs[6]
	dscale = gorgonia.ScalarAsTensor(dscale, 4, e)
	dbias = gorgonia.ScalarAsTensor(dbias, 4, e)

	// As in the forward pass, alpha = 1 and beta = 0 make cuDNN overwrite the
	// destinations rather than blend with their prior contents.
	alpha := 1.0
	beta := 0.0
	err = ctx.BatchNormalizationBackward(op.mode,
		alpha, beta, // for data
		alpha, beta, // for param
		op.xDesc,
		x.(cudnn.Memory),
		op.xDesc, // dyDesc
		dy.(cudnn.Memory),
		op.xDesc, // dxDesc
		prealloc.(cudnn.Memory),
		op.bnScratch, // scratch space descriptor
		scale.(cudnn.Memory),
		dscale.(cudnn.Memory), // deriv of scale
		dbias.(cudnn.Memory),  // deriv of bias
		op.epsilon,
		cachedMean.(cudnn.Memory), cachedVariance.(cudnn.Memory))
	return prealloc, err
}