github.com/wzzhu/tensor@v0.9.24/genlib2/dense_compat_tests.go (about)

     1  package main
     2  
     3  import (
     4  	"io"
     5  	"text/template"
     6  )
     7  
     8  const compatTestsRaw = `var toMat64Tests = []struct{
     9  	data interface{}
    10  	sliced interface{}
    11  	shape Shape
    12  	dt Dtype
    13  }{
    14  	{{range .Kinds -}}
    15  	{{if isNumber . -}}
    16  	{ Range({{asType . | title | strip}}, 0, 6), []{{asType .}}{0,1,3,4}, Shape{2,3}, {{asType . | title | strip}} },
    17  	{{end -}}
    18  	{{end -}}
    19  }
    20  func TestToMat64(t *testing.T){
    21  	assert := assert.New(t)
    22  	for i, tmt := range toMat64Tests {
    23  		T := New(WithBacking(tmt.data), WithShape(tmt.shape...))
    24  		var m *mat.Dense
    25  		var err error
    26  		if m, err = ToMat64(T); err != nil {
    27  			t.Errorf("ToMat basic test %d failed : %v", i, err)
    28  			continue
    29  		}
    30  		conv := anyToFloat64s(tmt.data)
    31  		assert.Equal(conv, m.RawMatrix().Data, "i %d from %v", i, tmt.dt)
    32  
    33  		if T, err = sliceDense(T, nil, makeRS(0, 2)); err != nil{
    34  			t.Errorf("Slice failed %v", err)
    35  			continue
    36  		}
    37  		if m, err = ToMat64(T); err != nil {
    38  			t.Errorf("ToMat of slice test %d failed : %v", i, err)
    39  			continue
    40  		}
    41  		conv = anyToFloat64s(tmt.sliced)
    42  		assert.Equal(conv, m.RawMatrix().Data, "sliced test %d from %v", i, tmt.dt)
    43  		t.Logf("Done")
    44  
    45  		if tmt.dt == Float64 {
    46  			T = New(WithBacking(tmt.data), WithShape(tmt.shape...))
    47  			if m, err = ToMat64(T, UseUnsafe()); err != nil {
    48  				t.Errorf("ToMat64 unsafe test %d failed: %v", i, err)
    49  			}
    50  			conv = anyToFloat64s(tmt.data)
    51  			assert.Equal(conv, m.RawMatrix().Data, "float64 unsafe i %d from %v", i, tmt.dt)
    52  			conv[0] = 1000
    53  			assert.Equal(conv, m.RawMatrix().Data,"float64 unsafe i %d from %v", i, tmt.dt)
    54  			conv[0] = 0 // reset for future tests that use the same backing
    55  		}
    56  	}
    57  	// idiocy test
    58  	T := New(Of(Float64), WithShape(2,3,4))
    59  	_, err := ToMat64(T)
    60  	if err == nil {
    61  		t.Error("Expected an error when trying to convert a 3-T to *mat.Dense")
    62  	}
    63  }
    64  
    65  func TestFromMat64(t *testing.T){
    66  	assert := assert.New(t)
    67  	var m *mat.Dense
    68  	var T *Dense
    69  	var backing []float64
    70  
    71  
    72  	for i, tmt := range toMat64Tests {
    73  		backing = Range(Float64, 0, 6).([]float64)
    74  		m = mat.NewDense(2, 3, backing)
    75  		T = FromMat64(m)
    76  		conv := anyToFloat64s(tmt.data)
    77  		assert.Equal(conv, T.Float64s(), "test %d: []float64 from %v", i, tmt.dt)
    78  		assert.True(T.Shape().Eq(tmt.shape))
    79  
    80  		T = FromMat64(m, As(tmt.dt))
    81  		assert.Equal(tmt.data, T.Data())
    82  		assert.True(T.Shape().Eq(tmt.shape))
    83  
    84  		if tmt.dt == Float64{
    85  			backing = Range(Float64, 0, 6).([]float64)
    86  			m = mat.NewDense(2, 3, backing)
    87  			T = FromMat64(m, UseUnsafe())
    88  			assert.Equal(backing, T.Float64s())
    89  			assert.True(T.Shape().Eq(tmt.shape))
    90  			backing[0] = 1000 
    91  			assert.Equal(backing, T.Float64s(), "test %d - unsafe float64", i)
    92  		}
    93  	}
    94  }
    95  `
    96  
    97  const compatArrowArrayTestsRaw = `var toArrowArrayTests = []struct{
    98  	data interface{}
    99  	valid []bool
   100  	dt arrow.DataType
   101  	shape Shape
   102  }{
   103  	{{range .PrimitiveTypes -}}
   104  	{
   105  		data: Range({{.}}, 0, 6),
   106  		valid: []bool{true, true, true, false, true, true},
   107  		dt: arrow.PrimitiveTypes.{{ . }},
   108  		shape: Shape{6,1},
   109  	},
   110  	{{end -}}
   111  }
   112  func TestFromArrowArray(t *testing.T){
   113  	assert := assert.New(t)
   114  	var T *Dense
   115  	pool := memory.NewGoAllocator()
   116  
   117  	for i, taat := range toArrowArrayTests {
   118  		var m arrowArray.Interface
   119  
   120  		switch taat.dt {
   121  		{{range .BinaryTypes -}}
   122  		case arrow.BinaryTypes.{{ . }}:
   123  			b := arrowArray.New{{ . }}Builder(pool)
   124  			defer b.Release()
   125  			b.AppendValues(
   126  				{{if eq . "String" -}}
   127  				[]string{"0", "1", "2", "3", "4", "5"},
   128  				{{else -}}
   129  				Range({{ . }}, 0, 6).([]{{lower . }}),
   130  				{{end -}}
   131  				taat.valid,
   132  			)
   133  			m = b.NewArray()
   134  			defer m.Release()
   135  		{{end -}}
   136  		{{range .FixedWidthTypes -}}
   137  		case arrow.FixedWidthTypes.{{ . }}:
   138  			b := arrowArray.New{{ . }}Builder(pool)
   139  			defer b.Release()
   140  			b.AppendValues(
   141  				{{if eq . "Boolean" -}}
   142  				[]bool{true, false, true, false, true, false},
   143  				{{else -}}
   144  				Range({{ . }}, 0, 6).([]{{lower . }}),
   145  				{{end -}}
   146  				taat.valid,
   147  			)
   148  			m = b.NewArray()
   149  			defer m.Release()
   150  		{{end -}}
   151  		{{range .PrimitiveTypes -}}
   152  		case arrow.PrimitiveTypes.{{ . }}:
   153  			b := arrowArray.New{{ . }}Builder(pool)
   154  			defer b.Release()
   155  			b.AppendValues(
   156  				Range({{ . }}, 0, 6).([]{{lower . }}),
   157  				taat.valid,
   158  			)
   159  			m = b.NewArray()
   160  			defer m.Release()
   161  		{{end -}}
   162  		default:
   163  			t.Errorf("DataType not supported in tests: %v", taat.dt)
   164  		}
   165  
   166  		T = FromArrowArray(m)
   167  		switch taat.dt {
   168  		{{range .PrimitiveTypes -}}
   169  		case arrow.PrimitiveTypes.{{ . }}:
   170  			conv := taat.data.([]{{lower . }})
   171  			assert.Equal(conv, T.{{ . }}s(), "test %d: []{{lower . }} from %v", i, taat.dt)
   172  		{{end -}}
   173  		default:
   174  			t.Errorf("DataType not supported in tests: %v", taat.dt)
   175  		}
   176  		for i, invalid := range T.Mask() {
   177  			assert.Equal(taat.valid[i], !invalid)
   178  		}
   179  		assert.True(T.Shape().Eq(taat.shape))
   180  	}
   181  }
   182  `
   183  
   184  const compatArrowTensorTestsRaw = `var toArrowTensorTests = []struct{
   185  	rowMajorData interface{}
   186  	colMajorData interface{}
   187  	rowMajorValid []bool
   188  	colMajorValid []bool
   189  	dt arrow.DataType
   190  	shape Shape
   191  }{
   192  	{{range .PrimitiveTypes -}}
   193  	{
   194  		rowMajorData: []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
   195  		colMajorData: []{{lower .}}{1, 6, 2, 7, 3, 8, 4, 9, 5, 10},
   196  		rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false},
   197  		colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false},
   198  		dt: arrow.PrimitiveTypes.{{ . }},
   199  		shape: Shape{2,5},
   200  	},
   201  	{{end -}}
   202  }
   203  func TestFromArrowTensor(t *testing.T){
   204  	assert := assert.New(t)
   205  	var rowMajorT *Dense
   206  	var colMajorT *Dense
   207  	pool := memory.NewGoAllocator()
   208  
   209  	for i, taat := range toArrowTensorTests {
   210  		var rowMajorArr arrowArray.Interface
   211  		var colMajorArr arrowArray.Interface
   212  		var rowMajor arrowTensor.Interface
   213  		var colMajor arrowTensor.Interface
   214  
   215  		switch taat.dt {
   216  		{{range .PrimitiveTypes -}}
   217  		case arrow.PrimitiveTypes.{{ . }}:
   218  			b := arrowArray.New{{ . }}Builder(pool)
   219  			defer b.Release()
   220  			b.AppendValues(
   221  				[]{{lower . }}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
   222  				taat.rowMajorValid,
   223  			)
   224  			rowMajorArr = b.NewArray()
   225  			defer rowMajorArr.Release()
   226  
   227  			b.AppendValues(
   228  				[]{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
   229  				taat.rowMajorValid,
   230  			)
   231  			colMajorArr = b.NewArray()
   232  			defer colMajorArr.Release()
   233  
   234  			rowMajor = arrowTensor.New{{.}}(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"})
   235  			defer rowMajor.Release()
   236  			colMajor = arrowTensor.New{{.}}(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.{{ . }}SizeBytes), int64(arrow.{{ . }}SizeBytes * 2)}, []string{"x", "y"})
   237  			defer colMajor.Release()
   238  		{{end -}}
   239  		default:
   240  			t.Errorf("DataType not supported in tests: %v", taat.dt)
   241  		}
   242  
   243  		rowMajorT = FromArrowTensor(rowMajor)
   244  		colMajorT = FromArrowTensor(colMajor)
   245  
   246  		assert.Equal(taat.rowMajorData, rowMajorT.Data(), "test %d: row major %v", i, taat.dt)
   247  		assert.Equal(len(taat.rowMajorValid), len(rowMajorT.Mask()), "test %d: row major %v mask length incorrect", i, taat.dt)
   248  		for i, invalid := range rowMajorT.Mask() {
   249  			assert.Equal(taat.rowMajorValid[i], !invalid, "test %d: row major %v mask value incorrect", i, taat.dt)
   250  		}
   251  		assert.True(colMajorT.Shape().Eq(taat.shape))
   252  
   253  		assert.Equal(taat.colMajorData, colMajorT.Data(), "test %d: column major %v", i, taat.dt)
   254  		assert.Equal(len(taat.colMajorValid), len(colMajorT.Mask()), "test %d: column major %v mask length incorrect", i, taat.dt)
   255  		for i, invalid := range colMajorT.Mask() {
   256  			assert.Equal(taat.colMajorValid[i], !invalid, "test %d: column major %v mask value incorrect", i, taat.dt)
   257  		}
   258  		assert.True(rowMajorT.Shape().Eq(taat.shape))
   259  	}
   260  }
   261  `
   262  
   263  var (
   264  	compatTests            *template.Template
   265  	compatArrowArrayTests  *template.Template
   266  	compatArrowTensorTests *template.Template
   267  )
   268  
   269  func init() {
   270  	compatTests = template.Must(template.New("testCompat").Funcs(funcs).Parse(compatTestsRaw))
   271  	compatArrowArrayTests = template.Must(template.New("testArrowArrayCompat").Funcs(funcs).Parse(compatArrowArrayTestsRaw))
   272  	compatArrowTensorTests = template.Must(template.New("testArrowTensorCompat").Funcs(funcs).Parse(compatArrowTensorTestsRaw))
   273  }
   274  
   275  func generateDenseCompatTests(f io.Writer, generic Kinds) {
   276  	// NOTE(poopoothegorilla): an alias is needed for the Arrow Array pkg to prevent naming
   277  	// collisions
   278  	importsArrow.Execute(f, generic)
   279  	compatTests.Execute(f, generic)
   280  	arrowData := ArrowData{
   281  		BinaryTypes:     arrowBinaryTypes,
   282  		FixedWidthTypes: arrowFixedWidthTypes,
   283  		PrimitiveTypes:  arrowPrimitiveTypes,
   284  	}
   285  	compatArrowArrayTests.Execute(f, arrowData)
   286  	compatArrowTensorTests.Execute(f, arrowData)
   287  }