github.com/wzzhu/tensor@v0.9.24/genlib2/agg3_body.go (about)

     1  package main
     2  
     3  import "text/template"
     4  
     5  // 3rd level function aggregation templates
     6  
     // denseArithBodyRaw is the text/template source for the body of a generated
     // Dense method that combines the receiver t with another tensor.
     //
     // The rendered code dispatches in two steps:
     //   1. If t.oe is non-nil, the operation is called on t.oe.
     //   2. Otherwise t.e is type-asserted to the interface named by
     //      interfaceName .Name and the operation is called on that.
     // In both cases the method called is {{.Name}}, prefixed with "El" when
     // .Name is "Eq" or "Ne", and the Tensor result is converted with assertDense.
     // If neither step applies the rendered code returns an
     // "Engine does not support ..." error.
     //
     // The template does not declare t, other, opts, retVal or err; the bare
     // `return` statements rely on the enclosing generated function providing
     // retVal and err as named results. opFail is used as a Wrapf format string
     // and is defined outside this template.
     7  const denseArithBodyRaw = `{{$elne := eq .Name "Ne"}}
     8  {{$eleq := eq .Name "Eq"}}
     9  {{$eleqne := or $eleq $elne}}
    10  var ret Tensor
    11  if t.oe != nil {
    12  	if ret, err = t.oe.{{if $eleqne}}El{{end}}{{.Name}}(t, other, opts...); err != nil {
    13  		return nil, errors.Wrapf(err, "Unable to do {{.Name}}()")
    14  	}
    15  	if retVal, err = assertDense(ret); err != nil {
    16  		return nil, errors.Wrapf(err, opFail, "{{.Name}}")
    17  	}
    18  	return
    19  }
    20  
    21  if {{interfaceName .Name | lower}}, ok := t.e.({{interfaceName .Name}}); ok {
    22  	if ret, err = {{interfaceName .Name | lower}}.{{if $eleqne}}El{{end}}{{.Name}}(t, other, opts...); err != nil {
    23  		return nil, errors.Wrapf(err, "Unable to do {{.Name}}()")
    24  	}
    25  	if retVal, err = assertDense(ret); err != nil {
    26  		return nil, errors.Wrapf(err, opFail, "{{.Name}}")
    27  	}
    28  	return
    29  }
    30  return  nil, errors.Errorf("Engine does not support {{.Name}}()")
    31  `
    32  
    // denseArithScalarBodyRaw is the text/template source for the body of a
    // generated Dense method that combines the receiver t with a scalar.
    //
    // It follows the same two-step dispatch as the tensor-tensor body: t.oe is
    // used when it is non-nil, otherwise t.e is type-asserted to the interface
    // named by interfaceName .Name. The method called is {{.Name}}Scalar with
    // the arguments (t, other, leftTensor, opts...), and the result is converted
    // with assertDense. When neither path applies the rendered code returns an
    // "Engine does not support {{.Name}}Scalar()" error.
    //
    // t, other, leftTensor, opts, retVal and err are not declared here and must
    // be supplied by the enclosing generated function.
    33  const denseArithScalarBodyRaw = `var ret Tensor
    34  if t.oe != nil {
    35  	if ret, err = t.oe.{{.Name}}Scalar(t, other, leftTensor, opts...); err != nil{
    36  		return nil, errors.Wrapf(err, "Unable to do {{.Name}}Scalar()")
    37  	}
    38  	if retVal, err = assertDense(ret); err != nil {
    39  			return nil, errors.Wrapf(err, opFail, "{{.Name}}Scalar")
    40  		}
    41  		return
    42  }
    43  
    44  	if {{interfaceName .Name | lower}}, ok := t.e.({{interfaceName .Name}}); ok {
    45  		if ret, err = {{interfaceName .Name | lower}}.{{.Name}}Scalar(t, other, leftTensor, opts...); err != nil {
    46  			return nil, errors.Wrapf(err, "Unable to do {{.Name}}Scalar()")
    47  		}
    48  		if retVal, err = assertDense(ret); err != nil {
    49  			return nil, errors.Wrapf(err, opFail, "{{.Name}}Scalar")
    50  		}
    51  		return
    52  	}
    53  	return nil, errors.Errorf("Engine does not support {{.Name}}Scalar()")
    54  `
    55  
    // denseIdentityArithTestBodyRaw is the text/template source for a generated
    // property test of the identity law for a tensor-tensor operation.
    //
    // The rendered closure `iden` builds b with the same dtype, shape and engine
    // as a, fills b with identityVal({{.Identity}}, a.t) only when .Identity is
    // non-zero, runs the "call0" template, and then checks that ret.Data()
    // matches the data of a clone of a. qcErrCheck decides whether to return
    // early; `we` (error expected) is also set when a's engine does not
    // implement the interface named by interfaceName .Name. The closure is then
    // driven by quick.Check.
    //
    // The templates "funcoptdecl", "funcoptcorrect", "call0" and "funcoptcheck"
    // are referenced but not defined here, so they must be associated with the
    // parsed template before it is executed. t, quickchecks and newRand come
    // from the surrounding generated test code.
    56  const denseIdentityArithTestBodyRaw = `iden := func(a *Dense) bool {
    57  	b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine()))
    58  	{{if ne .Identity 0 -}}
    59  			b.Memset(identityVal({{.Identity}}, a.t))
    60  	{{end -}}
    61  	{{template "funcoptdecl" -}}
    62  	correct := a.Clone().(*Dense)
    63  	{{template "funcoptcorrect" -}}
    64  
    65  	we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}})
    66  	_, ok := a.Engine().({{interfaceName .Name}}); we = we || !ok
    67  
    68  	{{template "call0" . }}
    69  	if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{
    70  		if err != nil {
    71  			return false
    72  		}
    73  		return true
    74  	}
    75  
    76  	if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) {
    77  		return false
    78  	}
    79  	{{template "funcoptcheck" -}}
    80  
    81  	return true
    82  }
    83  		if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil{
    84  			t.Errorf("Identity test for {{.Name}} failed: %v", err)
    85  		}
    86  `
    87  
    // denseIdentityArithScalarTestRaw is the text/template source for generated
    // property tests of the identity law for a tensor-scalar operation.
    //
    // Two closures may be rendered:
    //   - iden1 is always emitted. It uses the "call0" template, and its
    //     failure message labels it "tensor as left, scalar as right".
    //   - iden2 is emitted only when .IsCommutative is true. It uses the
    //     "call1" template, and its failure message labels it
    //     "scalar as left, tensor as right".
    // In both, a is a clone of the quick.Check input q, b is
    // identityVal({{.Identity}}, q.t), and the result is compared against a
    // clone of a via qcEqCheck.
    //
    // "funcoptdecl", "funcoptcorrect", "funcoptcheck", "call0" and "call1" are
    // referenced but not defined here and must be supplied before execution.
    88  const denseIdentityArithScalarTestRaw = `iden1 := func(q *Dense) bool {
    89  	a := q.Clone().(*Dense)
    90  	b := identityVal({{.Identity}}, q.t)
    91  	{{template "funcoptdecl"}}
    92  	correct := a.Clone().(*Dense)
    93  	{{template "funcoptcorrect" -}}
    94  
    95  	we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}})
    96  	_, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok
    97  
    98  	{{template "call0" . }}
    99  	if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{
   100  		if err != nil {
   101  			return false
   102  		}
   103  		return true
   104  	}
   105  
   106  	if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) {
   107  		return false
   108  	}
   109  	{{template "funcoptcheck" -}}
   110  
   111  	return true
   112  }
   113  
   114  if err := quick.Check(iden1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil {
   115  	t.Errorf("Identity test for {{.Name}} (tensor as left, scalar as right) failed: %v", err)
   116  }
   117  
   118  {{if .IsCommutative -}}
   119  iden2 := func(q *Dense) bool {
   120  	a := q.Clone().(*Dense)
   121  	b := identityVal({{.Identity}}, q.t)
   122  	{{template "funcoptdecl" -}}
   123  	correct := a.Clone().(*Dense)
   124  	{{template "funcoptcorrect" -}}
   125  
   126  	we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}})
   127  	_, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok
   128  
   129  	{{template "call1" . }}
   130  	if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{
   131  		if err != nil {
   132  			return false
   133  		}
   134  		return true
   135  	}
   136  
   137  	if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) {
   138  		return false
   139  	}
   140  	{{template "funcoptcheck" -}}
   141  
   142  	return true
   143  }
   144  if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil {
   145  	t.Errorf("Identity test for {{.Name}} (scalar as left, tensor as right) failed: %v", err)
   146  }
   147  {{end -}}
   148  `
   149  
   // denseInvArithTestBodyRaw is the text/template source for a generated
   // property test that a tensor-tensor operation can be undone by its inverse.
   //
   // The rendered closure `inv` sets up a, b and `correct` the same way as the
   // identity test (b is Memset to identityVal only when .Identity is non-zero).
   // It runs "call0", returns early if qcErrCheck says so, then runs "callInv"
   // and finally requires ret.Data() to equal the data of the clone of a taken
   // before the calls. The closure is driven by quick.Check.
   //
   // "funcoptdecl", "funcoptcorrect", "funcoptcheck", "call0" and "callInv" are
   // referenced but not defined here and must be supplied before execution.
   150  const denseInvArithTestBodyRaw = `inv := func(a *Dense) bool {
   151  	b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine()))
   152  	{{if ne .Identity 0 -}}
   153  			b.Memset(identityVal({{.Identity}}, a.t))
   154  	{{end -}}
   155  	{{template "funcoptdecl" -}}
   156  	correct := a.Clone().(*Dense)
   157  	{{template "funcoptcorrect" -}}
   158  
   159  	we, willFailEq := willerr(a,  {{.TypeClassName}}, {{.EqFailTypeClassName}})
   160  	_, ok := a.Engine().({{interfaceName .Name}}); we = we || !ok
   161  
   162  	{{template "call0" . }}
   163  	if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{
   164  		if err != nil {
   165  			return false
   166  		}
   167  		return true
   168  	}
   169  	{{template "callInv" .}}
   170  
   171  	if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) {
   172  		return false
   173  	}
   174  	{{template "funcoptcheck" -}}
   175  
   176  	return true
   177  }
   178  		if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil{
   179  			t.Errorf("Inv test for {{.Name}} failed: %v", err)
   180  		}
   181  `
   182  
   // denseInvArithScalarTestRaw is the text/template source for generated
   // property tests that a tensor-scalar operation can be undone by its inverse.
   //
   // Two closures may be rendered:
   //   - inv1 is always emitted. It runs "call0" followed by "callInv0" and
   //     reports errors under the name "{{.Name}}VS"; its failure message
   //     labels it "tensor as left, scalar as right".
   //   - inv2 is emitted only when .IsInvolutionary is true and .FuncOpt is not
   //     "incr". It runs "call1" followed by "callInv1" and reports errors
   //     under the name "{{.Name}}SV"; its failure message labels it
   //     "scalar as left, tensor as right".
   // Both compare the final ret.Data() against a clone of a taken before the
   // calls.
   //
   // "funcoptdecl", "funcoptcorrect", "funcoptcheck", "call0", "call1",
   // "callInv0" and "callInv1" are referenced but not defined here and must be
   // supplied before execution.
   183  const denseInvArithScalarTestRaw = `inv1 := func(q *Dense) bool {
   184  	a := q.Clone().(*Dense)
   185  	b := identityVal({{.Identity}}, q.t)
   186  	{{template "funcoptdecl"}}
   187  	correct := a.Clone().(*Dense)
   188  	{{template "funcoptcorrect" -}}
   189  
   190  	we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}})
   191  	_, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok
   192  
   193  	{{template "call0" . }}
   194  	if err, retEarly := qcErrCheck(t, "{{.Name}}VS", a, b, we, err); retEarly{
   195  		if err != nil {
   196  			return false
   197  		}
   198  		return true
   199  	}
   200  	{{template "callInv0" .}}
   201  
   202  	if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) {
   203  		return false
   204  	}
   205  	{{template "funcoptcheck" -}}
   206  
   207  	return true
   208  }
   209  if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil {
   210  	t.Errorf("Inv test for {{.Name}} (tensor as left, scalar as right) failed: %v", err)
   211  }
   212  
   213  {{if .IsInvolutionary -}}
   214  {{if eq .FuncOpt "incr" -}}
   215  {{else -}}
   216  inv2 := func(q *Dense) bool {
   217  	a := q.Clone().(*Dense)
   218  	b := identityVal({{.Identity}}, q.t)
   219  	{{template "funcoptdecl" -}}
   220  	correct := a.Clone().(*Dense)
   221  	{{template "funcoptcorrect" -}}
   222  
   223  	we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}})
   224  	_, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok
   225  
   226  	{{template "call1" . }}
   227  	if err, retEarly := qcErrCheck(t, "{{.Name}}SV", a, b, we, err); retEarly{
   228  		if err != nil {
   229  			return false
   230  		}
   231  		return true
   232  	}
   233  	{{template "callInv1" .}}
   234  
   235  	if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) {
   236  		return false
   237  	}
   238  	{{template "funcoptcheck" -}}
   239  
   240  	return true
   241  }
   242  if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil {
   243  	t.Errorf("Inv test for {{.Name}} (scalar as left, tensor as right) failed: %v", err)
   244  }
   245  {{end -}}
   246  {{end -}}
   247  `
   248  
   // denseArithScalarWrongTypeTestRaw is the text/template source for generated
   // property tests that a tensor-scalar operation rejects a scalar of an
   // unexpected type.
   //
   // The rendered code declares a local type Foo (underlying int) and two
   // closures, wt1 and wt2, which set b := Foo(0) and run the "call0" and
   // "call1" templates respectively. Each closure fails (returns false) when
   // err is nil after the call; `_ = ret` only keeps ret referenced. Both
   // closures are driven by quick.Check.
   //
   // "call0" and "call1" are referenced but not defined here and must be
   // supplied before execution.
   249  const denseArithScalarWrongTypeTestRaw = `type Foo int
   250  wt1 := func(a *Dense) bool{
   251  	b := Foo(0)
   252  	{{template "call0" .}}
   253  	if err == nil {
   254  		return false
   255  	}
   256  	_ = ret
   257  	return true
   258  }
   259  if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil {
   260  	t.Errorf("WrongType test for {{.Name}} (tensor as left, scalar as right) failed: %v", err)
   261  }
   262  
   263  wt2 := func(a *Dense) bool{
   264  	b := Foo(0)
   265  	{{template "call1" .}}
   266  	if err == nil {
   267  		return false
   268  	}
   269  	_ = ret
   270  	return true
   271  }
   272  if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil {
   273  	t.Errorf("WrongType test for {{.Name}} (tensor as right, scalar as left) failed: %v", err)
   274  }
   275  `
   276  
   // denseArithReuseMutationTestRaw is the text/template source for a generated
   // property test of a tensor-tensor operation called with WithReuse.
   //
   // The rendered closure `mut` first forces both inputs onto StdEng (fields e
   // and oe) with flag 0, and returns true (treating the case as passing) when
   // a and b differ in dtype or shape. It then runs "callVanilla", and performs
   // the operation with WithReuse(a) or WithReuse(b) depending on reuseA,
   // recording the reused tensor in `reuse`. Finally ret.Data() is compared
   // against correct.Data() via qcEqCheck.
   //
   // NOTE(review): `correct` is not declared in this template; presumably the
   // "callVanilla" template declares it. Confirm where that template is defined.
   // NOTE(review): the text `, WithReuse(a))` is appended directly after
   // {{template "call0" .}}, so the "call0" used with this template is expected
   // to render an unterminated call expression. Confirm against its definition.
   //
   // "callVanilla", "retType", "call0" and "funcoptcheck" are referenced but not
   // defined here and must be supplied before execution.
   277  const denseArithReuseMutationTestRaw = `mut := func(a, b *Dense, reuseA bool) bool {
   278  	// req because we're only testing on one kind of tensor/engine combo
   279  	a.e = StdEng{}
   280  	a.oe = StdEng{}
   281  	a.flag = 0
   282  	b.e = StdEng{}
   283  	b.oe = StdEng{}
   284  	b.flag = 0
   285  
   286  	if a.Dtype() != b.Dtype(){
   287  	return true
   288  	}
   289  	if !a.Shape().Eq(b.Shape()){
   290  	return true
   291  	}
   292  
   293  
   294  
   295  	{{template "callVanilla" .}}
   296  	we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}})
   297  	_, ok := a.Engine().({{interfaceName .Name}}); we = we || !ok
   298  
   299  
   300  
   301  	var ret, reuse {{template "retType" .}}
   302  	if reuseA {
   303  		{{template "call0" .}}, WithReuse(a))
   304  		reuse = a
   305  	} else {
   306  		{{template "call0" .}}, WithReuse(b))
   307  		reuse = b
   308  	}
   309  
   310  
   311  	if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{
   312  		if err != nil {
   313  			return false
   314  		}
   315  		return true
   316  	}
   317  
   318  	if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) {
   319  		return false
   320  	}
   321  
   322  	{{template "funcoptcheck" -}}
   323  
   324  	return true
   325  }
   326  if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil {
   327  	t.Errorf("Reuse Mutation test for {{.Name}} failed: %v", err)
   328  }
   329  
   330  `
   331  
   // Parsed forms of the raw template constants in this file. They are nil
   // until init runs and are populated exactly once there.
   332  var (
   333  	denseArithBody       *template.Template
   334  	denseArithScalarBody *template.Template
   335  
   336  	denseIdentityArithTest       *template.Template
   337  	denseIdentityArithScalarTest *template.Template
   338  
   339  	denseArithScalarWrongTypeTest *template.Template
   340  )
   341  
   // init parses the method-body, identity-test and wrong-type-test templates
   // with the shared funcs map. template.Must panics on a parse error, so a
   // malformed template stops the generator at start-up.
   //
   // Each template is given a distinct name because text/template prefixes its
   // error messages with that name. The scalar body previously reused
   // "dense arith body", which made failures of the two body templates
   // indistinguishable.
   //
   // The inverse-test and reuse-mutation raw templates are not parsed in this
   // function.
   342  func init() {
   343  	denseArithBody = template.Must(template.New("dense arith body").Funcs(funcs).Parse(denseArithBodyRaw))
   344  	denseArithScalarBody = template.Must(template.New("dense arith scalar body").Funcs(funcs).Parse(denseArithScalarBodyRaw))
   345  
   346  	denseIdentityArithTest = template.Must(template.New("dense identity test").Funcs(funcs).Parse(denseIdentityArithTestBodyRaw))
   347  	denseIdentityArithScalarTest = template.Must(template.New("dense scalar identity test").Funcs(funcs).Parse(denseIdentityArithScalarTestRaw))
   348  
   349  	denseArithScalarWrongTypeTest = template.Must(template.New("dense scalar wrongtype test").Funcs(funcs).Parse(denseArithScalarWrongTypeTestRaw))
   350  }