github.com/wzzhu/tensor@v0.9.24/types.go (about) 1 package tensor 2 3 import ( 4 "fmt" 5 "math" 6 "reflect" 7 "unsafe" 8 9 "github.com/chewxy/hm" 10 "github.com/pkg/errors" 11 ) 12 13 // Dtype represents a data type of a Tensor. Concretely it's implemented as an embedded reflect.Type 14 // which allows for easy reflection operations. It also implements hm.Type, for type inference in Gorgonia 15 type Dtype struct { 16 reflect.Type 17 } 18 19 // note: the Name() and String() methods are already defined in reflect.Type. Might as well use the composed methods 20 21 func (dt Dtype) Apply(hm.Subs) hm.Substitutable { return dt } 22 func (dt Dtype) FreeTypeVar() hm.TypeVarSet { return nil } 23 func (dt Dtype) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { return dt, nil } 24 func (dt Dtype) Types() hm.Types { return nil } 25 func (dt Dtype) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%s", dt.Name()) } 26 func (dt Dtype) Eq(other hm.Type) bool { return other == dt } 27 28 var numpyDtypes map[Dtype]string 29 var reverseNumpyDtypes map[string]Dtype 30 31 func init() { 32 numpyDtypes = map[Dtype]string{ 33 Bool: "b1", 34 Int: fmt.Sprintf("i%d", Int.Size()), 35 Int8: "i1", 36 Int16: "i2", 37 Int32: "i4", 38 Int64: "i8", 39 Uint: fmt.Sprintf("u%d", Uint.Size()), 40 Uint8: "u1", 41 Uint16: "u2", 42 Uint32: "u4", 43 Uint64: "u8", 44 Float32: "f4", 45 Float64: "f8", 46 Complex64: "c8", 47 Complex128: "c16", 48 } 49 50 reverseNumpyDtypes = map[string]Dtype{ 51 "b1": Bool, 52 "i1": Int8, 53 "i2": Int16, 54 "i4": Int32, 55 "i8": Int64, 56 "u1": Uint8, 57 "u2": Uint16, 58 "u4": Uint32, 59 "u8": Uint64, 60 "f4": Float32, 61 "f8": Float64, 62 "c8": Complex64, 63 "c16": Complex128, 64 } 65 } 66 67 // NumpyDtype returns the Numpy's Dtype equivalent. 
This is predominantly used in converting a Tensor to a Numpy ndarray, 68 // however, not all Dtypes are supported 69 func (dt Dtype) numpyDtype() (string, error) { 70 retVal, ok := numpyDtypes[dt] 71 if !ok { 72 return "v", errors.Errorf("Unsupported Dtype conversion to Numpy Dtype: %v", dt) 73 } 74 return retVal, nil 75 } 76 77 func fromNumpyDtype(t string) (Dtype, error) { 78 retVal, ok := reverseNumpyDtypes[t] 79 if !ok { 80 return Dtype{}, errors.Errorf("Unsupported Dtype conversion from %q to Dtype", t) 81 } 82 if t == "i4" && Int.Size() == 4 { 83 return Int, nil 84 } 85 if t == "i8" && Int.Size() == 8 { 86 return Int, nil 87 } 88 if t == "u4" && Uint.Size() == 4 { 89 return Uint, nil 90 } 91 if t == "u8" && Uint.Size() == 8 { 92 return Uint, nil 93 } 94 return retVal, nil 95 } 96 97 type typeclass struct { 98 name string 99 set []Dtype 100 } 101 102 var parameterizedKinds = [...]reflect.Kind{ 103 reflect.Array, 104 reflect.Chan, 105 reflect.Func, 106 reflect.Interface, 107 reflect.Map, 108 reflect.Ptr, 109 reflect.Slice, 110 reflect.Struct, 111 } 112 113 func isParameterizedKind(k reflect.Kind) bool { 114 for _, v := range parameterizedKinds { 115 if v == k { 116 return true 117 } 118 } 119 return false 120 } 121 122 // oh how nice it'd be if I could make them immutable 123 var ( 124 Bool = Dtype{reflect.TypeOf(true)} 125 Int = Dtype{reflect.TypeOf(int(1))} 126 Int8 = Dtype{reflect.TypeOf(int8(1))} 127 Int16 = Dtype{reflect.TypeOf(int16(1))} 128 Int32 = Dtype{reflect.TypeOf(int32(1))} 129 Int64 = Dtype{reflect.TypeOf(int64(1))} 130 Uint = Dtype{reflect.TypeOf(uint(1))} 131 Uint8 = Dtype{reflect.TypeOf(uint8(1))} 132 Uint16 = Dtype{reflect.TypeOf(uint16(1))} 133 Uint32 = Dtype{reflect.TypeOf(uint32(1))} 134 Uint64 = Dtype{reflect.TypeOf(uint64(1))} 135 Float32 = Dtype{reflect.TypeOf(float32(1))} 136 Float64 = Dtype{reflect.TypeOf(float64(1))} 137 Complex64 = Dtype{reflect.TypeOf(complex64(1))} 138 Complex128 = Dtype{reflect.TypeOf(complex128(1))} 139 String = 
Dtype{reflect.TypeOf("")} 140 141 // aliases 142 Byte = Uint8 143 144 // extras 145 Uintptr = Dtype{reflect.TypeOf(uintptr(0))} 146 UnsafePointer = Dtype{reflect.TypeOf(unsafe.Pointer(&Uintptr))} 147 ) 148 149 // allTypes for indexing 150 var allTypes = &typeclass{ 151 name: "τ", 152 set: []Dtype{ 153 Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, Uintptr, UnsafePointer, 154 }, 155 } 156 157 // specialized types indicate that there are specialized code generated for these types 158 var specializedTypes = &typeclass{ 159 name: "Specialized", 160 set: []Dtype{ 161 Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, 162 }, 163 } 164 165 var addableTypes = &typeclass{ 166 name: "Addable", 167 set: []Dtype{ 168 Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, 169 }, 170 } 171 172 var numberTypes = &typeclass{ 173 name: "Number", 174 set: []Dtype{ 175 Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, 176 }, 177 } 178 179 var ordTypes = &typeclass{ 180 name: "Ord", 181 set: []Dtype{ 182 Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, String, 183 }, 184 } 185 186 var eqTypes = &typeclass{ 187 name: "Eq", 188 set: []Dtype{ 189 Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, Uintptr, UnsafePointer, 190 }, 191 } 192 193 var unsignedTypes = &typeclass{ 194 name: "Unsigned", 195 set: []Dtype{Uint, Uint8, Uint16, Uint32, Uint64}, 196 } 197 198 var signedTypes = &typeclass{ 199 name: "Signed", 200 set: []Dtype{ 201 Int, Int8, Int16, Int32, Int64, Float32, Float64, Complex64, Complex128, 202 }, 203 } 204 205 // this typeclass is ever only used by Sub tests 206 var signedNonComplexTypes = 
&typeclass{ 207 name: "Signed NonComplex", 208 set: []Dtype{ 209 Int, Int8, Int16, Int32, Int64, Float32, Float64, 210 }, 211 } 212 213 var floatTypes = &typeclass{ 214 name: "Float", 215 set: []Dtype{ 216 Float32, Float64, 217 }, 218 } 219 220 var complexTypes = &typeclass{ 221 name: "Complex Numbers", 222 set: []Dtype{Complex64, Complex128}, 223 } 224 225 var floatcmplxTypes = &typeclass{ 226 name: "Real", 227 set: []Dtype{ 228 Float32, Float64, Complex64, Complex128, 229 }, 230 } 231 232 var nonComplexNumberTypes = &typeclass{ 233 name: "Non complex numbers", 234 set: []Dtype{ 235 Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, 236 }, 237 } 238 239 // this typeclass is ever only used by Pow tests 240 var generatableTypes = &typeclass{ 241 name: "Generatable types", 242 set: []Dtype{ 243 Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, String, 244 }, 245 } 246 247 func isFloat(dt Dtype) bool { 248 return dt == Float64 || dt == Float32 249 } 250 251 func typeclassCheck(a Dtype, tc *typeclass) error { 252 if tc == nil { 253 return nil 254 } 255 for _, s := range tc.set { 256 if s == a { 257 return nil 258 } 259 } 260 return errors.Errorf("Type %v is not a member of %v", a, tc.name) 261 } 262 263 // RegisterNumber is a function required to register a new numerical Dtype. 264 // This package provides the following Dtype: 265 // Int 266 // Int8 267 // Int16 268 // Int32 269 // Int64 270 // Uint 271 // Uint8 272 // Uint16 273 // Uint32 274 // Uint64 275 // Float32 276 // Float64 277 // Complex64 278 // Complex128 279 // 280 // If a Dtype that is registered already exists on the list, it will not be added to the list. 
func RegisterNumber(a Dtype) {
	// NOTE(review): the Register* functions append to package-level typeclasses
	// without any locking; presumably they are meant to be called during program
	// initialization only — confirm before calling them concurrently.
	for _, dt := range numberTypes.set {
		if dt == a {
			return
		}
	}
	numberTypes.set = append(numberTypes.set, a)
	RegisterEq(a)
}

