github.com/wzzhu/tensor@v0.9.24/genlib2/generic_unary.go (about)

     1  package main
     2  
     3  import (
     4  	"io"
     5  	"text/template"
     6  )
     7  
     8  type GenericUnary struct {
     9  	TypedUnaryOp
    10  	Iter bool
    11  	Cond bool
    12  }
    13  
    14  func (fn *GenericUnary) Name() string {
    15  	if fn.Iter {
    16  		return fn.TypedUnaryOp.Name() + "Iter"
    17  	}
    18  	return fn.TypedUnaryOp.Name()
    19  }
    20  
    21  func (fn *GenericUnary) Signature() *Signature {
    22  	paramNames := []string{"a"}
    23  	paramTemplates := []*template.Template{sliceType}
    24  	var err bool
    25  	if fn.Iter {
    26  		paramNames = append(paramNames, "ait")
    27  		paramTemplates = append(paramTemplates, iteratorType)
    28  		err = true
    29  	}
    30  	return &Signature{
    31  		Name:           fn.Name(),
    32  		NameTemplate:   typeAnnotatedName,
    33  		ParamNames:     paramNames,
    34  		ParamTemplates: paramTemplates,
    35  
    36  		Kind: fn.Kind(),
    37  		Err:  err,
    38  	}
    39  }
    40  
    41  func (fn *GenericUnary) WriteBody(w io.Writer) {
    42  	var IterName0 string
    43  	T := template.New(fn.Name()).Funcs(funcs)
    44  
    45  	if fn.Iter {
    46  		T = template.Must(T.Parse(genericUnaryIterLoopRaw))
    47  		IterName0 = "ait"
    48  	} else {
    49  		T = template.Must(T.Parse(genericLoopRaw))
    50  	}
    51  	if fn.Cond {
    52  		template.Must(T.New("loopbody").Parse(fn.SymbolTemplate()))
    53  	} else {
    54  		template.Must(T.New("loopbody").Parse(basicSet))
    55  		template.Must(T.New("symbol").Parse(fn.SymbolTemplate()))
    56  	}
    57  	template.Must(T.New("opDo").Parse(unaryOpDo))
    58  	template.Must(T.New("callFunc").Parse(unaryOpCallFunc))
    59  	template.Must(T.New("check").Parse(""))
    60  
    61  	lb := LoopBody{
    62  		TypedOp:   fn.TypedUnaryOp,
    63  		Range:     "a",
    64  		Left:      "a",
    65  		Index0:    "i",
    66  		IterName0: IterName0,
    67  	}
    68  	T.Execute(w, lb)
    69  }
    70  
    71  func (fn *GenericUnary) Write(w io.Writer) {
    72  	sig := fn.Signature()
    73  	w.Write([]byte("func "))
    74  	sig.Write(w)
    75  	w.Write([]byte("{\n"))
    76  	fn.WriteBody(w)
    77  	if sig.Err {
    78  		w.Write([]byte("\nreturn\n"))
    79  	}
    80  	w.Write([]byte("}\n\n"))
    81  }
    82  
    83  func generateGenericUncondUnary(f io.Writer, ak Kinds) {
    84  	var gen []*GenericUnary
    85  	for _, tu := range typedUncondUnaries {
    86  		if tc := tu.TypeClass(); tc != nil && !tc(tu.Kind()) {
    87  			continue
    88  		}
    89  		fn := &GenericUnary{
    90  			TypedUnaryOp: tu,
    91  		}
    92  		gen = append(gen, fn)
    93  	}
    94  
    95  	for _, g := range gen {
    96  		g.Write(f)
    97  		g.Iter = true
    98  	}
    99  	for _, g := range gen {
   100  		g.Write(f)
   101  	}
   102  }
   103  
   104  func generateGenericCondUnary(f io.Writer, ak Kinds) {
   105  	var gen []*GenericUnary
   106  	for _, tu := range typedCondUnaries {
   107  		if tc := tu.TypeClass(); tc != nil && !tc(tu.Kind()) {
   108  			continue
   109  		}
   110  		// special case for cmplx
   111  		if isComplex(tu.Kind()) {
   112  			continue
   113  		}
   114  
   115  		fn := &GenericUnary{
   116  			TypedUnaryOp: tu,
   117  			Cond:         true,
   118  		}
   119  		gen = append(gen, fn)
   120  	}
   121  	for _, g := range gen {
   122  		g.Write(f)
   123  		g.Iter = true
   124  	}
   125  	for _, g := range gen {
   126  		g.Write(f)
   127  	}
   128  }
   129  
   130  /*
   131  SPECIAL CASES
   132  */
   133  
   134  type GenericUnarySpecial struct {
   135  	*GenericUnary
   136  	AdditionalParams         []string
   137  	AdditionalParamTemplates []*template.Template
   138  }
   139  
   140  func (fn *GenericUnarySpecial) Signature() *Signature {
   141  	sig := fn.GenericUnary.Signature()
   142  	sig.ParamNames = append(sig.ParamNames, fn.AdditionalParams...)
   143  	sig.ParamTemplates = append(sig.ParamTemplates, fn.AdditionalParamTemplates...)
   144  	return sig
   145  }
   146  
   147  func (fn *GenericUnarySpecial) Write(w io.Writer) {
   148  	sig := fn.Signature()
   149  	w.Write([]byte("func "))
   150  	sig.Write(w)
   151  	w.Write([]byte("{\n"))
   152  	fn.WriteBody(w)
   153  	if sig.Err {
   154  		w.Write([]byte("\nreturn\n"))
   155  	}
   156  	w.Write([]byte("}\n\n"))
   157  }
   158  
   159  func (fn *GenericUnarySpecial) WriteBody(w io.Writer) {
   160  	var IterName0 string
   161  	T := template.New(fn.Name()).Funcs(funcs)
   162  
   163  	if fn.Iter {
   164  		T = template.Must(T.Parse(genericUnaryIterLoopRaw))
   165  		IterName0 = "ait"
   166  	} else {
   167  		T = template.Must(T.Parse(genericLoopRaw))
   168  	}
   169  	template.Must(T.New("loopbody").Parse(clampBody))
   170  	template.Must(T.New("opDo").Parse(unaryOpDo))
   171  	template.Must(T.New("callFunc").Parse(unaryOpCallFunc))
   172  	template.Must(T.New("check").Parse(""))
   173  
   174  	lb := LoopBody{
   175  		TypedOp:   fn.TypedUnaryOp,
   176  		Range:     "a",
   177  		Left:      "a",
   178  		Index0:    "i",
   179  		IterName0: IterName0,
   180  	}
   181  	T.Execute(w, lb)
   182  }
   183  
   184  func generateSpecialGenericUnaries(f io.Writer, ak Kinds) {
   185  	var gen []*GenericUnarySpecial
   186  	for _, tu := range typedSpecialUnaries {
   187  		if tc := tu.TypeClass(); tc != nil && !tc(tu.Kind()) {
   188  			continue
   189  		}
   190  
   191  		additional := tu.UnaryOp.(specialUnaryOp).additionalParams
   192  		tmpls := make([]*template.Template, len(additional))
   193  		for i := range tmpls {
   194  			tmpls[i] = scalarType
   195  		}
   196  		fn := &GenericUnarySpecial{
   197  			GenericUnary: &GenericUnary{
   198  				TypedUnaryOp: tu,
   199  			},
   200  			AdditionalParams:         additional,
   201  			AdditionalParamTemplates: tmpls,
   202  		}
   203  		gen = append(gen, fn)
   204  	}
   205  
   206  	for _, fn := range gen {
   207  		fn.Write(f)
   208  		fn.Iter = true
   209  	}
   210  
   211  	for _, fn := range gen {
   212  		fn.Write(f)
   213  	}
   214  }