gorgonia.org/gorgonia@v0.9.17/typeSystem.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"github.com/chewxy/hm"
     5  	"github.com/pkg/errors"
     6  	"gorgonia.org/tensor"
     7  )
     8  
     9  // inferType infers the type of the expression
    10  func inferType(expr interface{}) (retVal hm.Type, err error) {
    11  	switch e := expr.(type) {
    12  	case *Node:
    13  		if e.isInput() || e.isConstant() {
    14  			// Var (and Let const)
    15  			return e.t, nil
    16  		}
    17  
    18  		// stop the recursive inference early - if the node already has a type, return it
    19  		if e.t != nil {
    20  			return e.t, nil
    21  		}
    22  
    23  		return inferNodeType(e.op, e.children...)
    24  	case Op:
    25  		return e.Type(), nil
    26  	case float32:
    27  		return Float32, nil
    28  	case float64:
    29  		return Float64, nil
    30  	case int:
    31  		return Int, nil
    32  	case int64:
    33  		return Int64, nil
    34  	case int32:
    35  		return Int32, nil
    36  	case bool:
    37  		return Bool, nil
    38  	default:
    39  		err = errors.Errorf(nyiTypeFail, "inferType", expr)
    40  		return
    41  	}
    42  }
    43  
    44  // Instead of using hm's Infer function, since all the nodes are pretty much hm.Apply, we write our own.
    45  func inferNodeType(op Op, children ...*Node) (retVal hm.Type, err error) {
    46  	fnType := op.Type()
    47  	if fnt, ok := fnType.(*hm.FunctionType); ok {
    48  		defer hm.ReturnFnType(fnt)
    49  	}
    50  
    51  	argTypes := hm.BorrowTypes(len(children) + 1)
    52  	defer hm.ReturnTypes(argTypes)
    53  	for i, child := range children {
    54  		if argTypes[i], err = inferType(child); err != nil {
    55  			return nil, errors.Wrapf(err, "Failed to infer type of %v", child)
    56  		}
    57  	}
    58  
    59  	b := hm.TypeVariable('b')
    60  	argTypes[len(argTypes)-1] = b
    61  
    62  	fn := hm.NewFnType(argTypes...)
    63  	defer hm.ReturnFnType(fn)
    64  
    65  	// var t0 hm.Type
    66  	var sub hm.Subs
    67  	if sub, err = hm.Unify(fn, fnType); err != nil {
    68  		return nil, errors.Wrapf(err, "Unable to unify while inferring type of %v", op)
    69  	}
    70  
    71  	var ok bool
    72  	if retVal, ok = sub.Get(b); !ok {
    73  		return nil, errors.Errorf("Expected a replacement for %v", b)
    74  	}
    75  
    76  	// return pruneReturn(t0.(*hm.FunctionType).ReturnType()), nil
    77  	return retVal, nil
    78  }
    79  
    80  func isScalarType(t hm.Type) bool {
    81  	switch tt := t.(type) {
    82  	case tensor.Dtype:
    83  		return true
    84  	case TensorType:
    85  		if tt.Dims == 0 {
    86  			return true
    87  		}
    88  		return false
    89  	case hm.TypeVariable:
    90  		panic("Type Variable is a type that is not yet known.")
    91  	default:
    92  		panic("Unhandled type")
    93  	}
    94  }
    95  
    96  func dtypeOf(t hm.Type) (retVal tensor.Dtype, err error) {
    97  	switch p := t.(type) {
    98  	case tensor.Dtype:
    99  		retVal = p
   100  	case TensorType:
   101  		return dtypeOf(p.Of)
   102  	case hm.TypeVariable:
   103  		err = errors.Errorf("instance %v does not have a dtype", p)
   104  	default:
   105  		err = errors.Errorf(nyiFail, "dtypeOf", p)
   106  		return
   107  	}
   108  
   109  	return
   110  }
   111  
   112  // DEPRECATED
   113  
   114  /*
   115  func runtimeTypeCheck(expected, got hm.Types) (of Dtype, err error) {
   116  	if len(expected) != len(got) {
   117  		err = NewError(RuntimeError, "Input length mismatch")
   118  		return
   119  	}
   120  
   121  	if of, err = dtypeOf(expected[0]); err != nil {
   122  		return
   123  	}
   124  
   125  	for i, e := range expected {
   126  		g := got[i]
   127  		if !e.Eq(g) {
   128  			err = NewError(RuntimeError, "Expected input[%d] to be %v. Got %v instead", i, e, got[i])
   129  			return
   130  		}
   131  
   132  		if i > 0 {
   133  			var gdt Dtype
   134  			if gdt, err = dtypeOf(g); err == nil {
   135  				if gdt != of {
   136  					err = NewError(RuntimeError, "Different dtypes encountered... Expected %v. Got %v instead", of, gdt)
   137  					return
   138  				}
   139  			} else {
   140  				return
   141  			}
   142  		}
   143  	}
   144  	return
   145  }
   146  */