gorgonia.org/gorgonia@v0.9.17/cmd/gencudaengine/tmpl.go (about)

     1  package main
     2  
     3  import "text/template"
     4  
     // binopRaw is the text/template source used to generate the CUDA engine's
     // binary-operation methods. Each execution of the template emits two methods
     // on *Engine:
     //
     //   - {{.Method}}: the tensor-tensor form, taking tensors a and b.
     //   - {{.ScalarMethod}}Scalar: the tensor-scalar form, where b must be a
     //     tensor.Memory and leftTensor selects the operand order (the two device
     //     pointers are swapped before the launch when it is false).
     //
     // Template inputs: the fields .Method and .ScalarMethod, plus a template
     // function named "lower", which must be present in the func map registered
     // before this text is parsed.
     //
     // Both generated methods follow the same steps: build the kernel name
     // (constructName2 / constructName1), return an error if the engine has no
     // function of that name, run the basic operand check (binaryCheck /
     // unaryCheck), pass opts to handleFuncOpts, then pick the destination: the
     // reuse tensor (after copying a into it) or a itself when the operation is
     // unsafe. Any other combination returns an error. Finally the kernel is
     // launched through e.c.LaunchAndSync with the arguments (mem, memB, size).
     5  const binopRaw = `// {{.Method}} implements tensor.{{.Method}}er. It does not support safe or increment operation options and will return an error if those options are passed in
     6  func (e *Engine) {{.Method}}(a tensor.Tensor, b tensor.Tensor, opts ...tensor.FuncOpt) (retVal tensor.Tensor, err error) {
     7  	name := constructName2(a, b, "{{.ScalarMethod | lower}}")
     8  
     9  	if !e.HasFunc(name) {
    10  		return nil, errors.Errorf("Unable to perform {{.Method}}(). The tensor engine does not have the function %q", name)
    11  	}
    12  
    13  	if err = binaryCheck(a, b); err != nil {
    14  		return nil, errors.Wrap(err, "Basic checks failed for {{.Method}}")
    15  	}
    16  
    17  	var reuse tensor.DenseTensor
    18  	var safe, toReuse bool
    19  	if reuse, safe, toReuse, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil {
    20  		return nil, errors.Wrap(err, "Unable to handle funcOpts")
    21  	}
    22  
    23  	var mem, memB cu.DevicePtr
    24  	var size int64
    25  
    26  	switch {
    27  	case toReuse:
    28  		mem = cu.DevicePtr(reuse.Uintptr())
    29  		memA := cu.DevicePtr(a.Uintptr())
    30  		memSize := int64(a.MemSize())
    31  		e.memcpy(mem, memA, memSize)
    32  
    33  		size = int64(logicalSize(reuse.Shape()))
    34  		retVal = reuse
    35  	case !safe:
    36  		mem = cu.DevicePtr(a.Uintptr())
    37  		retVal = a
    38  		size = int64(logicalSize(a.Shape()))
    39  	default:
    40  		return nil, errors.New("Impossible state: A reuse tensor must be passed in, or the operation must be unsafe. Incr and safe operations are not supported")
    41  	}
    42  
    43  	memB = cu.DevicePtr(b.Uintptr())
    44  	fn := e.f[name]
    45  	gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ := e.ElemGridSize(int(size))
    46  	args := []unsafe.Pointer{
    47  		unsafe.Pointer(&mem),
    48  		unsafe.Pointer(&memB),
    49  		unsafe.Pointer(&size),
    50  	}
    51  	logf("gx %d, gy %d, gz %d | bx %d by %d, bz %d", gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ)
    52  	logf("CUDADO %q, Mem: %v MemB: %v size %v, args %v", name, mem, memB, size, args)
    53  	logf("LaunchKernel Params. mem: %v. Size %v", mem, size)
    54  	e.c.LaunchAndSync(fn, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, 0, cu.NoStream, args)
    55  	return
    56  }
    57  
    58  // {{.ScalarMethod}}Scalar implements tensor.{{.Method}}er. It does not support safe or increment operation options and will return an error if those options are passed in
    59  func (e *Engine) {{.ScalarMethod}}Scalar(a tensor.Tensor, b interface{}, leftTensor bool, opts ...tensor.FuncOpt) (retVal tensor.Tensor, err error) {
    60  	name := constructName1(a, leftTensor, "{{.ScalarMethod | lower}}")
    61  	if !e.HasFunc(name) {
    62  		return nil, errors.Errorf("Unable to perform {{.ScalarMethod}}Scalar(). The tensor engine does not have the function %q", name)
    63  	}
    64  
    65  	var bMem tensor.Memory
    66  	var ok bool
    67  	if bMem, ok = b.(tensor.Memory); !ok {
    68  		return nil, errors.Errorf("b has to be a tensor.Memory. Got %T instead", b)
    69  	}
    70  
    71  	if err = unaryCheck(a); err != nil {
    72  		return nil, errors.Wrap(err, "Basic checks failed for {{.ScalarMethod}}Scalar")
    73  	}
    74  
    75  	var reuse tensor.DenseTensor
    76  	var safe, toReuse bool
    77  	if reuse, safe, toReuse, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil {
    78  		return nil, errors.Wrap(err, "Unable to handle funcOpts")
    79  	}
    80  
    81  	var mem, memB cu.DevicePtr
    82  	var size int64
    83  
    84  	switch {
    85  	case toReuse:
    86  		mem = cu.DevicePtr(reuse.Uintptr())
    87  		memA := cu.DevicePtr(a.Uintptr())
    88  		memSize := int64(a.MemSize())
    89  		e.memcpy(mem, memA, memSize)
    90  
    91  		size = int64(logicalSize(reuse.Shape()))
    92  		retVal = reuse
    93  	case !safe:
    94  		mem = cu.DevicePtr(a.Uintptr())
    95  		retVal = a
    96  		size = int64(logicalSize(a.Shape()))
    97  	default:
    98  		return nil, errors.New("Impossible state: A reuse tensor must be passed in, or the operation must be unsafe. Incr and safe operations are not supported")
    99  	}
   100  
   101  	memB = cu.DevicePtr(bMem.Uintptr())
   102  	if !leftTensor {
   103  		mem, memB = memB, mem
   104  	}
   105  	
   106  	fn := e.f[name]
   107  	gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ := e.ElemGridSize(int(size))
   108  	args := []unsafe.Pointer{
   109  		unsafe.Pointer(&mem),
   110  		unsafe.Pointer(&memB),
   111  		unsafe.Pointer(&size),
   112  	}
   113  	logf("gx %d, gy %d, gz %d | bx %d by %d, bz %d", gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ)
   114  	logf("CUDADO %q, Mem: %v size %v, args %v", name, mem, size, args)
   115  	logf("LaunchKernel Params. mem: %v. Size %v", mem, size)
   116  	e.c.LaunchAndSync(fn, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, 0, cu.NoStream, args)
   117  	return
   118  }
   119  `
   120  
   121  var (
        	// binopTmpl is the parsed form of binopRaw. It is nil until init
        	// assigns it.
   122  	binopTmpl *template.Template
   123  )
   124  
   125  func init() {
        	// Parse binopRaw once, at program start-up, into a template named "binop".
        	// Funcs(funcmap) is called before Parse because template functions must be
        	// registered before the text that uses them is parsed; funcmap is defined
        	// outside this view and is expected to supply "lower", which binopRaw uses.
        	// template.Must panics if Parse returns an error, so a malformed template
        	// stops the generator immediately instead of failing later at execution.
   126  	binopTmpl = template.Must(template.New("binop").Funcs(funcmap).Parse(binopRaw))
   127  }