github.com/wzzhu/tensor@v0.9.24/genlib2/dense_compat_tests.go (about) 1 package main 2 3 import ( 4 "io" 5 "text/template" 6 ) 7 8 const compatTestsRaw = `var toMat64Tests = []struct{ 9 data interface{} 10 sliced interface{} 11 shape Shape 12 dt Dtype 13 }{ 14 {{range .Kinds -}} 15 {{if isNumber . -}} 16 { Range({{asType . | title | strip}}, 0, 6), []{{asType .}}{0,1,3,4}, Shape{2,3}, {{asType . | title | strip}} }, 17 {{end -}} 18 {{end -}} 19 } 20 func TestToMat64(t *testing.T){ 21 assert := assert.New(t) 22 for i, tmt := range toMat64Tests { 23 T := New(WithBacking(tmt.data), WithShape(tmt.shape...)) 24 var m *mat.Dense 25 var err error 26 if m, err = ToMat64(T); err != nil { 27 t.Errorf("ToMat basic test %d failed : %v", i, err) 28 continue 29 } 30 conv := anyToFloat64s(tmt.data) 31 assert.Equal(conv, m.RawMatrix().Data, "i %d from %v", i, tmt.dt) 32 33 if T, err = sliceDense(T, nil, makeRS(0, 2)); err != nil{ 34 t.Errorf("Slice failed %v", err) 35 continue 36 } 37 if m, err = ToMat64(T); err != nil { 38 t.Errorf("ToMat of slice test %d failed : %v", i, err) 39 continue 40 } 41 conv = anyToFloat64s(tmt.sliced) 42 assert.Equal(conv, m.RawMatrix().Data, "sliced test %d from %v", i, tmt.dt) 43 t.Logf("Done") 44 45 if tmt.dt == Float64 { 46 T = New(WithBacking(tmt.data), WithShape(tmt.shape...)) 47 if m, err = ToMat64(T, UseUnsafe()); err != nil { 48 t.Errorf("ToMat64 unsafe test %d failed: %v", i, err) 49 } 50 conv = anyToFloat64s(tmt.data) 51 assert.Equal(conv, m.RawMatrix().Data, "float64 unsafe i %d from %v", i, tmt.dt) 52 conv[0] = 1000 53 assert.Equal(conv, m.RawMatrix().Data,"float64 unsafe i %d from %v", i, tmt.dt) 54 conv[0] = 0 // reset for future tests that use the same backing 55 } 56 } 57 // idiocy test 58 T := New(Of(Float64), WithShape(2,3,4)) 59 _, err := ToMat64(T) 60 if err == nil { 61 t.Error("Expected an error when trying to convert a 3-T to *mat.Dense") 62 } 63 } 64 65 func TestFromMat64(t *testing.T){ 66 assert := assert.New(t) 67 var m *mat.Dense 68 var T *Dense 69 var backing []float64 70 71 72 for i, tmt := range toMat64Tests { 73 backing = Range(Float64, 0, 6).([]float64) 74 m = mat.NewDense(2, 3, backing) 75 T = FromMat64(m) 76 conv := anyToFloat64s(tmt.data) 77 assert.Equal(conv, T.Float64s(), "test %d: []float64 from %v", i, tmt.dt) 78 assert.True(T.Shape().Eq(tmt.shape)) 79 80 T = FromMat64(m, As(tmt.dt)) 81 assert.Equal(tmt.data, T.Data()) 82 assert.True(T.Shape().Eq(tmt.shape)) 83 84 if tmt.dt == Float64{ 85 backing = Range(Float64, 0, 6).([]float64) 86 m = mat.NewDense(2, 3, backing) 87 T = FromMat64(m, UseUnsafe()) 88 assert.Equal(backing, T.Float64s()) 89 assert.True(T.Shape().Eq(tmt.shape)) 90 backing[0] = 1000 91 assert.Equal(backing, T.Float64s(), "test %d - unsafe float64", i) 92 } 93 } 94 } 95 ` 96 97 const compatArrowArrayTestsRaw = `var toArrowArrayTests = []struct{ 98 data interface{} 99 valid []bool 100 dt arrow.DataType 101 shape Shape 102 }{ 103 {{range .PrimitiveTypes -}} 104 { 105 data: Range({{.}}, 0, 6), 106 valid: []bool{true, true, true, false, true, true}, 107 dt: arrow.PrimitiveTypes.{{ . }}, 108 shape: Shape{6,1}, 109 }, 110 {{end -}} 111 } 112 func TestFromArrowArray(t *testing.T){ 113 assert := assert.New(t) 114 var T *Dense 115 pool := memory.NewGoAllocator() 116 117 for i, taat := range toArrowArrayTests { 118 var m arrowArray.Interface 119 120 switch taat.dt { 121 {{range .BinaryTypes -}} 122 case arrow.BinaryTypes.{{ . }}: 123 b := arrowArray.New{{ . }}Builder(pool) 124 defer b.Release() 125 b.AppendValues( 126 {{if eq . "String" -}} 127 []string{"0", "1", "2", "3", "4", "5"}, 128 {{else -}} 129 Range({{ . }}, 0, 6).([]{{lower . }}), 130 {{end -}} 131 taat.valid, 132 ) 133 m = b.NewArray() 134 defer m.Release() 135 {{end -}} 136 {{range .FixedWidthTypes -}} 137 case arrow.FixedWidthTypes.{{ . }}: 138 b := arrowArray.New{{ . }}Builder(pool) 139 defer b.Release() 140 b.AppendValues( 141 {{if eq . "Boolean" -}} 142 []bool{true, false, true, false, true, false}, 143 {{else -}} 144 Range({{ . }}, 0, 6).([]{{lower . }}), 145 {{end -}} 146 taat.valid, 147 ) 148 m = b.NewArray() 149 defer m.Release() 150 {{end -}} 151 {{range .PrimitiveTypes -}} 152 case arrow.PrimitiveTypes.{{ . }}: 153 b := arrowArray.New{{ . }}Builder(pool) 154 defer b.Release() 155 b.AppendValues( 156 Range({{ . }}, 0, 6).([]{{lower . }}), 157 taat.valid, 158 ) 159 m = b.NewArray() 160 defer m.Release() 161 {{end -}} 162 default: 163 t.Errorf("DataType not supported in tests: %v", taat.dt) 164 } 165 166 T = FromArrowArray(m) 167 switch taat.dt { 168 {{range .PrimitiveTypes -}} 169 case arrow.PrimitiveTypes.{{ . }}: 170 conv := taat.data.([]{{lower . }}) 171 assert.Equal(conv, T.{{ . }}s(), "test %d: []{{lower . }} from %v", i, taat.dt) 172 {{end -}} 173 default: 174 t.Errorf("DataType not supported in tests: %v", taat.dt) 175 } 176 for i, invalid := range T.Mask() { 177 assert.Equal(taat.valid[i], !invalid) 178 } 179 assert.True(T.Shape().Eq(taat.shape)) 180 } 181 } 182 ` 183 184 const compatArrowTensorTestsRaw = `var toArrowTensorTests = []struct{ 185 rowMajorData interface{} 186 colMajorData interface{} 187 rowMajorValid []bool 188 colMajorValid []bool 189 dt arrow.DataType 190 shape Shape 191 }{ 192 {{range .PrimitiveTypes -}} 193 { 194 rowMajorData: []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 195 colMajorData: []{{lower .}}{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, 196 rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, 197 colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, 198 dt: arrow.PrimitiveTypes.{{ . }}, 199 shape: Shape{2,5}, 200 }, 201 {{end -}} 202 } 203 func TestFromArrowTensor(t *testing.T){ 204 assert := assert.New(t) 205 var rowMajorT *Dense 206 var colMajorT *Dense 207 pool := memory.NewGoAllocator() 208 209 for i, taat := range toArrowTensorTests { 210 var rowMajorArr arrowArray.Interface 211 var colMajorArr arrowArray.Interface 212 var rowMajor arrowTensor.Interface 213 var colMajor arrowTensor.Interface 214 215 switch taat.dt { 216 {{range .PrimitiveTypes -}} 217 case arrow.PrimitiveTypes.{{ . }}: 218 b := arrowArray.New{{ . }}Builder(pool) 219 defer b.Release() 220 b.AppendValues( 221 []{{lower . }}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 222 taat.rowMajorValid, 223 ) 224 rowMajorArr = b.NewArray() 225 defer rowMajorArr.Release() 226 227 b.AppendValues( 228 []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 229 taat.rowMajorValid, 230 ) 231 colMajorArr = b.NewArray() 232 defer colMajorArr.Release() 233 234 rowMajor = arrowTensor.New{{.}}(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) 235 defer rowMajor.Release() 236 colMajor = arrowTensor.New{{.}}(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.{{ . }}SizeBytes), int64(arrow.{{ . }}SizeBytes * 2)}, []string{"x", "y"}) 237 defer colMajor.Release() 238 {{end -}} 239 default: 240 t.Errorf("DataType not supported in tests: %v", taat.dt) 241 } 242 243 rowMajorT = FromArrowTensor(rowMajor) 244 colMajorT = FromArrowTensor(colMajor) 245 246 assert.Equal(taat.rowMajorData, rowMajorT.Data(), "test %d: row major %v", i, taat.dt) 247 assert.Equal(len(taat.rowMajorValid), len(rowMajorT.Mask()), "test %d: row major %v mask length incorrect", i, taat.dt) 248 for i, invalid := range rowMajorT.Mask() { 249 assert.Equal(taat.rowMajorValid[i], !invalid, "test %d: row major %v mask value incorrect", i, taat.dt) 250 } 251 assert.True(colMajorT.Shape().Eq(taat.shape)) 252 253 assert.Equal(taat.colMajorData, colMajorT.Data(), "test %d: column major %v", i, taat.dt) 254 assert.Equal(len(taat.colMajorValid), len(colMajorT.Mask()), "test %d: column major %v mask length incorrect", i, taat.dt) 255 for i, invalid := range colMajorT.Mask() { 256 assert.Equal(taat.colMajorValid[i], !invalid, "test %d: column major %v mask value incorrect", i, taat.dt) 257 } 258 assert.True(rowMajorT.Shape().Eq(taat.shape)) 259 } 260 } 261 ` 262 263 var ( 264 compatTests *template.Template 265 compatArrowArrayTests *template.Template 266 compatArrowTensorTests *template.Template 267 ) 268 269 func init() { 270 compatTests = template.Must(template.New("testCompat").Funcs(funcs).Parse(compatTestsRaw)) 271 compatArrowArrayTests = template.Must(template.New("testArrowArrayCompat").Funcs(funcs).Parse(compatArrowArrayTestsRaw)) 272 compatArrowTensorTests = template.Must(template.New("testArrowTensorCompat").Funcs(funcs).Parse(compatArrowTensorTestsRaw)) 273 } 274 275 func generateDenseCompatTests(f io.Writer, generic Kinds) { 276 // NOTE(poopoothegorilla): an alias is needed for the Arrow Array pkg to prevent naming 277 // collisions 278 importsArrow.Execute(f, generic) 279 compatTests.Execute(f, generic) 280 arrowData := ArrowData{ 281 BinaryTypes: arrowBinaryTypes, 282 FixedWidthTypes: arrowFixedWidthTypes, 283 PrimitiveTypes: arrowPrimitiveTypes, 284 } 285 compatArrowArrayTests.Execute(f, arrowData) 286 compatArrowTensorTests.Execute(f, arrowData) 287 }