gorgonia.org/gorgonia@v0.9.17/vm_tape_nocuda.go (about) 1 // +build !cuda 2 3 package gorgonia 4 5 import ( 6 "github.com/pkg/errors" 7 "gorgonia.org/tensor" 8 ) 9 10 func finalizeTapeMachine(m *tapeMachine) {} 11 12 // UseCudaFor is an option for *tapeMachine. This function is NO-OP unless the program is built with the `cuda` tag. 13 func UseCudaFor(ops ...string) VMOpt { 14 return func(m VM) {} 15 } 16 17 func (m *tapeMachine) getEngine(dev Device) tensor.Engine { return m.Engine } 18 19 func (instr *execOp) exec(m *tapeMachine) (err error) { 20 m.logf("Executing %v. Node is: %x", instr, instr.id) 21 m.enterLogScope() 22 defer m.leaveLogScope() 23 24 // Read 25 m.watchedLogf("Inputs:") 26 m.enterLogScope() 27 var inputs []Value 28 for _, reg := range instr.readFrom { 29 v := m.cpumem[reg.id] 30 inputs = append(inputs, v) 31 m.watchedLogf(m.valueFmt, v) 32 } 33 m.leaveLogScope() 34 35 // check if the destination has already been allocated 36 var usePrealloc bool 37 dest := instr.writeTo.id 38 if m.cpumem[dest] != nil { 39 usePrealloc = true 40 } 41 42 // Execute 43 var v Value 44 switch { 45 case instr.preAllocated: 46 if pd, ok := instr.op.(UsePreallocDoer); ok { 47 p := m.cpumem[instr.writeTo.id] 48 if v, err = pd.UsePreallocDo(p, inputs...); err != nil { 49 return errors.Wrapf(err, "Happened while attempting to execute %v. Node is %x. Register was: %v ", instr, instr.id, instr.writeTo.id) 50 } 51 } else { 52 // TODO: maybe warn? 53 if v, err = instr.op.Do(inputs...); err != nil { 54 return errors.Wrap(err, opDoFail) 55 } 56 } 57 case usePrealloc: 58 if pd, ok := instr.op.(UsePreallocDoer); ok { 59 p := m.cpumem[instr.writeTo.id] 60 if v, err = pd.UsePreallocDo(p, inputs...); err != nil { 61 if v, err = instr.op.Do(inputs...); err != nil { 62 return errors.Wrap(err, opDoFail) 63 } 64 } 65 } else { 66 if v, err = instr.op.Do(inputs...); err != nil { 67 return errors.Wrap(err, opDoFail) 68 } 69 } 70 case instr.useUnsafe: 71 if ud, ok := instr.op.(UnsafeDoer); ok { 72 if v, err = ud.UnsafeDo(inputs...); err != nil { 73 return errors.Wrap(err, "Failed to carry UnsafeDo()") 74 } 75 } else { 76 // TODO: warn? 77 if v, err = instr.op.Do(inputs...); err != nil { 78 return errors.Wrap(err, opDoFail) 79 } 80 } 81 default: 82 if v, err = instr.op.Do(inputs...); err != nil { 83 return errors.Wrap(err, opDoFail) 84 } 85 } 86 87 m.watchedLogf("Result:") 88 m.enterLogScope() 89 m.watchedLogf(m.valueFmt, v) 90 m.leaveLogScope() 91 // TODO: type and shape checks 92 93 // Write 94 setEngine(v, m.Engine) 95 96 m.cpumem[dest] = v 97 node := m.p.g.Node(instr.id).(*Node) 98 99 if m.trace() && (len(m.watchNodes) == 0 || m.watchNodes.Contains(node)) { 100 if err = node.bindCopy(v); err != nil { 101 return errors.Wrapf(err, "TraceExec failed to bind copy") 102 } 103 } else { 104 node.bind(v) 105 } 106 107 // this is a gradient node then, we should also bind the value to the node's dualValue 108 if m.bindDV() && node.derivOf != nil { 109 for _, src := range node.derivOf { 110 if len(m.bindNodesDV) > 0 && !m.bindNodesDV.Contains(src) { 111 continue 112 } 113 114 if src.boundTo != nil { 115 dv := dvUnit(src.boundTo) 116 117 add := newEBOByType(addOpType, TypeOf(dv.d), TypeOf(v)) 118 119 if d, err := add.UnsafeDo(dv.d, v); err == nil { 120 dv.SetDeriv(d) 121 src.bind(dv) 122 } else { 123 return err 124 } 125 } 126 } 127 128 } 129 130 m.watchedLogf("Written To: %v", instr.writeTo) 131 m.enterLogScope() 132 m.watchedLogf(m.valueFmt, v) 133 m.leaveLogScope() 134 return nil 135 } 136 137 func (instr deviceTransport) exec(m *tapeMachine) error { 138 return nil 139 }