github.com/wzzhu/tensor@v0.9.24/genlib2/generic_cmp.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"text/template"
     7  )
     8  
// GenericVecVecCmp generates a comparison function between two slices of the
// same element kind (both a and b are emitted with sliceType in Signature).
//
// The embedded TypedBinOp supplies the operator, kind and symbol template.
// The two flags select one of four generated variants; see Name for the
// resulting function-name suffixes.
type GenericVecVecCmp struct {
	TypedBinOp

	// RetSame: when true no separate retVal []bool parameter is generated and
	// the loop ranges over a with the sameSet body (presumably storing the
	// comparison result back into a).
	RetSame bool
	// Iter: when true the generated function takes Iterator parameters and
	// reports an error (Signature.Err is set).
	Iter    bool
}
    15  func (fn *GenericVecVecCmp) Name() string {
    16  	switch {
    17  	case fn.Iter && fn.RetSame:
    18  		return fmt.Sprintf("%sSameIter", fn.TypedBinOp.Name())
    19  	case fn.Iter && !fn.RetSame:
    20  		return fmt.Sprintf("%sIter", fn.TypedBinOp.Name())
    21  	case !fn.Iter && fn.RetSame:
    22  		return fmt.Sprintf("%sSame", fn.TypedBinOp.Name())
    23  	default:
    24  		return fn.TypedBinOp.Name()
    25  	}
    26  }
    27  
    28  func (fn *GenericVecVecCmp) Signature() *Signature {
    29  	var paramNames []string
    30  	var paramTemplates []*template.Template
    31  	var err bool
    32  
    33  	switch {
    34  	case fn.Iter && fn.RetSame:
    35  		paramNames = []string{"a", "b", "ait", "bit"}
    36  		paramTemplates = []*template.Template{sliceType, sliceType, iteratorType, iteratorType}
    37  		err = true
    38  	case fn.Iter && !fn.RetSame:
    39  		paramNames = []string{"a", "b", "retVal", "ait", "bit", "rit"}
    40  		paramTemplates = []*template.Template{sliceType, sliceType, boolsType, iteratorType, iteratorType, iteratorType}
    41  		err = true
    42  	case !fn.Iter && fn.RetSame:
    43  		paramNames = []string{"a", "b"}
    44  		paramTemplates = []*template.Template{sliceType, sliceType}
    45  	default:
    46  		paramNames = []string{"a", "b", "retVal"}
    47  		paramTemplates = []*template.Template{sliceType, sliceType, boolsType}
    48  	}
    49  
    50  	return &Signature{
    51  		Name:           fn.Name(),
    52  		NameTemplate:   typeAnnotatedName,
    53  		ParamNames:     paramNames,
    54  		ParamTemplates: paramTemplates,
    55  
    56  		Kind: fn.Kind(),
    57  		Err:  err,
    58  	}
    59  }
    60  
    61  func (fn *GenericVecVecCmp) WriteBody(w io.Writer) {
    62  	var Range, Left, Right string
    63  	var Index0, Index1, Index2 string
    64  	var IterName0, IterName1, IterName2 string
    65  	var T *template.Template
    66  
    67  	Range = "a"
    68  	Left = "a[i]"
    69  	Right = "b[j]"
    70  	Index0 = "i"
    71  	Index1 = "j"
    72  	switch {
    73  	case fn.Iter && fn.RetSame:
    74  		IterName0 = "ait"
    75  		IterName1 = "bit"
    76  		T = template.Must(template.New(fn.Name()).Funcs(funcs).Parse(genericBinaryIterLoopRaw))
    77  		template.Must(T.New("loopbody").Funcs(funcs).Parse(sameSet))
    78  	case fn.Iter && !fn.RetSame:
    79  		Range = "retVal"
    80  		Index2 = "k"
    81  		IterName0 = "ait"
    82  		IterName1 = "bit"
    83  		IterName2 = "rit"
    84  		T = template.Must(template.New(fn.Name()).Funcs(funcs).Parse(genericTernaryIterLoopRaw))
    85  		template.Must(T.New("loopbody").Funcs(funcs).Parse(ternaryIterSet))
    86  	case !fn.Iter && fn.RetSame:
    87  		Right = "b[i]"
    88  		T = template.Must(template.New(fn.Name()).Funcs(funcs).Parse(genericLoopRaw))
    89  		template.Must(T.New("loopbody").Funcs(funcs).Parse(sameSet))
    90  	default:
    91  		Range = "retVal"
    92  		Right = "b[i]"
    93  		T = template.Must(template.New(fn.Name()).Funcs(funcs).Parse(genericLoopRaw))
    94  		template.Must(T.New("loopbody").Funcs(funcs).Parse(basicSet))
    95  	}
    96  	template.Must(T.New("opDo").Funcs(funcs).Parse(binOpDo))
    97  	template.Must(T.New("callFunc").Funcs(funcs).Parse(""))
    98  	template.Must(T.New("check").Funcs(funcs).Parse(""))
    99  	template.Must(T.New("symbol").Funcs(funcs).Parse(fn.SymbolTemplate()))
   100  
   101  	lb := LoopBody{
   102  		TypedOp: fn.TypedBinOp,
   103  		Range:   Range,
   104  		Left:    Left,
   105  		Right:   Right,
   106  
   107  		Index0: Index0,
   108  		Index1: Index1,
   109  		Index2: Index2,
   110  
   111  		IterName0: IterName0,
   112  		IterName1: IterName1,
   113  		IterName2: IterName2,
   114  	}
   115  	T.Execute(w, lb)
   116  }
   117  
   118  func (fn *GenericVecVecCmp) Write(w io.Writer) {
   119  	sig := fn.Signature()
   120  	w.Write([]byte("func "))
   121  	sig.Write(w)
   122  	switch {
   123  	case !fn.Iter && !fn.RetSame:
   124  		w.Write([]byte("{\na = a[:len(a)]; b = b[:len(a)]; retVal=retVal[:len(a)]\n"))
   125  	case !fn.Iter && fn.RetSame:
   126  		w.Write([]byte("{\na = a[:len(a)]; b = b[:len(a)]\n"))
   127  	default:
   128  		w.Write([]byte("{"))
   129  	}
   130  	fn.WriteBody(w)
   131  	if sig.Err {
   132  		w.Write([]byte("\n return\n"))
   133  	}
   134  	w.Write([]byte("}\n\n"))
   135  }
   136  
