github.com/wzzhu/tensor@v0.9.24/genlib2/engine.go (about)

     1  package main
     2  
     3  import (
     4  	"io"
     5  	"reflect"
     6  	"text/template"
     7  )
     8  
     9  type EngineArith struct {
    10  	Name           string
    11  	VecVar         string
    12  	PrepData       string
    13  	TypeClassCheck string
    14  	IsCommutative  bool
    15  
    16  	VV      bool
    17  	LeftVec bool
    18  }
    19  
    20  func (fn *EngineArith) methName() string {
    21  	switch {
    22  	case fn.VV:
    23  		return fn.Name
    24  	default:
    25  		return fn.Name + "Scalar"
    26  	}
    27  }
    28  
    29  func (fn *EngineArith) Signature() *Signature {
    30  	var paramNames []string
    31  	var paramTemplates []*template.Template
    32  
    33  	switch {
    34  	case fn.VV:
    35  		paramNames = []string{"a", "b", "opts"}
    36  		paramTemplates = []*template.Template{tensorType, tensorType, splatFuncOptType}
    37  	default:
    38  		paramNames = []string{"t", "s", "leftTensor", "opts"}
    39  		paramTemplates = []*template.Template{tensorType, interfaceType, boolType, splatFuncOptType}
    40  	}
    41  	return &Signature{
    42  		Name:           fn.methName(),
    43  		NameTemplate:   plainName,
    44  		ParamNames:     paramNames,
    45  		ParamTemplates: paramTemplates,
    46  		Err:            false,
    47  	}
    48  }
    49  
    50  func (fn *EngineArith) WriteBody(w io.Writer) {
    51  	var prep *template.Template
    52  	switch {
    53  	case fn.VV:
    54  		prep = prepVV
    55  		fn.VecVar = "a"
    56  	case !fn.VV && fn.LeftVec:
    57  		fn.VecVar = "t"
    58  		fn.PrepData = "prepDataVS"
    59  		prep = prepMixed
    60  	default:
    61  		fn.VecVar = "t"
    62  		fn.PrepData = "prepDataSV"
    63  		prep = prepMixed
    64  	}
    65  	template.Must(prep.New("prep").Parse(arithPrepRaw))
    66  	prep.Execute(w, fn)
    67  	agg2Body.Execute(w, fn)
    68  }
    69  
    70  func (fn *EngineArith) Write(w io.Writer) {
    71  	if tmpl, ok := arithDocStrings[fn.methName()]; ok {
    72  		type tmp struct {
    73  			Left, Right string
    74  		}
    75  		var ds tmp
    76  		if fn.VV {
    77  			ds.Left = "a"
    78  			ds.Right = "b"
    79  		} else {
    80  			ds.Left = "t"
    81  			ds.Right = "s"
    82  		}
    83  		tmpl.Execute(w, ds)
    84  	}
    85  
    86  	sig := fn.Signature()
    87  	w.Write([]byte("func (e StdEng) "))
    88  	sig.Write(w)
    89  	w.Write([]byte("(retVal Tensor, err error) {\n"))
    90  	fn.WriteBody(w)
    91  	w.Write([]byte("}\n\n"))
    92  }
    93  
    94  func generateStdEngArith(f io.Writer, ak Kinds) {
    95  	var methods []*EngineArith
    96  	for _, abo := range arithBinOps {
    97  		meth := &EngineArith{
    98  			Name:           abo.Name(),
    99  			VV:             true,
   100  			TypeClassCheck: "Number",
   101  			IsCommutative:  abo.IsCommutative,
   102  		}
   103  		methods = append(methods, meth)
   104  	}
   105  
   106  	// VV
   107  	for _, meth := range methods {
   108  		meth.Write(f)
   109  		meth.VV = false
   110  	}
   111  
   112  	// Scalar
   113  	for _, meth := range methods {
   114  		meth.Write(f)
   115  		meth.LeftVec = true
   116  	}
   117  
   118  }
   119  
   120  type EngineCmp struct {
   121  	Name           string
   122  	VecVar         string
   123  	PrepData       string
   124  	TypeClassCheck string
   125  	Inv            string
   126  
   127  	VV      bool
   128  	LeftVec bool
   129  }
   130  
   131  func (fn *EngineCmp) methName() string {
   132  	switch {
   133  	case fn.VV:
   134  		if fn.Name == "Eq" || fn.Name == "Ne" {
   135  			return "El" + fn.Name
   136  		}
   137  		return fn.Name
   138  	default:
   139  		return fn.Name + "Scalar"
   140  	}
   141  }
   142  
   143  func (fn *EngineCmp) Signature() *Signature {
   144  	var paramNames []string
   145  	var paramTemplates []*template.Template
   146  
   147  	switch {
   148  	case fn.VV:
   149  		paramNames = []string{"a", "b", "opts"}
   150  		paramTemplates = []*template.Template{tensorType, tensorType, splatFuncOptType}
   151  	default:
   152  		paramNames = []string{"t", "s", "leftTensor", "opts"}
   153  		paramTemplates = []*template.Template{tensorType, interfaceType, boolType, splatFuncOptType}
   154  	}
   155  	return &Signature{
   156  		Name:           fn.methName(),
   157  		NameTemplate:   plainName,
   158  		ParamNames:     paramNames,
   159  		ParamTemplates: paramTemplates,
   160  		Err:            false,
   161  	}
   162  }
   163  
   164  func (fn *EngineCmp) WriteBody(w io.Writer) {
   165  	var prep *template.Template
   166  	switch {
   167  	case fn.VV:
   168  		prep = prepVV
   169  		fn.VecVar = "a"
   170  	case !fn.VV && fn.LeftVec:
   171  		fn.VecVar = "t"
   172  		fn.PrepData = "prepDataVS"
   173  		prep = prepMixed
   174  	default:
   175  		fn.VecVar = "t"
   176  		fn.PrepData = "prepDataSV"
   177  		prep = prepMixed
   178  	}
   179  	template.Must(prep.New("prep").Parse(cmpPrepRaw))
   180  	prep.Execute(w, fn)
   181  	agg2CmpBody.Execute(w, fn)
   182  }
   183  
   184  func (fn *EngineCmp) Write(w io.Writer) {
   185  	if tmpl, ok := cmpDocStrings[fn.methName()]; ok {
   186  		type tmp struct {
   187  			Left, Right string
   188  		}
   189  		var ds tmp
   190  		if fn.VV {
   191  			ds.Left = "a"
   192  			ds.Right = "b"
   193  		} else {
   194  			ds.Left = "t"
   195  			ds.Right = "s"
   196  		}
   197  		tmpl.Execute(w, ds)
   198  	}
   199  	sig := fn.Signature()
   200  	w.Write([]byte("func (e StdEng) "))
   201  	sig.Write(w)
   202  	w.Write([]byte("(retVal Tensor, err error) {\n"))
   203  	fn.WriteBody(w)
   204  	w.Write([]byte("}\n\n"))
   205  }
   206  
   207  func generateStdEngCmp(f io.Writer, ak Kinds) {
   208  	var methods []*EngineCmp
   209  
   210  	for _, abo := range cmpBinOps {
   211  		var tc string
   212  		if abo.Name() == "Eq" || abo.Name() == "Ne" {
   213  			tc = "Eq"
   214  		} else {
   215  			tc = "Ord"
   216  		}
   217  		meth := &EngineCmp{
   218  			Name:           abo.Name(),
   219  			Inv:            abo.Inv,
   220  			VV:             true,
   221  			TypeClassCheck: tc,
   222  		}
   223  		methods = append(methods, meth)
   224  	}
   225  
   226  	// VV
   227  	for _, meth := range methods {
   228  		meth.Write(f)
   229  		meth.VV = false
   230  	}
   231  
   232  	// Scalar
   233  	for _, meth := range methods {
   234  		meth.Write(f)
   235  		meth.LeftVec = true
   236  	}
   237  }
   238  
   239  type EngineMinMax struct {
   240  	Name           string
   241  	VecVar         string
   242  	PrepData       string
   243  	TypeClassCheck string
   244  	Kinds          []reflect.Kind
   245  
   246  	VV      bool
   247  	LeftVec bool
   248  }
   249  
   250  func (fn *EngineMinMax) methName() string {
   251  	switch {
   252  	case fn.VV:
   253  		return fn.Name
   254  	default:
   255  		return fn.Name + "Scalar"
   256  	}
   257  }
   258  
   259  func (fn *EngineMinMax) Signature() *Signature {
   260  	var paramNames []string
   261  	var paramTemplates []*template.Template
   262  
   263  	switch {
   264  	case fn.VV:
   265  		paramNames = []string{"a", "b", "opts"}
   266  		paramTemplates = []*template.Template{tensorType, tensorType, splatFuncOptType}
   267  	default:
   268  		paramNames = []string{"t", "s", "leftTensor", "opts"}
   269  		paramTemplates = []*template.Template{tensorType, interfaceType, boolType, splatFuncOptType}
   270  	}
   271  	return &Signature{
   272  		Name:           fn.methName(),
   273  		NameTemplate:   plainName,
   274  		ParamNames:     paramNames,
   275  		ParamTemplates: paramTemplates,
   276  		Err:            false,
   277  	}
   278  }
   279  
   280  func (fn *EngineMinMax) WriteBody(w io.Writer) {
   281  	var prep *template.Template
   282  	switch {
   283  	case fn.VV:
   284  		prep = prepVV
   285  		fn.VecVar = "a"
   286  	case !fn.VV && fn.LeftVec:
   287  		fn.VecVar = "t"
   288  		fn.PrepData = "prepDataVS"
   289  		prep = prepMixed
   290  	default:
   291  		fn.VecVar = "t"
   292  		fn.PrepData = "prepDataSV"
   293  		prep = prepMixed
   294  	}
   295  	template.Must(prep.New("prep").Parse(minmaxPrepRaw))
   296  	prep.Execute(w, fn)
   297  	agg2MinMaxBody.Execute(w, fn)
   298  }
   299  
   300  func (fn *EngineMinMax) Write(w io.Writer) {
   301  	if tmpl, ok := cmpDocStrings[fn.methName()]; ok {
   302  		type tmp struct {
   303  			Left, Right string
   304  		}
   305  		var ds tmp
   306  		if fn.VV {
   307  			ds.Left = "a"
   308  			ds.Right = "b"
   309  		} else {
   310  			ds.Left = "t"
   311  			ds.Right = "s"
   312  		}
   313  		tmpl.Execute(w, ds)
   314  	}
   315  	sig := fn.Signature()
   316  	w.Write([]byte("func (e StdEng) "))
   317  	sig.Write(w)
   318  	w.Write([]byte("(retVal Tensor, err error) {\n"))
   319  	fn.WriteBody(w)
   320  	w.Write([]byte("}\n\n"))
   321  }
   322  
   323  func generateStdEngMinMax(f io.Writer, ak Kinds) {
   324  	methods := []*EngineMinMax{
   325  		&EngineMinMax{
   326  			Name:           "MinBetween",
   327  			VV:             true,
   328  			TypeClassCheck: "Ord",
   329  		},
   330  		&EngineMinMax{
   331  			Name:           "MaxBetween",
   332  			VV:             true,
   333  			TypeClassCheck: "Ord",
   334  		},
   335  	}
   336  	f.Write([]byte(`var (
   337  	_ MinBetweener = StdEng{}
   338  	_ MaxBetweener = StdEng{}
   339  )
   340  `))
   341  	// VV
   342  	for _, meth := range methods {
   343  		meth.Write(f)
   344  		meth.VV = false
   345  	}
   346  
   347  	// Scalar-Vector
   348  	for _, meth := range methods {
   349  		meth.Write(f)
   350  		meth.LeftVec = true
   351  	}
   352  }
   353  
   354  /* UNARY METHODS */
   355  
   356  type EngineUnary struct {
   357  	Name           string
   358  	TypeClassCheck string
   359  	Kinds          []reflect.Kind
   360  }
   361  
   362  func (fn *EngineUnary) Signature() *Signature {
   363  	return &Signature{
   364  		Name:            fn.Name,
   365  		NameTemplate:    plainName,
   366  		ParamNames:      []string{"a", "opts"},
   367  		ParamTemplates:  []*template.Template{tensorType, splatFuncOptType},
   368  		RetVals:         []string{"retVal"},
   369  		RetValTemplates: []*template.Template{tensorType},
   370  
   371  		Err: true,
   372  	}
   373  }
   374  
   375  func (fn *EngineUnary) WriteBody(w io.Writer) {
   376  	prepUnary.Execute(w, fn)
   377  	agg2UnaryBody.Execute(w, fn)
   378  }
   379  
   380  func (fn *EngineUnary) Write(w io.Writer) {
   381  	sig := fn.Signature()
   382  	w.Write([]byte("func (e StdEng) "))
   383  	sig.Write(w)
   384  	w.Write([]byte("{\n"))
   385  	fn.WriteBody(w)
   386  	w.Write([]byte("\n}\n"))
   387  }
   388  
   389  func generateStdEngUncondUnary(f io.Writer, ak Kinds) {
   390  	tcc := []string{
   391  		"Number",     // Neg
   392  		"Number",     // Inv
   393  		"Number",     // Square
   394  		"Number",     // Cube
   395  		"FloatCmplx", // Exp
   396  		"FloatCmplx", // Tanhh
   397  		"FloatCmplx", // Log
   398  		"Float",      // Log2
   399  		"FloatCmplx", // Log10
   400  		"FloatCmplx", // Sqrt
   401  		"Float",      // Cbrt
   402  		"Float",      // InvSqrt
   403  	}
   404  	var gen []*EngineUnary
   405  	for i, u := range unconditionalUnaries {
   406  		var ks []reflect.Kind
   407  		for _, k := range ak.Kinds {
   408  			if tc := u.TypeClass(); tc != nil && !tc(k) {
   409  				continue
   410  			}
   411  			ks = append(ks, k)
   412  		}
   413  		fn := &EngineUnary{
   414  			Name:           u.Name(),
   415  			TypeClassCheck: tcc[i],
   416  			Kinds:          ks,
   417  		}
   418  		gen = append(gen, fn)
   419  	}
   420  
   421  	for _, fn := range gen {
   422  		fn.Write(f)
   423  	}
   424  }
   425  
   426  func generateStdEngCondUnary(f io.Writer, ak Kinds) {
   427  	tcc := []string{
   428  		"Signed", // Abs
   429  		"Signed", // Sign
   430  	}
   431  	var gen []*EngineUnary
   432  	for i, u := range conditionalUnaries {
   433  		var ks []reflect.Kind
   434  		for _, k := range ak.Kinds {
   435  			if tc := u.TypeClass(); tc != nil && !tc(k) {
   436  				continue
   437  			}
   438  			ks = append(ks, k)
   439  		}
   440  		fn := &EngineUnary{
   441  			Name:           u.Name(),
   442  			TypeClassCheck: tcc[i],
   443  			Kinds:          ks,
   444  		}
   445  		gen = append(gen, fn)
   446  	}
   447  
   448  	for _, fn := range gen {
   449  		fn.Write(f)
   450  	}
   451  }