// RegisterFloat registers a as a floating-point Dtype. It also registers a as a
// Number (RegisterNumber) and as an ordered type (RegisterOrd). Registering a
// Dtype that is already in the Float typeclass is a no-op.
func RegisterFloat(a Dtype) {
	for _, dt := range floatTypes.set {
		if dt == a {
			return
		}
	}
	floatTypes.set = append(floatTypes.set, a)
	RegisterNumber(a)
	RegisterOrd(a)
}

// RegisterOrd registers a as a Dtype whose values can be ordered, and also
// registers it via RegisterEq. Registering a Dtype that is already in the Ord
// typeclass is a no-op.
func RegisterOrd(a Dtype) {
	for _, dt := range ordTypes.set {
		if dt == a {
			return
		}
	}
	ordTypes.set = append(ordTypes.set, a)
	RegisterEq(a)
}

// RegisterEq registers a as a Dtype whose values can be compared for equality,
// and adds it to the set of all known Dtypes via Register. Registering a Dtype
// that is already in the Eq typeclass is a no-op.
func RegisterEq(a Dtype) {
	for _, dt := range eqTypes.set {
		if dt == a {
			return
		}
	}
	eqTypes.set = append(eqTypes.set, a)
	Register(a)
}

// Register adds a to the set of all known Dtypes. New Dtypes are appended, so
// the dtypeID of every previously known Dtype is unchanged. Registering a Dtype
// that is already known is a no-op.
func Register(a Dtype) {
	for _, dt := range allTypes.set {
		if a == dt {
			return
		}
	}
	allTypes.set = append(allTypes.set, a)
}

