github.com/wzzhu/tensor@v0.9.24/genlib2/reduction_specialization.go (about)

     1  package main
     2  
     3  import (
     4  	"io"
     5  	"reflect"
     6  	"text/template"
     7  )
     8  
     9  type ReductionOp struct {
    10  	OpName      string
    11  	VecVec      string // sum(a, b []T)
    12  	OpOfVec     string // sum([]T)
    13  	GenericName string // sum(T, T) T
    14  	Kinds       []reflect.Kind
    15  	Typeclass   TypeClass
    16  }
    17  
    18  var reductionOps = []ReductionOp{
    19  	{OpName: "Sum", VecVec: "VecAdd", OpOfVec: "Sum", GenericName: "Add", Typeclass: isNumber},
    20  	{OpName: "Max", VecVec: "VecMax", OpOfVec: "SliceMax", GenericName: "Max", Typeclass: isNonComplexNumber},
    21  	{OpName: "Min", VecVec: "VecMin", OpOfVec: "SliceMin", GenericName: "Min", Typeclass: isNonComplexNumber},
    22  }
    23  
    24  const reductionSpecializationRaw = `func Monotonic{{.OpName | title}}(t reflect.Type, a *storage.Header) (retVal interface{}, err error) {
    25  	switch t {
    26  		{{$opOfVec := .OpOfVec -}}
    27  		{{range .Kinds -}}
    28  		{{if isNumber . -}}
    29  	case {{reflectKind .}}:
    30  		retVal = {{$opOfVec}}{{short .}}(a.{{sliceOf .}})
    31  		return
    32  		{{end -}}
    33  		{{end -}}
    34  	default:
    35  		err = errors.Errorf("Cannot perform {{.OpName}} on %v", t)
    36  		return
    37  	}
    38  }
    39  
    40  func {{.OpName | title}}Methods(t reflect.Type)(firstFn, lasFn, defaultFn interface{}, err error) {
    41  	{{$vecVec := .VecVec -}}
    42  	{{$opOfVec := .OpOfVec -}}
    43  	{{$genericName := .GenericName -}}
    44  	switch t {
    45  		{{range .Kinds -}}
    46  		{{if isNumber . -}}
    47  	case {{reflectKind .}}:
    48  		return {{$vecVec}}{{short .}}, {{$opOfVec}}{{short .}}, {{$genericName}}{{short .}}, nil
    49  		{{end -}}
    50  		{{end -}}
    51  	default:
    52  		return nil, nil, nil, errors.Errorf("No methods found for {{.OpName}} for %v", t)
    53  	}
    54  }
    55  
    56  `
    57  
    58  var reductionSpecialization *template.Template
    59  
    60  func init() {
    61  	reductionSpecialization = template.Must(template.New("reduction specialization").Funcs(funcs).Parse(reductionSpecializationRaw))
    62  }
    63  
    64  func generateReductionSpecialization(f io.Writer, ak Kinds) {
    65  	for _, op := range reductionOps {
    66  		for _, k := range ak.Kinds {
    67  			if !op.Typeclass(k) {
    68  				continue
    69  			}
    70  			op.Kinds = append(op.Kinds, k)
    71  		}
    72  		reductionSpecialization.Execute(f, op)
    73  	}
    74  }