github.com/syumai/protoreflect@v1.7.1-0.20200810020253-2ac7e3b3a321/desc/protoparse/source_code_info_test.go (about)

     1  package protoparse
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"testing"
     9  
    10  	"github.com/golang/protobuf/proto"
    11  
    12  	"github.com/syumai/protoreflect/desc"
    13  	"github.com/syumai/protoreflect/dynamic"
    14  	"github.com/syumai/protoreflect/internal/testutil"
    15  )
    16  
    17  // regenerateMode, if true, makes TestSourceCodeInfo overwrite the golden file test-source-info.txt with the freshly generated output before comparing against it.
    18  const regenerateMode = false
    19  
    20  func TestSourceCodeInfo(t *testing.T) {
    21  	p := Parser{ImportPaths: []string{"../../internal/testprotos"}, IncludeSourceCodeInfo: true}
    22  	fds, err := p.ParseFiles("desc_test_comments.proto", "desc_test_complex.proto")
    23  	testutil.Ok(t, err)
    24  	// also test that imported files have source code info
    25  	// (desc_test_comments.proto imports desc_test_options.proto)
    26  	var importedFd *desc.FileDescriptor
    27  	for _, dep := range fds[0].GetDependencies() {
    28  		if dep.GetName() == "desc_test_options.proto" {
    29  			importedFd = dep
    30  			break
    31  		}
    32  	}
    33  	testutil.Require(t, importedFd != nil)
    34  
    35  	// create description of source code info
    36  	// (human readable so diffs in source control are comprehensible)
    37  	var buf bytes.Buffer
    38  	for _, fd := range fds {
    39  		printSourceCodeInfo(t, fd, &buf)
    40  	}
    41  	printSourceCodeInfo(t, importedFd, &buf)
    42  	actual := buf.String()
    43  
    44  	if regenerateMode {
    45  		// re-generate the file
    46  		err = ioutil.WriteFile("test-source-info.txt", buf.Bytes(), 0666)
    47  		testutil.Ok(t, err)
    48  	}
    49  
    50  	b, err := ioutil.ReadFile("test-source-info.txt")
    51  	testutil.Ok(t, err)
    52  	golden := string(b)
    53  
    54  	testutil.Eq(t, golden, actual, "wrong source code info")
    55  }
    56  
    57  // NB: this function can be used to manually inspect the source code info for a
    58  // descriptor, in a manner that is much easier to read and check than raw
    59  // descriptor form.
    60  func printSourceCodeInfo(t *testing.T, fd *desc.FileDescriptor, out io.Writer) {
    61  	fmt.Fprintf(out, "---- %s ----\n", fd.GetName())
    62  	md, err := desc.LoadMessageDescriptorForMessage(fd.AsProto())
    63  	testutil.Ok(t, err)
    64  	er := &dynamic.ExtensionRegistry{}
    65  	er.AddExtensionsFromFileRecursively(fd)
    66  	mf := dynamic.NewMessageFactoryWithExtensionRegistry(er)
    67  	dfd := mf.NewDynamicMessage(md)
    68  	err = dfd.ConvertFrom(fd.AsProto())
    69  	testutil.Ok(t, err)
    70  
    71  	for _, loc := range fd.AsFileDescriptorProto().GetSourceCodeInfo().GetLocation() {
    72  		var buf bytes.Buffer
    73  		findLocation(mf, dfd, md, loc.Path, &buf)
    74  		fmt.Fprintf(out, "\n\n%s:\n", buf.String())
    75  		if len(loc.Span) == 3 {
    76  			fmt.Fprintf(out, "%s:%d:%d\n", fd.GetName(), loc.Span[0]+1, loc.Span[1]+1)
    77  			fmt.Fprintf(out, "%s:%d:%d\n", fd.GetName(), loc.Span[0]+1, loc.Span[2]+1)
    78  		} else {
    79  			fmt.Fprintf(out, "%s:%d:%d\n", fd.GetName(), loc.Span[0]+1, loc.Span[1]+1)
    80  			fmt.Fprintf(out, "%s:%d:%d\n", fd.GetName(), loc.Span[2]+1, loc.Span[3]+1)
    81  		}
    82  		if len(loc.LeadingDetachedComments) > 0 {
    83  			for i, comment := range loc.LeadingDetachedComments {
    84  				fmt.Fprintf(out, "    Leading detached comment [%d]:\n%s\n", i, comment)
    85  			}
    86  		}
    87  		if loc.LeadingComments != nil {
    88  			fmt.Fprintf(out, "    Leading comments:\n%s\n", loc.GetLeadingComments())
    89  		}
    90  		if loc.TrailingComments != nil {
    91  			fmt.Fprintf(out, "    Trailing comments:\n%s\n", loc.GetTrailingComments())
    92  		}
    93  	}
    94  }
    95  
    96  func findLocation(mf *dynamic.MessageFactory, msg *dynamic.Message, md *desc.MessageDescriptor, path []int32, buf *bytes.Buffer) {
    97  	if len(path) == 0 {
    98  		return
    99  	}
   100  
   101  	var fld *desc.FieldDescriptor
   102  	if msg != nil {
   103  		fld = msg.FindFieldDescriptor(path[0])
   104  	} else {
   105  		fld = md.FindFieldByNumber(path[0])
   106  		if fld == nil {
   107  			fld = mf.GetExtensionRegistry().FindExtension(md.GetFullyQualifiedName(), path[0])
   108  		}
   109  	}
   110  	if fld == nil {
   111  		panic(fmt.Sprintf("could not find field with tag %d in message of type %s", path[0], msg.XXX_MessageName()))
   112  	}
   113  
   114  	fmt.Fprintf(buf, " > %s", fld.GetName())
   115  	path = path[1:]
   116  	idx := -1
   117  	if fld.IsRepeated() && len(path) > 0 {
   118  		idx = int(path[0])
   119  		fmt.Fprintf(buf, "[%d]", path[0])
   120  		path = path[1:]
   121  	}
   122  
   123  	if len(path) > 0 {
   124  		var next proto.Message
   125  		if msg != nil {
   126  			if idx >= 0 {
   127  				if idx < msg.FieldLength(fld) {
   128  					next = msg.GetRepeatedField(fld, idx).(proto.Message)
   129  				}
   130  			} else {
   131  				if m, ok := msg.GetField(fld).(proto.Message); ok {
   132  					next = m
   133  				} else {
   134  					panic(fmt.Sprintf("path traverses into non-message type %T: %s -> %v", msg.GetField(fld), buf.String(), path))
   135  				}
   136  			}
   137  		}
   138  
   139  		if next == nil && msg != nil {
   140  			buf.WriteString(" !!! ")
   141  		}
   142  
   143  		if dm, ok := next.(*dynamic.Message); ok || next == nil {
   144  			findLocation(mf, dm, fld.GetMessageType(), path, buf)
   145  		} else {
   146  			dm := mf.NewDynamicMessage(fld.GetMessageType())
   147  			err := dm.ConvertFrom(next)
   148  			if err != nil {
   149  				panic(err.Error())
   150  			}
   151  			findLocation(mf, dm, fld.GetMessageType(), path, buf)
   152  		}
   153  	}
   154  }