// dtypeID returns the index of a within the set of all known Dtypes, or -1 if
// a has not been registered.
func dtypeID(a Dtype) int {
	for i, v := range allTypes.set {
		if a == v {
			return i
		}
	}
	return -1
}

// NormOrder represents the order of a norm. Ideally norms would only be
// represented with a uint/byte, but some norms have no numerical order, such as
// the nuclear norm and the Frobenius norm. NormOrder is therefore backed by a
// float64. Because Go cannot declare NaN or Inf as constants, constructors are
// used instead: the unordered, Frobenius and nuclear norms are each encoded as a
// NaN with a distinct payload, and the infinity norms as ±Inf.
//
// Since NaN never compares equal to anything (itself included), the special
// norms must be detected with IsUnordered, IsFrobenius and IsNuclear (which
// compare bit patterns), and the infinity norms with IsInf — not with ==.
type NormOrder float64

// Norm returns the NormOrder of integer order ord.
func Norm(ord int) NormOrder { return NormOrder(float64(ord)) }

// InfNorm returns the +Inf norm order.
func InfNorm() NormOrder { return NormOrder(math.Inf(1)) }

// NegInfNorm returns the -Inf norm order.
func NegInfNorm() NormOrder { return NormOrder(math.Inf(-1)) }

// UnorderedNorm returns the NormOrder denoting an unordered norm, encoded as the
// NaN with bit pattern 0x7ff8000000000001.
//
// NOTE(review): this appears to be the same bit pattern Go's math.NaN() returns,
// so NormOrder(math.NaN()) would report IsUnordered — confirm this is intended.
func UnorderedNorm() NormOrder { return NormOrder(math.Float64frombits(0x7ff8000000000001)) }

