github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/courier/status_error/gen/status_error_generator.go (about) 1 package gen 2 3 import ( 4 "bytes" 5 "fmt" 6 "go/build" 7 "go/parser" 8 "go/types" 9 "path" 10 "path/filepath" 11 "reflect" 12 "sort" 13 "strconv" 14 "strings" 15 16 "golang.org/x/tools/go/loader" 17 18 "github.com/artisanhe/tools/codegen" 19 "github.com/artisanhe/tools/codegen/loaderx" 20 "github.com/artisanhe/tools/courier/status_error" 21 ) 22 23 type StatusErrorGenerator struct { 24 pkgImportPath string 25 program *loader.Program 26 statusErrorCodes map[*types.Package]status_error.StatusErrorCodeMap 27 } 28 29 func (g *StatusErrorGenerator) Load(cwd string) { 30 ldr := loader.Config{ 31 AllowErrors: true, 32 ParserMode: parser.ParseComments, 33 } 34 35 pkgImportPath := codegen.GetPackageImportPath(cwd) 36 ldr.Import(pkgImportPath) 37 38 p, err := ldr.Load() 39 if err != nil { 40 panic(err) 41 } 42 43 g.program = p 44 g.pkgImportPath = pkgImportPath 45 g.statusErrorCodes = map[*types.Package]status_error.StatusErrorCodeMap{} 46 } 47 48 func (g *StatusErrorGenerator) Pick() { 49 statusErrorCodeType := reflect.TypeOf(status_error.StatusErrorCode(0)) 50 statusErrorCodeTypeFullName := fmt.Sprintf("%s.%s", statusErrorCodeType.PkgPath(), statusErrorCodeType.Name()) 51 52 for pkg, pkgInfo := range g.program.AllPackages { 53 if pkg.Path() != g.pkgImportPath { 54 continue 55 } 56 for ident, obj := range pkgInfo.Defs { 57 if constObj, ok := obj.(*types.Const); ok { 58 if strings.HasSuffix(constObj.Type().String(), statusErrorCodeTypeFullName) { 59 key := constObj.Name() 60 if key == "_" { 61 continue 62 } 63 64 doc := loaderx.CommentsOf(g.program.Fset, ident, pkgInfo.Files...) 65 code, _ := strconv.ParseInt(constObj.Val().String(), 10, 64) 66 msg, desc, canBeErrTalk := ParseStatusErrorDesc(doc) 67 68 if g.statusErrorCodes[pkg] == nil { 69 g.statusErrorCodes[pkg] = status_error.StatusErrorCodeMap{} 70 } 71 g.statusErrorCodes[pkg].Register(key, code, msg, desc, canBeErrTalk) 72 } 73 } 74 } 75 } 76 } 77 78 func (g *StatusErrorGenerator) Output(cwd string) codegen.Outputs { 79 statusErrorCodeType := reflect.TypeOf(status_error.StatusErrorCode(0)) 80 outputs := codegen.Outputs{} 81 for pkg, statusErrorCodeMap := range g.statusErrorCodes { 82 p, _ := build.Import(pkg.Path(), "", build.FindOnly) 83 dir, _ := filepath.Rel(cwd, p.Dir) 84 content := fmt.Sprintf(` 85 package %s 86 87 import( 88 %s 89 ) 90 91 %s `, 92 pkg.Name(), 93 strconv.Quote(statusErrorCodeType.PkgPath()), 94 g.toRegisterInit(statusErrorCodeMap), 95 ) 96 outputs.Add(codegen.GeneratedSuffix(path.Join(dir, "status_err_codes.go")), content) 97 } 98 return outputs 99 } 100 101 func (g *StatusErrorGenerator) toRegisterInit(statusErrorCodeMap status_error.StatusErrorCodeMap) string { 102 buffer := bytes.Buffer{} 103 buffer.WriteString("func init () {") 104 105 registerMethod := "status_error.StatusErrorCodes" 106 pkgs := strings.Split(g.pkgImportPath, "/") 107 if strings.HasPrefix(registerMethod, pkgs[len(pkgs)-1]) { 108 registerMethod = "StatusErrorCodes" 109 } 110 111 statusErrorCodeList := []int{} 112 for _, statusErrorCode := range statusErrorCodeMap { 113 statusErrorCodeList = append(statusErrorCodeList, int(statusErrorCode.Code)) 114 } 115 sort.Ints(statusErrorCodeList) 116 117 for _, code := range statusErrorCodeList { 118 statusErrorCode := statusErrorCodeMap[int64(code)] 119 120 buffer.WriteString(fmt.Sprintf( 121 "%s.Register(%s, %d, %s, %s, %s)\n", 122 registerMethod, 123 strconv.Quote(statusErrorCode.Key), 124 statusErrorCode.Code, 125 strconv.Quote(statusErrorCode.Msg), 126 strconv.Quote(statusErrorCode.Desc), 127 strconv.FormatBool(statusErrorCode.CanBeErrorTalk), 128 )) 129 } 130 131 buffer.WriteString("}") 132 return buffer.String() 133 } 134 135 func ParseStatusErrorDesc(str string) (msg string, desc string, canBeErrTalk bool) { 136 lines := strings.Split(str, "\n") 137 firstLine := strings.Split(lines[0], "@errTalk") 138 139 if len(firstLine) > 1 { 140 canBeErrTalk = true 141 msg = strings.TrimSpace(firstLine[1]) 142 } else { 143 canBeErrTalk = false 144 msg = strings.TrimSpace(firstLine[0]) 145 } 146 147 if len(lines) > 1 { 148 desc = strings.TrimSpace(strings.Join(lines[1:], "\n")) 149 } 150 return 151 }