github.com/wzzhu/tensor@v0.9.24/genlib2/dense_maskedmethods_tests.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"reflect"
     7  	"text/template"
     8  )
     9  
    10  type MaskCmpMethodTest struct {
    11  	Kind reflect.Kind
    12  	Name string
    13  }
    14  
// testMaskCmpMethodRaw is the text/template source for one generated test
// function, named TestDense_<Name>_<Kind>. The template's .Name is used only
// in that function name; the body depends solely on .Kind.
//
// The rendered test builds a 2x3x4x5 Dense (120 elements), fills it with
// sequential values (fmt.Sprint(i) for string kinds, a numeric conversion
// otherwise), and then exercises MaskedEqual, MaskedNotEqual, MaskedInside,
// MaskedOutside and ResetMask, asserting on individual mask entries. It ends
// by masking five values (0, 10, 20, 30, 40) and checking that the iterator
// reports 120 total, 115 valid and 5 invalid elements.
const testMaskCmpMethodRaw = `func TestDense_{{title .Name}}_{{short .Kind}}(t *testing.T){
    assert := assert.New(t)
    T := New(Of({{reflectKind .Kind}}), WithShape(2, 3, 4, 5))
    assert.False(T.IsMasked())
    data := T.{{sliceOf .Kind}}
    for i := range data {
{{if eq "string" (asType .Kind) -}}
		data[i] = fmt.Sprint(i)
{{else -}}
		data[i] = {{asType .Kind}}(i)
{{end -}}
	}
{{if eq "string" (asType .Kind) -}}
    T.MaskedEqual(fmt.Sprint(0))
{{else -}}
    T.MaskedEqual({{asType .Kind}}(0))
{{end -}}
	assert.True(T.IsMasked())
{{if eq "string" (asType .Kind) -}}
	T.MaskedEqual(fmt.Sprint(1))
{{else -}}
	T.MaskedEqual({{asType .Kind}}(1))
{{end -}}
	assert.True(T.mask[0] && T.mask[1])
{{if eq "string" (asType .Kind) -}}
	T.MaskedNotEqual(fmt.Sprint(2))
{{else -}}
	T.MaskedNotEqual({{asType .Kind}}(2))
{{end -}}
	assert.False(T.mask[2] && !(T.mask[0]))

    T.ResetMask()
{{if eq "string" (asType .Kind) -}}
	T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22))
{{else -}}
	T.MaskedInside({{asType .Kind}}(1), {{asType .Kind}}(22))
{{end -}}
	assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22])

	T.ResetMask()
{{if eq "string" (asType .Kind) -}}
	T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22))
{{else -}}
	T.MaskedOutside({{asType .Kind}}(1), {{asType .Kind}}(22))
{{end -}}
	assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22])

    T.ResetMask()
    for i := 0; i < 5; i++ {
{{if eq "string" (asType .Kind) -}}
		T.MaskedEqual(fmt.Sprint(i*10))
{{else -}}
		T.MaskedEqual({{asType .Kind}}(i*10))
{{end -}}
	}
    it := IteratorFromDense(T)

    j := 0
	for _, err := it.Next(); err == nil; _, err = it.Next() {
		j++
	}

	it.Reset()
	assert.Equal(120, j)
	j = 0
	for _, _, err := it.NextValid(); err == nil; _, _, err = it.NextValid() {
		j++
	}
	it.Reset()
	assert.Equal(115, j)
	j = 0
	for _, _, err := it.NextInvalid(); err == nil; _, _, err = it.NextInvalid() {
		j++
	}
	it.Reset()
	assert.Equal(5,j)
    }
    `
    93  
    94  var (
    95  	testMaskCmpMethod *template.Template
    96  )
    97  
    98  func init() {
    99  	testMaskCmpMethod = template.Must(template.New("testmaskcmpmethod").Funcs(funcs).Parse(testMaskCmpMethodRaw))
   100  }
   101  
   102  func generateMaskCmpMethodsTests(f io.Writer, generic Kinds) {
   103  	for _, mm := range maskcmpMethods {
   104  		fmt.Fprintf(f, "/* %s */ \n\n", mm.Name)
   105  		for _, k := range generic.Kinds {
   106  			if isOrd(k) {
   107  				if mm.ReqFloat && isntFloat(k) {
   108  
   109  				} else {
   110  					op := MaskCmpMethodTest{k, mm.Name}
   111  					testMaskCmpMethod.Execute(f, op)
   112  				}
   113  			}
   114  		}
   115  	}
   116  }