gorgonia.org/gorgonia@v0.9.17/values_primitives.go (about) 1 package gorgonia 2 3 import ( 4 "bytes" 5 "fmt" 6 "reflect" 7 "unsafe" 8 9 "github.com/chewxy/hm" 10 "github.com/pkg/errors" 11 "gorgonia.org/tensor" 12 ) 13 14 // Scalar represents a scalar(non-array-based) value. Do note that it's the pointers of the scalar types (F64, F32, etc) that implement 15 // the Scalar interface. The main reason is primarily due to optimizations with regards to memory allocation and copying for device interoperability. 16 type Scalar interface { 17 Value 18 isScalar() bool 19 } 20 21 // F64 represents a float64 value. 22 type F64 float64 23 24 // F32 represents a float32 value. 25 type F32 float32 26 27 // I represents a int value. 28 type I int 29 30 // I64 represents a int64 value. 31 type I64 int64 32 33 // I32 represents a int32 value. 34 type I32 int32 35 36 // U8 represents a byte value. 37 type U8 byte 38 39 // B represents a bool value. 40 type B bool 41 42 func NewF64(v float64) *F64 { r := F64(v); return &r } 43 func NewF32(v float32) *F32 { r := F32(v); return &r } 44 func NewI(v int) *I { r := I(v); return &r } 45 func NewI64(v int64) *I64 { r := I64(v); return &r } 46 func NewI32(v int32) *I32 { r := I32(v); return &r } 47 func NewU8(v byte) *U8 { r := U8(v); return &r } 48 func NewB(v bool) *B { r := B(v); return &r } 49 50 /* Shape() */ 51 52 // Shape returns a scalar shape for all scalar values 53 func (v *F64) Shape() tensor.Shape { return scalarShape } 54 55 // Shape returns a scalar shape for all scalar values 56 func (v *F32) Shape() tensor.Shape { return scalarShape } 57 58 // Shape returns a scalar shape for all scalar values 59 func (v *I) Shape() tensor.Shape { return scalarShape } 60 61 // Shape returns a scalar shape for all scalar values 62 func (v *I64) Shape() tensor.Shape { return scalarShape } 63 64 // Shape returns a scalar shape for all scalar values 65 func (v *I32) Shape() tensor.Shape { return scalarShape } 66 67 // Shape returns a scalar shape for all scalar values 68 func (v *U8) Shape() tensor.Shape { return scalarShape } 69 70 // Shape returns a scalar shape for all scalar values 71 func (v *B) Shape() tensor.Shape { return scalarShape } 72 73 // Size returns 0 for all scalar Values 74 func (v *F64) Size() int { return 0 } 75 76 // Size returns 0 for all scalar Values 77 func (v *F32) Size() int { return 0 } 78 79 // Size returns 0 for all scalar Values 80 func (v *I) Size() int { return 0 } 81 82 // Size returns 0 for all scalar Values 83 func (v *I64) Size() int { return 0 } 84 85 // Size returns 0 for all scalar Values 86 func (v *I32) Size() int { return 0 } 87 88 // Size returns 0 for all scalar Values 89 func (v *U8) Size() int { return 0 } 90 91 // Size returns 0 for all scalar Values 92 func (v *B) Size() int { return 0 } 93 94 /* Data() */ 95 96 // Data returns the original representation of the Value 97 func (v *F64) Data() interface{} { return v.any() } 98 99 // Data returns the original representation of the Value 100 func (v *F32) Data() interface{} { return v.any() } 101 102 // Data returns the original representation of the Value 103 func (v *I) Data() interface{} { return v.any() } 104 105 // Data returns the original representation of the Value 106 func (v *I64) Data() interface{} { return v.any() } 107 108 // Data returns the original representation of the Value 109 func (v *I32) Data() interface{} { return v.any() } 110 111 // Data returns the original representation of the Value 112 func (v *U8) Data() interface{} { return v.any() } 113 114 // Data returns the original representation of the Value 115 func (v *B) Data() interface{} { return v.any() } 116 117 func (v *F64) any() float64 { return float64(*v) } 118 func (v *F32) any() float32 { return float32(*v) } 119 func (v *I) any() int { return int(*v) } 120 func (v *I64) any() int64 { return int64(*v) } 121 func (v *I32) any() int32 { return int32(*v) } 122 func (v *U8) any() byte { return byte(*v) } 123 func (v *B) any() bool { return bool(*v) } 124 125 /* implements fmt.Formatter */ 126 127 // Format implements fmt.Formatter 128 func (v *F64) Format(s fmt.State, c rune) { formatScalar(v, s, c) } 129 130 // Format implements fmt.Formatter 131 func (v *F32) Format(s fmt.State, c rune) { formatScalar(v, s, c) } 132 133 // Format implements fmt.Formatter 134 func (v *I) Format(s fmt.State, c rune) { formatScalar(v, s, c) } 135 136 // Format implements fmt.Formatter 137 func (v *I64) Format(s fmt.State, c rune) { formatScalar(v, s, c) } 138 139 // Format implements fmt.Formatter 140 func (v *I32) Format(s fmt.State, c rune) { formatScalar(v, s, c) } 141 142 // Format implements fmt.Formatter 143 func (v *U8) Format(s fmt.State, c rune) { formatScalar(v, s, c) } 144 145 // Format implements fmt.Formatter 146 func (v *B) Format(s fmt.State, c rune) { formatScalar(v, s, c) } 147 148 /* Dtype() */ 149 150 // Dtype returns the Dtype of the value 151 func (v *F64) Dtype() tensor.Dtype { return tensor.Float64 } 152 153 // Dtype returns the Dtype of the value 154 func (v *F32) Dtype() tensor.Dtype { return tensor.Float32 } 155 156 // Dtype returns the Dtype of the value 157 func (v *I) Dtype() tensor.Dtype { return tensor.Int } 158 159 // Dtype returns the Dtype of the value 160 func (v *I64) Dtype() tensor.Dtype { return tensor.Int64 } 161 162 // Dtype returns the Dtype of the value 163 func (v *I32) Dtype() tensor.Dtype { return tensor.Int32 } 164 165 // Dtype returns the Dtype of the value 166 func (v *U8) Dtype() tensor.Dtype { return tensor.Byte } 167 168 // Dtype returns the Dtype of the value 169 func (v *B) Dtype() tensor.Dtype { return tensor.Bool } 170 171 /* isScalar */ 172 173 func (v *F64) isScalar() bool { return true } 174 func (v *F32) isScalar() bool { return true } 175 func (v *I) isScalar() bool { return true } 176 func (v *I64) isScalar() bool { return true } 177 func (v *I32) isScalar() bool { return true } 178 func (v *U8) isScalar() bool { return true } 179 func (v *B) isScalar() bool { return true } 180 181 /* Uintptr */ 182 183 // Uintptr satisfies the tensor.Memory interface 184 func (v *F64) Uintptr() uintptr { return uintptr(unsafe.Pointer(v)) } 185 186 // Uintptr satisfies the tensor.Memory interface 187 func (v *F32) Uintptr() uintptr { return uintptr(unsafe.Pointer(v)) } 188 189 // Uintptr satisfies the tensor.Memory interface 190 func (v *I) Uintptr() uintptr { return uintptr(unsafe.Pointer(v)) } 191 192 // Uintptr satisfies the tensor.Memory interface 193 func (v *I64) Uintptr() uintptr { return uintptr(unsafe.Pointer(v)) } 194 195 // Uintptr satisfies the tensor.Memory interface 196 func (v *I32) Uintptr() uintptr { return uintptr(unsafe.Pointer(v)) } 197 198 // Uintptr satisfies the tensor.Memory interface 199 func (v *U8) Uintptr() uintptr { return uintptr(unsafe.Pointer(v)) } 200 201 // Uintptr satisfies the tensor.Memory interface 202 func (v *B) Uintptr() uintptr { return uintptr(unsafe.Pointer(v)) } 203 204 /* MemSize */ 205 206 // MemSize satisfies the tensor.Memory interface 207 func (v *F64) MemSize() uintptr { return 8 } 208 209 // MemSize satisfies the tensor.Memory interface 210 func (v *F32) MemSize() uintptr { return 4 } 211 212 // MemSize satisfies the tensor.Memory interface 213 func (v *I) MemSize() uintptr { return reflect.TypeOf(*v).Size() } 214 215 // MemSize satisfies the tensor.Memory interface 216 func (v *I64) MemSize() uintptr { return 8 } 217 218 // MemSize satisfies the tensor.Memory interface 219 func (v *I32) MemSize() uintptr { return 4 } 220 221 // MemSize satisfies the tensor.Memory interface 222 func (v *U8) MemSize() uintptr { return 1 } 223 224 // MemSize satisfies the tensor.Memory interface 225 func (v *B) MemSize() uintptr { return reflect.TypeOf(*v).Size() } 226 227 /* Pointer */ 228 229 // Pointer returns the pointer as an unsafe.Pointer. Satisfies the tensor.Memory interface 230 func (v *F64) Pointer() unsafe.Pointer { return unsafe.Pointer(v) } 231 232 // Pointer returns the pointer as an unsafe.Pointer. Satisfies the tensor.Memory interface 233 func (v *F32) Pointer() unsafe.Pointer { return unsafe.Pointer(v) } 234 235 // Pointer returns the pointer as an unsafe.Pointer. Satisfies the tensor.Memory interface 236 func (v *I) Pointer() unsafe.Pointer { return unsafe.Pointer(v) } 237 238 // Pointer returns the pointer as an unsafe.Pointer. Satisfies the tensor.Memory interface 239 func (v *I64) Pointer() unsafe.Pointer { return unsafe.Pointer(v) } 240 241 // Pointer returns the pointer as an unsafe.Pointer. Satisfies the tensor.Memory interface 242 func (v *I32) Pointer() unsafe.Pointer { return unsafe.Pointer(v) } 243 244 // Pointer returns the pointer as an unsafe.Pointer. Satisfies the tensor.Memory interface 245 func (v *U8) Pointer() unsafe.Pointer { return unsafe.Pointer(v) } 246 247 // Pointer returns the pointer as an unsafe.Pointer. Satisfies the tensor.Memory interface 248 func (v *B) Pointer() unsafe.Pointer { return unsafe.Pointer(v) } 249 250 func formatScalar(v Scalar, s fmt.State, c rune) { 251 var buf bytes.Buffer 252 var ok bool 253 254 buf.WriteRune('%') 255 256 var width int 257 if width, ok = s.Width(); ok { 258 fmt.Fprintf(&buf, "%d", width) 259 } 260 261 var prec int 262 if prec, ok = s.Precision(); ok { 263 fmt.Fprintf(&buf, ".%d", prec) 264 } 265 266 switch c { 267 case 's': 268 buf.WriteRune('v') 269 case 'd': 270 switch v.(type) { 271 case *F64, *F32, *U8, *B: 272 buf.WriteRune('v') 273 default: 274 buf.WriteRune(c) 275 } 276 case 'f', 'g': 277 switch v.(type) { 278 case *I, *I64, *I32, *U8, *B: 279 buf.WriteRune('v') 280 default: 281 buf.WriteRune(c) 282 } 283 default: 284 buf.WriteRune(c) 285 } 286 287 if s.Flag('+') { 288 s.Write([]byte(v.Dtype().String())) 289 s.Write([]byte{' '}) 290 } 291 292 fmt.Fprintf(s, buf.String(), v.Data()) 293 } 294 295 func anyToScalar(any interface{}) (Scalar, tensor.Dtype) { 296 switch at := any.(type) { 297 case Scalar: 298 return at, at.Dtype() 299 case float64: 300 return NewF64(at), Float64 301 case float32: 302 return NewF32(at), Float32 303 case int: 304 return NewI(at), Int 305 case int32: 306 return NewI32(at), Int32 307 case int64: 308 return NewI64(at), Int64 309 case byte: 310 return NewU8(at), Byte 311 case bool: 312 return NewB(at), Bool 313 default: 314 panic(fmt.Sprintf("%v(%T) not scalar/not handled", any, any)) 315 } 316 } 317 318 func anyToValue(any interface{}) (val Value, t hm.Type, dt tensor.Dtype, err error) { 319 switch a := any.(type) { 320 case Value: 321 val = a 322 t = TypeOf(a) 323 dt = a.Dtype() 324 return 325 case float64, float32, int, int64, int32, byte, bool: 326 val, dt = anyToScalar(any) 327 t = dt 328 return 329 case F64: 330 return NewF64(float64(a)), tensor.Float64, tensor.Float64, nil 331 case F32: 332 return NewF32(float32(a)), tensor.Float32, tensor.Float32, nil 333 case I: 334 return NewI(int(a)), tensor.Int, tensor.Int, nil 335 case I64: 336 return NewI64(int64(a)), tensor.Int64, tensor.Int64, nil 337 case I32: 338 return NewI32(int32(a)), tensor.Int32, tensor.Int32, nil 339 case U8: 340 return NewU8(byte(a)), tensor.Uint8, tensor.Uint8, nil 341 case B: 342 return NewB(bool(a)), tensor.Bool, tensor.Bool, nil 343 case tensor.Tensor: 344 val = a 345 t = TypeOf(a) 346 dt = a.Dtype() 347 return 348 default: 349 err = errors.Errorf("value %v of %T not yet handled", any, any) 350 return 351 } 352 } 353 354 func one(dt tensor.Dtype) Scalar { 355 switch dt { 356 case tensor.Float64: 357 return NewF64(float64(1)) 358 case tensor.Float32: 359 return NewF32(float32(1)) 360 case tensor.Int: 361 return NewI(1) 362 case tensor.Int32: 363 return NewI32(int32(1)) 364 case tensor.Int64: 365 return NewI64(int64(1)) 366 case tensor.Byte: 367 return NewU8(byte(1)) 368 case tensor.Bool: 369 return NewB(true) 370 default: 371 panic("Unhandled dtype") 372 } 373 } 374 375 func zero(dt tensor.Dtype) Scalar { 376 switch dt { 377 case tensor.Float64: 378 return NewF64(float64(0)) 379 case tensor.Float32: 380 return NewF32(float32(0)) 381 case tensor.Int: 382 return NewI(0) 383 case tensor.Int32: 384 return NewI32(int32(0)) 385 case tensor.Int64: 386 return NewI64(int64(0)) 387 case tensor.Byte: 388 return NewU8(byte(0)) 389 case tensor.Bool: 390 return NewB(false) 391 default: 392 panic("Unhandled dtype") 393 } 394 }