github.com/wzzhu/tensor@v0.9.24/genlib2/signature.go (about) 1 package main 2 3 import ( 4 "io" 5 "reflect" 6 "text/template" 7 ) 8 9 type Signature struct { 10 Name string 11 NameTemplate *template.Template 12 ParamNames []string 13 ParamTemplates []*template.Template 14 RetVals []string 15 RetValTemplates []*template.Template 16 17 Kind reflect.Kind 18 Err bool 19 } 20 21 func (s *Signature) Write(w io.Writer) { 22 s.NameTemplate.Execute(w, s) 23 w.Write([]byte("(")) 24 for i, p := range s.ParamTemplates { 25 w.Write([]byte(s.ParamNames[i])) 26 w.Write([]byte(" ")) 27 p.Execute(w, s.Kind) 28 29 if i < len(s.ParamNames) { 30 w.Write([]byte(", ")) 31 } 32 } 33 w.Write([]byte(")")) 34 if len(s.RetVals) > 0 { 35 w.Write([]byte("(")) 36 for i, r := range s.RetValTemplates { 37 w.Write([]byte(s.RetVals[i])) 38 w.Write([]byte(" ")) 39 r.Execute(w, s.Kind) 40 41 if i < len(s.RetVals) { 42 w.Write([]byte(", ")) 43 } 44 } 45 46 if s.Err { 47 w.Write([]byte("err error")) 48 } 49 w.Write([]byte(")")) 50 return 51 } 52 53 if s.Err { 54 w.Write([]byte("(err error)")) 55 } 56 } 57 58 const ( 59 golinkPragmaRaw = "//go:linkname {{.Name}}{{short .Kind}} github.com/chewxy/{{vecPkg .Kind}}{{getalias .Name}}\n" 60 61 typeAnnotatedNameRaw = `{{.Name}}{{short .Kind}}` 62 plainNameRaw = `{{.Name}}` 63 ) 64 65 const ( 66 scalarTypeRaw = `{{asType .}}` 67 sliceTypeRaw = `[]{{asType .}}` 68 iteratorTypeRaw = `Iterator` 69 interfaceTypeRaw = "interface{}" 70 boolsTypeRaw = `[]bool` 71 boolTypeRaw = `bool` 72 intTypeRaw = `int` 73 intsTypeRaw = `[]int` 74 reflectTypeRaw = `reflect.Type` 75 76 // arrayTypeRaw = `Array` 77 arrayTypeRaw = `*storage.Header` 78 unaryFuncTypeRaw = `func({{asType .}}){{asType .}} ` 79 unaryFuncErrTypeRaw = `func({{asType .}}) ({{asType .}}, error)` 80 reductionFuncTypeRaw = `func(a, b {{asType .}}) {{asType .}}` 81 reductionFuncTypeErrRaw = `func(a, b {{asType .}}) ({{asType .}}, error)` 82 tensorTypeRaw = `Tensor` 83 splatFuncOptTypeRaw = `...FuncOpt` 84 denseTypeRaw = `*Dense` 85 86 testingTypeRaw = `*testing.T` 87 ) 88 89 var ( 90 golinkPragma *template.Template 91 typeAnnotatedName *template.Template 92 plainName *template.Template 93 94 scalarType *template.Template 95 sliceType *template.Template 96 iteratorType *template.Template 97 interfaceType *template.Template 98 boolsType *template.Template 99 boolType *template.Template 100 intType *template.Template 101 intsType *template.Template 102 reflectType *template.Template 103 arrayType *template.Template 104 unaryFuncType *template.Template 105 unaryFuncErrType *template.Template 106 tensorType *template.Template 107 splatFuncOptType *template.Template 108 denseType *template.Template 109 testingType *template.Template 110 ) 111 112 func init() { 113 golinkPragma = template.Must(template.New("golinkPragma").Funcs(funcs).Parse(golinkPragmaRaw)) 114 typeAnnotatedName = template.Must(template.New("type annotated name").Funcs(funcs).Parse(typeAnnotatedNameRaw)) 115 plainName = template.Must(template.New("plainName").Funcs(funcs).Parse(plainNameRaw)) 116 117 scalarType = template.Must(template.New("scalarType").Funcs(funcs).Parse(scalarTypeRaw)) 118 sliceType = template.Must(template.New("sliceType").Funcs(funcs).Parse(sliceTypeRaw)) 119 iteratorType = template.Must(template.New("iteratorType").Funcs(funcs).Parse(iteratorTypeRaw)) 120 interfaceType = template.Must(template.New("interfaceType").Funcs(funcs).Parse(interfaceTypeRaw)) 121 boolsType = template.Must(template.New("boolsType").Funcs(funcs).Parse(boolsTypeRaw)) 122 boolType = template.Must(template.New("boolType").Funcs(funcs).Parse(boolTypeRaw)) 123 intType = template.Must(template.New("intTYpe").Funcs(funcs).Parse(intTypeRaw)) 124 intsType = template.Must(template.New("intsType").Funcs(funcs).Parse(intsTypeRaw)) 125 reflectType = template.Must(template.New("reflectType").Funcs(funcs).Parse(reflectTypeRaw)) 126 arrayType = template.Must(template.New("arrayType").Funcs(funcs).Parse(arrayTypeRaw)) 127 unaryFuncType = template.Must(template.New("unaryFuncType").Funcs(funcs).Parse(unaryFuncTypeRaw)) 128 unaryFuncErrType = template.Must(template.New("unaryFuncErrType").Funcs(funcs).Parse(unaryFuncErrTypeRaw)) 129 tensorType = template.Must(template.New("tensorType").Funcs(funcs).Parse(tensorTypeRaw)) 130 splatFuncOptType = template.Must(template.New("splatFuncOpt").Funcs(funcs).Parse(splatFuncOptTypeRaw)) 131 denseType = template.Must(template.New("*Dense").Funcs(funcs).Parse(denseTypeRaw)) 132 testingType = template.Must(template.New("*testing.T").Funcs(funcs).Parse(testingTypeRaw)) 133 }