github.com/nyan233/littlerpc@v0.4.6-0.20230316182519-0c8d5c48abaf/cmd/pxtor/generator.go (about) 1 package main 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "encoding/hex" 7 "errors" 8 "fmt" 9 flag "github.com/spf13/pflag" 10 "go/ast" 11 "go/format" 12 "go/parser" 13 "go/token" 14 "math/rand" 15 "os" 16 "path" 17 "strconv" 18 "strings" 19 "text/template" 20 "time" 21 ) 22 23 type GenMethod func(receive Argument, name, service string, input []Argument, output []Argument) (getResult func() string, err error) 24 25 const ( 26 SyncStyle = "sync" 27 AsyncStyle = "async" 28 RequestsStyle = "requests" 29 ) 30 31 var ( 32 receiver = flag.StringP("receive", "r", "", "代理对象的接收器: package.RecvName") 33 dir = flag.StringP("dir", "d", "./", "解析接收器的路径: ./") 34 outName = flag.StringP("out", "o", "", "输出的文件名,默认的格式: receiver_proxy.go") 35 sourceName = flag.StringP("source", "s", "", "SourceName Example(Hello1.Hello2) SourceName == Hello1") 36 generateId = flag.BoolP("gen_id", "i", false, "生成唯一id, 多个文件在同一个包时binder/caller不会冲突, 但对于mock场景不友好") 37 // TODO: 实现不同API风格的生成函数 38 style = flag.StringP("gen", "g", SyncStyle, "生成的API风格, TODO") 39 fileSet *token.FileSet 40 ) 41 42 func main() { 43 flag.Parse() 44 if *receiver == "" { 45 panic(interface{}("no receiver specified")) 46 } 47 if *sourceName == "" { 48 *sourceName = strings.Split(*receiver, ".")[1] 49 } 50 genCode() 51 } 52 53 func genCode() { 54 tmp := strings.SplitN(*receiver, ".", 2) 55 pkgName, recvName := tmp[0], tmp[1] 56 // 输出文件名 57 if *outName == "" { 58 *outName = recvName + "_proxy.go" 59 } 60 fileSet = token.NewFileSet() 61 parseDir, err := parser.ParseDir(fileSet, *dir, nil, 0) 62 if err != nil { 63 panic(interface{}(err)) 64 } 65 // ast.Print(fileSet,parseDir[pkgName].Files["test/proxy_2.go"]) 66 // 创建 67 pkgDir := parseDir[pkgName] 68 funcStrs := make([]string, 0, 20) 69 // 要写入到文件的数据,提供这个是为了方便格式化生成的代码 70 var fileBuffer bytes.Buffer 71 fileBuffer.Grow(512) 72 var genFn GenMethod 73 switch *style { 74 case SyncStyle: 75 genFn = genSync 76 default: 77 panic("no support gen style") 78 } 79 usageImportNameAndPath := make(map[string]string) 80 ignoreSetup := ignoreSetup(pkgDir.Files, *receiver) 81 for k, v := range pkgDir.Files { 82 rawFile, err := os.Open(path.Dir(*dir) + "/" + k) 83 if err != nil { 84 panic(interface{}(err)) 85 } 86 tmp := getAllFunc(v, rawFile, usageImportNameAndPath, *sourceName, genFn, func(recvT string) bool { 87 if recvT == recvName { 88 return true 89 } 90 return false 91 }, ignoreSetup) 92 funcStrs = append(funcStrs, tmp...) 93 } 94 fileBuffer.WriteString(createBeforeCode(pkgName, recvName, *sourceName, funcStrs, usageImportNameAndPath)) 95 for _, v := range funcStrs { 96 fileBuffer.WriteString("\n\n") 97 fileBuffer.WriteString(v) 98 } 99 if string(fileBuffer.Bytes()[fileBuffer.Len()-4:]) == "}\n}\n" { 100 fmt.Println("double }") 101 } 102 fmtBytes, err := format.Source(fileBuffer.Bytes()) 103 if err != nil { 104 panic(err) 105 } 106 file, err := os.OpenFile(*dir+"/"+*outName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0755) 107 if err != nil { 108 panic(interface{}(err)) 109 } 110 writeN, err := file.Write(fmtBytes) 111 if err != nil { 112 panic(interface{}(err)) 113 } 114 if writeN != len(fmtBytes) { 115 panic(interface{}(errors.New("write format bytes no equal"))) 116 } 117 } 118 119 func getAllFunc(file *ast.File, rawFile *os.File, usageImportNameAndPath map[string]string, sourceName string, 120 genFunc GenMethod, filter func(recvT string) bool, ignoreSetup bool) []string { 121 funcStrs := make([]string, 0) 122 importNamePathMapping := buildImportNameAndPath(file.Imports) 123 for _, v := range file.Decls { 124 funcDecl, ok := v.(*ast.FuncDecl) 125 if !ok { 126 continue 127 } 128 if funcDecl.Recv == nil { 129 continue 130 } 131 var receiver *ast.Ident 132 for _, v := range funcDecl.Recv.List { 133 // 目前只支持生成底层类型是struct的代理对象 134 sExp, ok := v.Type.(*ast.StarExpr) 135 if !ok { 136 continue 137 } 138 ident, ok := sExp.X.(*ast.Ident) 139 if !ok { 140 continue 141 } 142 receiver = ident 143 } 144 // 无接收器的函数不是正确的声明 145 if receiver == nil { 146 continue 147 } 148 if !filter(receiver.Name) { 149 continue 150 } 151 // 被代理对象的类型名 152 recvName := receiver.Name 153 // 被代理对应持有的方法名 154 funName := funcDecl.Name.Name 155 if funName == "Setup" && ignoreSetup { 156 continue 157 } 158 inputList := make([]Argument, 0, 4) 159 outputList := make([]Argument, 0, 4) 160 // 处理参数的序列化 161 for _, pv := range funcDecl.Type.Params.List { 162 for _, pvName := range pv.Names { 163 arg := Argument{ 164 Name: pvName.Name, 165 Type: handleAstType(pv.Type, rawFile), 166 } 167 inputList = append(inputList, arg) 168 // usage import ? 169 if !strings.Contains(arg.Type, ".") { 170 continue 171 } 172 typeName := strings.Trim(arg.Type, "*") 173 importName := strings.SplitN(typeName, ".", 2)[0] 174 usageImportNameAndPath[importName] = importNamePathMapping[importName] 175 } 176 } 177 // 找出所有的返回值类型 178 for _, rv := range funcDecl.Type.Results.List { 179 res := Argument{ 180 Type: handleAstType(rv.Type, rawFile), 181 } 182 outputList = append(outputList, res) 183 // usage import ? 184 if !strings.Contains(res.Type, ".") { 185 continue 186 } 187 typeName := strings.Trim(res.Type, "*") 188 importName := strings.SplitN(typeName, ".", 2)[0] 189 usageImportNameAndPath[importName] = importNamePathMapping[importName] 190 } 191 after, err := genFunc(Argument{ 192 Name: "p", 193 Type: recvName, 194 }, funName, sourceName+"."+funName, inputList, outputList) 195 if err != nil { 196 return nil 197 } 198 funcStrs = append(funcStrs, after()) 199 } 200 return funcStrs 201 } 202 203 func ignoreSetup(astFiles map[string]*ast.File, receive string) (ignore bool) { 204 for _, astFile := range astFiles { 205 importNameAndPath := buildImportNameAndPath(astFile.Imports) 206 for _, decl := range astFile.Decls { 207 genDecl, ok := decl.(*ast.GenDecl) 208 if !ok { 209 continue 210 } 211 typeSpec, ok := genDecl.Specs[0].(*ast.TypeSpec) 212 if !ok { 213 continue 214 } 215 targetTypeName := strings.Split(receive, ".")[1] 216 if typeSpec.Name.Name != targetTypeName { 217 continue 218 } 219 // 只有Struct类型才能内嵌RpcServer 220 structType, ok := typeSpec.Type.(*ast.StructType) 221 if !ok { 222 continue 223 } 224 for _, field := range structType.Fields.List { 225 se, ok := field.Type.(*ast.SelectorExpr) 226 if !ok { 227 continue 228 } 229 if se.Sel.Name != "RpcServer" { 230 continue 231 } 232 ident, ok := se.X.(*ast.Ident) 233 if !ok { 234 continue 235 } 236 importPath := importNameAndPath[ident.Name] 237 if importPath != "github.com/nyan233/littlerpc/core/server" { 238 continue 239 } 240 return true 241 } 242 } 243 } 244 return 245 } 246 247 func buildImportNameAndPath(imports []*ast.ImportSpec) map[string]string { 248 result := make(map[string]string, len(imports)) 249 for _, v := range imports { 250 // 没有别名 251 pathVal := strings.Trim(v.Path.Value, "\"") 252 if v.Name == nil { 253 tmp := strings.Split(pathVal, "/") 254 result[tmp[len(tmp)-1]] = pathVal 255 continue 256 } 257 result[v.Name.Name] = pathVal 258 } 259 return result 260 } 261 262 // 生成同步调用的Api 263 func genSync(receive Argument, name, service string, input []Argument, output []Argument) (getResult func() string, err error) { 264 receive.Type = GetTypeName(receive.Type) 265 m := Method{ 266 Receive: receive, 267 ServiceName: service, 268 Name: name, 269 InputList: input, 270 OutputList: output, 271 Statement: Statement{}, 272 } 273 return m.FormatToSync, nil 274 } 275 276 // 生成异步调用的Api 277 func genAsyncApi(recvName, source, service string, inNameList, inTypeList, outList []string) (asyncApi [2]string, err error) { 278 if len(inNameList) != len(inTypeList) { 279 return [2]string{}, errors.New("inNameList and inTypeList length not equal") 280 } 281 recvName = GetTypeName(recvName) 282 var sb strings.Builder 283 _, _ = fmt.Fprintf(&sb, "func (p %s) Async%s(", recvName, service) 284 for i := 0; i < len(inNameList); i++ { 285 _, _ = fmt.Fprintf(&sb, "%s %s,", inNameList[i], inTypeList[i]) 286 } 287 _, _ = fmt.Fprintf(&sb, ") error {return p.SyncCall(\"%s.%s\",", source, service) 288 for _, v := range inNameList { 289 sb.WriteString(v) 290 sb.WriteByte(',') 291 } 292 sb.WriteString(")}") 293 asyncApi[0] = sb.String() 294 sb.Reset() 295 _, _ = fmt.Fprintf(&sb, "func (p %sProxy) Register%sCallBack(fn func(", recvName, service) 296 for k, v := range outList { 297 _, _ = fmt.Fprintf(&sb, "r%s %s,", strconv.Itoa(k), v) 298 } 299 sb.WriteString("))") 300 _, _ = fmt.Fprintf(&sb, "{p.RegisterCallBack(\"%s.%s\",func(rep []interface{}, err error) {", recvName, service) 301 // gen error check 302 sb.WriteString("if err != nil {fn(") 303 for k, v := range outList { 304 // 关于error的生成必须独立处理,否则则会被替换为nil作为默认值 305 if k == len(outList)-1 { 306 // 一定要注入return,否则过程在出错的时候也会调用无错才会调用的回调函数 307 sb.WriteString("err);return};") 308 continue 309 } 310 str, err := writeDefaultValue(v) 311 if err != nil { 312 return [2]string{}, err 313 } 314 sb.WriteString(str) 315 sb.WriteString(",") 316 } 317 // 生成断言的代码 318 for k, v := range outList { 319 // error类型的返回值使用安全断言 320 if v == "error" { 321 _, _ = fmt.Fprintf(&sb, "r%d,_ := rep[%d].(%s);", k, k, v) 322 continue 323 } 324 _, _ = fmt.Fprintf(&sb, "r%d := rep[%d].(%s);", k, k, v) 325 } 326 // 最后生成调用的代码 327 sb.WriteString("fn(") 328 for k := range outList { 329 _, _ = fmt.Fprintf(&sb, "r%d,", k) 330 } 331 sb.WriteString(");})}") 332 asyncApi[1] = sb.String() 333 return 334 } 335 336 type ImportDesc struct { 337 Name string 338 Path string 339 } 340 341 type BeforeCodeDesc struct { 342 PackageName string 343 GeneratorName string 344 CreateTime time.Time 345 Author string 346 ImportList []ImportDesc 347 InterfaceName string 348 MethodList []string 349 SourceName string 350 TypeName string 351 RealTypeName string 352 GenId string 353 } 354 355 // 在这里生成包注释、导入、工厂函数、各种需要的类型 356 func createBeforeCode(pkgName, recvName, source string, allFunc []string, usageImportNameAndPath map[string]string) string { 357 interfaceName := recvName + "Proxy" 358 typeName := GetTypeName(recvName) 359 t, err := template.New("BeforeCodeDesc").Parse(BeforeCodeTemplate) 360 if err != nil { 361 panic(err) 362 } 363 var sb strings.Builder 364 sb.Grow(1024) 365 desc := &BeforeCodeDesc{ 366 PackageName: pkgName, 367 GeneratorName: "pxtor", 368 CreateTime: time.Now(), 369 Author: "NoAuthor", 370 ImportList: []ImportDesc{ 371 { 372 Path: "github.com/nyan233/littlerpc/core/client", 373 }, 374 }, 375 InterfaceName: interfaceName, 376 SourceName: source, 377 TypeName: typeName, 378 RealTypeName: recvName, 379 } 380 for importName, importPath := range usageImportNameAndPath { 381 // 未使用别名 382 if strings.HasSuffix(importPath, importName) { 383 desc.ImportList = append(desc.ImportList, ImportDesc{Path: importPath}) 384 continue 385 } 386 desc.ImportList = append(desc.ImportList, ImportDesc{importName, importPath}) 387 } 388 if *generateId { 389 desc.GenId = getId() 390 } 391 for _, v := range allFunc { 392 // func (x receiver) Say(i int) error {... 393 methodMeta := strings.SplitN(v, ")", 2)[1] 394 methodMeta = strings.SplitN(methodMeta, "{", 2)[0] 395 desc.MethodList = append(desc.MethodList, methodMeta) 396 } 397 err = t.Execute(&sb, desc) 398 if err != nil { 399 panic(err) 400 } 401 return sb.String() 402 } 403 404 func getId() string { 405 after := time.Now().UnixNano() 406 rand.Seed(after) 407 before := rand.Uint64() 408 bStr := hex.EncodeToString(binary.BigEndian.AppendUint64(nil, before)) 409 aStr := hex.EncodeToString(binary.BigEndian.AppendUint64(nil, uint64(after))) 410 return aStr + bStr 411 } 412 413 func GetTypeName(recvName string) string { 414 if len(recvName) == 0 { 415 return "" 416 } 417 bytes4Str := []byte(recvName) 418 lowBytes := bytes.ToLower(bytes4Str[:1]) 419 bytes4Str[0] = lowBytes[0] 420 return string(bytes4Str) + "Impl" 421 }