gorgonia.org/gorgonia@v0.9.17/vm_genera_cuda.go (about) 1 // +build cuda 2 3 package gorgonia 4 5 import ( 6 "log" 7 8 "github.com/pkg/errors" 9 "gorgonia.org/tensor" 10 ) 11 12 func (m *lispMachine) init() error { 13 if err := m.prepGraph(); err != nil { 14 return err 15 } 16 17 // VERY simple data analysis (even simpler than the one used in Compile) 18 // using replaceWithSelf reduces the need for hashing, hence less work is required 19 // However this also means that CSE won't be performed 20 df := newdataflow() 21 df.replaceWithSelf(m.sorted) 22 df.buildIntervals(m.sorted) 23 df.fixIntervalDevices(m.sorted) 24 m.df = df 25 26 if err := m.calcMemSize(); err != nil { 27 log.Printf("err1") 28 return err 29 } 30 31 if len(m.gpumem) == 0 { 32 m.ForceCPU() 33 return nil 34 } 35 36 if err := m.ExternMetadata.init(m.gpumem); err != nil { 37 m.ExternMetadata.initFail() 38 return err 39 } 40 m.loadStdLib() 41 42 if len(m.engines) == 0 { 43 m.ForceCPU() 44 } 45 return nil 46 } 47 48 func finalizeLispMachine(m *lispMachine) { 49 m.ExternMetadata.cleanup() 50 m.ExternMetadata.initFail() 51 } 52 53 func (m *lispMachine) WorkAvailable() <-chan bool { 54 if m.ExternMetadata.WorkAvailable() == nil { 55 return nil 56 } 57 return m.ExternMetadata.WorkAvailable() 58 } 59 60 func (m *lispMachine) calcMemSize() (err error) { 61 compileLogf("calcmemsize") 62 enterLogScope() 63 defer leaveLogScope() 64 var cpumem int64 65 var gpumem []int64 66 for _, n := range m.sorted { 67 interv := m.df.intervals[n] 68 dev := interv.result.device 69 compileLogf("n: %v | %v", n, interv) 70 71 var dt tensor.Dtype 72 if dt, err = dtypeOf(n.t); err != nil { 73 if n.isStmt { 74 continue 75 } 76 return errors.Wrapf(err, "Cannot calculate memsize of n(%v)", n) 77 } 78 switch { 79 case n.isArg(): 80 cpumem += calcMemSize(dt, n.Shape()) 81 case n.isStmt: 82 default: 83 // if !n.op.ReturnsPtr() { 84 if dev != CPU { 85 if len(gpumem) < int(dev)+1 { 86 diff := int(dev) + 1 - len(gpumem) 87 gpumem = append(gpumem, make([]int64, diff)...) 88 } 89 } 90 91 switch dev { 92 case CPU: 93 cpumem += calcMemSize(dt, n.Shape()) 94 default: 95 compileLogf("n: %v. AddedDEF", n) 96 gpumem[int(dev)] += 4 * calcMemSize(dt, n.Shape()) 97 } 98 // } 99 } 100 } 101 102 m.cpumem = cpumem 103 m.gpumem = gpumem 104 return nil 105 } 106 107 func (m *lispMachine) execDevTrans(op devTrans, n *Node, children Nodes) (err error) { 108 child := children[0] 109 m.logf("DevTrans: %v | %v | %v", op, n.boundTo, child.boundTo) 110 111 var dv *dualValue 112 var cv, cd, v, d Value 113 if child.boundTo != nil { 114 var ok bool 115 if dv, ok = child.boundTo.(*dualValue); ok { 116 cv = dv.Value 117 cd = dv.d 118 } else { 119 cv = child.boundTo 120 } 121 } else { 122 err = errors.Errorf("Cannot execute transfer when there is no value in child") 123 return 124 } 125 126 var synchronous bool 127 if op.to == CPU && op.from != CPU { 128 synchronous = true 129 } 130 131 if v, err = m.Transfer(op.to, op.from, cv, false); err != nil { 132 return 133 } 134 135 if cd != nil { 136 if d, err = m.Transfer(op.to, op.from, cd, false); err != nil { 137 return 138 } 139 } else { 140 var mem tensor.Memory 141 if mem, err = m.Get(op.to, calcMemSize(cv.Dtype(), child.shape)); err != nil { 142 return 143 } 144 if _, err = makeValueFromMem(child.t, child.shape, mem); err != nil { 145 return 146 } 147 } 148 149 if synchronous { 150 m.Signal() 151 } 152 153 dv = new(dualValue) 154 dv.Value = v 155 dv.d = d 156 n.boundTo = dv 157 158 return nil 159 } 160 161 // loads the standardlib 162 func (m *lispMachine) loadStdLib() { 163 if cudaStdLib == nil { 164 return 165 } 166 167 for _, lib := range cudaStdLib { 168 for i := range m.engines { 169 e := &m.engines[i] 170 if err := e.LoadCUDAFunc(lib.name, lib.data, lib.funcs); err != nil { 171 panic(err) 172 } 173 } 174 } 175 } 176 177 // ForceCPU forces the lispMachine to have the nodes run on the CPU 178 func (m *lispMachine) ForceCPU() { 179 m.cleanup() 180 m.initFail() 181 m.df = nil 182 183 for _, n := range m.sorted { 184 n.dataOn = CPU 185 } 186 187 // remove devTrans if any 188 for i := 0; i < len(m.sorted); i++ { 189 n := m.sorted[i] 190 if _, ok := n.op.(devTrans); ok { 191 copy(m.sorted[i:], m.sorted[i+1:]) 192 m.sorted = m.sorted[:len(m.sorted)-1] 193 i-- 194 } 195 } 196 }