// GenericMixedCmp generates a comparison function between a slice and a
// scalar of the same kind. It reuses GenericVecVecCmp's op and RetSame/Iter
// flags; Signature swaps one of the two slice parameters for scalarType.
type GenericMixedCmp struct {
	GenericVecVecCmp
	// LeftVec: when true a is the vector and b the scalar (name suffix "VS");
	// otherwise a is the scalar and b the vector (name suffix "SV").
	LeftVec bool
}
   141  
   142  func (fn *GenericMixedCmp) Name() string {
   143  	n := fn.GenericVecVecCmp.Name()
   144  	if fn.LeftVec {
   145  		n += "VS"
   146  	} else {
   147  		n += "SV"
   148  	}
   149  	return n
   150  }
   151  
   152  func (fn *GenericMixedCmp) Signature() *Signature {
   153  	var paramNames []string
   154  	var paramTemplates []*template.Template
   155  	var err bool
   156  
   157  	switch {
   158  	case fn.Iter && !fn.RetSame:
   159  		paramNames = []string{"a", "b", "retVal", "ait", "rit"}
   160  		paramTemplates = []*template.Template{sliceType, sliceType, boolsType, iteratorType, iteratorType}
   161  		err = true
   162  	case fn.Iter && fn.RetSame:
   163  		paramNames = []string{"a", "b", "ait"}
   164  		paramTemplates = []*template.Template{sliceType, sliceType, iteratorType}
   165  		err = true
   166  	case !fn.Iter && fn.RetSame:
   167  		paramNames = []string{"a", "b"}
   168  		paramTemplates = []*template.Template{sliceType, sliceType}
   169  	default:
   170  		paramNames = []string{"a", "b", "retVal"}
   171  		paramTemplates = []*template.Template{sliceType, sliceType, boolsType}
   172  	}
   173  	if fn.LeftVec {
   174  		paramTemplates[1] = scalarType
   175  	} else {
   176  		paramTemplates[0] = scalarType
   177  		if fn.Iter && !fn.RetSame {
   178  			paramNames[3] = "bit"
   179  		} else if fn.Iter && fn.RetSame {
   180  			paramNames[2] = "bit"
   181  		}
   182  	}
   183  	return &Signature{
   184  		Name:           fn.Name(),
   185  		NameTemplate:   typeAnnotatedName,
   186  		ParamNames:     paramNames,
   187  		ParamTemplates: paramTemplates,
   188  
   189  		Kind: fn.Kind(),
   190  		Err:  err,
   191  	}
   192  }
   193  
   194  func (fn *GenericMixedCmp) WriteBody(w io.Writer) {
   195  	var Range, Left, Right string
   196  	var Index0, Index1 string
   197  	var IterName0, IterName1 string
   198  	var T *template.Template
   199  
   200  	Range = "a"
   201  	Left = "a[i]"
   202  	Right = "b[i]"
   203  	Index0 = "i"
   204  
   205  	T = template.New(fn.Name()).Funcs(funcs)
   206  	switch {
   207  	case fn.Iter && !fn.RetSame:
   208  		Range = "retVal"
   209  		T = template.Must(T.Parse(genericBinaryIterLoopRaw))
   210  		template.Must(T.New("loopbody").Parse(ternaryIterSet))
   211  	case fn.Iter && fn.RetSame:
   212  		T = template.Must(T.Parse(genericUnaryIterLoopRaw))
   213  		template.Must(T.New("loopbody").Parse(sameSet))
   214  	case !fn.Iter && fn.RetSame:
   215  		T = template.Must(T.Parse(genericLoopRaw))
   216  		template.Must(T.New("loopbody").Parse(sameSet))
   217  	default:
   218  		T = template.Must(T.Parse(genericLoopRaw))
   219  		template.Must(T.New("loopbody").Parse(basicSet))
   220  	}
   221  
   222  	if fn.LeftVec {
   223  		Right = "b"
   224  	} else {
   225  		Left = "a"
   226  	}
   227  	if !fn.RetSame {
   228  		Range = "retVal"
   229  	} else {
   230  		if !fn.LeftVec {
   231  			Range = "b"
   232  		}
   233  	}
   234  
   235  	switch {
   236  	case fn.Iter && !fn.RetSame && fn.LeftVec:
   237  		IterName0 = "ait"
   238  		IterName1 = "rit"
   239  		Index1 = "k"
   240  	case fn.Iter && fn.RetSame && fn.LeftVec:
   241  		IterName0 = "ait"
   242  	case fn.Iter && !fn.RetSame && !fn.LeftVec:
   243  		IterName0 = "bit"
   244  		IterName1 = "rit"
   245  		Index1 = "k"
   246  	case fn.Iter && fn.RetSame && !fn.LeftVec:
   247  		IterName0 = "bit"
   248  	}
   249  
   250  	template.Must(T.New("callFunc").Parse(""))
   251  	template.Must(T.New("opDo").Parse(binOpDo))
   252  	template.Must(T.New("symbol").Parse(fn.SymbolTemplate()))
   253  	template.Must(T.New("check").Parse(""))
   254  
   255  	lb := LoopBody{
   256  		TypedOp: fn.TypedBinOp,
   257  		Range:   Range,
   258  		Left:    Left,
   259  		Right:   Right,
   260  
   261  		Index0:    Index0,
   262  		Index1:    Index1,
   263  		IterName0: IterName0,
   264  		IterName1: IterName1,
   265  	}
   266  	T.Execute(w, lb)
   267  }
   268  
   269  func (fn *GenericMixedCmp) Write(w io.Writer) {
   270  	sig := fn.Signature()
   271  	w.Write([]byte("func "))
   272  	sig.Write(w)
   273  	w.Write([]byte("{ \n"))
   274  	fn.WriteBody(w)
   275  	if sig.Err {
   276  		w.Write([]byte("\nreturn\n"))
   277  	}
   278  	w.Write([]byte("}\n\n"))
   279  }
   280  
   281  func makeGenericVecVecCmps(tbo []TypedBinOp) (retVal []*GenericVecVecCmp) {
   282  	for _, tb := range tbo {
   283  		if tc := tb.TypeClass(); tc != nil && !tc(tb.Kind()) {
   284  			continue
   285  		}
   286  		fn := &GenericVecVecCmp{
   287  			TypedBinOp: tb,
   288  		}
   289  		retVal = append(retVal, fn)
   290  	}
   291  	return
   292  }
   293  
   294  func makeGenericMixedCmps(tbo []TypedBinOp) (retVal []*GenericMixedCmp) {
   295  	for _, tb := range tbo {
   296  		if tc := tb.TypeClass(); tc != nil && !tc(tb.Kind()) {
   297  			continue
   298  		}
   299  		fn := &GenericMixedCmp{
   300  			GenericVecVecCmp: GenericVecVecCmp{
   301  				TypedBinOp: tb,
   302  			},
   303  		}
   304  		retVal = append(retVal, fn)
   305  	}
   306  	return
   307  }
   308  
   309  func generateGenericVecVecCmp(f io.Writer, ak Kinds) {
   310  	gen := makeGenericVecVecCmps(typedCmps)
   311  	for _, g := range gen {
   312  		g.Write(f)
   313  		g.RetSame = true
   314  
   315  	}
   316  	for _, g := range gen {
   317  		if isBoolRepr(g.Kind()) {
   318  			g.Write(f)
   319  		}
   320  		g.RetSame = false
   321  		g.Iter = true
   322  	}
   323  	for _, g := range gen {
   324  		g.Write(f)
   325  		g.RetSame = true
   326  	}
   327  	for _, g := range gen {
   328  		if isBoolRepr(g.Kind()) {
   329  			g.Write(f)
   330  		}
   331  	}
   332  }
   333  
   334  func generateGenericMixedCmp(f io.Writer, ak Kinds) {
   335  	gen := makeGenericMixedCmps(typedCmps)
   336  	for _, g := range gen {
   337  		g.Write(f)
   338  		g.RetSame = true
   339  	}
   340  	for _, g := range gen {
   341  		if isBoolRepr(g.Kind()) {
   342  			g.Write(f)
   343  		}
   344  		g.RetSame = false
   345  		g.Iter = true
   346  	}
   347  	for _, g := range gen {
   348  		g.Write(f)
   349  		g.RetSame = true
   350  	}
   351  	for _, g := range gen {
   352  		if isBoolRepr(g.Kind()) {
   353  			g.Write(f)
   354  		}
   355  		g.LeftVec = true
   356  		g.RetSame = false
   357  		g.Iter = false
   358  	}
   359  
   360  	// VS
   361  
   362  	for _, g := range gen {
   363  		g.Write(f)
   364  		g.RetSame = true
   365  	}
   366  	for _, g := range gen {
   367  		if isBoolRepr(g.Kind()) {
   368  			g.Write(f)
   369  		}
   370  		g.RetSame = false
   371  		g.Iter = true
   372  	}
   373  	for _, g := range gen {
   374  		g.Write(f)
   375  		g.RetSame = true
   376  	}
   377  	for _, g := range gen {
   378  		if isBoolRepr(g.Kind()) {
   379  			g.Write(f)
   380  		}
   381  	}
   382  }
   383  
   384  /* OTHER */
   385  
