gorgonia.org/gorgonia@v0.9.17/type_test.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "testing" 6 7 "github.com/chewxy/hm" 8 "github.com/stretchr/testify/assert" 9 "gorgonia.org/tensor" 10 ) 11 12 func TestDtypeBasics(t *testing.T) { 13 assert := assert.New(t) 14 15 var t0 tensor.Dtype 16 var a hm.TypeVariable 17 18 t0 = Float64 19 a = hm.TypeVariable('a') 20 21 assert.True(t0.Eq(Float64)) 22 assert.False(t0.Eq(Float32)) 23 assert.False(t0.Eq(a)) 24 assert.Nil(t0.Types()) 25 26 k := hm.TypeVarSet{'x', 'y'} 27 v := hm.TypeVarSet{'a', 'b'} 28 t1, err := t0.Normalize(k, v) 29 assert.Nil(err) 30 assert.Equal(t0, t1) 31 32 // for completeness sake 33 assert.Equal("float64", t0.Name()) 34 assert.Equal("float64", t0.String()) 35 assert.Equal("float64", fmt.Sprintf("%v", t0)) 36 37 } 38 39 func TestDtypeOps(t *testing.T) { 40 var sub hm.Subs 41 var a hm.TypeVariable 42 var err error 43 44 a = hm.TypeVariable('a') 45 46 if sub, err = hm.Unify(a, Float64); err != nil { 47 t.Fatal(err) 48 } 49 50 if repl, ok := sub.Get(a); !ok { 51 t.Errorf("Expected a substitution for %v", a) 52 } else if repl != Float64 { 53 t.Errorf("Expecetd substitution for %v to be %v. Got %v instead", a, Float64, repl) 54 } 55 56 if sub, err = hm.Unify(Float64, a); err != nil { 57 t.Fatal(err) 58 } 59 60 if repl, ok := sub.Get(a); !ok { 61 t.Errorf("Expected a substitution for %v", a) 62 } else if repl != Float64 { 63 t.Errorf("Expecetd substitution for %v to be %v. Got %v instead", a, Float64, repl) 64 } 65 } 66 67 var tensorTypeTests []struct { 68 a, b TensorType 69 70 eq bool 71 types hm.Types 72 format string 73 } 74 75 func TestTensorTypeBasics(t *testing.T) { 76 assert := assert.New(t) 77 78 for _, ttts := range tensorTypeTests { 79 // Equality 80 if ttts.eq { 81 assert.True(ttts.a.Eq(ttts.b), "TensorType Equality failed: %#v != %#v", ttts.a, ttts.b) 82 } else { 83 assert.False(ttts.a.Eq(ttts.b), "TensorType Equality: %v == %v should be false", ttts.a, ttts.b) 84 } 85 86 // Types 87 assert.Equal(ttts.types, ttts.a.Types()) 88 89 // string and format for completeness sake 90 assert.Equal("Tensor", ttts.a.Name()) 91 assert.Equal(ttts.format, fmt.Sprintf("%v", ttts.a)) 92 assert.Equal(fmt.Sprintf("Tensor-%d %v", ttts.a.Dims, ttts.a.Of), fmt.Sprintf("%#v", ttts.a)) 93 } 94 95 tt := makeTensorType(1, hm.TypeVariable('x')) 96 k := hm.TypeVarSet{'x', 'y'} 97 v := hm.TypeVarSet{'a', 'b'} 98 tt2, err := tt.Normalize(k, v) 99 if err != nil { 100 t.Error(err) 101 } 102 assert.True(tt2.Eq(makeTensorType(1, hm.TypeVariable('a')))) 103 104 } 105 106 var tensorOpsTest []struct { 107 name string 108 109 a hm.Type 110 b hm.Type 111 112 aSub hm.Type 113 } 114 115 func TestTensorTypeOps(t *testing.T) { 116 for _, tots := range tensorOpsTest { 117 sub, err := hm.Unify(tots.a, tots.b) 118 if err != nil { 119 t.Error(err) 120 continue 121 } 122 123 if subst, ok := sub.Get(hm.TypeVariable('a')); !ok { 124 t.Errorf("Expected a substitution for a") 125 } else if !subst.Eq(tots.aSub) { 126 t.Errorf("Expected substitution to be %v. Got %v instead", tots.aSub, subst) 127 } 128 } 129 } 130 131 func init() { 132 tensorTypeTests = []struct { 133 a, b TensorType 134 135 eq bool 136 types hm.Types 137 format string 138 }{ 139 140 {makeTensorType(1, Float64), makeTensorType(1, Float64), true, hm.Types{Float64}, "Vector float64"}, 141 {makeTensorType(1, Float64), makeTensorType(1, Float32), false, hm.Types{Float64}, "Vector float64"}, 142 {makeTensorType(1, Float64), makeTensorType(2, Float64), false, hm.Types{Float64}, "Vector float64"}, 143 {makeTensorType(1, hm.TypeVariable('a')), makeTensorType(1, hm.TypeVariable('a')), true, hm.Types{hm.TypeVariable('a')}, "Vector a"}, 144 {makeTensorType(1, hm.TypeVariable('a')), makeTensorType(1, hm.TypeVariable('b')), false, hm.Types{hm.TypeVariable('a')}, "Vector a"}, 145 } 146 147 tensorOpsTest = []struct { 148 name string 149 150 a hm.Type 151 b hm.Type 152 153 aSub hm.Type 154 }{ 155 {"a ~ Tensor Float64", hm.TypeVariable('a'), makeTensorType(1, Float64), makeTensorType(1, Float64)}, 156 {"Tensor Float64 ~ a", makeTensorType(1, Float64), hm.TypeVariable('a'), makeTensorType(1, Float64)}, 157 {"Tensor a ~ Tensor Float64", makeTensorType(1, hm.TypeVariable('a')), makeTensorType(1, Float64), Float64}, 158 {"Tensor a ~ Tensor Float64", makeTensorType(1, Float64), makeTensorType(1, hm.TypeVariable('a')), Float64}, 159 } 160 }