gorgonia.org/gorgonia@v0.9.17/type.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/chewxy/hm"
     7  	"gorgonia.org/tensor"
     8  )
     9  
    10  var (
    11  	// Represents the types that Nodes can take in Gorgonia
    12  
    13  	// Float64 ...
    14  	Float64 = tensor.Float64
    15  	// Float32 ...
    16  	Float32 = tensor.Float32
    17  	// Int ...
    18  	Int = tensor.Int
    19  	// Int64 ...
    20  	Int64 = tensor.Int64
    21  	// Int32 ...
    22  	Int32 = tensor.Int32
    23  	// Byte ...
    24  	Byte = tensor.Uint8
    25  	// Bool ...
    26  	Bool = tensor.Bool
    27  
    28  	// Ptr is equivalent to interface{}. Ugh Ugh Ugh
    29  	Ptr = tensor.UnsafePointer
    30  
    31  	vecF64  = &TensorType{Dims: 1, Of: tensor.Float64}
    32  	vecF32  = &TensorType{Dims: 1, Of: tensor.Float32}
    33  	matF64  = &TensorType{Dims: 2, Of: tensor.Float64}
    34  	matF32  = &TensorType{Dims: 2, Of: tensor.Float32}
    35  	ten3F64 = &TensorType{Dims: 3, Of: tensor.Float64}
    36  	ten3F32 = &TensorType{Dims: 3, Of: tensor.Float32}
    37  
    38  	// removes the need for type checking
    39  	f64T = tensor.Float64 // hm.Type
    40  	f32T = tensor.Float32 // hm.Type
    41  )
    42  
    43  var acceptableDtypes = [...]tensor.Dtype{tensor.Float64, tensor.Float32, tensor.Int, tensor.Int64, tensor.Int32, tensor.Byte, tensor.Bool}
    44  
/*Tensor Type*/

// TensorType is a type constructor for tensors.
//
// Think of it as something like this:
//		data Tensor a = Tensor d a
//
// where d is the number of dimensions (Dims) and a is the element type (Of).
// The shape of the Tensor is not part of TensorType.
// Shape checking is relegated to the dynamic part of the program run.
type TensorType struct {
	Dims int // number of dimensions

	Of hm.Type // element type: a concrete dtype (e.g. tensor.Float64) or an hm.TypeVariable
}
    59  
    60  func makeFromTensorType(t TensorType, tv hm.TypeVariable) TensorType {
    61  	return makeTensorType(t.Dims, tv)
    62  }
    63  
    64  func makeTensorType(dims int, typ hm.Type) TensorType {
    65  	return TensorType{
    66  		Dims: dims,
    67  		Of:   typ,
    68  	}
    69  }
    70  
    71  func newTensorType(dims int, typ hm.Type) *TensorType {
    72  	switch {
    73  	case dims == 1 && typ == f64T:
    74  		return vecF64
    75  	case dims == 1 && typ == f32T:
    76  		return vecF32
    77  	case dims == 2 && typ == f64T:
    78  		return matF64
    79  	case dims == 2 && typ == f32T:
    80  		return matF32
    81  	case dims == 3 && typ == f64T:
    82  		return ten3F64
    83  	case dims == 3 && typ == f32T:
    84  		return ten3F32
    85  	}
    86  	t := borrowTensorType()
    87  	t.Dims = dims
    88  	t.Of = typ
    89  	return t
    90  }
    91  
    92  // Name returns the name of the type, which will always be "Tensor". Satisfies the hm.Type interface.
    93  func (t TensorType) Name() string { return "Tensor" }
    94  
    95  // Format implements fmt.Formatter. It is also required for the satisfication the hm.Type interface.
    96  func (t TensorType) Format(state fmt.State, c rune) {
    97  	if state.Flag('#') {
    98  		fmt.Fprintf(state, "Tensor-%d %#v", t.Dims, t.Of)
    99  	} else {
   100  		switch t.Dims {
   101  		case 1:
   102  			fmt.Fprintf(state, "Vector %v", t.Of)
   103  		case 2:
   104  			fmt.Fprintf(state, "Matrix %v", t.Of)
   105  		default:
   106  			fmt.Fprintf(state, "Tensor-%d %v", t.Dims, t.Of)
   107  		}
   108  	}
   109  }
   110  
   111  // String implements fmt.Stringer and runtime.Stringer. Satisfies the hm.Type interface.
   112  func (t TensorType) String() string { return fmt.Sprintf("%v", t) }
   113  
   114  // Types returns a list of types that TensorType contains - in this case, the type of Tensor (float64, float32, etc). Satisfies the hm.Type interface.
   115  func (t TensorType) Types() hm.Types { ts := hm.BorrowTypes(1); ts[0] = t.Of; return ts }
   116  
   117  // Normalize normalizes the type variable names (if any) in the TensorType. Satisfies the hm.Type interface.
   118  func (t TensorType) Normalize(k, v hm.TypeVarSet) (hm.Type, error) {
   119  	var err error
   120  	if t.Of, err = t.Of.Normalize(k, v); err != nil {
   121  		return nil, err
   122  	}
   123  
   124  	return t, nil
   125  }
   126  
   127  // Apply applies the substitutions on the types. Satisfies the hm.Type interface.
   128  func (t TensorType) Apply(sub hm.Subs) hm.Substitutable {
   129  	t.Of = t.Of.Apply(sub).(hm.Type)
   130  	return t
   131  }
   132  
   133  // FreeTypeVar returns any free (unbound) type variables in this type. Satisfies the hm.Type interface.
   134  func (t TensorType) FreeTypeVar() hm.TypeVarSet {
   135  	return t.Of.FreeTypeVar()
   136  }
   137  
   138  // Eq is the equality function of this type. The type of Tensor has to be the same, and for now, only the dimensions are compared.
   139  // Shape may be compared in the future for tighter type inference. Satisfies the hm.Type interface.
   140  func (t TensorType) Eq(other hm.Type) bool {
   141  	switch ot := other.(type) {
   142  	case TensorType:
   143  		return t.Of.Eq(ot.Of) && t.Dims == ot.Dims
   144  	case *TensorType:
   145  		return t.Of.Eq(ot.Of) && t.Dims == ot.Dims
   146  	}
   147  	return false
   148  }