github.com/hoveychen/protoreflect@v1.4.7-0.20221103114119-0b4b3385ec76/desc/protoparse/source_code_info_test.go (about)

     1  package protoparse
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"testing"
     9  
    10  	"github.com/golang/protobuf/proto"
    11  
    12  	"github.com/hoveychen/protoreflect/desc"
    13  	"github.com/hoveychen/protoreflect/dynamic"
    14  	"github.com/hoveychen/protoreflect/internal/testutil"
    15  )
    16  
// regenerateMode, if true, makes TestSourceCodeInfo overwrite the golden
// output file (test-source-info.txt) with the freshly generated output before
// reading it back for comparison, so the comparison then trivially succeeds.
// Keep it false except when deliberately re-generating the golden file.
const regenerateMode = false
    19  
    20  func TestSourceCodeInfo(t *testing.T) {
    21  	p := Parser{ImportPaths: []string{"../../internal/testprotos"}, IncludeSourceCodeInfo: true}
    22  	fds, err := p.ParseFiles("desc_test_comments.proto")
    23  	testutil.Ok(t, err)
    24  	fd := fds[0]
    25  	// also test that imported files have source code info
    26  	// (desc_test_comments.proto imports desc_test_options.proto)
    27  	var importedFd *desc.FileDescriptor
    28  	for _, dep := range fd.GetDependencies() {
    29  		if dep.GetName() == "desc_test_options.proto" {
    30  			importedFd = dep
    31  			break
    32  		}
    33  	}
    34  	testutil.Require(t, importedFd != nil)
    35  
    36  	// create description of source code info
    37  	// (human readable so diffs in source control are comprehensible)
    38  	var buf bytes.Buffer
    39  	printSourceCodeInfo(t, fd, &buf)
    40  	printSourceCodeInfo(t, importedFd, &buf)
    41  	actual := buf.String()
    42  
    43  	if regenerateMode {
    44  		// re-generate the file
    45  		err = ioutil.WriteFile("test-source-info.txt", buf.Bytes(), 0666)
    46  		testutil.Ok(t, err)
    47  	}
    48  
    49  	b, err := ioutil.ReadFile("test-source-info.txt")
    50  	testutil.Ok(t, err)
    51  	golden := string(b)
    52  
    53  	testutil.Eq(t, golden, actual, "wrong source code info")
    54  }
    55  
    56  // NB: this function can be used to manually inspect the source code info for a
    57  // descriptor, in a manner that is much easier to read and check than raw
    58  // descriptor form.
    59  func printSourceCodeInfo(t *testing.T, fd *desc.FileDescriptor, out io.Writer) {
    60  	fmt.Fprintf(out, "---- %s ----\n", fd.GetName())
    61  	md, err := desc.LoadMessageDescriptorForMessage(fd.AsProto())
    62  	testutil.Ok(t, err)
    63  	er := &dynamic.ExtensionRegistry{}
    64  	er.AddExtensionsFromFileRecursively(fd)
    65  	mf := dynamic.NewMessageFactoryWithExtensionRegistry(er)
    66  	dfd := mf.NewDynamicMessage(md)
    67  	err = dfd.ConvertFrom(fd.AsProto())
    68  	testutil.Ok(t, err)
    69  
    70  	for _, loc := range fd.AsFileDescriptorProto().GetSourceCodeInfo().GetLocation() {
    71  		var buf bytes.Buffer
    72  		findLocation(mf, dfd, md, loc.Path, &buf)
    73  		fmt.Fprintf(out, "\n\n%s:\n", buf.String())
    74  		if len(loc.Span) == 3 {
    75  			fmt.Fprintf(out, "%s:%d:%d\n", fd.GetName(), loc.Span[0]+1, loc.Span[1]+1)
    76  			fmt.Fprintf(out, "%s:%d:%d\n", fd.GetName(), loc.Span[0]+1, loc.Span[2]+1)
    77  		} else {
    78  			fmt.Fprintf(out, "%s:%d:%d\n", fd.GetName(), loc.Span[0]+1, loc.Span[1]+1)
    79  			fmt.Fprintf(out, "%s:%d:%d\n", fd.GetName(), loc.Span[2]+1, loc.Span[3]+1)
    80  		}
    81  		if len(loc.LeadingDetachedComments) > 0 {
    82  			for i, comment := range loc.LeadingDetachedComments {
    83  				fmt.Fprintf(out, "    Leading detached comment [%d]:\n%s\n", i, comment)
    84  			}
    85  		}
    86  		if loc.LeadingComments != nil {
    87  			fmt.Fprintf(out, "    Leading comments:\n%s\n", loc.GetLeadingComments())
    88  		}
    89  		if loc.TrailingComments != nil {
    90  			fmt.Fprintf(out, "    Trailing comments:\n%s\n", loc.GetTrailingComments())
    91  		}
    92  	}
    93  }
    94  
    95  func findLocation(mf *dynamic.MessageFactory, msg *dynamic.Message, md *desc.MessageDescriptor, path []int32, buf *bytes.Buffer) {
    96  	if len(path) == 0 {
    97  		return
    98  	}
    99  
   100  	var fld *desc.FieldDescriptor
   101  	if msg != nil {
   102  		fld = msg.FindFieldDescriptor(path[0])
   103  	} else {
   104  		fld = md.FindFieldByNumber(path[0])
   105  		if fld == nil {
   106  			fld = mf.GetExtensionRegistry().FindExtension(md.GetFullyQualifiedName(), path[0])
   107  		}
   108  	}
   109  	if fld == nil {
   110  		panic(fmt.Sprintf("could not find field with tag %d in message of type %s", path[0], msg.XXX_MessageName()))
   111  	}
   112  
   113  	fmt.Fprintf(buf, " > %s", fld.GetName())
   114  	path = path[1:]
   115  	idx := -1
   116  	if fld.IsRepeated() && len(path) > 0 {
   117  		idx = int(path[0])
   118  		fmt.Fprintf(buf, "[%d]", path[0])
   119  		path = path[1:]
   120  	}
   121  
   122  	if len(path) > 0 {
   123  		var next proto.Message
   124  		if msg != nil {
   125  			if idx >= 0 {
   126  				if idx < msg.FieldLength(fld) {
   127  					next = msg.GetRepeatedField(fld, idx).(proto.Message)
   128  				}
   129  			} else {
   130  				if m, ok := msg.GetField(fld).(proto.Message); ok {
   131  					next = m
   132  				} else {
   133  					panic(fmt.Sprintf("path traverses into non-message type %T: %s -> %v", msg.GetField(fld), buf.String(), path))
   134  				}
   135  			}
   136  		}
   137  
   138  		if next == nil && msg != nil {
   139  			buf.WriteString(" !!! ")
   140  		}
   141  
   142  		if dm, ok := next.(*dynamic.Message); ok || next == nil {
   143  			findLocation(mf, dm, fld.GetMessageType(), path, buf)
   144  		} else {
   145  			dm := mf.NewDynamicMessage(fld.GetMessageType())
   146  			err := dm.ConvertFrom(next)
   147  			if err != nil {
   148  				panic(err.Error())
   149  			}
   150  			findLocation(mf, dm, fld.GetMessageType(), path, buf)
   151  		}
   152  	}
   153  }