gorgonia.org/gorgonia@v0.9.17/type.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 6 "github.com/chewxy/hm" 7 "gorgonia.org/tensor" 8 ) 9 10 var ( 11 // Represents the types that Nodes can take in Gorgonia 12 13 // Float64 ... 14 Float64 = tensor.Float64 15 // Float32 ... 16 Float32 = tensor.Float32 17 // Int ... 18 Int = tensor.Int 19 // Int64 ... 20 Int64 = tensor.Int64 21 // Int32 ... 22 Int32 = tensor.Int32 23 // Byte ... 24 Byte = tensor.Uint8 25 // Bool ... 26 Bool = tensor.Bool 27 28 // Ptr is equivalent to interface{}. Ugh Ugh Ugh 29 Ptr = tensor.UnsafePointer 30 31 vecF64 = &TensorType{Dims: 1, Of: tensor.Float64} 32 vecF32 = &TensorType{Dims: 1, Of: tensor.Float32} 33 matF64 = &TensorType{Dims: 2, Of: tensor.Float64} 34 matF32 = &TensorType{Dims: 2, Of: tensor.Float32} 35 ten3F64 = &TensorType{Dims: 3, Of: tensor.Float64} 36 ten3F32 = &TensorType{Dims: 3, Of: tensor.Float32} 37 38 // removes the need for type checking 39 f64T = tensor.Float64 // hm.Type 40 f32T = tensor.Float32 // hm.Type 41 ) 42 43 var acceptableDtypes = [...]tensor.Dtype{tensor.Float64, tensor.Float32, tensor.Int, tensor.Int64, tensor.Int32, tensor.Byte, tensor.Bool} 44 45 /*Tensor Type*/ 46 47 // TensorType is a type constructor for tensors. 48 // 49 // Think of it as something like this: 50 // data Tensor a = Tensor d a 51 // 52 // The shape of the Tensor is not part of TensorType. 53 // Shape checking is relegated to the dynamic part of the program run 54 type TensorType struct { 55 Dims int // dims 56 57 Of hm.Type 58 } 59 60 func makeFromTensorType(t TensorType, tv hm.TypeVariable) TensorType { 61 return makeTensorType(t.Dims, tv) 62 } 63 64 func makeTensorType(dims int, typ hm.Type) TensorType { 65 return TensorType{ 66 Dims: dims, 67 Of: typ, 68 } 69 } 70 71 func newTensorType(dims int, typ hm.Type) *TensorType { 72 switch { 73 case dims == 1 && typ == f64T: 74 return vecF64 75 case dims == 1 && typ == f32T: 76 return vecF32 77 case dims == 2 && typ == f64T: 78 return matF64 79 case dims == 2 && typ == f32T: 80 return matF32 81 case dims == 3 && typ == f64T: 82 return ten3F64 83 case dims == 3 && typ == f32T: 84 return ten3F32 85 } 86 t := borrowTensorType() 87 t.Dims = dims 88 t.Of = typ 89 return t 90 } 91 92 // Name returns the name of the type, which will always be "Tensor". Satisfies the hm.Type interface. 93 func (t TensorType) Name() string { return "Tensor" } 94 95 // Format implements fmt.Formatter. It is also required for the satisfication the hm.Type interface. 96 func (t TensorType) Format(state fmt.State, c rune) { 97 if state.Flag('#') { 98 fmt.Fprintf(state, "Tensor-%d %#v", t.Dims, t.Of) 99 } else { 100 switch t.Dims { 101 case 1: 102 fmt.Fprintf(state, "Vector %v", t.Of) 103 case 2: 104 fmt.Fprintf(state, "Matrix %v", t.Of) 105 default: 106 fmt.Fprintf(state, "Tensor-%d %v", t.Dims, t.Of) 107 } 108 } 109 } 110 111 // String implements fmt.Stringer and runtime.Stringer. Satisfies the hm.Type interface. 112 func (t TensorType) String() string { return fmt.Sprintf("%v", t) } 113 114 // Types returns a list of types that TensorType contains - in this case, the type of Tensor (float64, float32, etc). Satisfies the hm.Type interface. 115 func (t TensorType) Types() hm.Types { ts := hm.BorrowTypes(1); ts[0] = t.Of; return ts } 116 117 // Normalize normalizes the type variable names (if any) in the TensorType. Satisfies the hm.Type interface. 118 func (t TensorType) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { 119 var err error 120 if t.Of, err = t.Of.Normalize(k, v); err != nil { 121 return nil, err 122 } 123 124 return t, nil 125 } 126 127 // Apply applies the substitutions on the types. Satisfies the hm.Type interface. 128 func (t TensorType) Apply(sub hm.Subs) hm.Substitutable { 129 t.Of = t.Of.Apply(sub).(hm.Type) 130 return t 131 } 132 133 // FreeTypeVar returns any free (unbound) type variables in this type. Satisfies the hm.Type interface. 134 func (t TensorType) FreeTypeVar() hm.TypeVarSet { 135 return t.Of.FreeTypeVar() 136 } 137 138 // Eq is the equality function of this type. The type of Tensor has to be the same, and for now, only the dimensions are compared. 139 // Shape may be compared in the future for tighter type inference. Satisfies the hm.Type interface. 140 func (t TensorType) Eq(other hm.Type) bool { 141 switch ot := other.(type) { 142 case TensorType: 143 return t.Of.Eq(ot.Of) && t.Dims == ot.Dims 144 case *TensorType: 145 return t.Of.Eq(ot.Of) && t.Dims == ot.Dims 146 } 147 return false 148 }