github.com/wzzhu/tensor@v0.9.24/genlib2/generic_argmethods.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "reflect" 7 "text/template" 8 ) 9 10 const argMethodLoopBody = `v := a[i] 11 if !set { 12 f = v 13 {{.ArgX}} = i 14 set = true 15 continue 16 } 17 {{if isFloat .Kind -}} 18 if {{mathPkg .Kind}}IsNaN(v) || {{mathPkg .Kind}}IsInf(v, {{if eq .ArgX "min"}}-{{end}}1) { 19 {{.ArgX}} = i 20 return {{.ArgX}} 21 } 22 {{end -}} 23 if v {{if eq .ArgX "max"}}>{{else}}<{{end}} f { 24 {{.ArgX}} = i 25 f = v 26 } 27 ` 28 29 const argMethodIter = `data := t.{{sliceOf .}} 30 tmp := make([]{{asType .}}, 0, lastSize) 31 for next, err = it.Next(); err == nil; next; err = it.Next() { 32 tmp = append(tmp, data[next]) 33 if len(tmp) == lastSize { 34 am := {{.ArgX | title}}(tmp) 35 indices = append(indices, am) 36 tmp = tmp[:0] 37 } 38 } 39 return 40 ` 41 42 type GenericArgMethod struct { 43 ArgX string 44 Masked bool 45 Range string 46 47 Kind reflect.Kind 48 } 49 50 func (fn *GenericArgMethod) Name() string { 51 switch { 52 case fn.ArgX == "max" && fn.Masked: 53 return "ArgmaxMasked" 54 case fn.ArgX == "min" && fn.Masked: 55 return "ArgminMasked" 56 case fn.ArgX == "max" && !fn.Masked: 57 return "Argmax" 58 case fn.ArgX == "min" && !fn.Masked: 59 return "Argmin" 60 } 61 panic("Unreachable") 62 } 63 64 func (fn *GenericArgMethod) Signature() *Signature { 65 paramNames := []string{"a"} 66 paramTemplates := []*template.Template{sliceType} 67 68 if fn.Masked { 69 paramNames = append(paramNames, "mask") 70 paramTemplates = append(paramTemplates, boolsType) 71 } 72 return &Signature{ 73 Name: fn.Name(), 74 NameTemplate: typeAnnotatedName, 75 ParamNames: paramNames, 76 ParamTemplates: paramTemplates, 77 Kind: fn.Kind, 78 } 79 } 80 81 func (fn *GenericArgMethod) WriteBody(w io.Writer) { 82 T := template.Must(template.New(fn.Name()).Funcs(funcs).Parse(genericLoopRaw)) 83 template.Must(T.New("loopbody").Parse(argMethodLoopBody)) 84 if fn.Masked { 85 template.Must(T.New("check").Parse(maskCheck)) 86 } else { 87 template.Must(T.New("check").Parse("")) 88 } 89 genericArgmaxVarDecl.Execute(w, fn) 90 T.Execute(w, fn) 91 fmt.Fprintf(w, "\nreturn %s\n", fn.ArgX) 92 } 93 94 func (fn *GenericArgMethod) Write(w io.Writer) { 95 sig := fn.Signature() 96 w.Write([]byte("func ")) 97 sig.Write(w) 98 w.Write([]byte("int {\n")) 99 fn.WriteBody(w) 100 w.Write([]byte("}\n\n")) 101 } 102 103 func generateGenericArgMethods(f io.Writer, ak Kinds) { 104 var argMethods []*GenericArgMethod 105 for _, k := range ak.Kinds { 106 if !isOrd(k) { 107 continue 108 } 109 m := &GenericArgMethod{ 110 ArgX: "max", 111 Kind: k, 112 Range: "a", 113 } 114 argMethods = append(argMethods, m) 115 } 116 117 // argmax 118 for _, m := range argMethods { 119 m.Write(f) 120 m.Masked = true 121 } 122 123 for _, m := range argMethods { 124 m.Write(f) 125 m.Masked = false 126 m.ArgX = "min" 127 } 128 // argmin 129 for _, m := range argMethods { 130 m.Write(f) 131 m.Masked = true 132 } 133 134 for _, m := range argMethods { 135 m.Write(f) 136 } 137 }