github.com/wzzhu/tensor@v0.9.24/genlib2/dense_argmethods_tests.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "reflect" 7 "text/template" 8 ) 9 10 type ArgMethodTestData struct { 11 Kind reflect.Kind 12 Data []int 13 } 14 15 var data = []int{ 16 3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 17 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 18 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 19 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 20 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 21 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 22 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 23 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 24 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 25 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 26 6, 2, 9, 4, 4, 2, 4, 4, 4, 3, 27 } 28 29 const argMethodsDataRaw = `var basicDense{{short .Kind}} = New(WithShape(2,3,4,5,2), WithBacking([]{{asType .Kind}}{ {{range .Data -}}{{.}}, {{end -}} })) 30 ` 31 32 const argmaxCorrect = `var argmaxCorrect = []struct { 33 shape Shape 34 data []int 35 }{ 36 {Shape{3,4,5,2}, []int{ 37 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 38 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 39 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 40 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 41 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 42 1, 0, 0, 0, 0, 43 }}, 44 {Shape{2,4,5,2}, []int{ 45 1, 0, 1, 1, 2, 0, 2, 0, 0, 1, 2, 1, 2, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 46 2, 2, 0, 1, 1, 2, 2, 1, 0, 2, 0, 2, 0, 2, 2, 1, 0, 0, 0, 0, 0, 1, 0, 47 0, 0, 2, 1, 0, 1, 2, 1, 0, 1, 1, 2, 0, 1, 0, 0, 0, 0, 2, 1, 0, 1, 0, 48 0, 2, 1, 1, 0, 0, 0, 0, 0, 2, 0, 49 }}, 50 {Shape{2,3,5,2}, []int{ 51 3, 2, 2, 1, 1, 2, 1, 0, 0, 1, 3, 2, 1, 0, 1, 0, 2, 2, 3, 0, 1, 0, 1, 52 3, 0, 2, 3, 3, 2, 1, 2, 2, 0, 0, 1, 3, 2, 0, 1, 2, 0, 3, 0, 1, 0, 1, 53 3, 2, 2, 1, 2, 1, 3, 1, 2, 0, 2, 2, 0, 0, 54 }}, 55 {Shape{2,3,4,2}, []int{ 56 4, 3, 2, 1, 1, 2, 0, 1, 1, 1, 1, 3, 1, 0, 0, 2, 2, 1, 0, 4, 2, 2, 3, 57 1, 1, 1, 0, 2, 0, 0, 2, 2, 1, 4, 0, 1, 4, 1, 1, 0, 4, 3, 1, 1, 2, 3, 58 1, 1, 59 }}, 60 {Shape{2,3,4,5}, []int{ 61 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 62 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 63 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 64 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 65 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 66 0, 0, 0, 0, 0, 67 }}, 68 } 69 ` 70 71 const argminCorrect = `var argminCorrect = []struct { 72 shape Shape 73 data []int 74 }{ 75 {Shape{3,4,5,2}, []int{ 76 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 77 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 78 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 79 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 80 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 81 0, 1, 1, 0, 1, 82 }}, 83 {Shape{2,4,5,2}, []int{ 84 2, 1, 0, 0, 1, 2, 1, 2, 1, 2, 1, 0, 0, 2, 1, 0, 1, 2, 0, 1, 0, 2, 2, 85 0, 0, 1, 2, 0, 0, 1, 2, 1, 0, 1, 0, 2, 0, 1, 0, 1, 2, 1, 2, 1, 2, 1, 86 2, 1, 1, 0, 2, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 2, 2, 0, 0, 1, 0, 2, 87 2, 0, 0, 0, 1, 2, 2, 2, 2, 1, 1, 88 }}, 89 {Shape{2,3,5,2}, []int{ 90 0, 1, 0, 2, 2, 1, 3, 2, 3, 2, 1, 0, 3, 3, 0, 1, 0, 3, 0, 2, 0, 1, 0, 91 1, 3, 0, 2, 1, 0, 0, 3, 1, 3, 1, 2, 2, 1, 2, 0, 1, 3, 0, 1, 0, 1, 0, 92 2, 1, 0, 3, 0, 2, 0, 0, 0, 1, 0, 1, 1, 1, 93 }}, 94 {Shape{2,3,4,2}, []int{ 95 1, 0, 0, 0, 2, 3, 4, 0, 3, 0, 3, 0, 4, 4, 3, 1, 0, 2, 3, 0, 3, 0, 0, 96 2, 4, 4, 3, 4, 2, 3, 0, 0, 4, 0, 1, 3, 3, 2, 0, 4, 2, 1, 4, 2, 4, 0, 97 2, 0, 98 }}, 99 {Shape{2,3,4,5}, []int{ 100 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 101 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 102 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 103 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 104 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 105 1, 1, 1, 0, 1, 106 }}, 107 } 108 ` 109 110 type ArgMethodTest struct { 111 Kind reflect.Kind 112 ArgMethod string 113 ArgAllAxes int 114 } 115 116 const testArgMethodsRaw = `func TestDense_{{title .ArgMethod}}_{{short .Kind}}(t *testing.T){ 117 assert := assert.New(t) 118 var T, {{.ArgMethod}} *Dense 119 var err error 120 T = basicDense{{short .Kind}}.Clone().(*Dense) 121 for i:= 0; i < T.Dims(); i++ { 122 if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(i); err != nil { 123 t.Error(err) 124 continue 125 } 126 127 assert.True({{.ArgMethod}}Correct[i].shape.Eq({{.ArgMethod}}.Shape()), "{{title .ArgMethod}}(%d) error. Want shape %v. Got %v", i, {{.ArgMethod}}Correct[i].shape) 128 assert.Equal({{.ArgMethod}}Correct[i].data, {{.ArgMethod}}.Data(), "{{title .ArgMethod}}(%d) error. ", i) 129 } 130 // test all axes 131 if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(AllAxes); err != nil { 132 t.Error(err) 133 return 134 } 135 assert.True({{.ArgMethod}}.IsScalar()) 136 assert.Equal({{.ArgAllAxes}}, {{.ArgMethod}}.ScalarValue()) 137 138 {{if hasPrefix .Kind.String "float" -}} 139 // test with NaN 140 T = New(WithShape(4), WithBacking([]{{asType .Kind}}{1,2,{{mathPkg .Kind}}NaN(), 4})) 141 if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(AllAxes); err != nil { 142 t.Errorf("Failed test with NaN: %v", err) 143 } 144 assert.True({{.ArgMethod}}.IsScalar()) 145 assert.Equal(2, {{.ArgMethod}}.ScalarValue(), "NaN test") 146 147 // test with Mask and Nan 148 T = New(WithShape(4), WithBacking([]{{asType .Kind}}{1,{{if eq .ArgMethod "argmax"}}9{{else}}-9{{end}},{{mathPkg .Kind}}NaN(), 4}, []bool{false,true,true,false})) 149 if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(AllAxes); err != nil { 150 t.Errorf("Failed test with NaN: %v", err) 151 } 152 assert.True({{.ArgMethod}}.IsScalar()) 153 assert.Equal({{if eq .ArgMethod "argmin"}}0{{else}}3{{end}}, {{.ArgMethod}}.ScalarValue(), "Masked NaN test") 154 155 // test with +Inf 156 T = New(WithShape(4), WithBacking([]{{asType .Kind}}{1,2,{{mathPkg .Kind}}Inf(1),4})) 157 if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(AllAxes); err != nil { 158 t.Errorf("Failed test with +Inf: %v", err) 159 } 160 assert.True({{.ArgMethod}}.IsScalar()) 161 assert.Equal({{if eq .ArgMethod "argmax"}}2{{else}}0{{end}}, {{.ArgMethod}}.ScalarValue(), "+Inf test") 162 163 // test with Mask and +Inf 164 T = New(WithShape(4), WithBacking([]{{asType .Kind}}{1,{{if eq .ArgMethod "argmax"}}9{{else}}-9{{end}},{{mathPkg .Kind}}Inf(1), 4}, []bool{false,true,true,false})) 165 if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(AllAxes); err != nil { 166 t.Errorf("Failed test with NaN: %v", err) 167 } 168 assert.True({{.ArgMethod}}.IsScalar()) 169 assert.Equal({{if eq .ArgMethod "argmin"}}0{{else}}3{{end}}, {{.ArgMethod}}.ScalarValue(), "Masked NaN test") 170 171 // test with -Inf 172 T = New(WithShape(4), WithBacking([]{{asType .Kind}}{1,2,{{mathPkg .Kind}}Inf(-1),4 })) 173 if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(AllAxes); err != nil { 174 t.Errorf("Failed test with -Inf: %v", err) 175 } 176 assert.True({{.ArgMethod}}.IsScalar()) 177 assert.Equal({{if eq .ArgMethod "argmin"}}2{{else}}3{{end}}, {{.ArgMethod}}.ScalarValue(), "+Inf test") 178 179 // test with Mask and -Inf 180 T = New(WithShape(4), WithBacking([]{{asType .Kind}}{1,{{if eq .ArgMethod "argmax"}}9{{else}}-9{{end}},{{mathPkg .Kind}}Inf(-1), 4}, []bool{false,true,true,false})) 181 if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(AllAxes); err != nil { 182 t.Errorf("Failed test with NaN: %v", err) 183 } 184 assert.True({{.ArgMethod}}.IsScalar()) 185 assert.Equal({{if eq .ArgMethod "argmin"}}0{{else}}3{{end}}, {{.ArgMethod}}.ScalarValue(), "Masked -Inf test") 186 187 {{end -}} 188 189 // with different engine 190 T = basicDense{{short .Kind}}.Clone().(*Dense) 191 WithEngine(dummyEngine2{})(T) 192 for i:= 0; i < T.Dims(); i++ { 193 if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(i); err != nil { 194 t.Error(err) 195 continue 196 } 197 198 assert.True({{.ArgMethod}}Correct[i].shape.Eq({{.ArgMethod}}.Shape()), "{{title .ArgMethod}}(%d) error. Want shape %v. Got %v", i, {{.ArgMethod}}Correct[i].shape) 199 assert.Equal({{.ArgMethod}}Correct[i].data, {{.ArgMethod}}.Data(), "{{title .ArgMethod}}(%d) error. ", i) 200 } 201 202 203 204 // idiotsville 205 _, err = T.{{title .ArgMethod}}(10000) 206 assert.NotNil(err) 207 208 } 209 ` 210 211 var ( 212 argMethodsData *template.Template 213 testArgMethods *template.Template 214 ) 215 216 func init() { 217 argMethodsData = template.Must(template.New("argmethodsData").Funcs(funcs).Parse(argMethodsDataRaw)) 218 testArgMethods = template.Must(template.New("testArgMethod").Funcs(funcs).Parse(testArgMethodsRaw)) 219 } 220 221 func generateArgmethodsTests(f io.Writer, generic Kinds) { 222 fmt.Fprintf(f, "/* Test data */\n\n") 223 for _, k := range generic.Kinds { 224 if isNumber(k) && isOrd(k) { 225 op := ArgMethodTestData{k, data} 226 argMethodsData.Execute(f, op) 227 } 228 } 229 fmt.Fprintf(f, "\n%s\n%s\n", argmaxCorrect, argminCorrect) 230 for _, k := range generic.Kinds { 231 if isNumber(k) && isOrd(k) { 232 op := ArgMethodTest{k, "argmax", 7} 233 testArgMethods.Execute(f, op) 234 op = ArgMethodTest{k, "argmin", 11} 235 testArgMethods.Execute(f, op) 236 } 237 } 238 }