github.com/wzzhu/tensor@v0.9.24/genlib2/agg3_body.go (about) 1 package main 2 3 import "text/template" 4 5 // 3rd level function aggregation templates 6 7 const denseArithBodyRaw = `{{$elne := eq .Name "Ne"}} 8 {{$eleq := eq .Name "Eq"}} 9 {{$eleqne := or $eleq $elne}} 10 var ret Tensor 11 if t.oe != nil { 12 if ret, err = t.oe.{{if $eleqne}}El{{end}}{{.Name}}(t, other, opts...); err != nil { 13 return nil, errors.Wrapf(err, "Unable to do {{.Name}}()") 14 } 15 if retVal, err = assertDense(ret); err != nil { 16 return nil, errors.Wrapf(err, opFail, "{{.Name}}") 17 } 18 return 19 } 20 21 if {{interfaceName .Name | lower}}, ok := t.e.({{interfaceName .Name}}); ok { 22 if ret, err = {{interfaceName .Name | lower}}.{{if $eleqne}}El{{end}}{{.Name}}(t, other, opts...); err != nil { 23 return nil, errors.Wrapf(err, "Unable to do {{.Name}}()") 24 } 25 if retVal, err = assertDense(ret); err != nil { 26 return nil, errors.Wrapf(err, opFail, "{{.Name}}") 27 } 28 return 29 } 30 return nil, errors.Errorf("Engine does not support {{.Name}}()") 31 ` 32 33 const denseArithScalarBodyRaw = `var ret Tensor 34 if t.oe != nil { 35 if ret, err = t.oe.{{.Name}}Scalar(t, other, leftTensor, opts...); err != nil{ 36 return nil, errors.Wrapf(err, "Unable to do {{.Name}}Scalar()") 37 } 38 if retVal, err = assertDense(ret); err != nil { 39 return nil, errors.Wrapf(err, opFail, "{{.Name}}Scalar") 40 } 41 return 42 } 43 44 if {{interfaceName .Name | lower}}, ok := t.e.({{interfaceName .Name}}); ok { 45 if ret, err = {{interfaceName .Name | lower}}.{{.Name}}Scalar(t, other, leftTensor, opts...); err != nil { 46 return nil, errors.Wrapf(err, "Unable to do {{.Name}}Scalar()") 47 } 48 if retVal, err = assertDense(ret); err != nil { 49 return nil, errors.Wrapf(err, opFail, "{{.Name}}Scalar") 50 } 51 return 52 } 53 return nil, errors.Errorf("Engine does not support {{.Name}}Scalar()") 54 ` 55 56 const denseIdentityArithTestBodyRaw = `iden := func(a *Dense) bool { 57 b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) 58 {{if ne .Identity 0 -}} 59 b.Memset(identityVal({{.Identity}}, a.t)) 60 {{end -}} 61 {{template "funcoptdecl" -}} 62 correct := a.Clone().(*Dense) 63 {{template "funcoptcorrect" -}} 64 65 we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) 66 _, ok := a.Engine().({{interfaceName .Name}}); we = we || !ok 67 68 {{template "call0" . }} 69 if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ 70 if err != nil { 71 return false 72 } 73 return true 74 } 75 76 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 77 return false 78 } 79 {{template "funcoptcheck" -}} 80 81 return true 82 } 83 if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil{ 84 t.Errorf("Identity test for {{.Name}} failed: %v", err) 85 } 86 ` 87 88 const denseIdentityArithScalarTestRaw = `iden1 := func(q *Dense) bool { 89 a := q.Clone().(*Dense) 90 b := identityVal({{.Identity}}, q.t) 91 {{template "funcoptdecl"}} 92 correct := a.Clone().(*Dense) 93 {{template "funcoptcorrect" -}} 94 95 we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) 96 _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok 97 98 {{template "call0" . }} 99 if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ 100 if err != nil { 101 return false 102 } 103 return true 104 } 105 106 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 107 return false 108 } 109 {{template "funcoptcheck" -}} 110 111 return true 112 } 113 114 if err := quick.Check(iden1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { 115 t.Errorf("Identity test for {{.Name}} (tensor as left, scalar as right) failed: %v", err) 116 } 117 118 {{if .IsCommutative -}} 119 iden2 := func(q *Dense) bool { 120 a := q.Clone().(*Dense) 121 b := identityVal({{.Identity}}, q.t) 122 {{template "funcoptdecl" -}} 123 correct := a.Clone().(*Dense) 124 {{template "funcoptcorrect" -}} 125 126 we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) 127 _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok 128 129 {{template "call1" . }} 130 if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ 131 if err != nil { 132 return false 133 } 134 return true 135 } 136 137 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 138 return false 139 } 140 {{template "funcoptcheck" -}} 141 142 return true 143 } 144 if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { 145 t.Errorf("Identity test for {{.Name}} (scalar as left, tensor as right) failed: %v", err) 146 } 147 {{end -}} 148 ` 149 150 const denseInvArithTestBodyRaw = `inv := func(a *Dense) bool { 151 b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) 152 {{if ne .Identity 0 -}} 153 b.Memset(identityVal({{.Identity}}, a.t)) 154 {{end -}} 155 {{template "funcoptdecl" -}} 156 correct := a.Clone().(*Dense) 157 {{template "funcoptcorrect" -}} 158 159 we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) 160 _, ok := a.Engine().({{interfaceName .Name}}); we = we || !ok 161 162 {{template "call0" . }} 163 if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ 164 if err != nil { 165 return false 166 } 167 return true 168 } 169 {{template "callInv" .}} 170 171 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 172 return false 173 } 174 {{template "funcoptcheck" -}} 175 176 return true 177 } 178 if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil{ 179 t.Errorf("Inv test for {{.Name}} failed: %v", err) 180 } 181 ` 182 183 const denseInvArithScalarTestRaw = `inv1 := func(q *Dense) bool { 184 a := q.Clone().(*Dense) 185 b := identityVal({{.Identity}}, q.t) 186 {{template "funcoptdecl"}} 187 correct := a.Clone().(*Dense) 188 {{template "funcoptcorrect" -}} 189 190 we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) 191 _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok 192 193 {{template "call0" . }} 194 if err, retEarly := qcErrCheck(t, "{{.Name}}VS", a, b, we, err); retEarly{ 195 if err != nil { 196 return false 197 } 198 return true 199 } 200 {{template "callInv0" .}} 201 202 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 203 return false 204 } 205 {{template "funcoptcheck" -}} 206 207 return true 208 } 209 if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { 210 t.Errorf("Inv test for {{.Name}} (tensor as left, scalar as right) failed: %v", err) 211 } 212 213 {{if .IsInvolutionary -}} 214 {{if eq .FuncOpt "incr" -}} 215 {{else -}} 216 inv2 := func(q *Dense) bool { 217 a := q.Clone().(*Dense) 218 b := identityVal({{.Identity}}, q.t) 219 {{template "funcoptdecl" -}} 220 correct := a.Clone().(*Dense) 221 {{template "funcoptcorrect" -}} 222 223 we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) 224 _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok 225 226 {{template "call1" . }} 227 if err, retEarly := qcErrCheck(t, "{{.Name}}SV", a, b, we, err); retEarly{ 228 if err != nil { 229 return false 230 } 231 return true 232 } 233 {{template "callInv1" .}} 234 235 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 236 return false 237 } 238 {{template "funcoptcheck" -}} 239 240 return true 241 } 242 if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { 243 t.Errorf("Inv test for {{.Name}} (scalar as left, tensor as right) failed: %v", err) 244 } 245 {{end -}} 246 {{end -}} 247 ` 248 249 const denseArithScalarWrongTypeTestRaw = `type Foo int 250 wt1 := func(a *Dense) bool{ 251 b := Foo(0) 252 {{template "call0" .}} 253 if err == nil { 254 return false 255 } 256 _ = ret 257 return true 258 } 259 if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { 260 t.Errorf("WrongType test for {{.Name}} (tensor as left, scalar as right) failed: %v", err) 261 } 262 263 wt2 := func(a *Dense) bool{ 264 b := Foo(0) 265 {{template "call1" .}} 266 if err == nil { 267 return false 268 } 269 _ = ret 270 return true 271 } 272 if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { 273 t.Errorf("WrongType test for {{.Name}} (tensor as right, scalar as left) failed: %v", err) 274 } 275 ` 276 277 const denseArithReuseMutationTestRaw = `mut := func(a, b *Dense, reuseA bool) bool { 278 // req because we're only testing on one kind of tensor/engine combo 279 a.e = StdEng{} 280 a.oe = StdEng{} 281 a.flag = 0 282 b.e = StdEng{} 283 b.oe = StdEng{} 284 b.flag = 0 285 286 if a.Dtype() != b.Dtype(){ 287 return true 288 } 289 if !a.Shape().Eq(b.Shape()){ 290 return true 291 } 292 293 294 295 {{template "callVanilla" .}} 296 we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) 297 _, ok := a.Engine().({{interfaceName .Name}}); we = we || !ok 298 299 300 301 var ret, reuse {{template "retType" .}} 302 if reuseA { 303 {{template "call0" .}}, WithReuse(a)) 304 reuse = a 305 } else { 306 {{template "call0" .}}, WithReuse(b)) 307 reuse = b 308 } 309 310 311 if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ 312 if err != nil { 313 return false 314 } 315 return true 316 } 317 318 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 319 return false 320 } 321 322 {{template "funcoptcheck" -}} 323 324 return true 325 } 326 if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { 327 t.Errorf("Reuse Mutation test for {{.Name}} failed: %v", err) 328 } 329 330 ` 331 332 var ( 333 denseArithBody *template.Template 334 denseArithScalarBody *template.Template 335 336 denseIdentityArithTest *template.Template 337 denseIdentityArithScalarTest *template.Template 338 339 denseArithScalarWrongTypeTest *template.Template 340 ) 341 342 func init() { 343 denseArithBody = template.Must(template.New("dense arith body").Funcs(funcs).Parse(denseArithBodyRaw)) 344 denseArithScalarBody = template.Must(template.New("dense arith body").Funcs(funcs).Parse(denseArithScalarBodyRaw)) 345 346 denseIdentityArithTest = template.Must(template.New("dense identity test").Funcs(funcs).Parse(denseIdentityArithTestBodyRaw)) 347 denseIdentityArithScalarTest = template.Must(template.New("dense scalar identity test").Funcs(funcs).Parse(denseIdentityArithScalarTestRaw)) 348 349 denseArithScalarWrongTypeTest = template.Must(template.New("dense scalar wrongtype test").Funcs(funcs).Parse(denseArithScalarWrongTypeTestRaw)) 350 }