github.com/wzzhu/tensor@v0.9.24/genlib2/signature.go (about)

     1  package main
     2  
     3  import (
     4  	"io"
     5  	"reflect"
     6  	"text/template"
     7  )
     8  
     9  type Signature struct {
    10  	Name            string
    11  	NameTemplate    *template.Template
    12  	ParamNames      []string
    13  	ParamTemplates  []*template.Template
    14  	RetVals         []string
    15  	RetValTemplates []*template.Template
    16  
    17  	Kind reflect.Kind
    18  	Err  bool
    19  }
    20  
    21  func (s *Signature) Write(w io.Writer) {
    22  	s.NameTemplate.Execute(w, s)
    23  	w.Write([]byte("("))
    24  	for i, p := range s.ParamTemplates {
    25  		w.Write([]byte(s.ParamNames[i]))
    26  		w.Write([]byte(" "))
    27  		p.Execute(w, s.Kind)
    28  
    29  		if i < len(s.ParamNames) {
    30  			w.Write([]byte(", "))
    31  		}
    32  	}
    33  	w.Write([]byte(")"))
    34  	if len(s.RetVals) > 0 {
    35  		w.Write([]byte("("))
    36  		for i, r := range s.RetValTemplates {
    37  			w.Write([]byte(s.RetVals[i]))
    38  			w.Write([]byte(" "))
    39  			r.Execute(w, s.Kind)
    40  
    41  			if i < len(s.RetVals) {
    42  				w.Write([]byte(", "))
    43  			}
    44  		}
    45  
    46  		if s.Err {
    47  			w.Write([]byte("err error"))
    48  		}
    49  		w.Write([]byte(")"))
    50  		return
    51  	}
    52  
    53  	if s.Err {
    54  		w.Write([]byte("(err error)"))
    55  	}
    56  }
    57  
    58  const (
    59  	golinkPragmaRaw = "//go:linkname {{.Name}}{{short .Kind}} github.com/chewxy/{{vecPkg .Kind}}{{getalias .Name}}\n"
    60  
    61  	typeAnnotatedNameRaw = `{{.Name}}{{short .Kind}}`
    62  	plainNameRaw         = `{{.Name}}`
    63  )
    64  
    65  const (
    66  	scalarTypeRaw    = `{{asType .}}`
    67  	sliceTypeRaw     = `[]{{asType .}}`
    68  	iteratorTypeRaw  = `Iterator`
    69  	interfaceTypeRaw = "interface{}"
    70  	boolsTypeRaw     = `[]bool`
    71  	boolTypeRaw      = `bool`
    72  	intTypeRaw       = `int`
    73  	intsTypeRaw      = `[]int`
    74  	reflectTypeRaw   = `reflect.Type`
    75  
    76  	// arrayTypeRaw        = `Array`
    77  	arrayTypeRaw            = `*storage.Header`
    78  	unaryFuncTypeRaw        = `func({{asType .}}){{asType .}} `
    79  	unaryFuncErrTypeRaw     = `func({{asType .}}) ({{asType .}}, error)`
    80  	reductionFuncTypeRaw    = `func(a, b {{asType .}}) {{asType .}}`
    81  	reductionFuncTypeErrRaw = `func(a, b {{asType .}}) ({{asType .}}, error)`
    82  	tensorTypeRaw           = `Tensor`
    83  	splatFuncOptTypeRaw     = `...FuncOpt`
    84  	denseTypeRaw            = `*Dense`
    85  
    86  	testingTypeRaw = `*testing.T`
    87  )
    88  
    89  var (
    90  	golinkPragma      *template.Template
    91  	typeAnnotatedName *template.Template
    92  	plainName         *template.Template
    93  
    94  	scalarType       *template.Template
    95  	sliceType        *template.Template
    96  	iteratorType     *template.Template
    97  	interfaceType    *template.Template
    98  	boolsType        *template.Template
    99  	boolType         *template.Template
   100  	intType          *template.Template
   101  	intsType         *template.Template
   102  	reflectType      *template.Template
   103  	arrayType        *template.Template
   104  	unaryFuncType    *template.Template
   105  	unaryFuncErrType *template.Template
   106  	tensorType       *template.Template
   107  	splatFuncOptType *template.Template
   108  	denseType        *template.Template
   109  	testingType      *template.Template
   110  )
   111  
   112  func init() {
   113  	golinkPragma = template.Must(template.New("golinkPragma").Funcs(funcs).Parse(golinkPragmaRaw))
   114  	typeAnnotatedName = template.Must(template.New("type annotated name").Funcs(funcs).Parse(typeAnnotatedNameRaw))
   115  	plainName = template.Must(template.New("plainName").Funcs(funcs).Parse(plainNameRaw))
   116  
   117  	scalarType = template.Must(template.New("scalarType").Funcs(funcs).Parse(scalarTypeRaw))
   118  	sliceType = template.Must(template.New("sliceType").Funcs(funcs).Parse(sliceTypeRaw))
   119  	iteratorType = template.Must(template.New("iteratorType").Funcs(funcs).Parse(iteratorTypeRaw))
   120  	interfaceType = template.Must(template.New("interfaceType").Funcs(funcs).Parse(interfaceTypeRaw))
   121  	boolsType = template.Must(template.New("boolsType").Funcs(funcs).Parse(boolsTypeRaw))
   122  	boolType = template.Must(template.New("boolType").Funcs(funcs).Parse(boolTypeRaw))
   123  	intType = template.Must(template.New("intTYpe").Funcs(funcs).Parse(intTypeRaw))
   124  	intsType = template.Must(template.New("intsType").Funcs(funcs).Parse(intsTypeRaw))
   125  	reflectType = template.Must(template.New("reflectType").Funcs(funcs).Parse(reflectTypeRaw))
   126  	arrayType = template.Must(template.New("arrayType").Funcs(funcs).Parse(arrayTypeRaw))
   127  	unaryFuncType = template.Must(template.New("unaryFuncType").Funcs(funcs).Parse(unaryFuncTypeRaw))
   128  	unaryFuncErrType = template.Must(template.New("unaryFuncErrType").Funcs(funcs).Parse(unaryFuncErrTypeRaw))
   129  	tensorType = template.Must(template.New("tensorType").Funcs(funcs).Parse(tensorTypeRaw))
   130  	splatFuncOptType = template.Must(template.New("splatFuncOpt").Funcs(funcs).Parse(splatFuncOptTypeRaw))
   131  	denseType = template.Must(template.New("*Dense").Funcs(funcs).Parse(denseTypeRaw))
   132  	testingType = template.Must(template.New("*testing.T").Funcs(funcs).Parse(testingTypeRaw))
   133  }