gorgonia.org/gorgonia@v0.9.17/vm_tape_cuda.go

// +build cuda

package gorgonia

import (
	"github.com/pkg/errors"
	"gorgonia.org/cu"
	"gorgonia.org/tensor"
)

// finalizeTapeMachine cleans up the machine's external resources and tears down its CUDA contexts.
func finalizeTapeMachine(m *tapeMachine) {
	cudaLogf("Finalizing tape machine %p", m)
	m.cleanup()
	m.initFail() // not really a failure; just a call to destroy all the contexts
}

// init scans the compiled program for CUDA-capable ops. Only if any are found does it
// initialize the CUDA contexts and engines and load the standard library of kernels.
func (m *tapeMachine) init() {
	var initCUDA bool
	cudaLogf("instructions %v", len(m.p.instructions))
	for _, instr := range m.p.instructions {
		if eo, ok := instr.(*execOp); ok {
			if _, ok := eo.op.(CUDADoer); ok {
				initCUDA = true
				break
			}
		}
	}

	// don't bother initializing contexts if no instructions were CUDA based
	if !initCUDA {
		cudaLogf("No CUDA ops")
		return
	}

	if err := m.ExternMetadata.init(m.p.gpumem); err != nil {
		m.ExternMetadata.initFail()
		panic(err)
	}
	m.loadStdLib()
}

// loadStdLib loads the standard library of CUDA functions into every engine.
func (m *tapeMachine) loadStdLib() {
	if cudaStdLib == nil {
		return
	}

	for _, lib := range cudaStdLib {
		for i := range m.engines {
			e := &m.engines[i]
			if err := e.LoadCUDAFunc(lib.name, lib.data, lib.funcs); err != nil {
				panic(err)
			}
		}
	}
}

// getEngine returns the engine that owns values on the given device: the machine's
// default engine for the CPU, otherwise the CUDA engine for that device.
func (m *tapeMachine) getEngine(dev Device) tensor.Engine {
	if dev == CPU {
		return m.Engine
	}
	return &m.Engines()[int(dev)]
}
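
// By way of example, the resolution done by getEngine works out as follows
// (assuming the usual CUDA-build mapping in this package, where CPU is the
// sentinel device and GPUs are numbered from 0; the constants themselves are
// defined elsewhere):
//
//	m.getEngine(CPU)       // m.Engine, the default (CPU) tensor engine
//	m.getEngine(Device(0)) // &m.Engines()[0], the CUDA engine bound to GPU 0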

func (instr *execOp) exec(m *tapeMachine) (err error) {
	m.logf("Executing %v. Node is: %x", instr, instr.id)
	m.enterLogScope()
	defer m.leaveLogScope()

	enterLogScope()
	defer leaveLogScope()

	m.watchedLogf("Inputs:")
	m.enterLogScope()
	var inputs []Value
	for _, reg := range instr.readFrom {
		v := m.getValue(reg)
		inputs = append(inputs, v)
		m.watchedLogf(m.valueFmt, v.Uintptr())
	}
	m.leaveLogScope()

	toDev := instr.writeTo.device
	var v Value
	switch op := instr.op.(type) {
	case CUDADoer:
		prealloc := m.getValue(instr.writeTo)
		if v, err = op.CUDADo(m, toDev, prealloc, inputs...); err != nil {
			return errors.Wrapf(err, "Happened while attempting to use CUDA to execute %v. Node is %x. Register was %v", instr, instr.id, instr.writeTo.id)
		}
		e := &m.Engines()[int(toDev)]
		setEngine(v, e)
	case CLDoer:
	default:
		switch {
		case instr.preAllocated:
			if pd, ok := instr.op.(UsePreallocDoer); ok {
				p := m.cpumem[instr.writeTo.id]
				if v, err = pd.UsePreallocDo(p, inputs...); err != nil {
					return errors.Wrapf(err, "Happened while attempting to execute %v. Node is %x. Register was: %v", instr, instr.id, instr.writeTo.id)
				}
			} else {
				// TODO: maybe warn?
				if v, err = instr.op.Do(inputs...); err != nil {
					return errors.Wrap(err, opDoFail)
				}
			}
		case instr.useUnsafe:
			if ud, ok := instr.op.(UnsafeDoer); ok {
				if v, err = ud.UnsafeDo(inputs...); err != nil {
					return errors.Wrap(err, "Failed to carry out UnsafeDo()")
				}
			} else {
				// TODO: warn?
				if v, err = instr.op.Do(inputs...); err != nil {
					return errors.Wrap(err, opDoFail)
				}
			}
		default:
			if v, err = instr.op.Do(inputs...); err != nil {
				return errors.Wrap(err, opDoFail)
			}
		}
		setEngine(v, m.Engine)
	}

	m.watchedLogf("Result E:")
	m.enterLogScope()
	if vt, ok := v.(tensor.Tensor); ok {
		m.watchedLogf("%x | %T", v.Uintptr(), vt.Engine())
	} else {
		m.watchedLogf("%x", v.Uintptr())
	}
	m.leaveLogScope()
	// TODO: type and shape checks

	// Write
	m.writeValue(instr.writeTo, v)
	node := m.p.g.Node(instr.id).(*Node)

	if m.trace() && (len(m.watchNodes) == 0 || m.watchNodes.Contains(node)) {
		m.Signal()
		if err = node.bindCopy(v); err != nil {
			return errors.Wrapf(err, "TraceExec failed to bind copy")
		}
	} else {
		node.bind(v)
	}

	// if this is a gradient node, then we should also bind the value to the node's dualValue
	if m.bindDV() && node.derivOf != nil {
		for _, src := range node.derivOf {
			if len(m.bindNodesDV) > 0 && !m.bindNodesDV.Contains(src) {
				continue
			}

			if src.boundTo != nil {
				dv := dvUnit(src.boundTo)
				cudaLogf("dv.d 0x%x v 0x%x | writeTo: %v", dv.d.Uintptr(), v.Uintptr(), instr.writeTo)
				dev := instr.writeTo.device
				add := newEBOByType(addOpType, TypeOf(dv.d), TypeOf(v))
				switch dev {
				case CPU:
					if d, err := add.UnsafeDo(dv.d, v); err == nil {
						dv.SetDeriv(d)
						src.bind(dv)
					} else {
						return err
					}
				default:
					// temporarily allocate a value on the device
					ctx := m.Contexts()[int(dev)]

					dt := dv.d.Dtype()
					shp := dv.d.Shape()
					memsize := calcMemSize(dt, shp)

					var mem tensor.Memory
					if mem, err = m.Get(dev, memsize); err != nil {
						return errors.Wrapf(err, "Unable to allocate %v bytes from %v", memsize, dev)
					}

					var d Value
					if d, err = makeValueFromMem(dt, shp, mem); err != nil {
						return
					}

					// copy dv.d to d
					ctx.MemcpyHtoD(mem.(cu.DevicePtr), dv.d.Pointer(), memsize)

					// perform the op
					if _, err = add.CUDADo(m, dev, d, d, v); err != nil {
						return
					}
					// copy the value back into dv.d
					ctx.MemcpyDtoH(dv.d.Pointer(), mem.(cu.DevicePtr), memsize)
					m.Put(dev, mem, memsize) // then free it

					src.bind(dv)
					// The CPU method above is correct. The shortcut below is correct for MOST cases,
					// but will not be correct under some other circumstances:
					// ctx.MemcpyDtoH(dv.d.Pointer(), cu.DevicePtr(v.Uintptr()), instr.size)
				}
			}
		}
	}

	m.watchedLogf("Written To: %v", instr.writeTo)
	m.enterLogScope()
	m.watchedLogf(m.valueFmt, v.Uintptr())
	m.leaveLogScope()

	return nil
}
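
// For reference, the device branch of the gradient accumulation above is the
// out-of-place analogue of the CPU branch's add.UnsafeDo(dv.d, v). In brief,
// and eliding the type assertions and error handling shown above:
//
//	mem, _ := m.Get(dev, memsize)                // scratch allocation on the device
//	ctx.MemcpyHtoD(mem, dv.d.Pointer(), memsize) // host dv.d → device scratch
//	add.CUDADo(m, dev, d, d, v)                  // d = d + v on the device (v was produced there)
//	ctx.MemcpyDtoH(dv.d.Pointer(), mem, memsize) // device scratch → host dv.d
//	m.Put(dev, mem, memsize)                     // return the scratch memory to the pool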

func (instr deviceTransport) exec(m *tapeMachine) (err error) {
	m.logf("Executing %v", instr)
	from := m.getValue(instr.from)
	to := m.getValue(instr.to)

	var ctx *cu.BatchedContext
	switch {
	case instr.from.device == CPU && instr.to.device != CPU:
		memsize := int64(from.MemSize())
		ctx = m.Contexts()[int(instr.to.device)]
		ctx.MemcpyHtoD(cu.DevicePtr(to.Uintptr()), from.Pointer(), memsize)
	case instr.from.device != CPU && instr.to.device == CPU:
		dt := from.Dtype()
		memsize := calcMemSize(dt, from.Shape())
		ctx = m.Contexts()[int(instr.from.device)]
		ctx.MemcpyDtoH(to.Pointer(), cu.DevicePtr(from.Uintptr()), memsize)

		// when copying from device to host, it's assumed that the host will want to use the value immediately,
		// so signal the DoWork
		m.Signal()
	}

	return nil
}
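
// A minimal usage sketch, for orientation only. It relies on gorgonia's public
// API (NewGraph, NewMatrix, Must, Mul, NewTapeMachine, RunAll), none of which is
// defined in this file. When the compiled program contains CUDA-capable ops, the
// init/loadStdLib path above sets up the device contexts before the first run:
//
//	g := NewGraph()
//	x := NewMatrix(g, Float32, WithShape(32, 32), WithName("x"), WithInit(GlorotU(1)))
//	w := NewMatrix(g, Float32, WithShape(32, 32), WithName("w"), WithInit(GlorotU(1)))
//	y := Must(Mul(x, w))
//	_ = y
//
//	m := NewTapeMachine(g)
//	defer m.Close()
//	if err := m.RunAll(); err != nil {
//		log.Fatal(err)
//	}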