gorgonia.org/gorgonia@v0.9.17/cuda/engine.go (about) 1 package cuda 2 3 import "C" 4 5 import ( 6 "fmt" 7 "sync" 8 "unsafe" 9 10 "github.com/pkg/errors" 11 "gorgonia.org/cu" 12 "gorgonia.org/cu/blas" 13 "gorgonia.org/cu/dnn" 14 "gorgonia.org/tensor" 15 ) 16 17 var ( 18 _ tensor.Adder = &Engine{} 19 _ tensor.Suber = &Engine{} 20 _ tensor.Muler = &Engine{} 21 _ tensor.Diver = &Engine{} 22 _ tensor.Power = &Engine{} 23 _ tensor.Moder = &Engine{} 24 // _ tensor.FMAer = &Engine{} 25 _ tensor.MatMuler = &Engine{} 26 _ tensor.MatVecMuler = &Engine{} 27 _ tensor.OuterProder = &Engine{} 28 // _ tensor.Dotter = &Engine{} 29 // _ tensor.SVDer = &Engine{} 30 _ tensor.Lter = &Engine{} 31 _ tensor.Lteer = &Engine{} 32 _ tensor.Gter = &Engine{} 33 _ tensor.Gteer = &Engine{} 34 _ tensor.ElEqer = &Engine{} 35 ) 36 37 // Engine is a CUDA engine 38 type Engine struct { 39 tensor.Engine 40 sync.Mutex 41 42 a bfc 43 b cublas.Standard 44 c cu.BatchedContext 45 d cu.Device 46 f map[string]cu.Function 47 m map[string]cu.Module 48 n cudnn.Context 49 50 warp int 51 mtpb int 52 mgdx int 53 mgdy int 54 mgdz int 55 mbdx int 56 mbdy int 57 mbdz int 58 59 freeMem int64 60 totalMem int64 61 62 syncChan chan struct{} 63 finishChan chan struct{} 64 finishChan2 chan struct{} 65 workAvailable chan bool 66 err error 67 initialized bool 68 running bool 69 } 70 71 // AllocAccessible returns true because the engine return Go-accessible memory pointers 72 func (e *Engine) AllocAccessible() bool { return true } 73 74 // Alloc allocates a chunk of certain size from engine memory 75 func (e *Engine) Alloc(size int64) (tensor.Memory, error) { 76 // return e.c.MemAllocManaged(size, cu.AttachGlobal) 77 return e.Get(size) 78 } 79 80 // AllocFlags returns allocation flags 81 func (e *Engine) AllocFlags() (tensor.MemoryFlag, tensor.DataOrder) { 82 return tensor.MakeMemoryFlag(tensor.ManuallyManaged), tensor.ColMajor 83 } 84 85 // Free rees memory 86 func (e *Engine) Free(mem tensor.Memory, size int64) error { 87 // e.c.MemFree(mem.(cu.DevicePtr)) 88 // return e.c.Error() 89 e.Put(mem, size) 90 return nil 91 } 92 93 func (e *Engine) Memset(mem tensor.Memory, val interface{}) error { 94 panic("not implemented") 95 } 96 97 func (e *Engine) Memclr(mem tensor.Memory) { 98 panic("not implemented") 99 } 100 101 // Memcpy is part of the implementation of tensor.Engine. It is eager, and will signal the context to actually do work. 102 // The memory that will be copied is up to the smallest of sizes between dst and src. 103 // i.e. if dst is 8 bytes and src is 16 bytes, only the first 8 bytes of src will be copied. 104 // Likewise, if dst is 20 bytes and src is 3 bytes, only 3 bytes will be copied. 105 func (e *Engine) Memcpy(dst tensor.Memory, src tensor.Memory) error { 106 sSize := src.MemSize() 107 dSize := dst.MemSize() 108 109 var size int64 110 switch { 111 case dSize < sSize: 112 size = int64(dSize) 113 case sSize < dSize: 114 size = int64(sSize) 115 default: 116 size = int64(dSize) 117 } 118 d := cu.DevicePtr(dst.Uintptr()) 119 s := cu.DevicePtr(src.Uintptr()) 120 e.c.Memcpy(d, s, size) 121 e.Signal() 122 <-e.syncChan 123 return e.c.Error() 124 } 125 126 func (e *Engine) memcpy(dst cu.DevicePtr, src cu.DevicePtr, size int64) { 127 e.c.Memcpy(dst, src, size) 128 } 129 130 func (e *Engine) Accessible(mem tensor.Memory) (tensor.Memory, error) { 131 panic("not implemented") 132 } 133 134 // WorksWith returns true because the data order can be directly worked with 135 func (e *Engine) WorksWith(order tensor.DataOrder) bool { return true } 136 137 // NonStdAlloc nothing instead of running the default built in allocator 138 func (e *Engine) NonStdAlloc() {} 139 140 // Errors returns an error message 141 func (e *Engine) Errors() error { return e.c.Errors() } 142 143 // NaNChecker checks that the tensor contains a NaN 144 func (e *Engine) HasNaN(a tensor.Tensor) (bool, error) { 145 dt := a.Dtype() 146 name := fmt.Sprintf("misc.hasNaN_f%v", int(dt.Size()*8)) 147 148 if !e.HasFunc(name) { 149 return false, errors.Errorf("Unable to perform HasNaN(). The tensor engine does not have the function %q", name) 150 } 151 152 mem := cu.DevicePtr(a.Uintptr()) 153 size := int64(logicalSize(a.Shape())) 154 fn := e.f[name] 155 156 var retVal C.int 157 gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ := e.ElemGridSize(int(size)) 158 args := []unsafe.Pointer{ 159 unsafe.Pointer(&mem), 160 unsafe.Pointer(&size), 161 unsafe.Pointer(&retVal), 162 } 163 e.c.LaunchAndSync(fn, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, 0, cu.NoStream, args) 164 e.DoWork() 165 return int(retVal) > 0, e.c.Error() 166 } 167 168 // InfChecker checks that the tensor contains a Inf 169 func (e *Engine) HasInf(a tensor.Tensor) (bool, error) { 170 dt := a.Dtype() 171 name := fmt.Sprintf("misc.hasInf_f%v", int(dt.Size()*8)) 172 173 if !e.HasFunc(name) { 174 return false, errors.Errorf("Unable to perform HasInf(). The tensor engine does not have the function %q", name) 175 } 176 177 mem := cu.DevicePtr(a.Uintptr()) 178 size := int64(logicalSize(a.Shape())) 179 fn := e.f[name] 180 181 var retVal C.int 182 gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ := e.ElemGridSize(int(size)) 183 args := []unsafe.Pointer{ 184 unsafe.Pointer(&mem), 185 unsafe.Pointer(&size), 186 unsafe.Pointer(&retVal), 187 } 188 e.c.LaunchAndSync(fn, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, 0, cu.NoStream, args) 189 e.DoWork() 190 return int(retVal) > 0, e.c.Error() 191 }