github.com/wzzhu/tensor@v0.9.24/genlib2/declarations.go (about) 1 package main 2 3 import ( 4 "reflect" 5 "strings" 6 "text/template" 7 ) 8 9 var arithSymbolTemplates = [...]string{ 10 "+", 11 "-", 12 "*", 13 "/", 14 "{{mathPkg .}}Pow", 15 "{{if isFloatCmplx .}}{{mathPkg .}}Mod{{else}}%{{end}}", 16 } 17 18 var cmpSymbolTemplates = [...]string{ 19 ">", 20 ">=", 21 "<", 22 "<=", 23 "==", 24 "!=", 25 } 26 27 var nonFloatConditionalUnarySymbolTemplates = [...]string{ 28 `{{if isFloat .Kind -}} 29 {{.Range}}[{{.Index0}}] = {{mathPkg .Kind}}Abs({{.Range}}[{{.Index0}}]) {{else -}} 30 if {{.Range}}[{{.Index0}}] < 0 { 31 {{.Range}}[{{.Index0}}] = -{{.Range}}[{{.Index0}}] 32 }{{end -}}`, // abs 33 34 `if {{.Range}}[{{.Index0}}] < 0 { 35 {{.Range}}[{{.Index0}}] = -1 36 } else if {{.Range}}[{{.Index0}}] > 0 { 37 {{.Range}}[{{.Index0}}] = 1 38 }`, // sign 39 } 40 41 var unconditionalNumUnarySymbolTemplates = [...]string{ 42 "-", // neg 43 "1/", // inv 44 "{{.Range}}[i]*", // square 45 "{{.Range}}[i]*{{.Range}}[i]*", // cube 46 } 47 48 var unconditionalFloatUnarySymbolTemplates = [...]string{ 49 "{{mathPkg .Kind}}Exp", 50 "{{mathPkg .Kind}}Tanh", 51 "{{mathPkg .Kind}}Log", 52 "{{mathPkg .Kind}}Log2", 53 "{{mathPkg .Kind}}Log10", 54 "{{mathPkg .Kind}}Sqrt", 55 "{{mathPkg .Kind}}Cbrt", 56 `{{asType .Kind}}(1)/{{mathPkg .Kind}}Sqrt`, 57 } 58 59 var funcOptUse = map[string]string{ 60 "reuse": ",WithReuse(reuse)", 61 "incr": ",WithIncr(incr)", 62 "unsafe": ",UseUnsafe()", 63 "assame": ", AsSameType()", 64 } 65 66 var funcOptCheck = map[string]string{ 67 "reuse": `if reuse != ret { 68 t.Errorf("Expected reuse to be the same as retVal") 69 return false 70 } 71 72 `, 73 74 "incr": "", 75 76 "unsafe": `if ret != a { 77 t.Errorf("Expected ret to be the same as a") 78 return false 79 } 80 81 `, 82 } 83 84 var funcOptDecl = map[string]string{ 85 "reuse": "reuse := New(Of(a.t), WithShape(a.Shape().Clone()...))\n", 86 "incr": "incr := New(Of(a.t), WithShape(a.Shape().Clone()...))\n", 87 "unsafe": "", 88 "assame": `if err := typeclassCheck(q.Dtype(), {{.TypeClassName}}); err != nil { 89 return true // we exit early if the generated type is not something we can handle 90 } 91 `, 92 } 93 94 var funcOptCorrect = map[string]string{ 95 "reuse": "", 96 "incr": `incr.Memset(identityVal(100, a.t)) 97 correct.Add(incr, UseUnsafe()) 98 `, 99 "unsafe": "", 100 } 101 102 var stdTypes = [...]string{ 103 "Bool", 104 "Int", 105 "Int8", 106 "Int16", 107 "Int32", 108 "Int64", 109 "Uint", 110 "Uint8", 111 "Uint16", 112 "Uint32", 113 "Uint64", 114 "Float32", 115 "Float64", 116 "Complex64", 117 "Complex128", 118 "String", 119 "Uintptr", 120 "UnsafePointer", 121 } 122 123 var arrowBinaryTypes = []string{ 124 "String", 125 } 126 127 var arrowFixedWidthTypes = []string{ 128 "Boolean", 129 } 130 131 var arrowPrimitiveTypes = []string{ 132 "Int8", 133 "Int16", 134 "Int32", 135 "Int64", 136 "Uint8", 137 "Uint16", 138 "Uint32", 139 "Uint64", 140 "Float32", 141 "Float64", 142 } 143 144 var parameterizedKinds = [...]reflect.Kind{ 145 reflect.Array, 146 reflect.Chan, 147 reflect.Func, 148 reflect.Interface, 149 reflect.Map, 150 reflect.Ptr, 151 reflect.Slice, 152 reflect.Struct, 153 } 154 155 var number = [...]reflect.Kind{ 156 reflect.Int, 157 reflect.Int8, 158 reflect.Int16, 159 reflect.Int32, 160 reflect.Int64, 161 reflect.Uint, 162 reflect.Uint8, 163 reflect.Uint16, 164 reflect.Uint32, 165 reflect.Uint64, 166 reflect.Float32, 167 reflect.Float64, 168 reflect.Complex64, 169 reflect.Complex128, 170 } 171 172 var rangeable = [...]reflect.Kind{ 173 reflect.Int, 174 reflect.Int8, 175 reflect.Int16, 176 reflect.Int32, 177 reflect.Int64, 178 reflect.Uint, 179 reflect.Uint8, 180 reflect.Uint16, 181 reflect.Uint32, 182 reflect.Uint64, 183 reflect.Float32, 184 reflect.Float64, 185 reflect.Complex64, 186 reflect.Complex128, 187 } 188 189 var specialized = [...]reflect.Kind{ 190 reflect.Bool, 191 reflect.Int, 192 reflect.Int8, 193 reflect.Int16, 194 reflect.Int32, 195 reflect.Int64, 196 reflect.Uint, 197 reflect.Uint8, 198 reflect.Uint16, 199 reflect.Uint32, 200 reflect.Uint64, 201 reflect.Float32, 202 reflect.Float64, 203 reflect.Complex64, 204 reflect.Complex128, 205 reflect.String, 206 } 207 208 var signedNumber = [...]reflect.Kind{ 209 reflect.Int, 210 reflect.Int8, 211 reflect.Int16, 212 reflect.Int32, 213 reflect.Int64, 214 reflect.Float32, 215 reflect.Float64, 216 reflect.Complex64, 217 reflect.Complex128, 218 } 219 220 var nonComplexNumber = [...]reflect.Kind{ 221 reflect.Int, 222 reflect.Int8, 223 reflect.Int16, 224 reflect.Int32, 225 reflect.Int64, 226 reflect.Uint, 227 reflect.Uint8, 228 reflect.Uint16, 229 reflect.Uint32, 230 reflect.Uint64, 231 reflect.Float32, 232 reflect.Float64, 233 } 234 235 var elEq = [...]reflect.Kind{ 236 reflect.Bool, 237 reflect.Int, 238 reflect.Int8, 239 reflect.Int16, 240 reflect.Int32, 241 reflect.Int64, 242 reflect.Uint, 243 reflect.Uint8, 244 reflect.Uint16, 245 reflect.Uint32, 246 reflect.Uint64, 247 reflect.Uintptr, 248 reflect.Float32, 249 reflect.Float64, 250 reflect.Complex64, 251 reflect.Complex128, 252 reflect.String, 253 reflect.UnsafePointer, 254 } 255 256 var elOrd = [...]reflect.Kind{ 257 reflect.Int, 258 reflect.Int8, 259 reflect.Int16, 260 reflect.Int32, 261 reflect.Int64, 262 reflect.Uint, 263 reflect.Uint8, 264 reflect.Uint16, 265 reflect.Uint32, 266 reflect.Uint64, 267 // reflect.Uintptr, // comparison of pointers is not that great an idea - it can technically be done but should not be encouraged 268 reflect.Float32, 269 reflect.Float64, 270 // reflect.Complex64, 271 // reflect.Complex128, 272 reflect.String, // strings are orderable and the assumption is lexicographic sorting 273 } 274 275 var boolRepr = [...]reflect.Kind{ 276 reflect.Bool, 277 reflect.Int, 278 reflect.Int8, 279 reflect.Int16, 280 reflect.Int32, 281 reflect.Int64, 282 reflect.Uint, 283 reflect.Uint8, 284 reflect.Uint16, 285 reflect.Uint32, 286 reflect.Uint64, 287 reflect.Uintptr, 288 reflect.Float32, 289 reflect.Float64, 290 reflect.Complex64, 291 reflect.Complex128, 292 reflect.String, 293 } 294 295 var div0panics = [...]reflect.Kind{ 296 reflect.Int, 297 reflect.Int8, 298 reflect.Int16, 299 reflect.Int32, 300 reflect.Int64, 301 reflect.Uint, 302 reflect.Uint8, 303 reflect.Uint16, 304 reflect.Uint32, 305 reflect.Uint64, 306 } 307 308 var funcs = template.FuncMap{ 309 "lower": strings.ToLower, 310 "title": strings.Title, 311 "unexport": unexport, 312 "hasPrefix": strings.HasPrefix, 313 "hasSuffix": strings.HasSuffix, 314 "isParameterized": isParameterized, 315 "isRangeable": isRangeable, 316 "isSpecialized": isSpecialized, 317 "isNumber": isNumber, 318 "isSignedNumber": isSignedNumber, 319 "isNonComplexNumber": isNonComplexNumber, 320 "isAddable": isAddable, 321 "isFloat": isFloat, 322 "isFloatCmplx": isFloatCmplx, 323 "isEq": isEq, 324 "isOrd": isOrd, 325 "isBoolRepr": isBoolRepr, 326 "panicsDiv0": panicsDiv0, 327 328 "short": short, 329 "clean": clean, 330 "strip": strip, 331 332 "reflectKind": reflectKind, 333 "asType": asType, 334 "sliceOf": sliceOf, 335 "getOne": getOne, 336 "setOne": setOne, 337 "trueValue": trueValue, 338 "falseValue": falseValue, 339 340 "mathPkg": mathPkg, 341 "vecPkg": vecPkg, 342 "bitSizeOf": bitSizeOf, 343 "getalias": getalias, 344 "interfaceName": interfaceName, 345 346 "isntFloat": isntFloat, 347 } 348 349 var shortNames = map[reflect.Kind]string{ 350 reflect.Invalid: "Invalid", 351 reflect.Bool: "B", 352 reflect.Int: "I", 353 reflect.Int8: "I8", 354 reflect.Int16: "I16", 355 reflect.Int32: "I32", 356 reflect.Int64: "I64", 357 reflect.Uint: "U", 358 reflect.Uint8: "U8", 359 reflect.Uint16: "U16", 360 reflect.Uint32: "U32", 361 reflect.Uint64: "U64", 362 reflect.Uintptr: "Uintptr", 363 reflect.Float32: "F32", 364 reflect.Float64: "F64", 365 reflect.Complex64: "C64", 366 reflect.Complex128: "C128", 367 reflect.Array: "Array", 368 reflect.Chan: "Chan", 369 reflect.Func: "Func", 370 reflect.Interface: "Interface", 371 reflect.Map: "Map", 372 reflect.Ptr: "Ptr", 373 reflect.Slice: "Slice", 374 reflect.String: "Str", 375 reflect.Struct: "Struct", 376 reflect.UnsafePointer: "UnsafePointer", 377 } 378 379 var nameMaps = map[string]string{ 380 "VecAdd": "Add", 381 "VecSub": "Sub", 382 "VecMul": "Mul", 383 "VecDiv": "Div", 384 "VecPow": "Pow", 385 "VecMod": "Mod", 386 387 "AddVS": "Trans", 388 "AddSV": "TransR", 389 "SubVS": "TransInv", 390 "SubSV": "TransInvR", 391 "MulVS": "Scale", 392 "MulSV": "ScaleR", 393 "DivVS": "ScaleInv", 394 "DivSV": "ScaleInvR", 395 "PowVS": "PowOf", 396 "PowSV": "PowOfR", 397 398 "AddIncr": "IncrAdd", 399 "SubIncr": "IncrSub", 400 "MulIncr": "IncrMul", 401 "DivIncr": "IncrDiv", 402 "PowIncr": "IncrPow", 403 "ModIncr": "IncrMod", 404 } 405 406 var arithBinOps []arithOp 407 var cmpBinOps []cmpOp 408 var typedAriths []TypedBinOp 409 var typedCmps []TypedBinOp 410 411 var conditionalUnaries []unaryOp 412 var unconditionalUnaries []unaryOp 413 var specialUnaries []UnaryOp 414 var typedCondUnaries []TypedUnaryOp 415 var typedUncondUnaries []TypedUnaryOp 416 var typedSpecialUnaries []TypedUnaryOp 417 418 var allKinds []reflect.Kind 419 420 func init() { 421 // kinds 422 423 for k := reflect.Invalid + 1; k < reflect.UnsafePointer+1; k++ { 424 allKinds = append(allKinds, k) 425 } 426 427 // ops 428 429 arithBinOps = []arithOp{ 430 {basicBinOp{"", "Add", false, isAddable}, "numberTypes", true, 0, false, "", true, false}, 431 {basicBinOp{"", "Sub", false, isNumber}, "numberTypes", false, 0, true, "Add", false, true}, 432 {basicBinOp{"", "Mul", false, isNumber}, "numberTypes", true, 1, false, "", true, false}, 433 {basicBinOp{"", "Div", false, isNumber}, "numberTypes", false, 1, true, "Mul", false, false}, 434 {basicBinOp{"", "Pow", true, isFloatCmplx}, "floatcmplxTypes", true, 1, false, "", false, false}, 435 {basicBinOp{"", "Mod", false, isNonComplexNumber}, "nonComplexNumberTypes", false, 0, false, "", false, false}, 436 } 437 for i := range arithBinOps { 438 arithBinOps[i].symbol = arithSymbolTemplates[i] 439 } 440 441 cmpBinOps = []cmpOp{ 442 {basicBinOp{"", "Gt", false, isOrd}, "ordTypes", "Lt", true, false}, 443 {basicBinOp{"", "Gte", false, isOrd}, "ordTypes", "Lte", true, false}, 444 {basicBinOp{"", "Lt", false, isOrd}, "ordTypes", "Gt", true, false}, 445 {basicBinOp{"", "Lte", false, isOrd}, "ordTypes", "Gte", true, false}, 446 {basicBinOp{"", "Eq", false, isEq}, "eqTypes", "Eq", true, true}, 447 {basicBinOp{"", "Ne", false, isEq}, "eqTypes", "Ne", false, true}, 448 } 449 for i := range cmpBinOps { 450 cmpBinOps[i].symbol = cmpSymbolTemplates[i] 451 } 452 453 conditionalUnaries = []unaryOp{ 454 {"", "Abs", false, isSignedNumber, "signedTypes", ""}, 455 {"", "Sign", false, isSignedNumber, "signedTypes", ""}, 456 } 457 for i := range conditionalUnaries { 458 conditionalUnaries[i].symbol = nonFloatConditionalUnarySymbolTemplates[i] 459 } 460 461 unconditionalUnaries = []unaryOp{ 462 {"", "Neg", false, isNumber, "numberTypes", "Neg"}, 463 {"", "Inv", false, isNumber, "numberTypes", ""}, 464 {"", "Square", false, isNumber, "numberTypes", "Sqrt"}, 465 {"", "Cube", false, isNumber, "numberTypes", "Cbrt"}, 466 467 {"", "Exp", true, isFloatCmplx, "floatcmplxTypes", "Log"}, 468 {"", "Tanh", true, isFloatCmplx, "floatcmplxTypes", ""}, 469 {"", "Log", true, isFloatCmplx, "floatcmplxTypes", "Exp"}, 470 {"", "Log2", true, isFloat, "floatTypes", ""}, 471 {"", "Log10", true, isFloatCmplx, "floatcmplxTypes", ""}, 472 {"", "Sqrt", true, isFloatCmplx, "floatcmplxTypes", "Square"}, 473 {"", "Cbrt", true, isFloat, "floatTypes", "Cube"}, 474 {"", "InvSqrt", true, isFloat, "floatTypes", ""}, // TODO: cmplx requires to much finagling to the template. Come back to it later 475 } 476 nonF := len(unconditionalNumUnarySymbolTemplates) 477 for i := range unconditionalNumUnarySymbolTemplates { 478 unconditionalUnaries[i].symbol = unconditionalNumUnarySymbolTemplates[i] 479 } 480 for i := range unconditionalFloatUnarySymbolTemplates { 481 unconditionalUnaries[i+nonF].symbol = unconditionalFloatUnarySymbolTemplates[i] 482 } 483 484 specialUnaries = []UnaryOp{ 485 specialUnaryOp{unaryOp{clampBody, "Clamp", false, isNonComplexNumber, "nonComplexNumberTypes", ""}, []string{"min", "max"}}, 486 } 487 488 // typed operations 489 490 for _, bo := range arithBinOps { 491 for _, k := range allKinds { 492 tb := TypedBinOp{ 493 BinOp: bo, 494 k: k, 495 } 496 typedAriths = append(typedAriths, tb) 497 } 498 } 499 500 for _, bo := range cmpBinOps { 501 for _, k := range allKinds { 502 tb := TypedBinOp{ 503 BinOp: bo, 504 k: k, 505 } 506 typedCmps = append(typedCmps, tb) 507 } 508 } 509 510 for _, uo := range conditionalUnaries { 511 for _, k := range allKinds { 512 tu := TypedUnaryOp{ 513 UnaryOp: uo, 514 k: k, 515 } 516 typedCondUnaries = append(typedCondUnaries, tu) 517 } 518 } 519 520 for _, uo := range unconditionalUnaries { 521 for _, k := range allKinds { 522 tu := TypedUnaryOp{ 523 UnaryOp: uo, 524 k: k, 525 } 526 typedUncondUnaries = append(typedUncondUnaries, tu) 527 } 528 } 529 530 for _, uo := range specialUnaries { 531 for _, k := range allKinds { 532 tu := TypedUnaryOp{ 533 UnaryOp: uo, 534 k: k, 535 } 536 typedSpecialUnaries = append(typedSpecialUnaries, tu) 537 } 538 } 539 }