github.com/wzzhu/tensor@v0.9.24/genlib2/dense_maskedmethods_tests.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "reflect" 7 "text/template" 8 ) 9 10 type MaskCmpMethodTest struct { 11 Kind reflect.Kind 12 Name string 13 } 14 15 const testMaskCmpMethodRaw = `func TestDense_{{title .Name}}_{{short .Kind}}(t *testing.T){ 16 assert := assert.New(t) 17 T := New(Of({{reflectKind .Kind}}), WithShape(2, 3, 4, 5)) 18 assert.False(T.IsMasked()) 19 data := T.{{sliceOf .Kind}} 20 for i := range data { 21 {{if eq "string" (asType .Kind) -}} 22 data[i] = fmt.Sprint(i) 23 {{else -}} 24 data[i] = {{asType .Kind}}(i) 25 {{end -}} 26 } 27 {{if eq "string" (asType .Kind) -}} 28 T.MaskedEqual(fmt.Sprint(0)) 29 {{else -}} 30 T.MaskedEqual({{asType .Kind}}(0)) 31 {{end -}} 32 assert.True(T.IsMasked()) 33 {{if eq "string" (asType .Kind) -}} 34 T.MaskedEqual(fmt.Sprint(1)) 35 {{else -}} 36 T.MaskedEqual({{asType .Kind}}(1)) 37 {{end -}} 38 assert.True(T.mask[0] && T.mask[1]) 39 {{if eq "string" (asType .Kind) -}} 40 T.MaskedNotEqual(fmt.Sprint(2)) 41 {{else -}} 42 T.MaskedNotEqual({{asType .Kind}}(2)) 43 {{end -}} 44 assert.False(T.mask[2] && !(T.mask[0])) 45 46 T.ResetMask() 47 {{if eq "string" (asType .Kind) -}} 48 T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) 49 {{else -}} 50 T.MaskedInside({{asType .Kind}}(1), {{asType .Kind}}(22)) 51 {{end -}} 52 assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) 53 54 T.ResetMask() 55 {{if eq "string" (asType .Kind) -}} 56 T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) 57 {{else -}} 58 T.MaskedOutside({{asType .Kind}}(1), {{asType .Kind}}(22)) 59 {{end -}} 60 assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) 61 62 T.ResetMask() 63 for i := 0; i < 5; i++ { 64 {{if eq "string" (asType .Kind) -}} 65 T.MaskedEqual(fmt.Sprint(i*10)) 66 {{else -}} 67 T.MaskedEqual({{asType .Kind}}(i*10)) 68 {{end -}} 69 } 70 it := IteratorFromDense(T) 71 72 j := 0 73 for _, err := it.Next(); err == nil; _, err = it.Next() { 74 j++ 75 } 76 77 it.Reset() 78 assert.Equal(120, j) 79 j = 0 80 for _, _, err := it.NextValid(); err == nil; _, _, err = it.NextValid() { 81 j++ 82 } 83 it.Reset() 84 assert.Equal(115, j) 85 j = 0 86 for _, _, err := it.NextInvalid(); err == nil; _, _, err = it.NextInvalid() { 87 j++ 88 } 89 it.Reset() 90 assert.Equal(5,j) 91 } 92 ` 93 94 var ( 95 testMaskCmpMethod *template.Template 96 ) 97 98 func init() { 99 testMaskCmpMethod = template.Must(template.New("testmaskcmpmethod").Funcs(funcs).Parse(testMaskCmpMethodRaw)) 100 } 101 102 func generateMaskCmpMethodsTests(f io.Writer, generic Kinds) { 103 for _, mm := range maskcmpMethods { 104 fmt.Fprintf(f, "/* %s */ \n\n", mm.Name) 105 for _, k := range generic.Kinds { 106 if isOrd(k) { 107 if mm.ReqFloat && isntFloat(k) { 108 109 } else { 110 op := MaskCmpMethodTest{k, mm.Name} 111 testMaskCmpMethod.Execute(f, op) 112 } 113 } 114 } 115 } 116 }