gorgonia.org/gorgonia@v0.9.17/type_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/chewxy/hm"
     8  	"github.com/stretchr/testify/assert"
     9  	"gorgonia.org/tensor"
    10  )
    11  
    12  func TestDtypeBasics(t *testing.T) {
    13  	assert := assert.New(t)
    14  
    15  	var t0 tensor.Dtype
    16  	var a hm.TypeVariable
    17  
    18  	t0 = Float64
    19  	a = hm.TypeVariable('a')
    20  
    21  	assert.True(t0.Eq(Float64))
    22  	assert.False(t0.Eq(Float32))
    23  	assert.False(t0.Eq(a))
    24  	assert.Nil(t0.Types())
    25  
    26  	k := hm.TypeVarSet{'x', 'y'}
    27  	v := hm.TypeVarSet{'a', 'b'}
    28  	t1, err := t0.Normalize(k, v)
    29  	assert.Nil(err)
    30  	assert.Equal(t0, t1)
    31  
    32  	// for completeness sake
    33  	assert.Equal("float64", t0.Name())
    34  	assert.Equal("float64", t0.String())
    35  	assert.Equal("float64", fmt.Sprintf("%v", t0))
    36  
    37  }
    38  
    39  func TestDtypeOps(t *testing.T) {
    40  	var sub hm.Subs
    41  	var a hm.TypeVariable
    42  	var err error
    43  
    44  	a = hm.TypeVariable('a')
    45  
    46  	if sub, err = hm.Unify(a, Float64); err != nil {
    47  		t.Fatal(err)
    48  	}
    49  
    50  	if repl, ok := sub.Get(a); !ok {
    51  		t.Errorf("Expected a substitution for %v", a)
    52  	} else if repl != Float64 {
    53  		t.Errorf("Expecetd substitution for %v to be %v. Got %v instead", a, Float64, repl)
    54  	}
    55  
    56  	if sub, err = hm.Unify(Float64, a); err != nil {
    57  		t.Fatal(err)
    58  	}
    59  
    60  	if repl, ok := sub.Get(a); !ok {
    61  		t.Errorf("Expected a substitution for %v", a)
    62  	} else if repl != Float64 {
    63  		t.Errorf("Expecetd substitution for %v to be %v. Got %v instead", a, Float64, repl)
    64  	}
    65  }
    66  
    67  var tensorTypeTests []struct {
    68  	a, b TensorType
    69  
    70  	eq     bool
    71  	types  hm.Types
    72  	format string
    73  }
    74  
    75  func TestTensorTypeBasics(t *testing.T) {
    76  	assert := assert.New(t)
    77  
    78  	for _, ttts := range tensorTypeTests {
    79  		// Equality
    80  		if ttts.eq {
    81  			assert.True(ttts.a.Eq(ttts.b), "TensorType Equality failed: %#v != %#v", ttts.a, ttts.b)
    82  		} else {
    83  			assert.False(ttts.a.Eq(ttts.b), "TensorType Equality: %v == %v should be false", ttts.a, ttts.b)
    84  		}
    85  
    86  		// Types
    87  		assert.Equal(ttts.types, ttts.a.Types())
    88  
    89  		// string and format for completeness sake
    90  		assert.Equal("Tensor", ttts.a.Name())
    91  		assert.Equal(ttts.format, fmt.Sprintf("%v", ttts.a))
    92  		assert.Equal(fmt.Sprintf("Tensor-%d %v", ttts.a.Dims, ttts.a.Of), fmt.Sprintf("%#v", ttts.a))
    93  	}
    94  
    95  	tt := makeTensorType(1, hm.TypeVariable('x'))
    96  	k := hm.TypeVarSet{'x', 'y'}
    97  	v := hm.TypeVarSet{'a', 'b'}
    98  	tt2, err := tt.Normalize(k, v)
    99  	if err != nil {
   100  		t.Error(err)
   101  	}
   102  	assert.True(tt2.Eq(makeTensorType(1, hm.TypeVariable('a'))))
   103  
   104  }
   105  
   106  var tensorOpsTest []struct {
   107  	name string
   108  
   109  	a hm.Type
   110  	b hm.Type
   111  
   112  	aSub hm.Type
   113  }
   114  
   115  func TestTensorTypeOps(t *testing.T) {
   116  	for _, tots := range tensorOpsTest {
   117  		sub, err := hm.Unify(tots.a, tots.b)
   118  		if err != nil {
   119  			t.Error(err)
   120  			continue
   121  		}
   122  
   123  		if subst, ok := sub.Get(hm.TypeVariable('a')); !ok {
   124  			t.Errorf("Expected a substitution for a")
   125  		} else if !subst.Eq(tots.aSub) {
   126  			t.Errorf("Expected substitution to be %v. Got %v instead", tots.aSub, subst)
   127  		}
   128  	}
   129  }
   130  
   131  func init() {
   132  	tensorTypeTests = []struct {
   133  		a, b TensorType
   134  
   135  		eq     bool
   136  		types  hm.Types
   137  		format string
   138  	}{
   139  
   140  		{makeTensorType(1, Float64), makeTensorType(1, Float64), true, hm.Types{Float64}, "Vector float64"},
   141  		{makeTensorType(1, Float64), makeTensorType(1, Float32), false, hm.Types{Float64}, "Vector float64"},
   142  		{makeTensorType(1, Float64), makeTensorType(2, Float64), false, hm.Types{Float64}, "Vector float64"},
   143  		{makeTensorType(1, hm.TypeVariable('a')), makeTensorType(1, hm.TypeVariable('a')), true, hm.Types{hm.TypeVariable('a')}, "Vector a"},
   144  		{makeTensorType(1, hm.TypeVariable('a')), makeTensorType(1, hm.TypeVariable('b')), false, hm.Types{hm.TypeVariable('a')}, "Vector a"},
   145  	}
   146  
   147  	tensorOpsTest = []struct {
   148  		name string
   149  
   150  		a hm.Type
   151  		b hm.Type
   152  
   153  		aSub hm.Type
   154  	}{
   155  		{"a ~ Tensor Float64", hm.TypeVariable('a'), makeTensorType(1, Float64), makeTensorType(1, Float64)},
   156  		{"Tensor Float64 ~ a", makeTensorType(1, Float64), hm.TypeVariable('a'), makeTensorType(1, Float64)},
   157  		{"Tensor a ~ Tensor Float64", makeTensorType(1, hm.TypeVariable('a')), makeTensorType(1, Float64), Float64},
   158  		{"Tensor a ~ Tensor Float64", makeTensorType(1, Float64), makeTensorType(1, hm.TypeVariable('a')), Float64},
   159  	}
   160  }