github.com/nyan233/littlerpc@v0.4.6-0.20230316182519-0c8d5c48abaf/cmd/pxtor/generator.go (about)

     1  package main
     2  
import (
	"bytes"
	"encoding/binary"
	"encoding/hex"
	"errors"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"math/rand"
	"os"
	"path"
	"sort"
	"strconv"
	"strings"
	"text/template"
	"time"

	flag "github.com/spf13/pflag"
)
    22  
    23  type GenMethod func(receive Argument, name, service string, input []Argument, output []Argument) (getResult func() string, err error)
    24  
    25  const (
    26  	SyncStyle     = "sync"
    27  	AsyncStyle    = "async"
    28  	RequestsStyle = "requests"
    29  )
    30  
    31  var (
    32  	receiver   = flag.StringP("receive", "r", "", "代理对象的接收器: package.RecvName")
    33  	dir        = flag.StringP("dir", "d", "./", "解析接收器的路径: ./")
    34  	outName    = flag.StringP("out", "o", "", "输出的文件名,默认的格式: receiver_proxy.go")
    35  	sourceName = flag.StringP("source", "s", "", "SourceName Example(Hello1.Hello2) SourceName == Hello1")
    36  	generateId = flag.BoolP("gen_id", "i", false, "生成唯一id, 多个文件在同一个包时binder/caller不会冲突, 但对于mock场景不友好")
    37  	// TODO: 实现不同API风格的生成函数
    38  	style   = flag.StringP("gen", "g", SyncStyle, "生成的API风格, TODO")
    39  	fileSet *token.FileSet
    40  )
    41  
    42  func main() {
    43  	flag.Parse()
    44  	if *receiver == "" {
    45  		panic(interface{}("no receiver specified"))
    46  	}
    47  	if *sourceName == "" {
    48  		*sourceName = strings.Split(*receiver, ".")[1]
    49  	}
    50  	genCode()
    51  }
    52  
    53  func genCode() {
    54  	tmp := strings.SplitN(*receiver, ".", 2)
    55  	pkgName, recvName := tmp[0], tmp[1]
    56  	// 输出文件名
    57  	if *outName == "" {
    58  		*outName = recvName + "_proxy.go"
    59  	}
    60  	fileSet = token.NewFileSet()
    61  	parseDir, err := parser.ParseDir(fileSet, *dir, nil, 0)
    62  	if err != nil {
    63  		panic(interface{}(err))
    64  	}
    65  	// ast.Print(fileSet,parseDir[pkgName].Files["test/proxy_2.go"])
    66  	// 创建
    67  	pkgDir := parseDir[pkgName]
    68  	funcStrs := make([]string, 0, 20)
    69  	// 要写入到文件的数据,提供这个是为了方便格式化生成的代码
    70  	var fileBuffer bytes.Buffer
    71  	fileBuffer.Grow(512)
    72  	var genFn GenMethod
    73  	switch *style {
    74  	case SyncStyle:
    75  		genFn = genSync
    76  	default:
    77  		panic("no support gen style")
    78  	}
    79  	usageImportNameAndPath := make(map[string]string)
    80  	ignoreSetup := ignoreSetup(pkgDir.Files, *receiver)
    81  	for k, v := range pkgDir.Files {
    82  		rawFile, err := os.Open(path.Dir(*dir) + "/" + k)
    83  		if err != nil {
    84  			panic(interface{}(err))
    85  		}
    86  		tmp := getAllFunc(v, rawFile, usageImportNameAndPath, *sourceName, genFn, func(recvT string) bool {
    87  			if recvT == recvName {
    88  				return true
    89  			}
    90  			return false
    91  		}, ignoreSetup)
    92  		funcStrs = append(funcStrs, tmp...)
    93  	}
    94  	fileBuffer.WriteString(createBeforeCode(pkgName, recvName, *sourceName, funcStrs, usageImportNameAndPath))
    95  	for _, v := range funcStrs {
    96  		fileBuffer.WriteString("\n\n")
    97  		fileBuffer.WriteString(v)
    98  	}
    99  	if string(fileBuffer.Bytes()[fileBuffer.Len()-4:]) == "}\n}\n" {
   100  		fmt.Println("double }")
   101  	}
   102  	fmtBytes, err := format.Source(fileBuffer.Bytes())
   103  	if err != nil {
   104  		panic(err)
   105  	}
   106  	file, err := os.OpenFile(*dir+"/"+*outName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0755)
   107  	if err != nil {
   108  		panic(interface{}(err))
   109  	}
   110  	writeN, err := file.Write(fmtBytes)
   111  	if err != nil {
   112  		panic(interface{}(err))
   113  	}
   114  	if writeN != len(fmtBytes) {
   115  		panic(interface{}(errors.New("write format bytes no equal")))
   116  	}
   117  }
   118  
   119  func getAllFunc(file *ast.File, rawFile *os.File, usageImportNameAndPath map[string]string, sourceName string,
   120  	genFunc GenMethod, filter func(recvT string) bool, ignoreSetup bool) []string {
   121  	funcStrs := make([]string, 0)
   122  	importNamePathMapping := buildImportNameAndPath(file.Imports)
   123  	for _, v := range file.Decls {
   124  		funcDecl, ok := v.(*ast.FuncDecl)
   125  		if !ok {
   126  			continue
   127  		}
   128  		if funcDecl.Recv == nil {
   129  			continue
   130  		}
   131  		var receiver *ast.Ident
   132  		for _, v := range funcDecl.Recv.List {
   133  			// 目前只支持生成底层类型是struct的代理对象
   134  			sExp, ok := v.Type.(*ast.StarExpr)
   135  			if !ok {
   136  				continue
   137  			}
   138  			ident, ok := sExp.X.(*ast.Ident)
   139  			if !ok {
   140  				continue
   141  			}
   142  			receiver = ident
   143  		}
   144  		// 无接收器的函数不是正确的声明
   145  		if receiver == nil {
   146  			continue
   147  		}
   148  		if !filter(receiver.Name) {
   149  			continue
   150  		}
   151  		// 被代理对象的类型名
   152  		recvName := receiver.Name
   153  		// 被代理对应持有的方法名
   154  		funName := funcDecl.Name.Name
   155  		if funName == "Setup" && ignoreSetup {
   156  			continue
   157  		}
   158  		inputList := make([]Argument, 0, 4)
   159  		outputList := make([]Argument, 0, 4)
   160  		// 处理参数的序列化
   161  		for _, pv := range funcDecl.Type.Params.List {
   162  			for _, pvName := range pv.Names {
   163  				arg := Argument{
   164  					Name: pvName.Name,
   165  					Type: handleAstType(pv.Type, rawFile),
   166  				}
   167  				inputList = append(inputList, arg)
   168  				// usage import ?
   169  				if !strings.Contains(arg.Type, ".") {
   170  					continue
   171  				}
   172  				typeName := strings.Trim(arg.Type, "*")
   173  				importName := strings.SplitN(typeName, ".", 2)[0]
   174  				usageImportNameAndPath[importName] = importNamePathMapping[importName]
   175  			}
   176  		}
   177  		// 找出所有的返回值类型
   178  		for _, rv := range funcDecl.Type.Results.List {
   179  			res := Argument{
   180  				Type: handleAstType(rv.Type, rawFile),
   181  			}
   182  			outputList = append(outputList, res)
   183  			// usage import ?
   184  			if !strings.Contains(res.Type, ".") {
   185  				continue
   186  			}
   187  			typeName := strings.Trim(res.Type, "*")
   188  			importName := strings.SplitN(typeName, ".", 2)[0]
   189  			usageImportNameAndPath[importName] = importNamePathMapping[importName]
   190  		}
   191  		after, err := genFunc(Argument{
   192  			Name: "p",
   193  			Type: recvName,
   194  		}, funName, sourceName+"."+funName, inputList, outputList)
   195  		if err != nil {
   196  			return nil
   197  		}
   198  		funcStrs = append(funcStrs, after())
   199  	}
   200  	return funcStrs
   201  }
   202  
   203  func ignoreSetup(astFiles map[string]*ast.File, receive string) (ignore bool) {
   204  	for _, astFile := range astFiles {
   205  		importNameAndPath := buildImportNameAndPath(astFile.Imports)
   206  		for _, decl := range astFile.Decls {
   207  			genDecl, ok := decl.(*ast.GenDecl)
   208  			if !ok {
   209  				continue
   210  			}
   211  			typeSpec, ok := genDecl.Specs[0].(*ast.TypeSpec)
   212  			if !ok {
   213  				continue
   214  			}
   215  			targetTypeName := strings.Split(receive, ".")[1]
   216  			if typeSpec.Name.Name != targetTypeName {
   217  				continue
   218  			}
   219  			// 只有Struct类型才能内嵌RpcServer
   220  			structType, ok := typeSpec.Type.(*ast.StructType)
   221  			if !ok {
   222  				continue
   223  			}
   224  			for _, field := range structType.Fields.List {
   225  				se, ok := field.Type.(*ast.SelectorExpr)
   226  				if !ok {
   227  					continue
   228  				}
   229  				if se.Sel.Name != "RpcServer" {
   230  					continue
   231  				}
   232  				ident, ok := se.X.(*ast.Ident)
   233  				if !ok {
   234  					continue
   235  				}
   236  				importPath := importNameAndPath[ident.Name]
   237  				if importPath != "github.com/nyan233/littlerpc/core/server" {
   238  					continue
   239  				}
   240  				return true
   241  			}
   242  		}
   243  	}
   244  	return
   245  }
   246  
   247  func buildImportNameAndPath(imports []*ast.ImportSpec) map[string]string {
   248  	result := make(map[string]string, len(imports))
   249  	for _, v := range imports {
   250  		// 没有别名
   251  		pathVal := strings.Trim(v.Path.Value, "\"")
   252  		if v.Name == nil {
   253  			tmp := strings.Split(pathVal, "/")
   254  			result[tmp[len(tmp)-1]] = pathVal
   255  			continue
   256  		}
   257  		result[v.Name.Name] = pathVal
   258  	}
   259  	return result
   260  }
   261  
   262  // 生成同步调用的Api
   263  func genSync(receive Argument, name, service string, input []Argument, output []Argument) (getResult func() string, err error) {
   264  	receive.Type = GetTypeName(receive.Type)
   265  	m := Method{
   266  		Receive:     receive,
   267  		ServiceName: service,
   268  		Name:        name,
   269  		InputList:   input,
   270  		OutputList:  output,
   271  		Statement:   Statement{},
   272  	}
   273  	return m.FormatToSync, nil
   274  }
   275  
   276  // 生成异步调用的Api
   277  func genAsyncApi(recvName, source, service string, inNameList, inTypeList, outList []string) (asyncApi [2]string, err error) {
   278  	if len(inNameList) != len(inTypeList) {
   279  		return [2]string{}, errors.New("inNameList and inTypeList length not equal")
   280  	}
   281  	recvName = GetTypeName(recvName)
   282  	var sb strings.Builder
   283  	_, _ = fmt.Fprintf(&sb, "func (p %s) Async%s(", recvName, service)
   284  	for i := 0; i < len(inNameList); i++ {
   285  		_, _ = fmt.Fprintf(&sb, "%s %s,", inNameList[i], inTypeList[i])
   286  	}
   287  	_, _ = fmt.Fprintf(&sb, ") error {return p.SyncCall(\"%s.%s\",", source, service)
   288  	for _, v := range inNameList {
   289  		sb.WriteString(v)
   290  		sb.WriteByte(',')
   291  	}
   292  	sb.WriteString(")}")
   293  	asyncApi[0] = sb.String()
   294  	sb.Reset()
   295  	_, _ = fmt.Fprintf(&sb, "func (p %sProxy) Register%sCallBack(fn func(", recvName, service)
   296  	for k, v := range outList {
   297  		_, _ = fmt.Fprintf(&sb, "r%s %s,", strconv.Itoa(k), v)
   298  	}
   299  	sb.WriteString("))")
   300  	_, _ = fmt.Fprintf(&sb, "{p.RegisterCallBack(\"%s.%s\",func(rep []interface{}, err error) {", recvName, service)
   301  	// gen error check
   302  	sb.WriteString("if err != nil {fn(")
   303  	for k, v := range outList {
   304  		// 关于error的生成必须独立处理,否则则会被替换为nil作为默认值
   305  		if k == len(outList)-1 {
   306  			// 一定要注入return,否则过程在出错的时候也会调用无错才会调用的回调函数
   307  			sb.WriteString("err);return};")
   308  			continue
   309  		}
   310  		str, err := writeDefaultValue(v)
   311  		if err != nil {
   312  			return [2]string{}, err
   313  		}
   314  		sb.WriteString(str)
   315  		sb.WriteString(",")
   316  	}
   317  	// 生成断言的代码
   318  	for k, v := range outList {
   319  		// error类型的返回值使用安全断言
   320  		if v == "error" {
   321  			_, _ = fmt.Fprintf(&sb, "r%d,_ := rep[%d].(%s);", k, k, v)
   322  			continue
   323  		}
   324  		_, _ = fmt.Fprintf(&sb, "r%d := rep[%d].(%s);", k, k, v)
   325  	}
   326  	// 最后生成调用的代码
   327  	sb.WriteString("fn(")
   328  	for k := range outList {
   329  		_, _ = fmt.Fprintf(&sb, "r%d,", k)
   330  	}
   331  	sb.WriteString(");})}")
   332  	asyncApi[1] = sb.String()
   333  	return
   334  }
   335  
   336  type ImportDesc struct {
   337  	Name string
   338  	Path string
   339  }
   340  
   341  type BeforeCodeDesc struct {
   342  	PackageName   string
   343  	GeneratorName string
   344  	CreateTime    time.Time
   345  	Author        string
   346  	ImportList    []ImportDesc
   347  	InterfaceName string
   348  	MethodList    []string
   349  	SourceName    string
   350  	TypeName      string
   351  	RealTypeName  string
   352  	GenId         string
   353  }
   354  
   355  // 在这里生成包注释、导入、工厂函数、各种需要的类型
   356  func createBeforeCode(pkgName, recvName, source string, allFunc []string, usageImportNameAndPath map[string]string) string {
   357  	interfaceName := recvName + "Proxy"
   358  	typeName := GetTypeName(recvName)
   359  	t, err := template.New("BeforeCodeDesc").Parse(BeforeCodeTemplate)
   360  	if err != nil {
   361  		panic(err)
   362  	}
   363  	var sb strings.Builder
   364  	sb.Grow(1024)
   365  	desc := &BeforeCodeDesc{
   366  		PackageName:   pkgName,
   367  		GeneratorName: "pxtor",
   368  		CreateTime:    time.Now(),
   369  		Author:        "NoAuthor",
   370  		ImportList: []ImportDesc{
   371  			{
   372  				Path: "github.com/nyan233/littlerpc/core/client",
   373  			},
   374  		},
   375  		InterfaceName: interfaceName,
   376  		SourceName:    source,
   377  		TypeName:      typeName,
   378  		RealTypeName:  recvName,
   379  	}
   380  	for importName, importPath := range usageImportNameAndPath {
   381  		// 未使用别名
   382  		if strings.HasSuffix(importPath, importName) {
   383  			desc.ImportList = append(desc.ImportList, ImportDesc{Path: importPath})
   384  			continue
   385  		}
   386  		desc.ImportList = append(desc.ImportList, ImportDesc{importName, importPath})
   387  	}
   388  	if *generateId {
   389  		desc.GenId = getId()
   390  	}
   391  	for _, v := range allFunc {
   392  		// func (x receiver) Say(i int) error {...
   393  		methodMeta := strings.SplitN(v, ")", 2)[1]
   394  		methodMeta = strings.SplitN(methodMeta, "{", 2)[0]
   395  		desc.MethodList = append(desc.MethodList, methodMeta)
   396  	}
   397  	err = t.Execute(&sb, desc)
   398  	if err != nil {
   399  		panic(err)
   400  	}
   401  	return sb.String()
   402  }
   403  
   404  func getId() string {
   405  	after := time.Now().UnixNano()
   406  	rand.Seed(after)
   407  	before := rand.Uint64()
   408  	bStr := hex.EncodeToString(binary.BigEndian.AppendUint64(nil, before))
   409  	aStr := hex.EncodeToString(binary.BigEndian.AppendUint64(nil, uint64(after)))
   410  	return aStr + bStr
   411  }
   412  
   413  func GetTypeName(recvName string) string {
   414  	if len(recvName) == 0 {
   415  		return ""
   416  	}
   417  	bytes4Str := []byte(recvName)
   418  	lowBytes := bytes.ToLower(bytes4Str[:1])
   419  	bytes4Str[0] = lowBytes[0]
   420  	return string(bytes4Str) + "Impl"
   421  }