gorgonia.org/gorgonia@v0.9.17/op_math_cuda.go

// +build cuda

package gorgonia

import (
	"fmt"
	"unsafe"

	"github.com/pkg/errors"
	"gorgonia.org/cu"
	"gorgonia.org/gorgonia/cuda"
	"gorgonia.org/tensor"
)

// module names
const (
	elemBinOpMod   = "elembinop"
	elemUnaryOpMod = "elemunaryop"
)

func (op elemUnaryOp) CallsExtern() bool { return true }

func (op elemUnaryOp) CUDADo(extern External, dev Device, prealloc Value, inputs ...Value) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	cudaLogf("CUDADoing %v | prealloc %x | %x", op, prealloc.Uintptr(), inputs[0].Uintptr())
	enterLogScope()
	defer leaveLogScope()

	// check
	cudaLogf("checking if input is scalar")
	a := inputs[0]
	dt := a.Dtype()

	// build name
	name := fmt.Sprintf("%v.%v_f%d", elemUnaryOpMod, op.unaryOpType(), int(dt.Size())*8)

	machine := extern.(CUDAMachine)
	eng := machine.Engines()[int(dev)]
	if !eng.HasFunc(name) {
		cudaLogf("extern does not have func %q", name)
		extern.Signal()

		if retVal, err = op.do(a); err != nil {
			return
		}
		if prealloc == nil {
			return
		}
		return Copy(prealloc, retVal)
	}
	fn := eng.Functions()[name]
	ctx := machine.Contexts()[int(dev)]

	retVal = prealloc
	if prealloc == nil {
		prealloc = a
		retVal = a
	}

	var mem cu.DevicePtr
	if prealloc.Uintptr() == a.Uintptr() && a.Shape().Eq(prealloc.Shape()) {
		mem = cu.DevicePtr(a.Uintptr())
	} else {
		mem = cu.DevicePtr(prealloc.Uintptr())
		memSize := int64(a.MemSize())
		memA := cu.DevicePtr(a.Uintptr())
		ctx.Memcpy(mem, memA, memSize)
	}
	size := logicalSize(a.Shape())
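	// Compute a 1D launch configuration that covers all `size` elements.
	// The CUDA driver API takes kernel arguments as pointers to the
	// argument values, hence the &mem and &size below.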
Size %v", mem, size) 83 ctx.LaunchAndSync(fn, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, 0, cu.NoStream, args) 84 return 85 } 86 87 func (op elemBinOp) CallsExtern() bool { return true } 88 89 func (op elemBinOp) CUDADo(extern External, dev Device, prealloc Value, inputs ...Value) (retVal Value, err error) { 90 if err = checkArity(op, len(inputs)); err != nil { 91 return 92 } 93 cudaLogf("CUDADoing %v", op) 94 enterLogScope() 95 defer leaveLogScope() 96 97 a := inputs[0] 98 b := inputs[1] 99 as := a.Shape() 100 bs := b.Shape() 101 102 m := extern.(CUDAMachine) 103 e := &m.Engines()[int(dev)] 104 105 if as.IsScalar() && bs.IsScalar() { 106 return op.ssop(a, b, prealloc, e) 107 } 108 109 if aT, ok := a.(tensor.Tensor); ok { 110 tensor.WithEngine(e)(aT) 111 } 112 if bT, ok := b.(tensor.Tensor); ok { 113 tensor.WithEngine(e)(bT) 114 } 115 116 pT, toReuse := prealloc.(tensor.Tensor) 117 if toReuse { 118 tensor.WithEngine(e)(pT) 119 } 120 121 boType := op.binOpType() 122 if fn := binOps[boType]; fn != nil { 123 if toReuse { 124 return (*fn)(a, b, tensor.WithReuse(pT)) 125 } 126 return (*fn)(a, b, tensor.UseUnsafe()) 127 } 128 129 if fn := cmpOps[boType]; fn != nil { 130 if toReuse { 131 return (*fn)(a, b, tensor.WithReuse(pT)) 132 } 133 return (*fn)(a, b, tensor.UseUnsafe()) 134 } 135 136 return nil, errors.Errorf("op %v cannot be done by CUDA", op) 137 } 138 139 func (op elemBinOp) ssop(a, b, prealloc Value, e *cuda.Engine) (retVal Value, err error) { 140 dt := a.Dtype() 141 ctx := e.Context() 142 opName := ʘBinOpNames[op.binOpType()] 143 name := fmt.Sprintf("%v.%v_ss_f%d", elemBinOpMod, opName, int(dt.Size())*8) 144 var mem, memB cu.DevicePtr 145 var size int64 146 if prealloc == nil { 147 mem = cu.DevicePtr(a.Uintptr()) 148 retVal = a 149 size = int64(logicalSize(a.Shape())) 150 } else { 151 mem = cu.DevicePtr(prealloc.Uintptr()) 152 memA := cu.DevicePtr(a.Uintptr()) 153 memSize := int64(a.MemSize()) 154 ctx.Memcpy(mem, memA, memSize) 155 156 size = int64(logicalSize(prealloc.Shape())) 157 retVal = prealloc 158 } 159 memB = cu.DevicePtr(b.Uintptr()) 160 fn := e.Functions()[name] 161 162 var args []unsafe.Pointer 163 cudaLogf("%v mem %v, memB %v", op, mem, memB) 164 gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ := e.ElemGridSize(int(size)) 165 args = []unsafe.Pointer{ 166 unsafe.Pointer(&mem), 167 unsafe.Pointer(&memB), 168 unsafe.Pointer(&size), 169 } 170 171 cudaLogf("CUDADO %q, size %v", name, size) 172 cudaLogf("LaunchKernel params. mem: %v memB: %v size: %v", mem, memB, size) 173 cudaLogf("%d, %d, %d, %d, %d, %d", gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ) 174 ctx.LaunchAndSync(fn, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, 0, cu.NoStream, args) 175 return 176 } 177 178 /* LINEAR ALGEBRA STUFF */ 179 180 func (op linAlgBinOp) CallsExtern() bool { return true } 181 182 func (op linAlgBinOp) CUDADo(extern External, dev Device, prealloc Value, inputs ...Value) (retVal Value, err error) { 183 if err = checkArity(op, len(inputs)); err != nil { 184 return 185 } 186 187 m := extern.(CUDAMachine) 188 e := &m.Engines()[int(dev)] 189 190 a := inputs[0] 191 b := inputs[1] 192 193 aT, ok := a.(tensor.Tensor) 194 if !ok { 195 return nil, errors.Errorf("Expected a a to be a Tensor. Got %T instead", a) 196 } 197 bT, ok := b.(tensor.Tensor) 198 if !ok { 199 return nil, errors.Errorf("Expected a b to be a Tensor. 
Got %T instead", b) 200 } 201 202 pT, ok := prealloc.(tensor.Tensor) 203 if !ok { 204 return nil, errors.Errorf("Expected a prealloc to be a Tensor. Got %T instead", prealloc) 205 } 206 tensor.WithEngine(e)(bT) 207 tensor.WithEngine(e)(aT) 208 tensor.WithEngine(e)(pT) 209 210 if op.transA && op.āBinaryOperator != batchedMatMulOperator { 211 if err = aT.T(); err != nil { 212 return nil, errors.Wrap(err, tFail) 213 } 214 // untranspose 215 defer aT.T() 216 } 217 218 if op.transB && op.āBinaryOperator != batchedMatMulOperator { 219 if err = bT.T(); err != nil { 220 return nil, errors.Wrap(err, tFail) 221 } 222 // untranspose 223 defer bT.T() 224 } 225 226 switch op.āBinaryOperator { 227 case matMulOperator: 228 return tensor.MatMul(aT, bT, tensor.WithReuse(pT)) 229 case matVecMulOperator: 230 return tensor.MatVecMul(aT, bT, tensor.WithReuse(pT)) 231 case vecDotOperator: 232 return nil, errors.New("NYI") 233 case outerProdOperator: 234 return tensor.Outer(aT, bT, tensor.WithReuse(pT)) 235 case batchedMatMulOperator: 236 // checks were done when the op was created 237 return batchedMatMul(aT, bT, nil, op.transA, op.transB, false) 238 } 239 panic("Unreachable") 240 } 241 242 /* API stuff */ 243 244 // NewAddOp creates a new *ExternalOp that wraps a add op 245 func NewAddOp(a, b *Node, ctx ExecutionContext) *ExternalOp { 246 add := newElemBinOp(addOpType, a, b) 247 op := NewExternalOp(add, ctx, nil) 248 if a.Device() == CPU && b.Device() == CPU { 249 op.Device = CPU 250 return op 251 } 252 253 if a.Device() != CPU { 254 op.Device = a.Device() 255 return op 256 } 257 258 if b.Device() != CPU { 259 op.Device = b.Device() 260 return op 261 } 262 263 return op 264 } 265 266 // NewSubOp creates a new *ExternalOp that wraps a sub op 267 func NewSubOp(a, b *Node, ctx ExecutionContext) *ExternalOp { 268 sub := newEBOByType(subOpType, a.t, b.t) 269 op := NewExternalOp(sub, ctx, nil) 270 271 if a.Device() == CPU && b.Device() == CPU { 272 op.Device = CPU 273 return op 274 } 275 276 if a.Device() != CPU { 277 op.Device = a.Device() 278 return op 279 } 280 281 if b.Device() != CPU { 282 op.Device = b.Device() 283 return op 284 } 285 return op 286 } 287 288 // NewHadamardProdOp creates a new *ExternalOp that wraps a mul op 289 func NewHadamardProdOp(a, b *Node, ctx ExecutionContext) *ExternalOp { 290 mul := newEBOByType(mulOpType, a.t, b.t) 291 op := NewExternalOp(mul, ctx, nil) 292 293 if a.Device() == CPU && b.Device() == CPU { 294 op.Device = CPU 295 return op 296 } 297 298 if a.Device() != CPU { 299 op.Device = a.Device() 300 return op 301 } 302 303 if b.Device() != CPU { 304 op.Device = b.Device() 305 return op 306 } 307 return op 308 }