github.com/wzzhu/tensor@v0.9.24/types.go (about)

     1  package tensor
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"reflect"
     7  	"unsafe"
     8  
     9  	"github.com/chewxy/hm"
    10  	"github.com/pkg/errors"
    11  )
    12  
// Dtype represents a data type of a Tensor. Concretely it's implemented as an embedded reflect.Type
// which allows for easy reflection operations. It also implements hm.Type, for type inference in Gorgonia.
//
// Because Dtype is a struct whose only field is the embedded reflect.Type interface, two Dtypes
// compare equal with == exactly when they wrap the same reflect.Type; the typeclass sets and the
// Numpy dtype maps in this file rely on that comparison.
// The zero value Dtype{} wraps a nil reflect.Type, so calling any reflect.Type method
// (Name, Size, ...) on it panics.
type Dtype struct {
	reflect.Type
}
    18  
    19  // note: the Name() and String() methods are already defined in reflect.Type. Might as well use the composed methods
    20  
    21  func (dt Dtype) Apply(hm.Subs) hm.Substitutable                { return dt }
    22  func (dt Dtype) FreeTypeVar() hm.TypeVarSet                    { return nil }
    23  func (dt Dtype) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { return dt, nil }
    24  func (dt Dtype) Types() hm.Types                               { return nil }
    25  func (dt Dtype) Format(s fmt.State, c rune)                    { fmt.Fprintf(s, "%s", dt.Name()) }
    26  func (dt Dtype) Eq(other hm.Type) bool                         { return other == dt }
    27  
    28  var numpyDtypes map[Dtype]string
    29  var reverseNumpyDtypes map[string]Dtype
    30  
    31  func init() {
    32  	numpyDtypes = map[Dtype]string{
    33  		Bool:       "b1",
    34  		Int:        fmt.Sprintf("i%d", Int.Size()),
    35  		Int8:       "i1",
    36  		Int16:      "i2",
    37  		Int32:      "i4",
    38  		Int64:      "i8",
    39  		Uint:       fmt.Sprintf("u%d", Uint.Size()),
    40  		Uint8:      "u1",
    41  		Uint16:     "u2",
    42  		Uint32:     "u4",
    43  		Uint64:     "u8",
    44  		Float32:    "f4",
    45  		Float64:    "f8",
    46  		Complex64:  "c8",
    47  		Complex128: "c16",
    48  	}
    49  
    50  	reverseNumpyDtypes = map[string]Dtype{
    51  		"b1":  Bool,
    52  		"i1":  Int8,
    53  		"i2":  Int16,
    54  		"i4":  Int32,
    55  		"i8":  Int64,
    56  		"u1":  Uint8,
    57  		"u2":  Uint16,
    58  		"u4":  Uint32,
    59  		"u8":  Uint64,
    60  		"f4":  Float32,
    61  		"f8":  Float64,
    62  		"c8":  Complex64,
    63  		"c16": Complex128,
    64  	}
    65  }
    66  
    67  // NumpyDtype returns the Numpy's Dtype equivalent. This is predominantly used in converting a Tensor to a Numpy ndarray,
    68  // however, not all Dtypes are supported
    69  func (dt Dtype) numpyDtype() (string, error) {
    70  	retVal, ok := numpyDtypes[dt]
    71  	if !ok {
    72  		return "v", errors.Errorf("Unsupported Dtype conversion to Numpy Dtype: %v", dt)
    73  	}
    74  	return retVal, nil
    75  }
    76  
    77  func fromNumpyDtype(t string) (Dtype, error) {
    78  	retVal, ok := reverseNumpyDtypes[t]
    79  	if !ok {
    80  		return Dtype{}, errors.Errorf("Unsupported Dtype conversion from %q to Dtype", t)
    81  	}
    82  	if t == "i4" && Int.Size() == 4 {
    83  		return Int, nil
    84  	}
    85  	if t == "i8" && Int.Size() == 8 {
    86  		return Int, nil
    87  	}
    88  	if t == "u4" && Uint.Size() == 4 {
    89  		return Uint, nil
    90  	}
    91  	if t == "u8" && Uint.Size() == 8 {
    92  		return Uint, nil
    93  	}
    94  	return retVal, nil
    95  }
    96  
