gorgonia.org/gorgonia@v0.9.17/execution.go (about) 1 package gorgonia 2 3 import ( 4 "github.com/pkg/errors" 5 "gorgonia.org/tensor" 6 ) 7 8 // Arena is a representation of a pool of tensor.Memory 9 type Arena interface { 10 Get(dev Device, size int64) (tensor.Memory, error) // Get returns a NoOpError when it cannot get a memory. Please allocate 11 GetFromValue(dev Device, v Value) (tensor.Memory, error) // Gets a memory and copies the values into the memory and returns it. 12 Put(dev Device, mem tensor.Memory, size int64) // puts the memory back into the arena 13 PutValue(dev Device, v Value) // puts the memory back into the arena 14 15 // Transfers memory from device to device 16 Transfer(toDev, fromDev Device, v Value, synchronous bool) (retVal Value, err error) 17 } 18 19 // External is a representation of an external device (cuda/cgo/openCL), conceptually modelled as a machine. 20 type External interface { 21 Arena 22 Signal() // signals the machine to do work 23 Sync() chan struct{} 24 } 25 26 // ExecutionContext informs how an op should be executed 27 type ExecutionContext struct { 28 External 29 Device 30 } 31 32 // ExternalOp is an op that contains an external context. This allows for ops to be run without needing a VM 33 type ExternalOp struct { 34 Op 35 ExecutionContext 36 37 Prealloc Value 38 Incr Value // is this a Incr? IncrDoers have higher precedence over PreallocDo 39 UseUnsafe bool // Is this an unsafe op? Lowest of all "special" Dos 40 } 41 42 // NewExternalOp creates a new *ExternalOp. 43 func NewExternalOp(op Op, ctx ExecutionContext, prealloc Value) *ExternalOp { 44 retVal := &ExternalOp{ 45 Op: op, 46 ExecutionContext: ctx, 47 Prealloc: prealloc, 48 UseUnsafe: false, 49 } 50 51 return retVal 52 } 53 54 // DetermineDevice ... 55 func (op *ExternalOp) DetermineDevice(inputs Nodes, output *Node) error { 56 dev := output.dataOn 57 var inDev Device = -2 58 var allSame bool 59 for _, in := range inputs { 60 if in.dataOn != dev { 61 allSame = false 62 } 63 64 if inDev == -2 { 65 inDev = in.dataOn 66 continue 67 } 68 if in.dataOn != inDev && in.dataOn != dev { 69 return errors.Errorf("Cannot automatically determine device.") 70 } 71 } 72 73 if !allSame { 74 return errors.Errorf("Not all the same devices") 75 } 76 op.Device = dev 77 return nil 78 } 79 80 // Do performs the op, 81 func (op *ExternalOp) Do(vals ...Value) (Value, error) { 82 if op.Device == CPU { 83 switch { 84 case op.Incr != nil: 85 if id, ok := op.Op.(IncrDoer); ok { 86 if err := id.IncrDo(op.Incr, vals...); err != nil { 87 if ver, ok := err.(Valuer); ok { 88 return ver.Value(), nil 89 } 90 return nil, err 91 } 92 return op.Incr, nil 93 } 94 case op.Prealloc != nil: 95 if pd, ok := op.Op.(UsePreallocDoer); ok { 96 pd.UsePreallocDo(op.Prealloc, vals...) 97 } 98 retVal, err := op.Op.Do(vals...) 99 if err != nil { 100 return retVal, err 101 } 102 return Copy(op.Prealloc, retVal) 103 case op.UseUnsafe: 104 if ud, ok := op.Op.(UnsafeDoer); ok { 105 return ud.UnsafeDo(vals...) 106 } 107 fallthrough 108 default: 109 return op.Op.Do(vals...) 110 } 111 } 112 113 switch o := op.Op.(type) { 114 case CUDADoer: 115 if op.Incr != nil { 116 v, err := o.CUDADo(op.External, op.Device, op.Prealloc, vals...) 117 if err != nil { 118 return nil, err 119 } 120 121 add := newEBOByType(addOpType, TypeOf(op.Incr), TypeOf(v)) 122 addOp := NewExternalOp(add, op.ExecutionContext, nil) 123 addOp.UseUnsafe = true 124 retVal, err := addOp.Do(op.Incr, v) 125 return retVal, err 126 } 127 return o.CUDADo(op.External, op.Device, op.Prealloc, vals...) 128 case CLDoer: 129 case IncrDoer: 130 if op.Incr != nil { 131 if err := o.IncrDo(op.Incr, vals...); err != nil { 132 if ver, ok := err.(Valuer); ok { 133 return ver.Value(), nil 134 } 135 return nil, err 136 } 137 return op.Incr, nil 138 } 139 return op.Op.Do(vals...) 140 case UsePreallocDoer: 141 if op.Prealloc != nil { 142 return o.UsePreallocDo(op.Prealloc, vals...) 143 } 144 return op.Op.Do(vals...) 145 case UnsafeDoer: 146 if op.UseUnsafe { 147 return o.UnsafeDo(vals...) 148 } 149 return op.Op.Do(vals...) 150 default: 151 return o.Do(vals...) 152 } 153 154 panic("Unreachable") 155 } 156 157 func (op *ExternalOp) String() string { 158 return op.Op.String() 159 }