github.com/wzzhu/tensor@v0.9.24/genlib2/generic_map.go (about)

     1  package main
     2  
     3  import (
     4  	"io"
     5  	"reflect"
     6  	"text/template"
     7  )
     8  
     9  const (
    10  	fnErrSet = `if {{.Range}}[i], err = {{template "callFunc" .}}; handleNoOp(err) != nil {
    11  		return
    12  	}`
    13  	fnErrIncr = `var x {{asType .Kind}}
    14  	if x, err = {{template "callFunc" .}}; err != nil {
    15  		if err = handleNoOp(err);err != nil {
    16  			return
    17  		}
    18  	}
    19  	{{.Range}}[i] = x
    20  	`
    21  	simpleUnaryCallFunc = `{{template "symbol" .}}({{.Left}}[{{.Index0}}])`
    22  )
    23  
    24  type Map struct {
    25  	k    reflect.Kind
    26  	Iter bool
    27  	Incr bool
    28  	Err  bool
    29  }
    30  
    31  func (fn *Map) Name() string {
    32  	switch {
    33  	case fn.Iter && fn.Incr && fn.Err:
    34  		return "MapIterIncrErr"
    35  	case fn.Iter && fn.Incr && !fn.Err:
    36  		return "MapIterIncr"
    37  	case fn.Iter && !fn.Incr && fn.Err:
    38  		return "MapIterErr"
    39  	case fn.Iter && !fn.Incr && !fn.Err:
    40  		return "MapIter"
    41  	case !fn.Iter && fn.Incr && fn.Err:
    42  		return "MapIncrErr"
    43  	case !fn.Iter && fn.Incr && !fn.Err:
    44  		return "MapIncr"
    45  	case !fn.Iter && !fn.Incr && fn.Err:
    46  		return "MapErr"
    47  	default:
    48  		return "Map"
    49  	}
    50  }
    51  
    52  func (fn *Map) Arity() int             { return 1 }
    53  func (fn *Map) SymbolTemplate() string { return "fn" }
    54  func (fn *Map) TypeClass() TypeClass   { return nil }
    55  func (fn *Map) IsFunc() bool           { return true }
    56  func (fn *Map) Kind() reflect.Kind     { return fn.k }
    57  
    58  func (fn *Map) Signature() *Signature {
    59  	var retErr bool
    60  	paramNames := []string{"fn", "a"}
    61  	paramTemplates := []*template.Template{unaryFuncType, sliceType}
    62  	if fn.Iter {
    63  		paramNames = append(paramNames, "ait")
    64  		paramTemplates = append(paramTemplates, iteratorType)
    65  		retErr = true
    66  	}
    67  	if fn.Err {
    68  		paramTemplates[0] = unaryFuncErrType
    69  		retErr = true
    70  	}
    71  
    72  	return &Signature{
    73  		Name:           fn.Name(),
    74  		NameTemplate:   typeAnnotatedName,
    75  		ParamNames:     paramNames,
    76  		ParamTemplates: paramTemplates,
    77  		Kind:           fn.Kind(),
    78  		Err:            retErr,
    79  	}
    80  }
    81  
    82  func (fn *Map) WriteBody(w io.Writer) {
    83  	Range := "a"
    84  	Left := "a"
    85  
    86  	var T *template.Template
    87  	var IterName0 string
    88  	if fn.Iter {
    89  		T = template.Must(template.New(fn.Name()).Funcs(funcs).Parse(genericUnaryIterLoopRaw))
    90  		IterName0 = "ait"
    91  	} else {
    92  		T = template.Must(template.New(fn.Name()).Funcs(funcs).Parse(genericLoopRaw))
    93  	}
    94  
    95  	switch {
    96  	case fn.Incr && fn.Err:
    97  		template.Must(T.New("loopbody").Funcs(funcs).Parse(fnErrIncr))
    98  	case fn.Incr && !fn.Err:
    99  		template.Must(T.New("loopbody").Funcs(funcs).Parse(basicIncr))
   100  	case !fn.Incr && fn.Err:
   101  		template.Must(T.New("loopbody").Funcs(funcs).Parse(fnErrSet))
   102  	default:
   103  		template.Must(T.New("loopbody").Funcs(funcs).Parse(basicSet))
   104  	}
   105  	template.Must(T.New("callFunc").Funcs(funcs).Parse(simpleUnaryCallFunc))
   106  	template.Must(T.New("symbol").Funcs(funcs).Parse("fn"))
   107  	template.Must(T.New("opDo").Funcs(funcs).Parse(""))
   108  	template.Must(T.New("check").Funcs(funcs).Parse(""))
   109  
   110  	lb := LoopBody{
   111  		TypedOp:   fn,
   112  		Range:     Range,
   113  		Left:      Left,
   114  		Index0:    "i",
   115  		IterName0: IterName0,
   116  	}
   117  	T.Execute(w, lb)
   118  }
   119  
   120  func (fn *Map) Write(w io.Writer) {
   121  	sig := fn.Signature()
   122  	w.Write([]byte("func "))
   123  	sig.Write(w)
   124  	w.Write([]byte("{\n"))
   125  	fn.WriteBody(w)
   126  	w.Write([]byte("\nreturn \n"))
   127  	w.Write([]byte("}\n\n"))
   128  }
   129  
   130  func makeGenericMaps(incr bool) (retVal []*Map) {
   131  	for _, k := range allKinds {
   132  		if incr {
   133  			if !isAddable(k) {
   134  				continue
   135  			}
   136  		}
   137  		if isParameterized(k) {
   138  			continue
   139  		}
   140  
   141  		m := &Map{k: k}
   142  		if incr {
   143  			m.Incr = true
   144  		}
   145  		retVal = append(retVal, m)
   146  	}
   147  	return
   148  }
   149  
   150  func generateGenericMap(f io.Writer, ak Kinds) {
   151  	gen0 := makeGenericMaps(false)
   152  	for _, m := range gen0 {
   153  		m.Write(f)
   154  		m.Err = true
   155  	}
   156  	for _, m := range gen0 {
   157  		m.Write(f)
   158  		m.Err = false
   159  		m.Iter = true
   160  	}
   161  	for _, m := range gen0 {
   162  		m.Write(f)
   163  		m.Err = true
   164  	}
   165  	for _, m := range gen0 {
   166  		m.Write(f)
   167  	}
   168  
   169  	gen1 := makeGenericMaps(true)
   170  	for _, m := range gen1 {
   171  		m.Write(f)
   172  		m.Err = true
   173  	}
   174  	for _, m := range gen1 {
   175  		m.Write(f)
   176  		m.Err = false
   177  		m.Iter = true
   178  	}
   179  	for _, m := range gen1 {
   180  		m.Write(f)
   181  		m.Err = true
   182  	}
   183  	for _, m := range gen1 {
   184  		m.Write(f)
   185  	}
   186  }