// FrobeniusNorm returns the NormOrder denoting the Frobenius norm, encoded as
// the NaN with bit pattern 0x7ff8000000000002.
func FrobeniusNorm() NormOrder { return NormOrder(math.Float64frombits(0x7ff8000000000002)) }

// NuclearNorm returns the NormOrder denoting the nuclear norm, encoded as the
// NaN with bit pattern 0x7ff8000000000003.
func NuclearNorm() NormOrder { return NormOrder(math.Float64frombits(0x7ff8000000000003)) }

// Valid reports whether n is a usable norm order: one of the special NaN-encoded
// norms (unordered, Frobenius, nuclear), ±Inf, or a finite value whose
// fractional part is 0. Any other NaN is invalid.
func (n NormOrder) Valid() bool {
	f := float64(n)
	switch {
	case math.IsNaN(f):
		return n.IsUnordered() || n.IsFrobenius() || n.IsNuclear()
	case math.IsInf(f, 0):
		return true
	default:
		_, frac := math.Modf(f)
		return frac == 0.0
	}
}

// IsUnordered returns true if the NormOrder is not an ordered norm.
func (n NormOrder) IsUnordered() bool {
	return math.Float64bits(float64(n)) == math.Float64bits(float64(UnorderedNorm()))
}

// IsFrobenius returns true if the NormOrder is a Frobenius norm.
func (n NormOrder) IsFrobenius() bool {
	return math.Float64bits(float64(n)) == math.Float64bits(float64(FrobeniusNorm()))
}

// IsNuclear returns true if the NormOrder is a nuclear norm.
func (n NormOrder) IsNuclear() bool {
	return math.Float64bits(float64(n)) == math.Float64bits(float64(NuclearNorm()))
}

// IsInf reports whether n is an infinity norm, using math.IsInf semantics:
// sign > 0 checks for +Inf, sign < 0 for -Inf, and sign == 0 for either.
func (n NormOrder) IsInf(sign int) bool {
	return math.IsInf(float64(n), sign)
}

// String implements fmt.Stringer. Every case of the switch returns (including
// default), so no trailing return or panic is needed.
func (n NormOrder) String() string {
	switch {
	case n.IsUnordered():
		return "Unordered"
	case n.IsFrobenius():
		return "Frobenius"
	case n.IsNuclear():
		return "Nuclear"
	case n.IsInf(1):
		return "+Inf"
	case n.IsInf(-1):
		return "-Inf"
	default:
		return fmt.Sprintf("Norm %v", float64(n))
	}
}

// FuncOpt are optionals for calling Tensor function.
type FuncOpt func(*OpOpt)

// WithIncr passes in a Tensor to be incremented.
func WithIncr(incr Tensor) FuncOpt {
	return func(opt *OpOpt) { opt.incr = incr }
}

// WithReuse passes in a Tensor to be reused.
func WithReuse(reuse Tensor) FuncOpt {
	return func(opt *OpOpt) { opt.reuse = reuse }
}

// UseSafe ensures that the operation is a safe operation (copies data, does not
// clobber). This is the default option for most methods and functions.
func UseSafe() FuncOpt {
	return func(opt *OpOpt) { opt.unsafe = false }
}

// UseUnsafe ensures that the operation is an unsafe operation - data will be
// clobbered, and operations performed inplace.
func UseUnsafe() FuncOpt {
	return func(opt *OpOpt) { opt.unsafe = true }
}

// AsSameType makes sure that the return Tensor is the same type as input Tensors.
func AsSameType() FuncOpt {
	return func(opt *OpOpt) { opt.same = true }
}

// As makes sure that the return Tensor is of the type specified. Currently only
// works for FromMat64.
func As(t Dtype) FuncOpt {
	return func(opt *OpOpt) { opt.t = t }
}