gorgonia.org/gorgonia@v0.9.17/cmd/gencudaengine/tmpl.go (about) 1 package main 2 3 import "text/template" 4 5 const binopRaw = `// {{.Method}} implements tensor.{{.Method}}er. It does not support safe or increment operation options and will return an error if those options are passed in 6 func (e *Engine) {{.Method}}(a tensor.Tensor, b tensor.Tensor, opts ...tensor.FuncOpt) (retVal tensor.Tensor, err error) { 7 name := constructName2(a, b, "{{.ScalarMethod | lower}}") 8 9 if !e.HasFunc(name) { 10 return nil, errors.Errorf("Unable to perform {{.Method}}(). The tensor engine does not have the function %q", name) 11 } 12 13 if err = binaryCheck(a, b); err != nil { 14 return nil, errors.Wrap(err, "Basic checks failed for {{.Method}}") 15 } 16 17 var reuse tensor.DenseTensor 18 var safe, toReuse bool 19 if reuse, safe, toReuse, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { 20 return nil, errors.Wrap(err, "Unable to handle funcOpts") 21 } 22 23 var mem, memB cu.DevicePtr 24 var size int64 25 26 switch { 27 case toReuse: 28 mem = cu.DevicePtr(reuse.Uintptr()) 29 memA := cu.DevicePtr(a.Uintptr()) 30 memSize := int64(a.MemSize()) 31 e.memcpy(mem, memA, memSize) 32 33 size = int64(logicalSize(reuse.Shape())) 34 retVal = reuse 35 case !safe: 36 mem = cu.DevicePtr(a.Uintptr()) 37 retVal = a 38 size = int64(logicalSize(a.Shape())) 39 default: 40 return nil, errors.New("Impossible state: A reuse tensor must be passed in, or the operation must be unsafe. Incr and safe operations are not supported") 41 } 42 43 memB = cu.DevicePtr(b.Uintptr()) 44 fn := e.f[name] 45 gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ := e.ElemGridSize(int(size)) 46 args := []unsafe.Pointer{ 47 unsafe.Pointer(&mem), 48 unsafe.Pointer(&memB), 49 unsafe.Pointer(&size), 50 } 51 logf("gx %d, gy %d, gz %d | bx %d by %d, bz %d", gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ) 52 logf("CUDADO %q, Mem: %v MemB: %v size %v, args %v", name, mem, memB, size, args) 53 logf("LaunchKernel Params. mem: %v. Size %v", mem, size) 54 e.c.LaunchAndSync(fn, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, 0, cu.NoStream, args) 55 return 56 } 57 58 // {{.ScalarMethod}}Scalar implements tensor.{{.Method}}er. It does not support safe or increment operation options and will return an error if those options are passed in 59 func (e *Engine) {{.ScalarMethod}}Scalar(a tensor.Tensor, b interface{}, leftTensor bool, opts ...tensor.FuncOpt) (retVal tensor.Tensor, err error) { 60 name := constructName1(a, leftTensor, "{{.ScalarMethod | lower}}") 61 if !e.HasFunc(name) { 62 return nil, errors.Errorf("Unable to perform {{.ScalarMethod}}Scalar(). The tensor engine does not have the function %q", name) 63 } 64 65 var bMem tensor.Memory 66 var ok bool 67 if bMem, ok = b.(tensor.Memory); !ok { 68 return nil, errors.Errorf("b has to be a tensor.Memory. Got %T instead", b) 69 } 70 71 if err = unaryCheck(a); err != nil { 72 return nil, errors.Wrap(err, "Basic checks failed for {{.ScalarMethod}}Scalar") 73 } 74 75 var reuse tensor.DenseTensor 76 var safe, toReuse bool 77 if reuse, safe, toReuse, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { 78 return nil, errors.Wrap(err, "Unable to handle funcOpts") 79 } 80 81 var mem, memB cu.DevicePtr 82 var size int64 83 84 switch { 85 case toReuse: 86 mem = cu.DevicePtr(reuse.Uintptr()) 87 memA := cu.DevicePtr(a.Uintptr()) 88 memSize := int64(a.MemSize()) 89 e.memcpy(mem, memA, memSize) 90 91 size = int64(logicalSize(reuse.Shape())) 92 retVal = reuse 93 case !safe: 94 mem = cu.DevicePtr(a.Uintptr()) 95 retVal = a 96 size = int64(logicalSize(a.Shape())) 97 default: 98 return nil, errors.New("Impossible state: A reuse tensor must be passed in, or the operation must be unsafe. Incr and safe operations are not supported") 99 } 100 101 memB = cu.DevicePtr(bMem.Uintptr()) 102 if !leftTensor { 103 mem, memB = memB, mem 104 } 105 106 fn := e.f[name] 107 gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ := e.ElemGridSize(int(size)) 108 args := []unsafe.Pointer{ 109 unsafe.Pointer(&mem), 110 unsafe.Pointer(&memB), 111 unsafe.Pointer(&size), 112 } 113 logf("gx %d, gy %d, gz %d | bx %d by %d, bz %d", gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ) 114 logf("CUDADO %q, Mem: %v size %v, args %v", name, mem, size, args) 115 logf("LaunchKernel Params. mem: %v. Size %v", mem, size) 116 e.c.LaunchAndSync(fn, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, 0, cu.NoStream, args) 117 return 118 } 119 ` 120 121 var ( 122 binopTmpl *template.Template 123 ) 124 125 func init() { 126 binopTmpl = template.Must(template.New("binop").Funcs(funcmap).Parse(binopRaw)) 127 }