github.com/wzzhu/tensor@v0.9.24/genlib2/declarations.go (about)

     1  package main
     2  
     3  import (
     4  	"reflect"
     5  	"strings"
     6  	"text/template"
     7  )
     8  
     9  var arithSymbolTemplates = [...]string{
    10  	"+",
    11  	"-",
    12  	"*",
    13  	"/",
    14  	"{{mathPkg .}}Pow",
    15  	"{{if isFloatCmplx .}}{{mathPkg .}}Mod{{else}}%{{end}}",
    16  }
    17  
    18  var cmpSymbolTemplates = [...]string{
    19  	">",
    20  	">=",
    21  	"<",
    22  	"<=",
    23  	"==",
    24  	"!=",
    25  }
    26  
// nonFloatConditionalUnarySymbolTemplates holds the element-update templates
// for the conditional unary ops. init assigns them by index to
// conditionalUnaries: index 0 is Abs, index 1 is Sign.
//
// The templates read .Kind, .Range and .Index0 from the value they are
// executed with, and update {{.Range}}[{{.Index0}}] in place. Whitespace inside
// these raw strings (including the trailing space after the first `-}}`) is
// emitted verbatim apart from what the `-}}` trim markers remove, so edit with
// care.
//
// NOTE(review): both ops are registered with isSignedNumber, whose kind table
// (signedNumber) includes Complex64/Complex128. The non-float branches compare
// against 0 with < and >, which Go does not allow for complex values.
// Presumably the "signedTypes" type class keeps complex kinds away from these
// templates; confirm.
var nonFloatConditionalUnarySymbolTemplates = [...]string{
	// Abs: kinds where isFloat holds call mathPkg's Abs; every other kind
	// negates the element when it is below zero.
	`{{if isFloat .Kind -}} 
	{{.Range}}[{{.Index0}}] = {{mathPkg .Kind}}Abs({{.Range}}[{{.Index0}}]) {{else -}}
	if {{.Range}}[{{.Index0}}] < 0 {
		{{.Range}}[{{.Index0}}] = -{{.Range}}[{{.Index0}}]
	}{{end -}}`, // abs

	// Sign: negative elements become -1 and positive elements become 1; an
	// element equal to zero is left unchanged.
	`if {{.Range}}[{{.Index0}}] < 0 {
		{{.Range}}[{{.Index0}}] = -1
	} else if {{.Range}}[{{.Index0}}] > 0 {
		{{.Range}}[{{.Index0}}] = 1
	}`, // sign
}
    40  
    41  var unconditionalNumUnarySymbolTemplates = [...]string{
    42  	"-",                            // neg
    43  	"1/",                           // inv
    44  	"{{.Range}}[i]*",               // square
    45  	"{{.Range}}[i]*{{.Range}}[i]*", // cube
    46  }
    47  
    48  var unconditionalFloatUnarySymbolTemplates = [...]string{
    49  	"{{mathPkg .Kind}}Exp",
    50  	"{{mathPkg .Kind}}Tanh",
    51  	"{{mathPkg .Kind}}Log",
    52  	"{{mathPkg .Kind}}Log2",
    53  	"{{mathPkg .Kind}}Log10",
    54  	"{{mathPkg .Kind}}Sqrt",
    55  	"{{mathPkg .Kind}}Cbrt",
    56  	`{{asType .Kind}}(1)/{{mathPkg .Kind}}Sqrt`,
    57  }
    58  
    59  var funcOptUse = map[string]string{
    60  	"reuse":  ",WithReuse(reuse)",
    61  	"incr":   ",WithIncr(incr)",
    62  	"unsafe": ",UseUnsafe()",
    63  	"assame": ", AsSameType()",
    64  }
    65  
    66  var funcOptCheck = map[string]string{
    67  	"reuse": `if reuse != ret {
    68  			t.Errorf("Expected reuse to be the same as retVal")
    69  			return false
    70  	}
    71  
    72  	`,
    73  
    74  	"incr": "",
    75  
    76  	"unsafe": `if ret != a {
    77  		t.Errorf("Expected ret to be the same as a")
    78  		return false
    79  	}
    80  
    81  	`,
    82  }
    83  
    84  var funcOptDecl = map[string]string{
    85  	"reuse":  "reuse := New(Of(a.t), WithShape(a.Shape().Clone()...))\n",
    86  	"incr":   "incr := New(Of(a.t), WithShape(a.Shape().Clone()...))\n",
    87  	"unsafe": "",
    88  	"assame": `if err := typeclassCheck(q.Dtype(), {{.TypeClassName}}); err != nil {
    89  		return true // we exit early if the generated type is not something we can handle
    90  	}
    91  	`,
    92  }
    93  
    94  var funcOptCorrect = map[string]string{
    95  	"reuse": "",
    96  	"incr": `incr.Memset(identityVal(100, a.t))
    97  	correct.Add(incr, UseUnsafe())
    98  	`,
    99  	"unsafe": "",
   100  }
   101  
   102  var stdTypes = [...]string{
   103  	"Bool",
   104  	"Int",
   105  	"Int8",
   106  	"Int16",
   107  	"Int32",
   108  	"Int64",
   109  	"Uint",
   110  	"Uint8",
   111  	"Uint16",
   112  	"Uint32",
   113  	"Uint64",
   114  	"Float32",
   115  	"Float64",
   116  	"Complex64",
   117  	"Complex128",
   118  	"String",
   119  	"Uintptr",
   120  	"UnsafePointer",
   121  }
   122  
   123  var arrowBinaryTypes = []string{
   124  	"String",
   125  }
   126  
   127  var arrowFixedWidthTypes = []string{
   128  	"Boolean",
   129  }
   130  
   131  var arrowPrimitiveTypes = []string{
   132  	"Int8",
   133  	"Int16",
   134  	"Int32",
   135  	"Int64",
   136  	"Uint8",
   137  	"Uint16",
   138  	"Uint32",
   139  	"Uint64",
   140  	"Float32",
   141  	"Float64",
   142  }
   143  
   144  var parameterizedKinds = [...]reflect.Kind{
   145  	reflect.Array,
   146  	reflect.Chan,
   147  	reflect.Func,
   148  	reflect.Interface,
   149  	reflect.Map,
   150  	reflect.Ptr,
   151  	reflect.Slice,
   152  	reflect.Struct,
   153  }
   154  
   155  var number = [...]reflect.Kind{
   156  	reflect.Int,
   157  	reflect.Int8,
   158  	reflect.Int16,
   159  	reflect.Int32,
   160  	reflect.Int64,
   161  	reflect.Uint,
   162  	reflect.Uint8,
   163  	reflect.Uint16,
   164  	reflect.Uint32,
   165  	reflect.Uint64,
   166  	reflect.Float32,
   167  	reflect.Float64,
   168  	reflect.Complex64,
   169  	reflect.Complex128,
   170  }
   171  
   172  var rangeable = [...]reflect.Kind{
   173  	reflect.Int,
   174  	reflect.Int8,
   175  	reflect.Int16,
   176  	reflect.Int32,
   177  	reflect.Int64,
   178  	reflect.Uint,
   179  	reflect.Uint8,
   180  	reflect.Uint16,
   181  	reflect.Uint32,
   182  	reflect.Uint64,
   183  	reflect.Float32,
   184  	reflect.Float64,
   185  	reflect.Complex64,
   186  	reflect.Complex128,
   187  }
   188  
   189  var specialized = [...]reflect.Kind{
   190  	reflect.Bool,
   191  	reflect.Int,
   192  	reflect.Int8,
   193  	reflect.Int16,
   194  	reflect.Int32,
   195  	reflect.Int64,
   196  	reflect.Uint,
   197  	reflect.Uint8,
   198  	reflect.Uint16,
   199  	reflect.Uint32,
   200  	reflect.Uint64,
   201  	reflect.Float32,
   202  	reflect.Float64,
   203  	reflect.Complex64,
   204  	reflect.Complex128,
   205  	reflect.String,
   206  }
   207  
   208  var signedNumber = [...]reflect.Kind{
   209  	reflect.Int,
   210  	reflect.Int8,
   211  	reflect.Int16,
   212  	reflect.Int32,
   213  	reflect.Int64,
   214  	reflect.Float32,
   215  	reflect.Float64,
   216  	reflect.Complex64,
   217  	reflect.Complex128,
   218  }
   219  
   220  var nonComplexNumber = [...]reflect.Kind{
   221  	reflect.Int,
   222  	reflect.Int8,
   223  	reflect.Int16,
   224  	reflect.Int32,
   225  	reflect.Int64,
   226  	reflect.Uint,
   227  	reflect.Uint8,
   228  	reflect.Uint16,
   229  	reflect.Uint32,
   230  	reflect.Uint64,
   231  	reflect.Float32,
   232  	reflect.Float64,
   233  }
   234  
   235  var elEq = [...]reflect.Kind{
   236  	reflect.Bool,
   237  	reflect.Int,
   238  	reflect.Int8,
   239  	reflect.Int16,
   240  	reflect.Int32,
   241  	reflect.Int64,
   242  	reflect.Uint,
   243  	reflect.Uint8,
   244  	reflect.Uint16,
   245  	reflect.Uint32,
   246  	reflect.Uint64,
   247  	reflect.Uintptr,
   248  	reflect.Float32,
   249  	reflect.Float64,
   250  	reflect.Complex64,
   251  	reflect.Complex128,
   252  	reflect.String,
   253  	reflect.UnsafePointer,
   254  }
   255  
   256  var elOrd = [...]reflect.Kind{
   257  	reflect.Int,
   258  	reflect.Int8,
   259  	reflect.Int16,
   260  	reflect.Int32,
   261  	reflect.Int64,
   262  	reflect.Uint,
   263  	reflect.Uint8,
   264  	reflect.Uint16,
   265  	reflect.Uint32,
   266  	reflect.Uint64,
   267  	// reflect.Uintptr, // comparison of pointers is not that great an idea - it can technically be done but should not be encouraged
   268  	reflect.Float32,
   269  	reflect.Float64,
   270  	// reflect.Complex64,
   271  	// reflect.Complex128,
   272  	reflect.String, // strings are orderable and the assumption is lexicographic sorting
   273  }
   274  
   275  var boolRepr = [...]reflect.Kind{
   276  	reflect.Bool,
   277  	reflect.Int,
   278  	reflect.Int8,
   279  	reflect.Int16,
   280  	reflect.Int32,
   281  	reflect.Int64,
   282  	reflect.Uint,
   283  	reflect.Uint8,
   284  	reflect.Uint16,
   285  	reflect.Uint32,
   286  	reflect.Uint64,
   287  	reflect.Uintptr,
   288  	reflect.Float32,
   289  	reflect.Float64,
   290  	reflect.Complex64,
   291  	reflect.Complex128,
   292  	reflect.String,
   293  }
   294  
   295  var div0panics = [...]reflect.Kind{
   296  	reflect.Int,
   297  	reflect.Int8,
   298  	reflect.Int16,
   299  	reflect.Int32,
   300  	reflect.Int64,
   301  	reflect.Uint,
   302  	reflect.Uint8,
   303  	reflect.Uint16,
   304  	reflect.Uint32,
   305  	reflect.Uint64,
   306  }
   307  
   308  var funcs = template.FuncMap{
   309  	"lower":              strings.ToLower,
   310  	"title":              strings.Title,
   311  	"unexport":           unexport,
   312  	"hasPrefix":          strings.HasPrefix,
   313  	"hasSuffix":          strings.HasSuffix,
   314  	"isParameterized":    isParameterized,
   315  	"isRangeable":        isRangeable,
   316  	"isSpecialized":      isSpecialized,
   317  	"isNumber":           isNumber,
   318  	"isSignedNumber":     isSignedNumber,
   319  	"isNonComplexNumber": isNonComplexNumber,
   320  	"isAddable":          isAddable,
   321  	"isFloat":            isFloat,
   322  	"isFloatCmplx":       isFloatCmplx,
   323  	"isEq":               isEq,
   324  	"isOrd":              isOrd,
   325  	"isBoolRepr":         isBoolRepr,
   326  	"panicsDiv0":         panicsDiv0,
   327  
   328  	"short": short,
   329  	"clean": clean,
   330  	"strip": strip,
   331  
   332  	"reflectKind": reflectKind,
   333  	"asType":      asType,
   334  	"sliceOf":     sliceOf,
   335  	"getOne":      getOne,
   336  	"setOne":      setOne,
   337  	"trueValue":   trueValue,
   338  	"falseValue":  falseValue,
   339  
   340  	"mathPkg":       mathPkg,
   341  	"vecPkg":        vecPkg,
   342  	"bitSizeOf":     bitSizeOf,
   343  	"getalias":      getalias,
   344  	"interfaceName": interfaceName,
   345  
   346  	"isntFloat": isntFloat,
   347  }
   348  
   349  var shortNames = map[reflect.Kind]string{
   350  	reflect.Invalid:       "Invalid",
   351  	reflect.Bool:          "B",
   352  	reflect.Int:           "I",
   353  	reflect.Int8:          "I8",
   354  	reflect.Int16:         "I16",
   355  	reflect.Int32:         "I32",
   356  	reflect.Int64:         "I64",
   357  	reflect.Uint:          "U",
   358  	reflect.Uint8:         "U8",
   359  	reflect.Uint16:        "U16",
   360  	reflect.Uint32:        "U32",
   361  	reflect.Uint64:        "U64",
   362  	reflect.Uintptr:       "Uintptr",
   363  	reflect.Float32:       "F32",
   364  	reflect.Float64:       "F64",
   365  	reflect.Complex64:     "C64",
   366  	reflect.Complex128:    "C128",
   367  	reflect.Array:         "Array",
   368  	reflect.Chan:          "Chan",
   369  	reflect.Func:          "Func",
   370  	reflect.Interface:     "Interface",
   371  	reflect.Map:           "Map",
   372  	reflect.Ptr:           "Ptr",
   373  	reflect.Slice:         "Slice",
   374  	reflect.String:        "Str",
   375  	reflect.Struct:        "Struct",
   376  	reflect.UnsafePointer: "UnsafePointer",
   377  }
   378  
   379  var nameMaps = map[string]string{
   380  	"VecAdd": "Add",
   381  	"VecSub": "Sub",
   382  	"VecMul": "Mul",
   383  	"VecDiv": "Div",
   384  	"VecPow": "Pow",
   385  	"VecMod": "Mod",
   386  
   387  	"AddVS": "Trans",
   388  	"AddSV": "TransR",
   389  	"SubVS": "TransInv",
   390  	"SubSV": "TransInvR",
   391  	"MulVS": "Scale",
   392  	"MulSV": "ScaleR",
   393  	"DivVS": "ScaleInv",
   394  	"DivSV": "ScaleInvR",
   395  	"PowVS": "PowOf",
   396  	"PowSV": "PowOfR",
   397  
   398  	"AddIncr": "IncrAdd",
   399  	"SubIncr": "IncrSub",
   400  	"MulIncr": "IncrMul",
   401  	"DivIncr": "IncrDiv",
   402  	"PowIncr": "IncrPow",
   403  	"ModIncr": "IncrMod",
   404  }
   405  
   406  var arithBinOps []arithOp
   407  var cmpBinOps []cmpOp
   408  var typedAriths []TypedBinOp
   409  var typedCmps []TypedBinOp
   410  
   411  var conditionalUnaries []unaryOp
   412  var unconditionalUnaries []unaryOp
   413  var specialUnaries []UnaryOp
   414  var typedCondUnaries []TypedUnaryOp
   415  var typedUncondUnaries []TypedUnaryOp
   416  var typedSpecialUnaries []TypedUnaryOp
   417  
   418  var allKinds []reflect.Kind
   419  
   420  func init() {
   421  	// kinds
   422  
   423  	for k := reflect.Invalid + 1; k < reflect.UnsafePointer+1; k++ {
   424  		allKinds = append(allKinds, k)
   425  	}
   426  
   427  	// ops
   428  
   429  	arithBinOps = []arithOp{
   430  		{basicBinOp{"", "Add", false, isAddable}, "numberTypes", true, 0, false, "", true, false},
   431  		{basicBinOp{"", "Sub", false, isNumber}, "numberTypes", false, 0, true, "Add", false, true},
   432  		{basicBinOp{"", "Mul", false, isNumber}, "numberTypes", true, 1, false, "", true, false},
   433  		{basicBinOp{"", "Div", false, isNumber}, "numberTypes", false, 1, true, "Mul", false, false},
   434  		{basicBinOp{"", "Pow", true, isFloatCmplx}, "floatcmplxTypes", true, 1, false, "", false, false},
   435  		{basicBinOp{"", "Mod", false, isNonComplexNumber}, "nonComplexNumberTypes", false, 0, false, "", false, false},
   436  	}
   437  	for i := range arithBinOps {
   438  		arithBinOps[i].symbol = arithSymbolTemplates[i]
   439  	}
   440  
   441  	cmpBinOps = []cmpOp{
   442  		{basicBinOp{"", "Gt", false, isOrd}, "ordTypes", "Lt", true, false},
   443  		{basicBinOp{"", "Gte", false, isOrd}, "ordTypes", "Lte", true, false},
   444  		{basicBinOp{"", "Lt", false, isOrd}, "ordTypes", "Gt", true, false},
   445  		{basicBinOp{"", "Lte", false, isOrd}, "ordTypes", "Gte", true, false},
   446  		{basicBinOp{"", "Eq", false, isEq}, "eqTypes", "Eq", true, true},
   447  		{basicBinOp{"", "Ne", false, isEq}, "eqTypes", "Ne", false, true},
   448  	}
   449  	for i := range cmpBinOps {
   450  		cmpBinOps[i].symbol = cmpSymbolTemplates[i]
   451  	}
   452  
   453  	conditionalUnaries = []unaryOp{
   454  		{"", "Abs", false, isSignedNumber, "signedTypes", ""},
   455  		{"", "Sign", false, isSignedNumber, "signedTypes", ""},
   456  	}
   457  	for i := range conditionalUnaries {
   458  		conditionalUnaries[i].symbol = nonFloatConditionalUnarySymbolTemplates[i]
   459  	}
   460  
   461  	unconditionalUnaries = []unaryOp{
   462  		{"", "Neg", false, isNumber, "numberTypes", "Neg"},
   463  		{"", "Inv", false, isNumber, "numberTypes", ""},
   464  		{"", "Square", false, isNumber, "numberTypes", "Sqrt"},
   465  		{"", "Cube", false, isNumber, "numberTypes", "Cbrt"},
   466  
   467  		{"", "Exp", true, isFloatCmplx, "floatcmplxTypes", "Log"},
   468  		{"", "Tanh", true, isFloatCmplx, "floatcmplxTypes", ""},
   469  		{"", "Log", true, isFloatCmplx, "floatcmplxTypes", "Exp"},
   470  		{"", "Log2", true, isFloat, "floatTypes", ""},
   471  		{"", "Log10", true, isFloatCmplx, "floatcmplxTypes", ""},
   472  		{"", "Sqrt", true, isFloatCmplx, "floatcmplxTypes", "Square"},
   473  		{"", "Cbrt", true, isFloat, "floatTypes", "Cube"},
   474  		{"", "InvSqrt", true, isFloat, "floatTypes", ""}, // TODO: cmplx requires to much finagling to the template. Come back to it later
   475  	}
   476  	nonF := len(unconditionalNumUnarySymbolTemplates)
   477  	for i := range unconditionalNumUnarySymbolTemplates {
   478  		unconditionalUnaries[i].symbol = unconditionalNumUnarySymbolTemplates[i]
   479  	}
   480  	for i := range unconditionalFloatUnarySymbolTemplates {
   481  		unconditionalUnaries[i+nonF].symbol = unconditionalFloatUnarySymbolTemplates[i]
   482  	}
   483  
   484  	specialUnaries = []UnaryOp{
   485  		specialUnaryOp{unaryOp{clampBody, "Clamp", false, isNonComplexNumber, "nonComplexNumberTypes", ""}, []string{"min", "max"}},
   486  	}
   487  
   488  	// typed operations
   489  
   490  	for _, bo := range arithBinOps {
   491  		for _, k := range allKinds {
   492  			tb := TypedBinOp{
   493  				BinOp: bo,
   494  				k:     k,
   495  			}
   496  			typedAriths = append(typedAriths, tb)
   497  		}
   498  	}
   499  
   500  	for _, bo := range cmpBinOps {
   501  		for _, k := range allKinds {
   502  			tb := TypedBinOp{
   503  				BinOp: bo,
   504  				k:     k,
   505  			}
   506  			typedCmps = append(typedCmps, tb)
   507  		}
   508  	}
   509  
   510  	for _, uo := range conditionalUnaries {
   511  		for _, k := range allKinds {
   512  			tu := TypedUnaryOp{
   513  				UnaryOp: uo,
   514  				k:       k,
   515  			}
   516  			typedCondUnaries = append(typedCondUnaries, tu)
   517  		}
   518  	}
   519  
   520  	for _, uo := range unconditionalUnaries {
   521  		for _, k := range allKinds {
   522  			tu := TypedUnaryOp{
   523  				UnaryOp: uo,
   524  				k:       k,
   525  			}
   526  			typedUncondUnaries = append(typedUncondUnaries, tu)
   527  		}
   528  	}
   529  
   530  	for _, uo := range specialUnaries {
   531  		for _, k := range allKinds {
   532  			tu := TypedUnaryOp{
   533  				UnaryOp: uo,
   534  				k:       k,
   535  			}
   536  			typedSpecialUnaries = append(typedSpecialUnaries, tu)
   537  		}
   538  	}
   539  }