gorgonia.org/gorgonia@v0.9.17/typeSystem_test.go (about) 1 package gorgonia 2 3 import ( 4 "testing" 5 6 "github.com/chewxy/hm" 7 "github.com/stretchr/testify/assert" 8 "gorgonia.org/tensor" 9 ) 10 11 // TODO: gather edge cases 12 func TestInferNodeType(t *testing.T) { 13 assert := assert.New(t) 14 g := NewGraph() 15 var inferNodeTests = []struct { 16 name string 17 op Op 18 children Nodes 19 20 correct hm.Type 21 err bool 22 }{ 23 // simple case Float+Float 24 {"+(1, 2)", 25 newEBOByType(addOpType, Float64, Float64), 26 Nodes{ 27 newNode(In(g), WithType(Float64), WithName("a")), 28 newNode(In(g), WithType(Float64), WithName("b"))}, 29 Float64, 30 false}, 31 32 // complicated case: will error out due to mis match 33 {"+(1, 2)", 34 newEBOByType(addOpType, Float64, Float32), 35 Nodes{ 36 newNode(In(g), WithType(Float64), WithName("a")), 37 newNode(In(g), WithType(Float32), WithName("b"))}, 38 Float64, 39 true}, 40 } 41 42 for _, ints := range inferNodeTests { 43 t0, err := inferNodeType(ints.op, ints.children...) 44 switch { 45 case ints.err && err == nil: 46 t.Errorf("Expected an error in test %q", ints.name) 47 case !ints.err && err != nil: 48 t.Errorf("Error in test %q: %v", ints.name, err) 49 } 50 51 if ints.err { 52 continue 53 } 54 55 assert.True(ints.correct.Eq(t0)) 56 } 57 } 58 59 var inferTypeTests = []struct { 60 expr interface{} 61 62 correct hm.Type 63 err bool 64 }{ 65 {newEBOByType(addOpType, Float64, Float64), hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a')), false}, 66 {float32(0), Float32, false}, 67 {float64(0), Float64, false}, 68 {0, Int, false}, 69 {int64(0), Int64, false}, 70 {int32(0), Int32, false}, 71 {true, Bool, false}, 72 {newNode(In(NewGraph()), WithType(Float64), WithOp(newEBOByType(addOpType, Float64, Float64))), Float64, false}, 73 74 {[]int{0}, nil, true}, 75 } 76 77 func TestInferType(t *testing.T) { 78 for i, itts := range inferTypeTests { 79 t0, err := inferType(itts.expr) 80 switch { 81 case itts.err && err == nil: 82 t.Errorf("Expected an error in infering type of %T", itts.expr) 83 case !itts.err && err != nil: 84 t.Errorf("Error while inferring type of %T: %v", itts.expr, err) 85 } 86 87 if itts.err { 88 continue 89 } 90 assert.True(t, itts.correct.Eq(t0), "Test %d: %v != %v", i, t0, itts.correct) 91 } 92 93 // way out there stuff 94 g := NewGraph() 95 n := newNode(In(g), WithOp(newEBOByType(addOpType, Float64, Float64)), WithChildren(Nodes{newNode(In(g), WithName("a"), WithType(Float64)), newNode(In(g), WithName("b"), WithType(Float64))})) 96 t0, err := inferType(n) 97 if err != nil { 98 t.Errorf("Special Case #1: %v", err) 99 } 100 t.Logf("t0: %v", t0) 101 } 102 103 var scalarTypeTests []struct { 104 name string 105 a hm.Type 106 107 isScalar bool 108 panics bool 109 } 110 111 func TestIsScalarType(t *testing.T) { 112 for _, stts := range scalarTypeTests { 113 if stts.panics { 114 f := func() { 115 isScalarType(stts.a) 116 } 117 assert.Panics(t, f) 118 continue 119 } 120 121 if isScalarType(stts.a) != stts.isScalar { 122 t.Errorf("Expected isScalarType(%v) to be scalar: %v", stts.a, stts.isScalar) 123 } 124 } 125 } 126 127 var dtypeOfTests []struct { 128 a hm.Type 129 130 correct tensor.Dtype 131 err bool 132 } 133 134 func TestDtypeOf(t *testing.T) { 135 for _, dots := range dtypeOfTests { 136 dt, err := dtypeOf(dots.a) 137 138 switch { 139 case err != nil && !dots.err: 140 t.Errorf("Error when performing dtypeOf(%v): %v", dots.a, err) 141 case err == nil && dots.err: 142 t.Errorf("Expected an error when performing dtypeOf(%v)", dots.a) 143 } 144 145 if dots.err { 146 continue 147 } 148 149 if !dots.correct.Eq(dt) { 150 t.Errorf("Incorrect dtypeOf when performing dtypeOf(%v). Expected %v. Got %v", dots.a, dots.correct, dt) 151 } 152 } 153 } 154 155 func init() { 156 scalarTypeTests = []struct { 157 name string 158 a hm.Type 159 160 isScalar bool 161 panics bool 162 }{ 163 {"Float64", Float64, true, false}, 164 {"Tensor Float64", makeTensorType(1, Float64), false, false}, 165 {"Tensor Float64 (special)", makeTensorType(0, Float64), true, false}, 166 167 // this is bad 168 {"a", hm.TypeVariable('a'), false, true}, 169 {"malformed", malformed{}, false, true}, 170 } 171 172 dtypeOfTests = []struct { 173 a hm.Type 174 175 correct tensor.Dtype 176 err bool 177 }{ 178 {Float64, Float64, false}, 179 {makeTensorType(1, Float64), Float64, false}, 180 181 // this is bad 182 // {hm.TypeVariable('a'), MAXDTYPE, true}, 183 // {hm.TypeVariable('a'), MAXDTYPE, true}, 184 // {makeTensorType(1, hm.TypeVariable('a')), MAXDTYPE, true}, 185 // {malformed{}, MAXDTYPE, true}, 186 } 187 }