github.com/wzzhu/tensor@v0.9.24/genlib2/dense_maskedmethods.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "reflect" 7 "text/template" 8 ) 9 10 var maskcmpMethods = []struct { 11 Name string 12 Desc string 13 NumArgs int 14 CmpFn string 15 ReqFloat bool 16 Kinds []reflect.Kind 17 }{ 18 {"MaskedEqual", "equal to ", 1, "a == x", false, nil}, 19 {"MaskedNotEqual", "not equal to ", 1, "a != x", false, nil}, 20 {"MaskedValues", " equal to ", 3, "math.Abs(float64(a-x)) <= delta", true, nil}, 21 {"MaskedGreater", " greater than ", 1, "a > x", false, nil}, 22 {"MaskedGreaterEqual", " greater than or equal to ", 1, "a >= x", false, nil}, 23 {"MaskedLess", " less than ", 1, "a < x", false, nil}, 24 {"MaskedLessEqual", " less than or equal to ", 1, "a <= x", false, nil}, 25 {"MaskedInside", " inside range of ", 2, "(a >= x) && (a <= y)", false, nil}, 26 {"MaskedOutside", " outside range of ", 2, "(a < x) || (a > y)", false, nil}, 27 } 28 29 const maskCmpMethodRaw = `// {{.Name}} sets the mask to true where the corresponding data is {{.Desc}} val 30 // Any values must be the same type as the tensor 31 func (t *Dense) {{.Name}}({{if ge .NumArgs 1 -}} val1 interface{} {{end}} {{if ge .NumArgs 2 -}} , val2 interface{} {{end}} {{if ge .NumArgs 3 -}} , val3 ...interface{}{{end}})(err error){ 32 {{if .ReqFloat}} 33 if !isFloat(t.t) { 34 err = errors.Errorf("Can only do {{.Name}} with floating point types") 35 return 36 } 37 {{end}} 38 39 if !t.IsMasked() { 40 t.makeMask() 41 } 42 43 {{$numargs := .NumArgs}} 44 {{$name := .Name}} 45 {{$fn := .CmpFn}} 46 {{$reqFloat := .ReqFloat}} 47 switch t.t.Kind(){ 48 {{range .Kinds -}} 49 {{if isParameterized . -}} 50 {{else -}} 51 {{if or (not (isOrd .)) (and $reqFloat (isntFloat .)) -}} 52 {{else -}} 53 case reflect.{{reflectKind .}}: 54 data := t.{{sliceOf .}} 55 mask := t.mask 56 {{if ge $numargs 1 -}} x := val1.({{asType .}}) {{end}} 57 {{if ge $numargs 2 -}} y := val2.({{asType .}}){{end}} 58 {{if ge $numargs 3 -}} 59 {{if eq $name "MaskedValues"}} 60 delta := float64(1.0e-8) 61 if len(val3) > 0 { 62 delta = float64(val3[0].({{asType .}})) + float64(y)*math.Abs(float64(x)) 63 } 64 {{else}} 65 z := val3.({{asType .}}) 66 {{end}} 67 {{end}} 68 if t.maskIsSoft{ 69 for i := range data { 70 a := data[i] 71 mask[i] = ({{$fn}}) 72 } 73 } else { 74 for i := range data { 75 a := data[i] 76 mask[i] = mask[i] || ({{$fn}}) 77 } 78 } 79 80 {{end}} 81 {{end}} 82 {{end}} 83 } 84 return nil 85 } 86 ` 87 88 var ( 89 maskCmpMethod *template.Template 90 ) 91 92 func init() { 93 maskCmpMethod = template.Must(template.New("maskcmpmethod").Funcs(funcs).Parse(maskCmpMethodRaw)) 94 } 95 96 func generateDenseMaskedMethods(f io.Writer, generic Kinds) { 97 for _, mm := range maskcmpMethods { 98 mm.Kinds = generic.Kinds 99 fmt.Fprintf(f, "/* %s */ \n\n", mm.Name) 100 maskCmpMethod.Execute(f, mm) 101 102 } 103 }