// typeclass is a named set of Dtypes. typeclassCheck tests membership of a
// Dtype in a typeclass, and the Register* functions append to some of the sets.
type typeclass struct {
	name string  // human-readable name, used in typeclassCheck's error message
	set  []Dtype // member Dtypes
}
   101  
   102  var parameterizedKinds = [...]reflect.Kind{
   103  	reflect.Array,
   104  	reflect.Chan,
   105  	reflect.Func,
   106  	reflect.Interface,
   107  	reflect.Map,
   108  	reflect.Ptr,
   109  	reflect.Slice,
   110  	reflect.Struct,
   111  }
   112  
   113  func isParameterizedKind(k reflect.Kind) bool {
   114  	for _, v := range parameterizedKinds {
   115  		if v == k {
   116  			return true
   117  		}
   118  	}
   119  	return false
   120  }
   121  
   122  // oh how nice it'd be if I could make them immutable
   123  var (
   124  	Bool       = Dtype{reflect.TypeOf(true)}
   125  	Int        = Dtype{reflect.TypeOf(int(1))}
   126  	Int8       = Dtype{reflect.TypeOf(int8(1))}
   127  	Int16      = Dtype{reflect.TypeOf(int16(1))}
   128  	Int32      = Dtype{reflect.TypeOf(int32(1))}
   129  	Int64      = Dtype{reflect.TypeOf(int64(1))}
   130  	Uint       = Dtype{reflect.TypeOf(uint(1))}
   131  	Uint8      = Dtype{reflect.TypeOf(uint8(1))}
   132  	Uint16     = Dtype{reflect.TypeOf(uint16(1))}
   133  	Uint32     = Dtype{reflect.TypeOf(uint32(1))}
   134  	Uint64     = Dtype{reflect.TypeOf(uint64(1))}
   135  	Float32    = Dtype{reflect.TypeOf(float32(1))}
   136  	Float64    = Dtype{reflect.TypeOf(float64(1))}
   137  	Complex64  = Dtype{reflect.TypeOf(complex64(1))}
   138  	Complex128 = Dtype{reflect.TypeOf(complex128(1))}
   139  	String     = Dtype{reflect.TypeOf("")}
   140  
   141  	// aliases
   142  	Byte = Uint8
   143  
   144  	// extras
   145  	Uintptr       = Dtype{reflect.TypeOf(uintptr(0))}
   146  	UnsafePointer = Dtype{reflect.TypeOf(unsafe.Pointer(&Uintptr))}
   147  )
   148  
   149  // allTypes for indexing
   150  var allTypes = &typeclass{
   151  	name: "τ",
   152  	set: []Dtype{
   153  		Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, Uintptr, UnsafePointer,
   154  	},
   155  }
   156  
   157  // specialized types indicate that there are specialized code generated for these types
   158  var specializedTypes = &typeclass{
   159  	name: "Specialized",
   160  	set: []Dtype{
   161  		Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String,
   162  	},
   163  }
   164  
   165  var addableTypes = &typeclass{
   166  	name: "Addable",
   167  	set: []Dtype{
   168  		Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String,
   169  	},
   170  }
   171  
   172  var numberTypes = &typeclass{
   173  	name: "Number",
   174  	set: []Dtype{
   175  		Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128,
   176  	},
   177  }
   178  
   179  var ordTypes = &typeclass{
   180  	name: "Ord",
   181  	set: []Dtype{
   182  		Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, String,
   183  	},
   184  }
   185  
   186  var eqTypes = &typeclass{
   187  	name: "Eq",
   188  	set: []Dtype{
   189  		Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, Uintptr, UnsafePointer,
   190  	},
   191  }
   192  
   193  var unsignedTypes = &typeclass{
   194  	name: "Unsigned",
   195  	set:  []Dtype{Uint, Uint8, Uint16, Uint32, Uint64},
   196  }
   197  
   198  var signedTypes = &typeclass{
   199  	name: "Signed",
   200  	set: []Dtype{
   201  		Int, Int8, Int16, Int32, Int64, Float32, Float64, Complex64, Complex128,
   202  	},
   203  }
   204  
   205  // this typeclass is ever only used by Sub tests
   206  var signedNonComplexTypes = &typeclass{
   207  	name: "Signed NonComplex",
   208  	set: []Dtype{
   209  		Int, Int8, Int16, Int32, Int64, Float32, Float64,
   210  	},
   211  }
   212  
   213  var floatTypes = &typeclass{
   214  	name: "Float",
   215  	set: []Dtype{
   216  		Float32, Float64,
   217  	},
   218  }
   219  
   220  var complexTypes = &typeclass{
   221  	name: "Complex Numbers",
   222  	set:  []Dtype{Complex64, Complex128},
   223  }
   224  
   225  var floatcmplxTypes = &typeclass{
   226  	name: "Real",
   227  	set: []Dtype{
   228  		Float32, Float64, Complex64, Complex128,
   229  	},
   230  }
   231  
   232  var nonComplexNumberTypes = &typeclass{
   233  	name: "Non complex numbers",
   234  	set: []Dtype{
   235  		Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64,
   236  	},
   237  }
   238  
   239  // this typeclass is ever only used by Pow tests
   240  var generatableTypes = &typeclass{
   241  	name: "Generatable types",
   242  	set: []Dtype{
   243  		Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, String,
   244  	},
   245  }
   246  
   247  func isFloat(dt Dtype) bool {
   248  	return dt == Float64 || dt == Float32
   249  }
   250  
   251  func typeclassCheck(a Dtype, tc *typeclass) error {
   252  	if tc == nil {
   253  		return nil
   254  	}
   255  	for _, s := range tc.set {
   256  		if s == a {
   257  			return nil
   258  		}
   259  	}
   260  	return errors.Errorf("Type %v is not a member of %v", a, tc.name)
   261  }
   262  
   263  // RegisterNumber is a function required to register a new numerical Dtype.
   264  // This package provides the following Dtype:
   265  //		Int
   266  //		Int8
   267  //		Int16
   268  //		Int32
   269  //		Int64
   270  //		Uint
   271  //		Uint8
   272  //		Uint16
   273  //		Uint32
   274  //		Uint64
   275  //		Float32
   276  //		Float64
   277  //		Complex64
   278  //		Complex128
   279  //
   280  // If a Dtype that is registered already exists on the list, it will not be added to the list.
   281  func RegisterNumber(a Dtype) {
   282  	for _, dt := range numberTypes.set {
   283  		if dt == a {
   284  			return
   285  		}
   286  	}
   287  	numberTypes.set = append(numberTypes.set, a)
   288  	RegisterEq(a)
   289  }
   290  
   291  func RegisterFloat(a Dtype) {
   292  	for _, dt := range floatTypes.set {
   293  		if dt == a {
   294  			return
   295  		}
   296  	}
   297  	floatTypes.set = append(floatTypes.set, a)
   298  	RegisterNumber(a)
   299  	RegisterOrd(a)
   300  }
   301  
   302  // RegisterOrd registers a dtype as a type that can be typed
   303  func RegisterOrd(a Dtype) {
   304  	for _, dt := range ordTypes.set {
   305  		if dt == a {
   306  			return
   307  		}
   308  	}
   309  	ordTypes.set = append(ordTypes.set, a)
   310  	RegisterEq(a)
   311  }
   312  
   313  // RegisterEq registers a dtype as a type that can be compared for equality
   314  func RegisterEq(a Dtype) {
   315  	for _, dt := range eqTypes.set {
   316  		if dt == a {
   317  			return
   318  		}
   319  	}
   320  	eqTypes.set = append(eqTypes.set, a)
   321  	Register(a)
   322  }
   323  
   324  // Register registers a new Dtype
   325  func Register(a Dtype) {
   326  	for _, dt := range allTypes.set {
   327  		if a == dt {
   328  			return
   329  		}
   330  	}
   331  	allTypes.set = append(allTypes.set, a)
   332  }
   333  
   334  func dtypeID(a Dtype) int {
   335  	for i, v := range allTypes.set {
   336  		if a == v {
   337  			return i
   338  		}
   339  	}
   340  	return -1
   341  }
   342  
   343  // NormOrder represents the order of the norm. Ideally, we'd only represent norms with a uint/byte.
   344  // But there are norm types that are outside numerical types, such as nuclear norm and fobenius norm.
   345  // So it is internally represented by a float. If Go could use NaN and Inf as consts, it would have been best,
   346  // Instead, we use constructors. Both Nuclear and Frobenius norm types are represented as NaNs
   347  //
   348  // The using of NaN and Inf as "special" Norm types lead to the need for IsInf() and IsFrobenius() and IsNuclear() method
   349  type NormOrder float64
   350  
   351  func Norm(ord int) NormOrder   { return NormOrder(float64(ord)) }
   352  func InfNorm() NormOrder       { return NormOrder(math.Inf(1)) }
   353  func NegInfNorm() NormOrder    { return NormOrder(math.Inf(-1)) }
   354  func UnorderedNorm() NormOrder { return NormOrder(math.Float64frombits(0x7ff8000000000001)) }
   355  func FrobeniusNorm() NormOrder { return NormOrder(math.Float64frombits(0x7ff8000000000002)) }
   356  func NuclearNorm() NormOrder   { return NormOrder(math.Float64frombits(0x7ff8000000000003)) }
   357  
   358  // Valid() is a helper method that deterines if the norm order is valid. A valid norm order is
   359  // one where the fraction component is 0
   360  func (n NormOrder) Valid() bool {
   361  	switch {
   362  	case math.IsNaN(float64(n)):
   363  		nb := math.Float64bits(float64(n))
   364  		if math.Float64bits(float64(UnorderedNorm())) == nb || math.Float64bits(float64(FrobeniusNorm())) == nb || math.Float64bits(float64(NuclearNorm())) == nb {
   365  			return true
   366  		}
   367  	case math.IsInf(float64(n), 0):
   368  		return true
   369  	default:
   370  		if _, frac := math.Modf(float64(n)); frac == 0.0 {
   371  			return true
   372  		}
   373  	}
   374  	return false
   375  }
   376  
   377  // IsUnordered returns true if the NormOrder is not an ordered norm
   378  func (n NormOrder) IsUnordered() bool {
   379  	return math.Float64bits(float64(n)) == math.Float64bits(float64(UnorderedNorm()))
   380  }
   381  
   382  // IsFrobenius returns true if the NormOrder is a Frobenius norm
   383  func (n NormOrder) IsFrobenius() bool {
   384  	return math.Float64bits(float64(n)) == math.Float64bits(float64(FrobeniusNorm()))
   385  }
   386  
   387  // IsNuclear returns true if the NormOrder is a nuclear norm
   388  func (n NormOrder) IsNuclear() bool {
   389  	return math.Float64bits(float64(n)) == math.Float64bits(float64(NuclearNorm()))
   390  }
   391  
   392  func (n NormOrder) IsInf(sign int) bool {
   393  	return math.IsInf(float64(n), sign)
   394  }
   395  
   396  func (n NormOrder) String() string {
   397  	switch {
   398  	case n.IsUnordered():
   399  		return "Unordered"
   400  	case n.IsFrobenius():
   401  		return "Frobenius"
   402  	case n.IsNuclear():
   403  		return "Nuclear"
   404  	case n.IsInf(1):
   405  		return "+Inf"
   406  	case n.IsInf(-1):
   407  		return "-Inf"
   408  	default:
   409  		return fmt.Sprintf("Norm %v", float64(n))
   410  	}
   411  	panic("unreachable")
   412  }
   413  
   414  // FuncOpt are optionals for calling Tensor function.
   415  type FuncOpt func(*OpOpt)
   416  
   417  // WithIncr passes in a Tensor to be incremented.
   418  func WithIncr(incr Tensor) FuncOpt {
   419  	f := func(opt *OpOpt) {
   420  		opt.incr = incr
   421  	}
   422  	return f
   423  }
   424  
   425  // WithReuse passes in a Tensor to be reused.
   426  func WithReuse(reuse Tensor) FuncOpt {
   427  	f := func(opt *OpOpt) {
   428  		opt.reuse = reuse
   429  	}
   430  	return f
   431  }
   432  
   433  // UseSafe ensures that the operation is a safe operation (copies data, does not clobber). This is the default option for most methods and functions
   434  func UseSafe() FuncOpt {
   435  	f := func(opt *OpOpt) {
   436  		opt.unsafe = false
   437  	}
   438  	return f
   439  }
   440  
   441  // UseUnsafe ensures that the operation is an unsafe operation - data will be clobbered, and operations performed inplace
   442  func UseUnsafe() FuncOpt {
   443  	f := func(opt *OpOpt) {
   444  		opt.unsafe = true
   445  	}
   446  	return f
   447  }
   448  
   449  // AsSameType makes sure that the return Tensor is the same type as input Tensors.
   450  func AsSameType() FuncOpt {
   451  	f := func(opt *OpOpt) {
   452  		opt.same = true
   453  	}
   454  	return f
   455  }
   456  
   457  // As makes sure that the the return Tensor is of the type specified. Currently only works for FromMat64
   458  func As(t Dtype) FuncOpt {
   459  	f := func(opt *OpOpt) {
   460  		opt.t = t
   461  	}
   462  	return f
   463  }