github.com/wzzhu/tensor@v0.9.24/genlib2/generic_arith.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"strings"
     7  	"text/template"
     8  )
     9  
    10  type GenericVecVecArith struct {
    11  	TypedBinOp
    12  	Iter          bool
    13  	Incr          bool
    14  	WithRecv      bool      // not many BinOps have this
    15  	Check         TypeClass // can be nil
    16  	CheckTemplate string
    17  }
    18  
    19  func (fn *GenericVecVecArith) Name() string {
    20  	switch {
    21  	case fn.Iter && fn.Incr:
    22  		return fmt.Sprintf("%sIterIncr", fn.TypedBinOp.Name())
    23  	case fn.Iter && !fn.Incr:
    24  		return fmt.Sprintf("%sIter", fn.TypedBinOp.Name())
    25  	case !fn.Iter && fn.Incr:
    26  		return fmt.Sprintf("%sIncr", fn.TypedBinOp.Name())
    27  	case fn.WithRecv:
    28  		return fmt.Sprintf("%vRecv", fn.TypedBinOp.Name())
    29  	default:
    30  		return fmt.Sprintf("Vec%s", fn.TypedBinOp.Name())
    31  	}
    32  }
    33  
    34  func (fn *GenericVecVecArith) Signature() *Signature {
    35  	var paramNames []string
    36  	var paramTemplates []*template.Template
    37  	var err bool
    38  
    39  	switch {
    40  	case fn.Iter && fn.Incr:
    41  		paramNames = []string{"a", "b", "incr", "ait", "bit", "iit"}
    42  		paramTemplates = []*template.Template{sliceType, sliceType, sliceType, iteratorType, iteratorType, iteratorType}
    43  		err = true
    44  	case fn.Iter && !fn.Incr:
    45  		paramNames = []string{"a", "b", "ait", "bit"}
    46  		paramTemplates = []*template.Template{sliceType, sliceType, iteratorType, iteratorType}
    47  		err = true
    48  	case !fn.Iter && fn.Incr:
    49  		paramNames = []string{"a", "b", "incr"}
    50  		paramTemplates = []*template.Template{sliceType, sliceType, sliceType}
    51  	case fn.WithRecv:
    52  		paramNames = []string{"a", "b", "recv"}
    53  		paramTemplates = []*template.Template{sliceType, sliceType, sliceType}
    54  	default:
    55  		paramNames = []string{"a", "b"}
    56  		paramTemplates = []*template.Template{sliceType, sliceType}
    57  	}
    58  
    59  	if fn.Check != nil {
    60  		err = true
    61  	}
    62  
    63  	return &Signature{
    64  		Name:           fn.Name(),
    65  		NameTemplate:   typeAnnotatedName,
    66  		ParamNames:     paramNames,
    67  		ParamTemplates: paramTemplates,
    68  
    69  		Kind: fn.Kind(),
    70  		Err:  err,
    71  	}
    72  }
    73  
    74  func (fn *GenericVecVecArith) WriteBody(w io.Writer) {
    75  	var Range, Left, Right string
    76  	var Index0, Index1, Index2 string
    77  	var IterName0, IterName1, IterName2 string
    78  	var T *template.Template
    79  
    80  	Range = "a"
    81  	Index0 = "i"
    82  	Index1 = "j"
    83  	Left = "a[i]"
    84  	Right = "b[j]"
    85  
    86  	T = template.New(fn.Name()).Funcs(funcs)
    87  	switch {
    88  	case fn.Iter && fn.Incr:
    89  		Range = "incr"
    90  		Index2 = "k"
    91  		IterName0 = "ait"
    92  		IterName1 = "bit"
    93  		IterName2 = "iit"
    94  		T = template.Must(T.Parse(genericTernaryIterLoopRaw))
    95  		template.Must(T.New("loopbody").Parse(iterIncrLoopBody))
    96  	case fn.Iter && !fn.Incr:
    97  		IterName0 = "ait"
    98  		IterName1 = "bit"
    99  		T = template.Must(T.Parse(genericBinaryIterLoopRaw))
   100  		template.Must(T.New("loopbody").Parse(basicSet))
   101  	case !fn.Iter && fn.Incr:
   102  		Range = "incr"
   103  		Right = "b[i]"
   104  		T = template.Must(T.Parse(genericLoopRaw))
   105  		template.Must(T.New("loopbody").Parse(basicIncr))
   106  	case fn.WithRecv:
   107  		Range = "recv"
   108  		Right = "b[i]"
   109  		T = template.Must(T.Parse(genericLoopRaw))
   110  		template.Must(T.New("loopbody").Parse(basicSet))
   111  	default:
   112  		Right = "b[i]"
   113  		T = template.Must(T.Parse(genericLoopRaw))
   114  		template.Must(T.New("loopbody").Parse(basicSet))
   115  	}
   116  	template.Must(T.New("callFunc").Parse(binOpCallFunc))
   117  	template.Must(T.New("opDo").Parse(binOpDo))
   118  	template.Must(T.New("symbol").Parse(fn.SymbolTemplate()))
   119  
   120  	if fn.Check != nil && fn.Check(fn.Kind()) {
   121  		w.Write([]byte("var errs errorIndices\n"))
   122  	}
   123  	template.Must(T.New("check").Parse(fn.CheckTemplate))
   124  
   125  	lb := LoopBody{
   126  		TypedOp: fn.TypedBinOp,
   127  		Range:   Range,
   128  		Left:    Left,
   129  		Right:   Right,
   130  
   131  		Index0: Index0,
   132  		Index1: Index1,
   133  		Index2: Index2,
   134  
   135  		IterName0: IterName0,
   136  		IterName1: IterName1,
   137  		IterName2: IterName2,
   138  	}
   139  	T.Execute(w, lb)
   140  }
   141  
   142  func (fn *GenericVecVecArith) Write(w io.Writer) {
   143  	sig := fn.Signature()
   144  	if !fn.Iter && isFloat(fn.Kind()) && !fn.WithRecv {
   145  		// golinkPragma.Execute(w, fn)
   146  		w.Write([]byte("func "))
   147  		sig.Write(w)
   148  		if fn.Incr {
   149  			fmt.Fprintf(w, "{ %v%v(a, b, incr)}\n", vecPkg(fn.Kind()), getalias(fn.Name()))
   150  		} else {
   151  			fmt.Fprintf(w, "{ %v%v(a, b)}\n", vecPkg(fn.Kind()), getalias(fn.Name()))
   152  		}
   153  		return
   154  	}
   155  
   156  	w.Write([]byte("func "))
   157  	sig.Write(w)
   158  
   159  	switch {
   160  	case !fn.Iter && fn.Incr:
   161  		w.Write([]byte("{\na = a[:len(a)]; b = b[:len(a)]; incr = incr[:len(a)]\n"))
   162  	case fn.WithRecv:
   163  		w.Write([]byte("{\na = a[:len(recv)]; b = b[:len(recv)]\n"))
   164  	case !fn.Iter && !fn.Incr && !fn.WithRecv:
   165  		w.Write([]byte("{\na = a[:len(a)]; b = b[:len(a)]\n"))
   166  	default:
   167  		w.Write([]byte("{\n"))
   168  	}
   169  	fn.WriteBody(w)
   170  	if sig.Err {
   171  		if fn.Check != nil {
   172  			w.Write([]byte("\nif err != nil {\n return\n}\nif len(errs) > 0 {\n return errs }\nreturn nil"))
   173  		} else {
   174  			w.Write([]byte("\nreturn\n"))
   175  		}
   176  	}
   177  	w.Write([]byte("}\n\n"))
   178  }
   179  
   180  type GenericMixedArith struct {
   181  	GenericVecVecArith
   182  	LeftVec bool
   183  }
   184  
   185  func (fn *GenericMixedArith) Name() string {
   186  	n := fn.GenericVecVecArith.Name()
   187  	n = strings.TrimPrefix(n, "Vec")
   188  	if fn.LeftVec {
   189  		n += "VS"
   190  	} else {
   191  		n += "SV"
   192  	}
   193  	return n
   194  }
   195  
   196  func (fn *GenericMixedArith) Signature() *Signature {
   197  	var paramNames []string
   198  	var paramTemplates []*template.Template
   199  	var err bool
   200  
   201  	switch {
   202  	case fn.Iter && fn.Incr:
   203  		paramNames = []string{"a", "b", "incr", "ait", "iit"}
   204  		paramTemplates = []*template.Template{sliceType, sliceType, sliceType, iteratorType, iteratorType}
   205  		if fn.LeftVec {
   206  			paramTemplates[1] = scalarType
   207  		} else {
   208  			paramTemplates[0] = scalarType
   209  			paramNames[3] = "bit"
   210  		}
   211  		err = true
   212  	case fn.Iter && !fn.Incr:
   213  		paramNames = []string{"a", "b", "ait"}
   214  		paramTemplates = []*template.Template{sliceType, sliceType, iteratorType}
   215  		if fn.LeftVec {
   216  			paramTemplates[1] = scalarType
   217  		} else {
   218  			paramTemplates[0] = scalarType
   219  			paramNames[2] = "bit"
   220  		}
   221  
   222  		err = true
   223  	case !fn.Iter && fn.Incr:
   224  		paramNames = []string{"a", "b", "incr"}
   225  		paramTemplates = []*template.Template{sliceType, sliceType, sliceType}
   226  		if fn.LeftVec {
   227  			paramTemplates[1] = scalarType
   228  		} else {
   229  			paramTemplates[0] = scalarType
   230  		}
   231  
   232  	default:
   233  		paramNames = []string{"a", "b"}
   234  		paramTemplates = []*template.Template{sliceType, sliceType}
   235  		if fn.LeftVec {
   236  			paramTemplates[1] = scalarType
   237  		} else {
   238  			paramTemplates[0] = scalarType
   239  		}
   240  	}
   241  
   242  	if fn.Check != nil {
   243  		err = true
   244  	}
   245  
   246  	return &Signature{
   247  		Name:           fn.Name(),
   248  		NameTemplate:   typeAnnotatedName,
   249  		ParamNames:     paramNames,
   250  		ParamTemplates: paramTemplates,
   251  
   252  		Kind: fn.Kind(),
   253  		Err:  err,
   254  	}
   255  }
   256  
   257  func (fn *GenericMixedArith) WriteBody(w io.Writer) {
   258  	var Range, Left, Right string
   259  	var Index0, Index1 string
   260  	var IterName0, IterName1 string
   261  
   262  	Range = "a"
   263  	Left = "a[i]"
   264  	Right = "b[i]"
   265  	Index0 = "i"
   266  
   267  	T := template.New(fn.Name()).Funcs(funcs)
   268  	switch {
   269  	case fn.Iter && fn.Incr:
   270  		Range = "incr"
   271  		T = template.Must(T.Parse(genericBinaryIterLoopRaw))
   272  		template.Must(T.New("loopbody").Parse(iterIncrLoopBody))
   273  	case fn.Iter && !fn.Incr:
   274  		T = template.Must(T.Parse(genericUnaryIterLoopRaw))
   275  		template.Must(T.New("loopbody").Parse(basicSet))
   276  	case !fn.Iter && fn.Incr:
   277  		Range = "incr"
   278  		T = template.Must(T.Parse(genericLoopRaw))
   279  		template.Must(T.New("loopbody").Parse(basicIncr))
   280  	default:
   281  		T = template.Must(T.Parse(genericLoopRaw))
   282  		template.Must(T.New("loopbody").Parse(basicSet))
   283  	}
   284  
   285  	if fn.LeftVec {
   286  		Right = "b"
   287  	} else {
   288  		Left = "a"
   289  		if !fn.Incr {
   290  			Range = "b"
   291  		}
   292  		// Index0 = "j"
   293  	}
   294  
   295  	switch {
   296  	case fn.Iter && fn.Incr && fn.LeftVec:
   297  		IterName0 = "ait"
   298  		IterName1 = "iit"
   299  		Index1 = "k"
   300  	case fn.Iter && !fn.Incr && fn.LeftVec:
   301  		IterName0 = "ait"
   302  	case fn.Iter && fn.Incr && !fn.LeftVec:
   303  		IterName0 = "bit"
   304  		IterName1 = "iit"
   305  		Index1 = "k"
   306  	case fn.Iter && !fn.Incr && !fn.LeftVec:
   307  		IterName0 = "bit"
   308  	}
   309  
   310  	template.Must(T.New("callFunc").Parse(binOpCallFunc))
   311  	template.Must(T.New("opDo").Parse(binOpDo))
   312  	template.Must(T.New("symbol").Parse(fn.SymbolTemplate()))
   313  
   314  	if fn.Check != nil && fn.Check(fn.Kind()) {
   315  		w.Write([]byte("var errs errorIndices\n"))
   316  	}
   317  	template.Must(T.New("check").Parse(fn.CheckTemplate))
   318  
   319  	lb := LoopBody{
   320  		TypedOp: fn.TypedBinOp,
   321  		Range:   Range,
   322  		Left:    Left,
   323  		Right:   Right,
   324  
   325  		Index0:    Index0,
   326  		Index1:    Index1,
   327  		IterName0: IterName0,
   328  		IterName1: IterName1,
   329  	}
   330  	T.Execute(w, lb)
   331  }
   332  
   333  func (fn *GenericMixedArith) Write(w io.Writer) {
   334  	sig := fn.Signature()
   335  
   336  	w.Write([]byte("func "))
   337  	sig.Write(w)
   338  
   339  	w.Write([]byte("{\n"))
   340  
   341  	fn.WriteBody(w)
   342  	if sig.Err {
   343  		if fn.Check != nil {
   344  			w.Write([]byte("\nif err != nil {\n return\n}\nif len(errs) > 0 {\n return errs }\nreturn nil"))
   345  		} else {
   346  			w.Write([]byte("\nreturn\n"))
   347  		}
   348  	}
   349  	w.Write([]byte("}\n\n"))
   350  }
   351  
   352  type GenericScalarScalarArith struct {
   353  	TypedBinOp
   354  }
   355  
   356  func (fn *GenericScalarScalarArith) Signature() *Signature {
   357  	return &Signature{
   358  		Name:            fn.Name(),
   359  		NameTemplate:    typeAnnotatedName,
   360  		ParamNames:      []string{"a", "b"},
   361  		ParamTemplates:  []*template.Template{scalarType, scalarType},
   362  		RetVals:         []string{""},
   363  		RetValTemplates: []*template.Template{scalarType},
   364  		Kind:            fn.Kind(),
   365  	}
   366  }
   367  
   368  func (fn *GenericScalarScalarArith) WriteBody(w io.Writer) {
   369  	tmpl := `return {{if .IsFunc -}}
   370  			{{ template "callFunc" . -}}
   371  		{{else -}}
   372  			{{template "opDo" . -}}
   373  		{{end -}}`
   374  	opDo := `a {{template "symbol" .Kind}} b`
   375  	callFunc := `{{template "symbol" .Kind}}(a, b)`
   376  
   377  	T := template.Must(template.New(fn.Name()).Funcs(funcs).Parse(tmpl))
   378  	template.Must(T.New("opDo").Parse(opDo))
   379  	template.Must(T.New("callFunc").Parse(callFunc))
   380  	template.Must(T.New("symbol").Parse(fn.SymbolTemplate()))
   381  
   382  	T.Execute(w, fn)
   383  }
   384  
   385  func (fn *GenericScalarScalarArith) Write(w io.Writer) {
   386  	w.Write([]byte("func "))
   387  	sig := fn.Signature()
   388  	sig.Write(w)
   389  	w.Write([]byte("{"))
   390  	fn.WriteBody(w)
   391  	w.Write([]byte("}\n"))
   392  }
   393  
   394  func makeGenericVecVecAriths(tbo []TypedBinOp) (retVal []*GenericVecVecArith) {
   395  	for _, tb := range tbo {
   396  		if tc := tb.TypeClass(); tc != nil && !tc(tb.Kind()) {
   397  			continue
   398  		}
   399  		fn := &GenericVecVecArith{
   400  			TypedBinOp: tb,
   401  		}
   402  		if tb.Name() == "Div" && !isFloatCmplx(tb.Kind()) {
   403  			fn.Check = panicsDiv0
   404  			fn.CheckTemplate = check0
   405  		}
   406  
   407  		retVal = append(retVal, fn)
   408  
   409  	}
   410  
   411  	return retVal
   412  }
   413  
   414  func makeGenericMixedAriths(tbo []TypedBinOp) (retVal []*GenericMixedArith) {
   415  	for _, tb := range tbo {
   416  		if tc := tb.TypeClass(); tc != nil && !tc(tb.Kind()) {
   417  			continue
   418  		}
   419  		fn := &GenericMixedArith{
   420  			GenericVecVecArith: GenericVecVecArith{
   421  				TypedBinOp: tb,
   422  			},
   423  		}
   424  		if tb.Name() == "Div" && !isFloatCmplx(tb.Kind()) {
   425  			fn.Check = panicsDiv0
   426  			fn.CheckTemplate = check0
   427  		}
   428  		retVal = append(retVal, fn)
   429  	}
   430  	return
   431  }
   432  
   433  func makeGenericScalarScalarAriths(tbo []TypedBinOp) (retVal []*GenericScalarScalarArith) {
   434  	for _, tb := range tbo {
   435  		if tc := tb.TypeClass(); tc != nil && !tc(tb.Kind()) {
   436  			continue
   437  		}
   438  		fn := &GenericScalarScalarArith{
   439  			TypedBinOp: tb,
   440  		}
   441  		retVal = append(retVal, fn)
   442  	}
   443  	return
   444  }
   445  
   446  func generateGenericVecVecArith(f io.Writer, ak Kinds) {
   447  	gen := makeGenericVecVecAriths(typedAriths)
   448  
   449  	// importStmt := `
   450  	// import (
   451  	// 	_ "unsafe"
   452  
   453  	// _ "gorgonia.org/vecf32"
   454  	// _ "gorgonia.org/vecf64")
   455  	// `
   456  	// f.Write([]byte(importStmt))
   457  
   458  	for _, g := range gen {
   459  		g.Write(f)
   460  		g.Incr = true
   461  	}
   462  	for _, g := range gen {
   463  		g.Write(f)
   464  		g.Incr = false
   465  		g.Iter = true
   466  	}
   467  	for _, g := range gen {
   468  		g.Write(f)
   469  		g.Incr = true
   470  	}
   471  	for _, g := range gen {
   472  		g.Write(f)
   473  	}
   474  	for _, g := range gen {
   475  		g.Incr = false
   476  		g.Iter = false
   477  		g.WithRecv = true
   478  		g.Write(f)
   479  	}
   480  }
   481  
   482  func generateGenericMixedArith(f io.Writer, ak Kinds) {
   483  	gen := makeGenericMixedAriths(typedAriths)
   484  
   485  	// SV first
   486  	for _, g := range gen {
   487  		g.Write(f)
   488  		g.Incr = true
   489  	}
   490  	for _, g := range gen {
   491  		g.Write(f)
   492  		g.Incr = false
   493  		g.Iter = true
   494  	}
   495  	for _, g := range gen {
   496  		g.Write(f)
   497  		g.Incr = true
   498  	}
   499  	for _, g := range gen {
   500  		g.Write(f)
   501  
   502  		// reset
   503  		g.LeftVec = true
   504  		g.Incr = false
   505  		g.Iter = false
   506  	}
   507  
   508  	// VS
   509  	for _, g := range gen {
   510  		g.Write(f)
   511  		g.Incr = true
   512  	}
   513  	for _, g := range gen {
   514  		g.Write(f)
   515  		g.Incr = false
   516  		g.Iter = true
   517  	}
   518  	for _, g := range gen {
   519  		g.Write(f)
   520  		g.Incr = true
   521  	}
   522  	for _, g := range gen {
   523  		g.Write(f)
   524  	}
   525  }
   526  
   527  func generateGenericScalarScalarArith(f io.Writer, ak Kinds) {
   528  	gen := makeGenericScalarScalarAriths(typedAriths)
   529  	for _, g := range gen {
   530  		g.Write(f)
   531  	}
   532  }