github.com/wzzhu/tensor@v0.9.24/genlib2/dense_maskedmethods.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"reflect"
     7  	"text/template"
     8  )
     9  
    // maskcmpMethods describes every masked-comparison method generated for *Dense.
// Each generated method sets mask[i] wherever CmpFn holds for the element
// a = data[i] and the method's arguments, bound as x (val1), y (val2) and, for
// MaskedValues, a tolerance delta derived from val2/val3.
//
// Desc is spliced into the generated doc comment as "data is <Desc> val", so it
// must not carry leading or trailing spaces.
var maskcmpMethods = []struct {
	Name     string         // name of the generated *Dense method
	Desc     string         // phrase for the generated doc comment
	NumArgs  int            // number of value arguments (1..3); the third is variadic
	CmpFn    string         // Go boolean expression over a, x, y (and delta)
	ReqFloat bool           // only emit float cases and reject other dtypes at runtime
	Kinds    []reflect.Kind // filled in from the generic kind list at generation time
}{
	{"MaskedEqual", "equal to", 1, "a == x", false, nil},
	{"MaskedNotEqual", "not equal to", 1, "a != x", false, nil},
	{"MaskedValues", "equal to", 3, "math.Abs(float64(a-x)) <= delta", true, nil},
	{"MaskedGreater", "greater than", 1, "a > x", false, nil},
	{"MaskedGreaterEqual", "greater than or equal to", 1, "a >= x", false, nil},
	{"MaskedLess", "less than", 1, "a < x", false, nil},
	{"MaskedLessEqual", "less than or equal to", 1, "a <= x", false, nil},
	{"MaskedInside", "inside range of", 2, "(a >= x) && (a <= y)", false, nil},
	{"MaskedOutside", "outside range of", 2, "(a < x) || (a > y)", false, nil},
}
    28  
    // maskCmpMethodRaw is the text/template source for one masked-comparison
// method on *Dense. It is executed once per maskcmpMethods entry with Kinds
// populated. For every kind that is not parameterized, is ordered, and (when
// ReqFloat is set) is a float kind, it emits a switch case that evaluates
// CmpFn element-wise. With a soft mask the result overwrites mask[i];
// otherwise it is OR-ed in, so already-masked elements stay masked.
//
// For MaskedValues, val2 scales |val1| as a relative tolerance and val3[0]
// optionally replaces the default 1e-8 absolute tolerance. The relative term
// is included even when val3 is omitted.
const maskCmpMethodRaw = `// {{.Name}} sets the mask to true where the corresponding data is {{.Desc}} val
// Any values must be the same type as the tensor
func (t *Dense) {{.Name}}({{if ge .NumArgs 1 -}} val1 interface{} {{end}} {{if ge .NumArgs 2 -}} , val2 interface{} {{end}}  {{if ge .NumArgs 3 -}} , val3 ...interface{}{{end}})(err error){
	{{if .ReqFloat}}
	if !isFloat(t.t) {
		err = errors.Errorf("Can only do {{.Name}} with floating point types")
		return
	}
	{{end}}

	if !t.IsMasked() {
		t.makeMask()
	}

	{{$numargs := .NumArgs}}
	{{$name := .Name}}
	{{$fn := .CmpFn}}
	{{$reqFloat := .ReqFloat}}
	switch t.t.Kind(){
	{{range .Kinds -}}
	{{if isParameterized . -}}
	{{else -}}
	{{if or (not (isOrd .)) (and $reqFloat (isntFloat .)) -}}
	{{else -}}
		case reflect.{{reflectKind .}}:
			data := t.{{sliceOf .}}
			mask := t.mask
			{{if ge $numargs 1 -}} x := val1.({{asType .}}) {{end}}
			{{if ge $numargs 2 -}} y := val2.({{asType .}}){{end}}
			{{if and (ge $numargs 3) (eq $name "MaskedValues") -}}
			// y scales |x| (relative tolerance); val3[0], if given, replaces the
			// default absolute tolerance of 1e-8. Both terms always apply.
			delta := float64(1.0e-8)
			if len(val3) > 0 {
				delta = float64(val3[0].({{asType .}}))
			}
			delta += float64(y) * math.Abs(float64(x))
			{{end}}
			if t.maskIsSoft {
				for i := range data {
					a := data[i]
					mask[i] = ({{$fn}})
				}
			} else {
				for i := range data {
					a := data[i]
					mask[i] = mask[i] || ({{$fn}})
				}
			}
	{{end}}
	{{end}}
	{{end}}
	}
	return nil
}
`
    87  
    // maskCmpMethod holds the parsed form of maskCmpMethodRaw; it is assigned in
// this file's init function and executed by generateDenseMaskedMethods.
var maskCmpMethod *template.Template
    91  
    92  func init() {
    93  	maskCmpMethod = template.Must(template.New("maskcmpmethod").Funcs(funcs).Parse(maskCmpMethodRaw))
    94  }
    95  
    96  func generateDenseMaskedMethods(f io.Writer, generic Kinds) {
    97  	for _, mm := range maskcmpMethods {
    98  		mm.Kinds = generic.Kinds
    99  		fmt.Fprintf(f, "/* %s */ \n\n", mm.Name)
   100  		maskCmpMethod.Execute(f, mm)
   101  
   102  	}
   103  }