trpc.group/trpc-go/trpc-cmdline@v1.0.9/plugin/gotag.go (about)

     1  // Tencent is pleased to support the open source community by making tRPC available.
     2  //
     3  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     4  // All rights reserved.
     5  //
     6  // If you have downloaded a copy of the tRPC source code from Tencent,
     7  // please note that tRPC source code is licensed under the  Apache 2.0 License,
     8  // A copy of the Apache 2.0 License is included in this file.
     9  
    10  package plugin
    11  
    12  import (
    13  	"fmt"
    14  	"go/ast"
    15  	"go/parser"
    16  	"go/token"
    17  	"io"
    18  	"os"
    19  	"path/filepath"
    20  	"regexp"
    21  	"strings"
    22  
    23  	"github.com/iancoleman/strcase"
    24  	"github.com/jhump/protoreflect/desc"
    25  	"google.golang.org/protobuf/proto"
    26  	"google.golang.org/protobuf/types/descriptorpb"
    27  
    28  	trpc "trpc.group/trpc/trpc-protocol/pb/go/trpc/proto"
    29  
    30  	"trpc.group/trpc-go/trpc-cmdline/descriptor"
    31  	"trpc.group/trpc-go/trpc-cmdline/params"
    32  	tparser "trpc.group/trpc-go/trpc-cmdline/parser"
    33  	"trpc.group/trpc-go/trpc-cmdline/util/fs"
    34  	"trpc.group/trpc-go/trpc-cmdline/util/log"
    35  )
    36  
    37  var (
    38  	regexpInject          = regexp.MustCompile("`.+`$")
    39  	regexpTags            = regexp.MustCompile(`[\w_]+:"[^"]+"`)
    40  	regexpProtobufTagName = regexp.MustCompile(`protobuf:"[\w_\-,]+name=([\w_\-]+)`)
    41  )
    42  
    43  // textArea records tag text position and tag info in *.pb.go file
    44  type textArea struct {
    45  	StartPos   int
    46  	EndPos     int
    47  	CurrentTag string
    48  	NewTag     string
    49  }
    50  
    51  // GoTag generates go tag by proto field options
    52  type GoTag struct {
    53  }
    54  
    55  // Name return plugin's name
    56  func (p *GoTag) Name() string {
    57  	return "gotag"
    58  }
    59  
    60  // Check only run when `--lang=go && --go_tag=true`
    61  func (p *GoTag) Check(fd *descriptor.FileDescriptor, opt *params.Option) bool {
    62  	if opt.Language == "go" && opt.Gotag {
    63  		return true
    64  	}
    65  	return false
    66  }
    67  
    68  // Run exec go tag plugin
    69  func (p *GoTag) Run(fd *descriptor.FileDescriptor, opt *params.Option) error {
    70  	tags := optTagsFromProto(fd.FD)
    71  	if len(tags) == 0 {
    72  		return nil
    73  	}
    74  
    75  	outputdir := opt.OutputDir
    76  	pbfile := ""
    77  	pbname := fs.BaseNameWithoutExt(fd.FilePath) + ".pb.go"
    78  
    79  	if opt.RPCOnly {
    80  		pbfile = filepath.Join(outputdir, pbname)
    81  	} else {
    82  		importPath, err := tparser.GetPbPackage(fd, "go_package")
    83  		if err != nil {
    84  			return err
    85  		}
    86  		pbfile = filepath.Join(outputdir, "stub", importPath, pbname)
    87  	}
    88  
    89  	return p.replaceTags(pbfile, tags)
    90  }
    91  
    92  func (p *GoTag) replaceTags(pbfile string, tags map[string]string) error {
    93  	_, err := os.Lstat(pbfile)
    94  	if err != nil {
    95  		return err
    96  	}
    97  	areas, err := tagAreasFromPBFile(pbfile, tags)
    98  	if err != nil {
    99  		return err
   100  	}
   101  	if err = injectTagsToPBFile(pbfile, areas); err != nil {
   102  		return err
   103  	}
   104  	return nil
   105  }
   106  
   107  // optTagsFromProto parses field go tag option from proto file and maps it as a kv map
   108  // map structure should be like `messageName_fieldName`
   109  func optTagsFromProto(fd descriptor.Desc) map[string]string {
   110  	tagmap := make(map[string]string)
   111  	var scanNestedMsgFunc func(*desc.MessageDescriptor, string)
   112  	scanNestedMsgFunc = func(m *desc.MessageDescriptor, prefix string) {
   113  		for _, mm := range m.GetNestedMessageTypes() {
   114  			p := fmtgotagkey(prefix, m.GetName())
   115  			scanNestedMsgFunc(mm, p)
   116  		}
   117  		for _, field := range m.GetFields() {
   118  			tags := getGoTag(field.GetFieldOptions())
   119  			if tags == "" {
   120  				continue
   121  			}
   122  			key := fmtgotagkey(prefix, m.GetName(), field.GetName())
   123  			tagmap[key] = tags
   124  		}
   125  	}
   126  	for _, msg := range fd.GetMessageTypes() {
   127  		messageDescriptor, ok := msg.(*descriptor.ProtoMessageDescriptor)
   128  		if !ok {
   129  			continue
   130  		}
   131  		md := messageDescriptor.MD
   132  		scanNestedMsgFunc(md, "")
   133  	}
   134  	return tagmap
   135  }
   136  
   137  func getGoTag(opts *descriptorpb.FieldOptions) string {
   138  	if proto.HasExtension(opts, trpc.E_GoTag) {
   139  		return proto.GetExtension(opts, trpc.E_GoTag).(string)
   140  	}
   141  	return ""
   142  }
   143  
   144  // fmtgotagkey generates the key for `protoTags` to join the struct name and
   145  // field name by `_`, nested message names would be joined too
   146  func fmtgotagkey(s ...string) string {
   147  	for k, v := range s {
   148  		if v == "" {
   149  			s = append(s[:k], s[k+1:]...)
   150  		}
   151  	}
   152  	return strcase.ToCamel(strings.Join(s, "_"))
   153  }
   154  
   155  // tagAreasFromPBFile parses *.pb.go and records tag positions which need to be replaced
   156  func tagAreasFromPBFile(fp string, newtags map[string]string) (areas []textArea, err error) {
   157  	fset := token.NewFileSet()
   158  	f, err := parser.ParseFile(fset, fp, nil, parser.ParseComments)
   159  	if err != nil {
   160  		return
   161  	}
   162  	for _, decl := range f.Decls {
   163  		// check if is generic declaration
   164  		typeSpec := genTypeSpec(decl)
   165  		// skip if can't get type spec
   166  		if typeSpec == nil {
   167  			continue
   168  		}
   169  		// not a struct, skip
   170  		structDecl, ok := typeSpec.Type.(*ast.StructType)
   171  		if !ok {
   172  			continue
   173  		}
   174  		areas = append(areas, genAreas(structDecl, typeSpec, newtags)...)
   175  	}
   176  	return
   177  }
   178  
   179  func genAreas(structDecl *ast.StructType, typeSpec *ast.TypeSpec, newtags map[string]string) []textArea {
   180  	var areas []textArea
   181  	for _, field := range structDecl.Fields.List {
   182  		if field.Tag == nil {
   183  			continue
   184  		}
   185  		fieldname := protobufTagName(field.Tag.Value)
   186  		if fieldname == "" {
   187  			continue
   188  		}
   189  		structname := typeSpec.Name.String()
   190  		// key = structName_fieldName
   191  		key := fmtgotagkey(structname, fieldname)
   192  		newtag, ok := newtags[key]
   193  		if !ok {
   194  			continue
   195  		}
   196  		currentTag := field.Tag.Value
   197  		areas = append(areas, textArea{
   198  			StartPos:   int(field.Pos()),
   199  			EndPos:     int(field.End()),
   200  			CurrentTag: currentTag[1 : len(currentTag)-1],
   201  			NewTag:     newtag,
   202  		})
   203  	}
   204  	return areas
   205  }
   206  
   207  func genTypeSpec(decl ast.Decl) *ast.TypeSpec {
   208  	genDecl, ok := decl.(*ast.GenDecl)
   209  	if !ok {
   210  		return nil
   211  	}
   212  	var typeSpec *ast.TypeSpec
   213  	for _, spec := range genDecl.Specs {
   214  		if ts, ok := spec.(*ast.TypeSpec); ok {
   215  			typeSpec = ts
   216  			break
   217  		}
   218  	}
   219  	return typeSpec
   220  }
   221  
   222  func protobufTagName(tag string) string {
   223  	matches := regexpProtobufTagName.FindStringSubmatch(tag)
   224  	if len(matches) > 1 {
   225  		return matches[1]
   226  	}
   227  	return ""
   228  }
   229  
   230  // injectTagsToPBFile replaces tags and rewrites the *.pb.go file
   231  func injectTagsToPBFile(fp string, areas []textArea) (err error) {
   232  	f, err := os.Open(fp)
   233  	if err != nil {
   234  		return
   235  	}
   236  	contents, err := io.ReadAll(f)
   237  	if err != nil {
   238  		return
   239  	}
   240  	if err = f.Close(); err != nil {
   241  		return
   242  	}
   243  	return writeTagsToFile(fp, areas, contents, err)
   244  }
   245  
   246  func writeTagsToFile(fp string, areas []textArea, contents []byte, err error) error {
   247  	// inject custom tags from tail of file first to preserve order
   248  	for i := range areas {
   249  		area := areas[len(areas)-i-1]
   250  		log.Debug("inject custom tag %q to expression %q",
   251  			area.NewTag, string(contents[area.StartPos-1:area.EndPos-1]))
   252  		contents = injectGoTag(contents, area)
   253  	}
   254  	if err = os.WriteFile(fp, contents, 0644); err != nil {
   255  		return err
   256  	}
   257  	if len(areas) > 0 {
   258  		log.Debug("file %q is injected with custom tags", fp)
   259  	}
   260  	return nil
   261  }
   262  
   263  func injectGoTag(contents []byte, area textArea) (injected []byte) {
   264  	expr := make([]byte, area.EndPos-area.StartPos)
   265  	copy(expr, contents[area.StartPos-1:area.EndPos-1])
   266  	cti := newGoTagItems(area.CurrentTag)
   267  	iti := newGoTagItems(area.NewTag)
   268  	ti := cti.override(iti)
   269  	expr = regexpInject.ReplaceAll(expr, []byte(fmt.Sprintf("`%s`", ti.format())))
   270  	injected = append(injected, contents[:area.StartPos-1]...)
   271  	injected = append(injected, expr...)
   272  	injected = append(injected, contents[area.EndPos-1:]...)
   273  	return
   274  }
   275  
   276  type goTagItem struct {
   277  	key   string
   278  	value string
   279  }
   280  
   281  type goTagItems []goTagItem
   282  
   283  func (ti goTagItems) format() string {
   284  	tags := []string{}
   285  	for _, item := range ti {
   286  		tags = append(tags, fmt.Sprintf(`%s:%s`, item.key, item.value))
   287  	}
   288  	return strings.Join(tags, " ")
   289  }
   290  
   291  func (ti goTagItems) override(nti goTagItems) goTagItems {
   292  	overridden := []goTagItem{}
   293  	for i := range ti {
   294  		var dup = -1
   295  		for j := range nti {
   296  			if ti[i].key == nti[j].key {
   297  				dup = j
   298  				break
   299  			}
   300  		}
   301  		if dup == -1 {
   302  			overridden = append(overridden, ti[i])
   303  		} else {
   304  			overridden = append(overridden, nti[dup])
   305  			nti = append(nti[:dup], nti[dup+1:]...)
   306  		}
   307  	}
   308  	return append(overridden, nti...)
   309  }
   310  
   311  func newGoTagItems(tag string) goTagItems {
   312  	var items goTagItems
   313  	split := regexpTags.FindAllString(tag, -1)
   314  	for _, t := range split {
   315  		sepPos := strings.Index(t, ":")
   316  		items = append(items, goTagItem{
   317  			key:   t[:sepPos],
   318  			value: t[sepPos+1:],
   319  		})
   320  	}
   321  	return items
   322  }