github.com/wzzhu/tensor@v0.9.24/genlib2/arith_tests.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"text/template"
     7  )
     8  
     9  const (
    10  	APICallVVxRaw        = `correct, err := {{.Name}}(a, b)` // no funcopt
    11  	APICallVVReuseMutRaw = `ret, err = {{.Name}}(a, b`
    12  	APICallVVRaw         = `ret, err := {{.Name}}(a, b {{template "funcoptuse"}})`
    13  	APICallVSRaw         = `ret, err := {{.Name}}(a, b {{template "funcoptuse"}})`
    14  	APICallSVRaw         = `ret, err := {{.Name}}(b, a {{template "funcoptuse"}})`
    15  
    16  	APIInvVVRaw = `ret, err = {{.Inv}}(ret, b, UseUnsafe())`
    17  	APIInvVSRaw = `ret, err = {{.Inv}}(ret, b, UseUnsafe())`
    18  	APIInvSVRaw = `ret, err = {{.Name}}(b, ret, UseUnsafe())`
    19  
    20  	DenseMethodCallVVxRaw        = `correct, err := a.{{.Name}}(b)` // no funcopt
    21  	DenseMethodCallVVReuseMutRaw = `ret, err = a.{{.Name}}(b`
    22  	DenseMethodCallVVRaw         = `ret, err := a.{{.Name}}(b {{template "funcoptuse"}})`
    23  	DenseMethodCallVSRaw         = `ret, err := a.{{.Name}}Scalar(b, true {{template "funcoptuse"}})`
    24  	DenseMethodCallSVRaw         = `ret, err := a.{{.Name}}Scalar(b, false {{template "funcoptuse"}})`
    25  
    26  	DenseMethodInvVVRaw = `ret, err = ret.{{.Inv}}(b, UseUnsafe())`
    27  	DenseMethodInvVSRaw = `ret, err = ret.{{.Inv}}Scalar(b, true, UseUnsafe())`
    28  	DenseMethodInvSVRaw = `ret, err = ret.{{.Name}}Scalar(b, false, UseUnsafe())`
    29  
    30  	APIRetType   = `Tensor`
    31  	DenseRetType = `*Dense`
    32  )
    33  