// genericElMinMaxRaw is the template for in-place elementwise min/max helpers.
// generateMinMax executes it once per ordered kind, and it relies on the
// short, title and asType template functions. For each kind it emits:
//
//   - VecMin / VecMax: a[i] becomes the smaller / larger of a[i] and b[i].
//   - MinSV / MaxSV: clamp the slice b against the scalar a, writing into b.
//   - MinVS / MaxVS: clamp the slice a against the scalar b, writing into a.
//
// The emitted a = a[:len(a)]; b = b[:len(a)] lines are presumably a
// bounds-check-elimination hint.
//
// NOTE(review): the comparisons use < and >. For float kinds, a NaN in the
// second operand is therefore never copied into the destination, while a NaN
// already in the destination is kept. Confirm this is the intended NaN
// policy.
const genericElMinMaxRaw = `func VecMin{{short . | title}}(a, b []{{asType .}}) {
	a = a[:len(a)]
	b = b[:len(a)]
	for i, v := range a {
		bv := b[i]
		if bv < v {
			a[i] = bv
		}
	}
}

func MinSV{{short . | title}}(a {{asType .}}, b []{{asType .}}){
	for i := range b {
		if a < b[i]{
			b[i] = a
		}
	}
}

func MinVS{{short . | title}}(a []{{asType .}}, b {{asType .}}){
	for i := range a {
		if b < a[i]{
			a[i] = b
		}
	}
}

func VecMax{{short . | title}}(a, b []{{asType .}}) {
	a = a[:len(a)]
	b = b[:len(a)]
	for i, v := range a {
		bv := b[i]
		if bv > v {
			a[i] = bv
		}
	}
}



func MaxSV{{short . | title}}(a {{asType .}}, b []{{asType .}}){
	for i := range b {
		if a > b[i]{
			b[i] = a
		}
	}
}

func MaxVS{{short . | title}}(a []{{asType .}}, b {{asType .}}){
	for i := range a {
		if b > a[i]{
			a[i] = b
		}
	}
}
`
   443  
