github.com/wzzhu/tensor@v0.9.24/genlib2/generic_argmethods.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"reflect"
     7  	"text/template"
     8  )
     9  
    10  const argMethodLoopBody = `v := a[i]
    11  if !set {
    12  	f = v
    13  	{{.ArgX}} = i 
    14  	set = true
    15  	continue
    16  }
    17  {{if isFloat .Kind -}}
    18  if {{mathPkg .Kind}}IsNaN(v) || {{mathPkg .Kind}}IsInf(v, {{if eq .ArgX "min"}}-{{end}}1) {
    19  	{{.ArgX}} = i
    20  	return {{.ArgX}}
    21  }
    22  {{end -}}
    23  if v {{if eq .ArgX "max"}}>{{else}}<{{end}} f {
    24  	{{.ArgX}} = i
    25  	f = v
    26  }
    27  `
    28  
    29  const argMethodIter = `data := t.{{sliceOf .}}
    30  tmp := make([]{{asType .}}, 0, lastSize)
    31  for next, err = it.Next(); err == nil; next; err = it.Next() {
    32  	tmp = append(tmp, data[next])
    33  	if len(tmp) == lastSize {
    34  		am := {{.ArgX | title}}(tmp)
    35  		indices = append(indices, am)
    36  		tmp = tmp[:0]
    37  	}
    38  }
    39  return
    40  `
    41  
// GenericArgMethod describes one generated Argmax/Argmin function for a single
// element kind. Its methods render the function's signature and body.
type GenericArgMethod struct {
	// ArgX selects the search direction and must be "max" or "min" (Name
	// panics otherwise). It is also emitted verbatim as the name of the result
	// variable in the generated code.
	ArgX   string
	// Masked adds a second "mask" parameter to the signature and makes
	// WriteBody use the maskCheck template as its "check" sub-template.
	Masked bool
	// Range is set to "a" by generateGenericArgMethods but is not read in the
	// code shown here; presumably consumed by a shared loop template (TODO
	// confirm).
	Range  string

	// Kind is the reflect.Kind the function is generated for; it is passed
	// through to the Signature and to the templates.
	Kind reflect.Kind
}
    49  
    50  func (fn *GenericArgMethod) Name() string {
    51  	switch {
    52  	case fn.ArgX == "max" && fn.Masked:
    53  		return "ArgmaxMasked"
    54  	case fn.ArgX == "min" && fn.Masked:
    55  		return "ArgminMasked"
    56  	case fn.ArgX == "max" && !fn.Masked:
    57  		return "Argmax"
    58  	case fn.ArgX == "min" && !fn.Masked:
    59  		return "Argmin"
    60  	}
    61  	panic("Unreachable")
    62  }
    63  
    64  func (fn *GenericArgMethod) Signature() *Signature {
    65  	paramNames := []string{"a"}
    66  	paramTemplates := []*template.Template{sliceType}
    67  
    68  	if fn.Masked {
    69  		paramNames = append(paramNames, "mask")
    70  		paramTemplates = append(paramTemplates, boolsType)
    71  	}
    72  	return &Signature{
    73  		Name:           fn.Name(),
    74  		NameTemplate:   typeAnnotatedName,
    75  		ParamNames:     paramNames,
    76  		ParamTemplates: paramTemplates,
    77  		Kind:           fn.Kind,
    78  	}
    79  }
    80  
    81  func (fn *GenericArgMethod) WriteBody(w io.Writer) {
    82  	T := template.Must(template.New(fn.Name()).Funcs(funcs).Parse(genericLoopRaw))
    83  	template.Must(T.New("loopbody").Parse(argMethodLoopBody))
    84  	if fn.Masked {
    85  		template.Must(T.New("check").Parse(maskCheck))
    86  	} else {
    87  		template.Must(T.New("check").Parse(""))
    88  	}
    89  	genericArgmaxVarDecl.Execute(w, fn)
    90  	T.Execute(w, fn)
    91  	fmt.Fprintf(w, "\nreturn %s\n", fn.ArgX)
    92  }
    93  
    94  func (fn *GenericArgMethod) Write(w io.Writer) {
    95  	sig := fn.Signature()
    96  	w.Write([]byte("func "))
    97  	sig.Write(w)
    98  	w.Write([]byte("int {\n"))
    99  	fn.WriteBody(w)
   100  	w.Write([]byte("}\n\n"))
   101  }
   102  
   103  func generateGenericArgMethods(f io.Writer, ak Kinds) {
   104  	var argMethods []*GenericArgMethod
   105  	for _, k := range ak.Kinds {
   106  		if !isOrd(k) {
   107  			continue
   108  		}
   109  		m := &GenericArgMethod{
   110  			ArgX:  "max",
   111  			Kind:  k,
   112  			Range: "a",
   113  		}
   114  		argMethods = append(argMethods, m)
   115  	}
   116  
   117  	// argmax
   118  	for _, m := range argMethods {
   119  		m.Write(f)
   120  		m.Masked = true
   121  	}
   122  
   123  	for _, m := range argMethods {
   124  		m.Write(f)
   125  		m.Masked = false
   126  		m.ArgX = "min"
   127  	}
   128  	// argmin
   129  	for _, m := range argMethods {
   130  		m.Write(f)
   131  		m.Masked = true
   132  	}
   133  
   134  	for _, m := range argMethods {
   135  		m.Write(f)
   136  	}
   137  }