github.com/wzzhu/tensor@v0.9.24/genlib2/engine.go (about) 1 package main 2 3 import ( 4 "io" 5 "reflect" 6 "text/template" 7 ) 8 9 type EngineArith struct { 10 Name string 11 VecVar string 12 PrepData string 13 TypeClassCheck string 14 IsCommutative bool 15 16 VV bool 17 LeftVec bool 18 } 19 20 func (fn *EngineArith) methName() string { 21 switch { 22 case fn.VV: 23 return fn.Name 24 default: 25 return fn.Name + "Scalar" 26 } 27 } 28 29 func (fn *EngineArith) Signature() *Signature { 30 var paramNames []string 31 var paramTemplates []*template.Template 32 33 switch { 34 case fn.VV: 35 paramNames = []string{"a", "b", "opts"} 36 paramTemplates = []*template.Template{tensorType, tensorType, splatFuncOptType} 37 default: 38 paramNames = []string{"t", "s", "leftTensor", "opts"} 39 paramTemplates = []*template.Template{tensorType, interfaceType, boolType, splatFuncOptType} 40 } 41 return &Signature{ 42 Name: fn.methName(), 43 NameTemplate: plainName, 44 ParamNames: paramNames, 45 ParamTemplates: paramTemplates, 46 Err: false, 47 } 48 } 49 50 func (fn *EngineArith) WriteBody(w io.Writer) { 51 var prep *template.Template 52 switch { 53 case fn.VV: 54 prep = prepVV 55 fn.VecVar = "a" 56 case !fn.VV && fn.LeftVec: 57 fn.VecVar = "t" 58 fn.PrepData = "prepDataVS" 59 prep = prepMixed 60 default: 61 fn.VecVar = "t" 62 fn.PrepData = "prepDataSV" 63 prep = prepMixed 64 } 65 template.Must(prep.New("prep").Parse(arithPrepRaw)) 66 prep.Execute(w, fn) 67 agg2Body.Execute(w, fn) 68 } 69 70 func (fn *EngineArith) Write(w io.Writer) { 71 if tmpl, ok := arithDocStrings[fn.methName()]; ok { 72 type tmp struct { 73 Left, Right string 74 } 75 var ds tmp 76 if fn.VV { 77 ds.Left = "a" 78 ds.Right = "b" 79 } else { 80 ds.Left = "t" 81 ds.Right = "s" 82 } 83 tmpl.Execute(w, ds) 84 } 85 86 sig := fn.Signature() 87 w.Write([]byte("func (e StdEng) ")) 88 sig.Write(w) 89 w.Write([]byte("(retVal Tensor, err error) {\n")) 90 fn.WriteBody(w) 91 w.Write([]byte("}\n\n")) 92 } 93 94 func generateStdEngArith(f io.Writer, ak Kinds) { 95 var methods []*EngineArith 96 for _, abo := range arithBinOps { 97 meth := &EngineArith{ 98 Name: abo.Name(), 99 VV: true, 100 TypeClassCheck: "Number", 101 IsCommutative: abo.IsCommutative, 102 } 103 methods = append(methods, meth) 104 } 105 106 // VV 107 for _, meth := range methods { 108 meth.Write(f) 109 meth.VV = false 110 } 111 112 // Scalar 113 for _, meth := range methods { 114 meth.Write(f) 115 meth.LeftVec = true 116 } 117 118 } 119 120 type EngineCmp struct { 121 Name string 122 VecVar string 123 PrepData string 124 TypeClassCheck string 125 Inv string 126 127 VV bool 128 LeftVec bool 129 } 130 131 func (fn *EngineCmp) methName() string { 132 switch { 133 case fn.VV: 134 if fn.Name == "Eq" || fn.Name == "Ne" { 135 return "El" + fn.Name 136 } 137 return fn.Name 138 default: 139 return fn.Name + "Scalar" 140 } 141 } 142 143 func (fn *EngineCmp) Signature() *Signature { 144 var paramNames []string 145 var paramTemplates []*template.Template 146 147 switch { 148 case fn.VV: 149 paramNames = []string{"a", "b", "opts"} 150 paramTemplates = []*template.Template{tensorType, tensorType, splatFuncOptType} 151 default: 152 paramNames = []string{"t", "s", "leftTensor", "opts"} 153 paramTemplates = []*template.Template{tensorType, interfaceType, boolType, splatFuncOptType} 154 } 155 return &Signature{ 156 Name: fn.methName(), 157 NameTemplate: plainName, 158 ParamNames: paramNames, 159 ParamTemplates: paramTemplates, 160 Err: false, 161 } 162 } 163 164 func (fn *EngineCmp) WriteBody(w io.Writer) { 165 var prep *template.Template 166 switch { 167 case fn.VV: 168 prep = prepVV 169 fn.VecVar = "a" 170 case !fn.VV && fn.LeftVec: 171 fn.VecVar = "t" 172 fn.PrepData = "prepDataVS" 173 prep = prepMixed 174 default: 175 fn.VecVar = "t" 176 fn.PrepData = "prepDataSV" 177 prep = prepMixed 178 } 179 template.Must(prep.New("prep").Parse(cmpPrepRaw)) 180 prep.Execute(w, fn) 181 agg2CmpBody.Execute(w, fn) 182 } 183 184 func (fn *EngineCmp) Write(w io.Writer) { 185 if tmpl, ok := cmpDocStrings[fn.methName()]; ok { 186 type tmp struct { 187 Left, Right string 188 } 189 var ds tmp 190 if fn.VV { 191 ds.Left = "a" 192 ds.Right = "b" 193 } else { 194 ds.Left = "t" 195 ds.Right = "s" 196 } 197 tmpl.Execute(w, ds) 198 } 199 sig := fn.Signature() 200 w.Write([]byte("func (e StdEng) ")) 201 sig.Write(w) 202 w.Write([]byte("(retVal Tensor, err error) {\n")) 203 fn.WriteBody(w) 204 w.Write([]byte("}\n\n")) 205 } 206 207 func generateStdEngCmp(f io.Writer, ak Kinds) { 208 var methods []*EngineCmp 209 210 for _, abo := range cmpBinOps { 211 var tc string 212 if abo.Name() == "Eq" || abo.Name() == "Ne" { 213 tc = "Eq" 214 } else { 215 tc = "Ord" 216 } 217 meth := &EngineCmp{ 218 Name: abo.Name(), 219 Inv: abo.Inv, 220 VV: true, 221 TypeClassCheck: tc, 222 } 223 methods = append(methods, meth) 224 } 225 226 // VV 227 for _, meth := range methods { 228 meth.Write(f) 229 meth.VV = false 230 } 231 232 // Scalar 233 for _, meth := range methods { 234 meth.Write(f) 235 meth.LeftVec = true 236 } 237 } 238 239 type EngineMinMax struct { 240 Name string 241 VecVar string 242 PrepData string 243 TypeClassCheck string 244 Kinds []reflect.Kind 245 246 VV bool 247 LeftVec bool 248 } 249 250 func (fn *EngineMinMax) methName() string { 251 switch { 252 case fn.VV: 253 return fn.Name 254 default: 255 return fn.Name + "Scalar" 256 } 257 } 258 259 func (fn *EngineMinMax) Signature() *Signature { 260 var paramNames []string 261 var paramTemplates []*template.Template 262 263 switch { 264 case fn.VV: 265 paramNames = []string{"a", "b", "opts"} 266 paramTemplates = []*template.Template{tensorType, tensorType, splatFuncOptType} 267 default: 268 paramNames = []string{"t", "s", "leftTensor", "opts"} 269 paramTemplates = []*template.Template{tensorType, interfaceType, boolType, splatFuncOptType} 270 } 271 return &Signature{ 272 Name: fn.methName(), 273 NameTemplate: plainName, 274 ParamNames: paramNames, 275 ParamTemplates: paramTemplates, 276 Err: false, 277 } 278 } 279 280 func (fn *EngineMinMax) WriteBody(w io.Writer) { 281 var prep *template.Template 282 switch { 283 case fn.VV: 284 prep = prepVV 285 fn.VecVar = "a" 286 case !fn.VV && fn.LeftVec: 287 fn.VecVar = "t" 288 fn.PrepData = "prepDataVS" 289 prep = prepMixed 290 default: 291 fn.VecVar = "t" 292 fn.PrepData = "prepDataSV" 293 prep = prepMixed 294 } 295 template.Must(prep.New("prep").Parse(minmaxPrepRaw)) 296 prep.Execute(w, fn) 297 agg2MinMaxBody.Execute(w, fn) 298 } 299 300 func (fn *EngineMinMax) Write(w io.Writer) { 301 if tmpl, ok := cmpDocStrings[fn.methName()]; ok { 302 type tmp struct { 303 Left, Right string 304 } 305 var ds tmp 306 if fn.VV { 307 ds.Left = "a" 308 ds.Right = "b" 309 } else { 310 ds.Left = "t" 311 ds.Right = "s" 312 } 313 tmpl.Execute(w, ds) 314 } 315 sig := fn.Signature() 316 w.Write([]byte("func (e StdEng) ")) 317 sig.Write(w) 318 w.Write([]byte("(retVal Tensor, err error) {\n")) 319 fn.WriteBody(w) 320 w.Write([]byte("}\n\n")) 321 } 322 323 func generateStdEngMinMax(f io.Writer, ak Kinds) { 324 methods := []*EngineMinMax{ 325 &EngineMinMax{ 326 Name: "MinBetween", 327 VV: true, 328 TypeClassCheck: "Ord", 329 }, 330 &EngineMinMax{ 331 Name: "MaxBetween", 332 VV: true, 333 TypeClassCheck: "Ord", 334 }, 335 } 336 f.Write([]byte(`var ( 337 _ MinBetweener = StdEng{} 338 _ MaxBetweener = StdEng{} 339 ) 340 `)) 341 // VV 342 for _, meth := range methods { 343 meth.Write(f) 344 meth.VV = false 345 } 346 347 // Scalar-Vector 348 for _, meth := range methods { 349 meth.Write(f) 350 meth.LeftVec = true 351 } 352 } 353 354 /* UNARY METHODS */ 355 356 type EngineUnary struct { 357 Name string 358 TypeClassCheck string 359 Kinds []reflect.Kind 360 } 361 362 func (fn *EngineUnary) Signature() *Signature { 363 return &Signature{ 364 Name: fn.Name, 365 NameTemplate: plainName, 366 ParamNames: []string{"a", "opts"}, 367 ParamTemplates: []*template.Template{tensorType, splatFuncOptType}, 368 RetVals: []string{"retVal"}, 369 RetValTemplates: []*template.Template{tensorType}, 370 371 Err: true, 372 } 373 } 374 375 func (fn *EngineUnary) WriteBody(w io.Writer) { 376 prepUnary.Execute(w, fn) 377 agg2UnaryBody.Execute(w, fn) 378 } 379 380 func (fn *EngineUnary) Write(w io.Writer) { 381 sig := fn.Signature() 382 w.Write([]byte("func (e StdEng) ")) 383 sig.Write(w) 384 w.Write([]byte("{\n")) 385 fn.WriteBody(w) 386 w.Write([]byte("\n}\n")) 387 } 388 389 func generateStdEngUncondUnary(f io.Writer, ak Kinds) { 390 tcc := []string{ 391 "Number", // Neg 392 "Number", // Inv 393 "Number", // Square 394 "Number", // Cube 395 "FloatCmplx", // Exp 396 "FloatCmplx", // Tanhh 397 "FloatCmplx", // Log 398 "Float", // Log2 399 "FloatCmplx", // Log10 400 "FloatCmplx", // Sqrt 401 "Float", // Cbrt 402 "Float", // InvSqrt 403 } 404 var gen []*EngineUnary 405 for i, u := range unconditionalUnaries { 406 var ks []reflect.Kind 407 for _, k := range ak.Kinds { 408 if tc := u.TypeClass(); tc != nil && !tc(k) { 409 continue 410 } 411 ks = append(ks, k) 412 } 413 fn := &EngineUnary{ 414 Name: u.Name(), 415 TypeClassCheck: tcc[i], 416 Kinds: ks, 417 } 418 gen = append(gen, fn) 419 } 420 421 for _, fn := range gen { 422 fn.Write(f) 423 } 424 } 425 426 func generateStdEngCondUnary(f io.Writer, ak Kinds) { 427 tcc := []string{ 428 "Signed", // Abs 429 "Signed", // Sign 430 } 431 var gen []*EngineUnary 432 for i, u := range conditionalUnaries { 433 var ks []reflect.Kind 434 for _, k := range ak.Kinds { 435 if tc := u.TypeClass(); tc != nil && !tc(k) { 436 continue 437 } 438 ks = append(ks, k) 439 } 440 fn := &EngineUnary{ 441 Name: u.Name(), 442 TypeClassCheck: tcc[i], 443 Kinds: ks, 444 } 445 gen = append(gen, fn) 446 } 447 448 for _, fn := range gen { 449 fn.Write(f) 450 } 451 }