// ArithTest describes one generated test function for an arithmetic
// operation. It embeds the arithOp being tested, so the op's fields and
// methods (Name, HasIdentity, IsInv, ...) are available directly on it and to
// the templates it is executed against.
type ArithTest struct {
	arithOp
	scalars             bool   // true selects the scalar templates and adds "Scalar" to the test name
	lvl                 Level  // API or Dense; picks the call templates and the test name prefix
	FuncOpt             string // "", "unsafe", "reuse" or "incr"; key into the funcOpt* tables and test name suffix
	EqFailTypeClassName string // set by the generators ("nil", "complexTypes", "unsignedTypes"); presumably read by the test templates — confirm
}
    41  
    42  func (fn *ArithTest) Signature() *Signature {
    43  	var name string
    44  	switch fn.lvl {
    45  	case API:
    46  		name = fmt.Sprintf("Test%s", fn.Name())
    47  
    48  	case Dense:
    49  		name = fmt.Sprintf("TestDense_%s", fn.Name())
    50  	}
    51  	if fn.scalars {
    52  		name += "Scalar"
    53  	}
    54  	if fn.FuncOpt != "" {
    55  		name += "_" + fn.FuncOpt
    56  	}
    57  	return &Signature{
    58  		Name:           name,
    59  		NameTemplate:   plainName,
    60  		ParamNames:     []string{"t"},
    61  		ParamTemplates: []*template.Template{testingType},
    62  	}
    63  }
    64  
    65  func (fn *ArithTest) WriteBody(w io.Writer) {
    66  	if fn.HasIdentity {
    67  		fn.writeIdentity(w)
    68  		fmt.Fprintf(w, "\n")
    69  	}
    70  
    71  	if fn.IsInv {
    72  		fn.writeInv(w)
    73  	}
    74  	fn.WriteScalarWrongType(w)
    75  
    76  	if fn.FuncOpt == "reuse" && fn.arithOp.Name() != "Pow" {
    77  		fn.writeReuseMutate(w)
    78  	}
    79  }
    80  
    81  func (fn *ArithTest) canWrite() bool {
    82  	if fn.HasIdentity || fn.IsInv {
    83  		return true
    84  	}
    85  	return false
    86  }
    87  
    88  func (fn *ArithTest) writeIdentity(w io.Writer) {
    89  	var t *template.Template
    90  	if fn.scalars {
    91  		t = template.Must(template.New("dense identity test").Funcs(funcs).Parse(denseIdentityArithScalarTestRaw))
    92  	} else {
    93  		t = template.Must(template.New("dense identity test").Funcs(funcs).Parse(denseIdentityArithTestBodyRaw))
    94  	}
    95  	switch fn.lvl {
    96  	case API:
    97  		if fn.scalars {
    98  			template.Must(t.New("call0").Parse(APICallVSRaw))
    99  			template.Must(t.New("call1").Parse(APICallSVRaw))
   100  		} else {
   101  			template.Must(t.New("call0").Parse(APICallVVRaw))
   102  		}
   103  	case Dense:
   104  		if fn.scalars {
   105  			template.Must(t.New("call0").Parse(DenseMethodCallVSRaw))
   106  			template.Must(t.New("call1").Parse(DenseMethodCallSVRaw))
   107  		} else {
   108  			template.Must(t.New("call0").Parse(DenseMethodCallVVRaw))
   109  		}
   110  
   111  	}
   112  	template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt]))
   113  	template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt]))
   114  	template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt]))
   115  	template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt]))
   116  
   117  	t.Execute(w, fn)
   118  }
   119  
   120  func (fn *ArithTest) writeInv(w io.Writer) {
   121  	var t *template.Template
   122  	if fn.scalars {
   123  		t = template.Must(template.New("dense involution test").Funcs(funcs).Parse(denseInvArithScalarTestRaw))
   124  	} else {
   125  		t = template.Must(template.New("dense involution test").Funcs(funcs).Parse(denseInvArithTestBodyRaw))
   126  	}
   127  	switch fn.lvl {
   128  	case API:
   129  		if fn.scalars {
   130  			template.Must(t.New("call0").Parse(APICallVSRaw))
   131  			template.Must(t.New("call1").Parse(APICallSVRaw))
   132  			template.Must(t.New("callInv0").Parse(APIInvVSRaw))
   133  			template.Must(t.New("callInv1").Parse(APIInvSVRaw))
   134  		} else {
   135  			template.Must(t.New("call0").Parse(APICallVVRaw))
   136  			template.Must(t.New("callInv").Parse(APIInvVVRaw))
   137  		}
   138  	case Dense:
   139  		if fn.scalars {
   140  			template.Must(t.New("call0").Parse(DenseMethodCallVSRaw))
   141  			template.Must(t.New("call1").Parse(DenseMethodCallSVRaw))
   142  			template.Must(t.New("callInv0").Parse(DenseMethodInvVSRaw))
   143  			template.Must(t.New("callInv1").Parse(DenseMethodInvSVRaw))
   144  		} else {
   145  			template.Must(t.New("call0").Parse(DenseMethodCallVVRaw))
   146  			template.Must(t.New("callInv").Parse(DenseMethodInvVVRaw))
   147  		}
   148  	}
   149  
   150  	template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt]))
   151  	template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt]))
   152  	template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt]))
   153  	template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt]))
   154  
   155  	t.Execute(w, fn)
   156  }
   157  
   158  func (fn *ArithTest) writeReuseMutate(w io.Writer) {
   159  	t := template.Must(template.New("Reuse mutation test").Funcs(funcs).Parse(denseArithReuseMutationTestRaw))
   160  	switch fn.lvl {
   161  	case API:
   162  		return // tmp
   163  	case Dense:
   164  		template.Must(t.New("callVanilla").Parse(DenseMethodCallVVxRaw))
   165  		template.Must(t.New("retType").Parse(DenseRetType))
   166  		template.Must(t.New("call0").Parse(DenseMethodCallVVReuseMutRaw))
   167  
   168  	}
   169  	template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt]))
   170  	template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt]))
   171  	template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt]))
   172  	t.Execute(w, fn)
   173  }
   174  
   175  func (fn *ArithTest) WriteScalarWrongType(w io.Writer) {
   176  	if !fn.scalars {
   177  		return
   178  	}
   179  	if fn.FuncOpt != "" {
   180  		return
   181  	}
   182  	t := template.Must(template.New("dense scalar wrongtype test").Funcs(funcs).Parse(denseArithScalarWrongTypeTestRaw))
   183  	template.Must(t.New("call0").Parse(APICallVSRaw))
   184  	template.Must(t.New("call1").Parse(APICallSVRaw))
   185  	template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt]))
   186  	template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt]))
   187  	template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt]))
   188  	template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt]))
   189  
   190  	t.Execute(w, fn)
   191  }
   192  
   193  func (fn *ArithTest) Write(w io.Writer) {
   194  	sig := fn.Signature()
   195  	w.Write([]byte("func "))
   196  	sig.Write(w)
   197  	w.Write([]byte("{\n"))
   198  	fn.WriteBody(w)
   199  	w.Write([]byte("}\n"))
   200  }
   201  
   202  func generateAPIArithTests(f io.Writer, ak Kinds) {
   203  	var tests []*ArithTest
   204  	for _, op := range arithBinOps {
   205  		t := &ArithTest{
   206  			arithOp:             op,
   207  			lvl:                 API,
   208  			EqFailTypeClassName: "nil",
   209  		}
   210  		if t.name == "Pow" {
   211  			t.EqFailTypeClassName = "complexTypes"
   212  		}
   213  		tests = append(tests, t)
   214  	}
   215  
   216  	for _, fn := range tests {
   217  		if fn.canWrite() {
   218  			fn.Write(f)
   219  		}
   220  		fn.FuncOpt = "unsafe"
   221  	}
   222  
   223  	for _, fn := range tests {
   224  		if fn.canWrite() {
   225  			fn.Write(f)
   226  		}
   227  		fn.FuncOpt = "reuse"
   228  	}
   229  
   230  	for _, fn := range tests {
   231  		if fn.canWrite() {
   232  			fn.Write(f)
   233  		}
   234  		fn.FuncOpt = "incr"
   235  	}
   236  
   237  	for _, fn := range tests {
   238  		if fn.canWrite() {
   239  			fn.Write(f)
   240  		}
   241  	}
   242  }
   243  
   244  func generateAPIArithScalarTests(f io.Writer, ak Kinds) {
   245  	var tests []*ArithTest
   246  	for _, op := range arithBinOps {
   247  		t := &ArithTest{
   248  			arithOp:             op,
   249  			scalars:             true,
   250  			lvl:                 API,
   251  			EqFailTypeClassName: "nil",
   252  		}
   253  		switch t.name {
   254  		case "Pow":
   255  			t.EqFailTypeClassName = "complexTypes"
   256  		case "Sub":
   257  			t.EqFailTypeClassName = "unsignedTypes"
   258  		}
   259  		tests = append(tests, t)
   260  	}
   261  
   262  	for _, fn := range tests {
   263  		if fn.canWrite() {
   264  			fn.Write(f)
   265  		}
   266  		fn.FuncOpt = "unsafe"
   267  	}
   268  
   269  	for _, fn := range tests {
   270  		if fn.canWrite() {
   271  			fn.Write(f)
   272  		}
   273  		fn.FuncOpt = "reuse"
   274  	}
   275  
   276  	for _, fn := range tests {
   277  		if fn.canWrite() {
   278  			fn.Write(f)
   279  		}
   280  		fn.FuncOpt = "incr"
   281  	}
   282  
   283  	for _, fn := range tests {
   284  		if fn.canWrite() {
   285  			fn.Write(f)
   286  		}
   287  	}
   288  }
   289  
   290  func generateDenseMethodArithTests(f io.Writer, ak Kinds) {
   291  	var tests []*ArithTest
   292  	for _, op := range arithBinOps {
   293  		t := &ArithTest{
   294  			arithOp:             op,
   295  			lvl:                 Dense,
   296  			EqFailTypeClassName: "nil",
   297  		}
   298  		if t.name == "Pow" {
   299  			t.EqFailTypeClassName = "complexTypes"
   300  		}
   301  		tests = append(tests, t)
   302  	}
   303  
   304  	for _, fn := range tests {
   305  		if fn.canWrite() {
   306  			fn.Write(f)
   307  		}
   308  		fn.FuncOpt = "unsafe"
   309  	}
   310  
   311  	for _, fn := range tests {
   312  		if fn.canWrite() {
   313  			fn.Write(f)
   314  		}
   315  		fn.FuncOpt = "reuse"
   316  	}
   317  
   318  	for _, fn := range tests {
   319  		if fn.canWrite() {
   320  			fn.Write(f)
   321  		}
   322  		fn.FuncOpt = "incr"
   323  	}
   324  
   325  	for _, fn := range tests {
   326  		if fn.canWrite() {
   327  			fn.Write(f)
   328  		}
   329  	}
   330  }
   331  
   332  func generateDenseMethodScalarTests(f io.Writer, ak Kinds) {
   333  	var tests []*ArithTest
   334  	for _, op := range arithBinOps {
   335  		t := &ArithTest{
   336  			arithOp:             op,
   337  			scalars:             true,
   338  			lvl:                 Dense,
   339  			EqFailTypeClassName: "nil",
   340  		}
   341  		switch t.name {
   342  		case "Pow":
   343  			t.EqFailTypeClassName = "complexTypes"
   344  		case "Sub":
   345  			t.EqFailTypeClassName = "unsignedTypes"
   346  		}
   347  		tests = append(tests, t)
   348  	}
   349  
   350  	for _, fn := range tests {
   351  		if fn.canWrite() {
   352  			fn.Write(f)
   353  		}
   354  		fn.FuncOpt = "unsafe"
   355  	}
   356  
   357  	for _, fn := range tests {
   358  		if fn.canWrite() {
   359  			fn.Write(f)
   360  		}
   361  		fn.FuncOpt = "reuse"
   362  	}
   363  
   364  	for _, fn := range tests {
   365  		if fn.canWrite() {
   366  			fn.Write(f)
   367  		}
   368  		fn.FuncOpt = "incr"
   369  	}
   370  
   371  	for _, fn := range tests {
   372  		if fn.canWrite() {
   373  			fn.Write(f)
   374  		}
   375  	}
   376  }