gorgonia.org/gorgonia@v0.9.17/typeSystem.go (about) 1 package gorgonia 2 3 import ( 4 "github.com/chewxy/hm" 5 "github.com/pkg/errors" 6 "gorgonia.org/tensor" 7 ) 8 9 // inferType infers the type of the expression 10 func inferType(expr interface{}) (retVal hm.Type, err error) { 11 switch e := expr.(type) { 12 case *Node: 13 if e.isInput() || e.isConstant() { 14 // Var (and Let const) 15 return e.t, nil 16 } 17 18 // stop the recursive inference early - if the node already has a type, return it 19 if e.t != nil { 20 return e.t, nil 21 } 22 23 return inferNodeType(e.op, e.children...) 24 case Op: 25 return e.Type(), nil 26 case float32: 27 return Float32, nil 28 case float64: 29 return Float64, nil 30 case int: 31 return Int, nil 32 case int64: 33 return Int64, nil 34 case int32: 35 return Int32, nil 36 case bool: 37 return Bool, nil 38 default: 39 err = errors.Errorf(nyiTypeFail, "inferType", expr) 40 return 41 } 42 } 43 44 // Instead of using hm's Infer function, since all the nodes are pretty much hm.Apply, we write our own. 45 func inferNodeType(op Op, children ...*Node) (retVal hm.Type, err error) { 46 fnType := op.Type() 47 if fnt, ok := fnType.(*hm.FunctionType); ok { 48 defer hm.ReturnFnType(fnt) 49 } 50 51 argTypes := hm.BorrowTypes(len(children) + 1) 52 defer hm.ReturnTypes(argTypes) 53 for i, child := range children { 54 if argTypes[i], err = inferType(child); err != nil { 55 return nil, errors.Wrapf(err, "Failed to infer type of %v", child) 56 } 57 } 58 59 b := hm.TypeVariable('b') 60 argTypes[len(argTypes)-1] = b 61 62 fn := hm.NewFnType(argTypes...) 63 defer hm.ReturnFnType(fn) 64 65 // var t0 hm.Type 66 var sub hm.Subs 67 if sub, err = hm.Unify(fn, fnType); err != nil { 68 return nil, errors.Wrapf(err, "Unable to unify while inferring type of %v", op) 69 } 70 71 var ok bool 72 if retVal, ok = sub.Get(b); !ok { 73 return nil, errors.Errorf("Expected a replacement for %v", b) 74 } 75 76 // return pruneReturn(t0.(*hm.FunctionType).ReturnType()), nil 77 return retVal, nil 78 } 79 80 func isScalarType(t hm.Type) bool { 81 switch tt := t.(type) { 82 case tensor.Dtype: 83 return true 84 case TensorType: 85 if tt.Dims == 0 { 86 return true 87 } 88 return false 89 case hm.TypeVariable: 90 panic("Type Variable is a type that is not yet known.") 91 default: 92 panic("Unhandled type") 93 } 94 } 95 96 func dtypeOf(t hm.Type) (retVal tensor.Dtype, err error) { 97 switch p := t.(type) { 98 case tensor.Dtype: 99 retVal = p 100 case TensorType: 101 return dtypeOf(p.Of) 102 case hm.TypeVariable: 103 err = errors.Errorf("instance %v does not have a dtype", p) 104 default: 105 err = errors.Errorf(nyiFail, "dtypeOf", p) 106 return 107 } 108 109 return 110 } 111 112 // DEPRECATED 113 114 /* 115 func runtimeTypeCheck(expected, got hm.Types) (of Dtype, err error) { 116 if len(expected) != len(got) { 117 err = NewError(RuntimeError, "Input length mismatch") 118 return 119 } 120 121 if of, err = dtypeOf(expected[0]); err != nil { 122 return 123 } 124 125 for i, e := range expected { 126 g := got[i] 127 if !e.Eq(g) { 128 err = NewError(RuntimeError, "Expected input[%d] to be %v. Got %v instead", i, e, got[i]) 129 return 130 } 131 132 if i > 0 { 133 var gdt Dtype 134 if gdt, err = dtypeOf(g); err == nil { 135 if gdt != of { 136 err = NewError(RuntimeError, "Different dtypes encountered... Expected %v. Got %v instead", of, gdt) 137 return 138 } 139 } else { 140 return 141 } 142 } 143 } 144 return 145 } 146 */