// genericIterMinMaxRaw is the template for iterator-driven in-place min/max
// helpers. generateMinMax executes it once per ordered kind. For each kind it
// emits MinIterSV/MinIterVS/VecMinIter and the matching Max variants, each
// returning an error.
//
// Each helper loops on NextValidity() and only updates positions reported as
// valid. VecMinIter and VecMaxIter advance ait and bit in lockstep and update
// only when both positions are valid. When NextValidity returns an error, the
// loop stops and the error is passed through handleNoOp. handleNoOp is defined
// elsewhere; presumably it turns the normal end-of-iteration signal into nil,
// so confirm against its definition.
const genericIterMinMaxRaw = `func MinIterSV{{short . | title}}(a {{asType .}}, b []{{asType .}}, bit Iterator) (err error){
	var i int
	var validi bool
	for {
		if i, validi, err = bit.NextValidity(); err != nil{
			err = handleNoOp(err)
			break
		}
		if validi {
			if a < b[i] {
				b[i] = a
			}
		}
	}
	return
}

func MinIterVS{{short . | title}}(a []{{asType .}}, b {{asType .}}, ait Iterator) (err error){
	var i int
	var validi bool
	for {
		if i, validi, err = ait.NextValidity(); err != nil{
			err = handleNoOp(err)
			break
		}
		if validi {
			if b < a[i] {
				a[i] = b
			}
		}
	}
	return
}

func VecMinIter{{short . | title}}(a , b []{{asType .}}, ait, bit Iterator) (err error){
	var i,j int
	var validi ,validj bool
	for {
		if i, validi, err = ait.NextValidity(); err != nil{
			err = handleNoOp(err)
			break
		}
		if j, validj, err = bit.NextValidity(); err != nil{
			err = handleNoOp(err)
			break
		}
		if validi && validj {
			if b[j] < a[i] {
				a[i] = b[j]
			}
		}
	}
	return
}


func MaxIterSV{{short . | title}}(a {{asType .}}, b []{{asType .}}, bit Iterator) (err error){
	var i int
	var validi bool
	for {
		if i, validi, err = bit.NextValidity(); err != nil{
			err = handleNoOp(err)
			break
		}
		if validi {
			if a > b[i] {
				b[i] = a
			}
		}
	}
	return
}

func MaxIterVS{{short . | title}}(a []{{asType .}}, b {{asType .}}, ait Iterator) (err error){
	var i int
	var validi bool
	for {
		if i, validi, err = ait.NextValidity(); err != nil{
			err = handleNoOp(err)
			break
		}
		if validi {
			if b > a[i] {
				a[i] = b
			}
		}
	}
	return
}

func VecMaxIter{{short . | title}}(a , b []{{asType .}}, ait, bit Iterator) (err error){
	var i,j int
	var validi, validj bool
	for {
		if i, validi, err = ait.NextValidity(); err != nil{
			err = handleNoOp(err)
			break
		}
		if j, validj, err = bit.NextValidity(); err != nil{
			err = handleNoOp(err)
			break
		}
		if validi && validj {
			if b[j] > a[i] {
				a[i] = b[j]
			}
		}
	}
	return
}

`
   557  
