gorgonia.org/gorgonia@v0.9.17/values.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "unsafe" 6 7 "github.com/chewxy/hm" 8 "github.com/pkg/errors" 9 "gorgonia.org/tensor" 10 ) 11 12 // Value represents a value that Gorgonia accepts. At this point it is implemented by: 13 // - all scalar value types (F64, F32... etc) 14 // - *tensor.Dense 15 // - *dualValue 16 // 17 // A Value is essentially any thing that knows its own type and shape. 18 // Most importantly though, a Value is a pointer - and can be converted into a tensor.Memory. 19 // This is done for the sake of interoperability with external devices like cgo or CUDA or OpenCL. 20 // This also means for the most part most Values will be allocated on the heap. 21 // There are some performance tradeoffs made in this decision, but ultimately this is better than having to manually manage blocks of memory 22 type Value interface { 23 Shape() tensor.Shape // Shape returns the shape of the Value. Scalar values return ScalarShape() 24 Size() int // Size represents the number of elements in the Value. Note that in cases such as a *tensor.Dense, the underlying slice MAY have more elements than the Size() reports. This is correct. 25 Data() interface{} // Data returns the original representation of the Value 26 Dtype() tensor.Dtype // Dtype returns the Dtype of the value 27 28 tensor.Memory 29 fmt.Formatter 30 } 31 32 // Valuer is any type that can return a Value 33 type Valuer interface { 34 Value() Value 35 } 36 37 // Zeroer is a Value that can zero itself 38 type Zeroer interface { 39 Value 40 Zero() 41 } 42 43 // ZeroValuer is a a Value that can provide the zero-value of its type 44 type ZeroValuer interface { 45 Value 46 ZeroValue() Value 47 } 48 49 // Dtyper represents any type (typically a Value) that knows its own Dtype 50 type Dtyper interface { 51 Dtype() tensor.Dtype 52 } 53 54 // Typer represents any type (typically a Op) that knows its own Type 55 type Typer interface { 56 Type() hm.Type 57 } 58 59 // ValueEqualer represents any type that can perform a equal value check 60 type ValueEqualer interface { 61 ValueEq(Value) bool 62 } 63 64 // ValueCloser represents any type that can perform a close-value check 65 type ValueCloser interface { 66 ValueClose(interface{}) bool 67 } 68 69 // Cloner represents any type that can clone itself. 70 type Cloner interface { 71 Clone() interface{} 72 } 73 74 // CloneErrorer represents any type that can clone itself and return an error if necessary 75 type CloneErrorer interface { 76 Clone() (interface{}, error) 77 } 78 79 // CopierTo represents any type that can copy data to the destination. 80 type CopierTo interface { 81 CopyTo(dest interface{}) error 82 } 83 84 // CopierFrom represents any type that can copy data from the source provided. 85 type CopierFrom interface { 86 CopyFrom(src interface{}) error 87 } 88 89 // Setter is a any value that can Memset itself to the provided value 90 // type Setter interface { 91 // SetAll(interface{}) error 92 // } 93 94 // makeValue creates a value given a type and shape. The default value is the zero value of the type. 95 func makeValue(t hm.Type, s tensor.Shape) (retVal Value, err error) { 96 var dt tensor.Dtype 97 if dt, err = dtypeOf(t); err != nil { 98 return 99 } 100 101 if s.IsScalar() { 102 switch dt { 103 case tensor.Float64: 104 return NewF64(0), nil 105 case tensor.Float32: 106 return NewF32(0), nil 107 case tensor.Int: 108 return NewI(0), nil 109 case tensor.Int64: 110 return NewI64(0), nil 111 case tensor.Int32: 112 return NewI32(0), nil 113 case tensor.Byte: 114 return NewU8(0), nil 115 case tensor.Bool: 116 return NewB(false), nil 117 } 118 } 119 120 switch tt := t.(type) { 121 case TensorType: 122 return tensor.New(tensor.Of(dt), tensor.WithShape(s...)), nil 123 default: 124 err = errors.Errorf(nyiTypeFail, "MakeValue", tt) 125 return 126 } 127 } 128 129 func makeValueFromMem(t hm.Type, s tensor.Shape, mem tensor.Memory) (retVal Value, err error) { 130 var dt tensor.Dtype 131 if dt, err = dtypeOf(t); err != nil { 132 return 133 } 134 if s.IsScalar() { 135 return makeScalarFromMem(dt, mem) 136 } 137 138 switch tt := t.(type) { 139 case TensorType: 140 memsize := calcMemSize(dt, s) 141 return tensor.New(tensor.Of(dt), tensor.WithShape(s...), tensor.FromMemory(mem.Uintptr(), uintptr(memsize))), nil 142 case tensor.Dtype: 143 return makeScalarFromMem(tt, mem) 144 default: 145 err = errors.Errorf(nyiTypeFail, "MakeValue", tt) 146 return 147 } 148 } 149 150 func makeScalarFromMem(dt tensor.Dtype, mem tensor.Memory) (retVal Value, err error) { 151 switch dt { 152 case tensor.Float64: 153 retVal = (*F64)(unsafe.Pointer(mem.Uintptr())) 154 case tensor.Float32: 155 retVal = (*F32)(unsafe.Pointer(mem.Uintptr())) 156 case tensor.Int: 157 retVal = (*I)(unsafe.Pointer(mem.Uintptr())) 158 case tensor.Int64: 159 retVal = (*I64)(unsafe.Pointer(mem.Uintptr())) 160 case tensor.Int32: 161 retVal = (*I32)(unsafe.Pointer(mem.Uintptr())) 162 case tensor.Byte: 163 retVal = (*U8)(unsafe.Pointer(mem.Uintptr())) 164 case tensor.Bool: 165 retVal = (*B)(unsafe.Pointer(mem.Uintptr())) 166 default: 167 err = errors.Errorf(nyiTypeFail, "makeScalarFromMem", dt) 168 } 169 return 170 } 171 172 func logicalSize(s tensor.Shape) int { 173 if s.IsScalar() { 174 return 1 175 } 176 return s.TotalSize() 177 } 178 179 func calcMemSize(dt tensor.Dtype, s tensor.Shape) int64 { 180 var elemSize int64 181 if s.IsScalar() { 182 elemSize = 1 183 } else { 184 elemSize = int64(s.TotalSize()) 185 } 186 dtSize := int64(dt.Size()) 187 return elemSize * dtSize 188 } 189 190 // ScalarAsTensor returns the tensor representation of a scalar. It is particularly useful as a "reshape" of tensors of sorts 191 // 192 // The Value passed in are either Scalar, tensor.Tensor, or *dualValue. Anything else will panic. 193 func ScalarAsTensor(v Value, dims int, e tensor.Engine) Value { 194 switch a := v.(type) { 195 case Scalar: 196 sh := make(tensor.Shape, dims) 197 for i := range sh { 198 sh[i] = 1 199 } 200 return tensor.New(tensor.WithShape(sh...), tensor.Of(a.Dtype()), tensor.FromMemory(a.Uintptr(), a.MemSize()), tensor.WithEngine(e)) 201 case tensor.Tensor: 202 return a 203 case *dualValue: 204 b := new(dualValue) 205 b.Value = ScalarAsTensor(a.Value, dims, e) 206 b.d = ScalarAsTensor(a.d, dims, e) 207 return b 208 case nil: 209 return nil 210 default: 211 panic(fmt.Sprintf("Unable to convert %v to Tensor", v)) 212 } 213 }