trpc.group/trpc-go/trpc-cmdline@v1.0.9/plugin/gotag.go (about) 1 // Tencent is pleased to support the open source community by making tRPC available. 2 // 3 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 4 // All rights reserved. 5 // 6 // If you have downloaded a copy of the tRPC source code from Tencent, 7 // please note that tRPC source code is licensed under the Apache 2.0 License, 8 // A copy of the Apache 2.0 License is included in this file. 9 10 package plugin 11 12 import ( 13 "fmt" 14 "go/ast" 15 "go/parser" 16 "go/token" 17 "io" 18 "os" 19 "path/filepath" 20 "regexp" 21 "strings" 22 23 "github.com/iancoleman/strcase" 24 "github.com/jhump/protoreflect/desc" 25 "google.golang.org/protobuf/proto" 26 "google.golang.org/protobuf/types/descriptorpb" 27 28 trpc "trpc.group/trpc/trpc-protocol/pb/go/trpc/proto" 29 30 "trpc.group/trpc-go/trpc-cmdline/descriptor" 31 "trpc.group/trpc-go/trpc-cmdline/params" 32 tparser "trpc.group/trpc-go/trpc-cmdline/parser" 33 "trpc.group/trpc-go/trpc-cmdline/util/fs" 34 "trpc.group/trpc-go/trpc-cmdline/util/log" 35 ) 36 37 var ( 38 regexpInject = regexp.MustCompile("`.+`$") 39 regexpTags = regexp.MustCompile(`[\w_]+:"[^"]+"`) 40 regexpProtobufTagName = regexp.MustCompile(`protobuf:"[\w_\-,]+name=([\w_\-]+)`) 41 ) 42 43 // textArea records tag text position and tag info in *.pb.go file 44 type textArea struct { 45 StartPos int 46 EndPos int 47 CurrentTag string 48 NewTag string 49 } 50 51 // GoTag generates go tag by proto field options 52 type GoTag struct { 53 } 54 55 // Name return plugin's name 56 func (p *GoTag) Name() string { 57 return "gotag" 58 } 59 60 // Check only run when `--lang=go && --go_tag=true` 61 func (p *GoTag) Check(fd *descriptor.FileDescriptor, opt *params.Option) bool { 62 if opt.Language == "go" && opt.Gotag { 63 return true 64 } 65 return false 66 } 67 68 // Run exec go tag plugin 69 func (p *GoTag) Run(fd *descriptor.FileDescriptor, opt *params.Option) error { 70 tags := optTagsFromProto(fd.FD) 71 if len(tags) == 0 { 72 return nil 73 } 74 75 outputdir := opt.OutputDir 76 pbfile := "" 77 pbname := fs.BaseNameWithoutExt(fd.FilePath) + ".pb.go" 78 79 if opt.RPCOnly { 80 pbfile = filepath.Join(outputdir, pbname) 81 } else { 82 importPath, err := tparser.GetPbPackage(fd, "go_package") 83 if err != nil { 84 return err 85 } 86 pbfile = filepath.Join(outputdir, "stub", importPath, pbname) 87 } 88 89 return p.replaceTags(pbfile, tags) 90 } 91 92 func (p *GoTag) replaceTags(pbfile string, tags map[string]string) error { 93 _, err := os.Lstat(pbfile) 94 if err != nil { 95 return err 96 } 97 areas, err := tagAreasFromPBFile(pbfile, tags) 98 if err != nil { 99 return err 100 } 101 if err = injectTagsToPBFile(pbfile, areas); err != nil { 102 return err 103 } 104 return nil 105 } 106 107 // optTagsFromProto parses field go tag option from proto file and maps it as a kv map 108 // map structure should be like `messageName_fieldName` 109 func optTagsFromProto(fd descriptor.Desc) map[string]string { 110 tagmap := make(map[string]string) 111 var scanNestedMsgFunc func(*desc.MessageDescriptor, string) 112 scanNestedMsgFunc = func(m *desc.MessageDescriptor, prefix string) { 113 for _, mm := range m.GetNestedMessageTypes() { 114 p := fmtgotagkey(prefix, m.GetName()) 115 scanNestedMsgFunc(mm, p) 116 } 117 for _, field := range m.GetFields() { 118 tags := getGoTag(field.GetFieldOptions()) 119 if tags == "" { 120 continue 121 } 122 key := fmtgotagkey(prefix, m.GetName(), field.GetName()) 123 tagmap[key] = tags 124 } 125 } 126 for _, msg := range fd.GetMessageTypes() { 127 messageDescriptor, ok := msg.(*descriptor.ProtoMessageDescriptor) 128 if !ok { 129 continue 130 } 131 md := messageDescriptor.MD 132 scanNestedMsgFunc(md, "") 133 } 134 return tagmap 135 } 136 137 func getGoTag(opts *descriptorpb.FieldOptions) string { 138 if proto.HasExtension(opts, trpc.E_GoTag) { 139 return proto.GetExtension(opts, trpc.E_GoTag).(string) 140 } 141 return "" 142 } 143 144 // fmtgotagkey generates the key for `protoTags` to join the struct name and 145 // field name by `_`, nested message names would be joined too 146 func fmtgotagkey(s ...string) string { 147 for k, v := range s { 148 if v == "" { 149 s = append(s[:k], s[k+1:]...) 150 } 151 } 152 return strcase.ToCamel(strings.Join(s, "_")) 153 } 154 155 // tagAreasFromPBFile parses *.pb.go and records tag positions which need to be replaced 156 func tagAreasFromPBFile(fp string, newtags map[string]string) (areas []textArea, err error) { 157 fset := token.NewFileSet() 158 f, err := parser.ParseFile(fset, fp, nil, parser.ParseComments) 159 if err != nil { 160 return 161 } 162 for _, decl := range f.Decls { 163 // check if is generic declaration 164 typeSpec := genTypeSpec(decl) 165 // skip if can't get type spec 166 if typeSpec == nil { 167 continue 168 } 169 // not a struct, skip 170 structDecl, ok := typeSpec.Type.(*ast.StructType) 171 if !ok { 172 continue 173 } 174 areas = append(areas, genAreas(structDecl, typeSpec, newtags)...) 175 } 176 return 177 } 178 179 func genAreas(structDecl *ast.StructType, typeSpec *ast.TypeSpec, newtags map[string]string) []textArea { 180 var areas []textArea 181 for _, field := range structDecl.Fields.List { 182 if field.Tag == nil { 183 continue 184 } 185 fieldname := protobufTagName(field.Tag.Value) 186 if fieldname == "" { 187 continue 188 } 189 structname := typeSpec.Name.String() 190 // key = structName_fieldName 191 key := fmtgotagkey(structname, fieldname) 192 newtag, ok := newtags[key] 193 if !ok { 194 continue 195 } 196 currentTag := field.Tag.Value 197 areas = append(areas, textArea{ 198 StartPos: int(field.Pos()), 199 EndPos: int(field.End()), 200 CurrentTag: currentTag[1 : len(currentTag)-1], 201 NewTag: newtag, 202 }) 203 } 204 return areas 205 } 206 207 func genTypeSpec(decl ast.Decl) *ast.TypeSpec { 208 genDecl, ok := decl.(*ast.GenDecl) 209 if !ok { 210 return nil 211 } 212 var typeSpec *ast.TypeSpec 213 for _, spec := range genDecl.Specs { 214 if ts, ok := spec.(*ast.TypeSpec); ok { 215 typeSpec = ts 216 break 217 } 218 } 219 return typeSpec 220 } 221 222 func protobufTagName(tag string) string { 223 matches := regexpProtobufTagName.FindStringSubmatch(tag) 224 if len(matches) > 1 { 225 return matches[1] 226 } 227 return "" 228 } 229 230 // injectTagsToPBFile replaces tags and rewrites the *.pb.go file 231 func injectTagsToPBFile(fp string, areas []textArea) (err error) { 232 f, err := os.Open(fp) 233 if err != nil { 234 return 235 } 236 contents, err := io.ReadAll(f) 237 if err != nil { 238 return 239 } 240 if err = f.Close(); err != nil { 241 return 242 } 243 return writeTagsToFile(fp, areas, contents, err) 244 } 245 246 func writeTagsToFile(fp string, areas []textArea, contents []byte, err error) error { 247 // inject custom tags from tail of file first to preserve order 248 for i := range areas { 249 area := areas[len(areas)-i-1] 250 log.Debug("inject custom tag %q to expression %q", 251 area.NewTag, string(contents[area.StartPos-1:area.EndPos-1])) 252 contents = injectGoTag(contents, area) 253 } 254 if err = os.WriteFile(fp, contents, 0644); err != nil { 255 return err 256 } 257 if len(areas) > 0 { 258 log.Debug("file %q is injected with custom tags", fp) 259 } 260 return nil 261 } 262 263 func injectGoTag(contents []byte, area textArea) (injected []byte) { 264 expr := make([]byte, area.EndPos-area.StartPos) 265 copy(expr, contents[area.StartPos-1:area.EndPos-1]) 266 cti := newGoTagItems(area.CurrentTag) 267 iti := newGoTagItems(area.NewTag) 268 ti := cti.override(iti) 269 expr = regexpInject.ReplaceAll(expr, []byte(fmt.Sprintf("`%s`", ti.format()))) 270 injected = append(injected, contents[:area.StartPos-1]...) 271 injected = append(injected, expr...) 272 injected = append(injected, contents[area.EndPos-1:]...) 273 return 274 } 275 276 type goTagItem struct { 277 key string 278 value string 279 } 280 281 type goTagItems []goTagItem 282 283 func (ti goTagItems) format() string { 284 tags := []string{} 285 for _, item := range ti { 286 tags = append(tags, fmt.Sprintf(`%s:%s`, item.key, item.value)) 287 } 288 return strings.Join(tags, " ") 289 } 290 291 func (ti goTagItems) override(nti goTagItems) goTagItems { 292 overridden := []goTagItem{} 293 for i := range ti { 294 var dup = -1 295 for j := range nti { 296 if ti[i].key == nti[j].key { 297 dup = j 298 break 299 } 300 } 301 if dup == -1 { 302 overridden = append(overridden, ti[i]) 303 } else { 304 overridden = append(overridden, nti[dup]) 305 nti = append(nti[:dup], nti[dup+1:]...) 306 } 307 } 308 return append(overridden, nti...) 309 } 310 311 func newGoTagItems(tag string) goTagItems { 312 var items goTagItems 313 split := regexpTags.FindAllString(tag, -1) 314 for _, t := range split { 315 sepPos := strings.Index(t, ":") 316 items = append(items, goTagItem{ 317 key: t[:sepPos], 318 value: t[sepPos+1:], 319 }) 320 } 321 return items 322 }