// genericScalarMinMaxRaw is the template for scalar Min/Max helpers.
// generateMinMax executes it once per ordered kind. When a == b, both Min and
// Max return b.
//
// NOTE(review): unlike the other min/max templates, the function name uses
// {{short .}} without the `title` pipe, and the named result c is never used.
// Both are left as-is because generated call sites depend on the exact names.
const genericScalarMinMaxRaw = `func Min{{short .}}(a, b {{asType .}}) (c {{asType .}}) {if a < b {
	return a
	}
	return b
}


func Max{{short .}}(a, b {{asType .}}) (c {{asType .}}) {if a > b {
	return a
	}
	return b
}
`
   572  
   573  var (
   574  	genericElMinMax     *template.Template
   575  	genericMinMax       *template.Template
   576  	genericElMinMaxIter *template.Template
   577  )
   578  
   579  func init() {
   580  	genericElMinMax = template.Must(template.New("genericVecVecMinMax").Funcs(funcs).Parse(genericElMinMaxRaw))
   581  	genericMinMax = template.Must(template.New("genericMinMax").Funcs(funcs).Parse(genericScalarMinMaxRaw))
   582  	genericElMinMaxIter = template.Must(template.New("genericIterMinMax").Funcs(funcs).Parse(genericIterMinMaxRaw))
   583  }
   584  
   585  func generateMinMax(f io.Writer, ak Kinds) {
   586  	for _, k := range filter(ak.Kinds, isOrd) {
   587  		genericElMinMax.Execute(f, k)
   588  	}
   589  
   590  	for _, k := range filter(ak.Kinds, isOrd) {
   591  		genericMinMax.Execute(f, k)
   592  	}
   593  
   594  	for _, k := range filter(ak.Kinds, isOrd) {
   595  		genericElMinMaxIter.Execute(f, k)
   596  	}
   597  }