github.com/Big-big-orange/protoreflect@v0.0.0-20240408141420-285cedfdf6a4/desc/protoparse/source_code_info_test.go (about)

     1  package protoparse
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"testing"
     9  
    10  	"google.golang.org/protobuf/reflect/protoreflect"
    11  	"google.golang.org/protobuf/reflect/protoregistry"
    12  
    13  	"github.com/Big-big-orange/protoreflect/desc"
    14  	"github.com/Big-big-orange/protoreflect/desc/internal"
    15  	"github.com/Big-big-orange/protoreflect/internal/testutil"
    16  )
    17  
    // regenerateMode, when true, makes TestSourceCodeInfo overwrite the golden
    // file test-source-info.txt with freshly generated output before reading it
    // back, so the comparison passes trivially. Enable it only to refresh the
    // golden file after an intentional change.
    const regenerateMode bool = false
    20  
    21  func TestSourceCodeInfo(t *testing.T) {
    22  	p := Parser{ImportPaths: []string{"../../internal/testprotos"}, IncludeSourceCodeInfo: true}
    23  	fds, err := p.ParseFiles("desc_test_comments.proto", "desc_test_complex.proto")
    24  	testutil.Ok(t, err)
    25  	// also test that imported files have source code info
    26  	// (desc_test_comments.proto imports desc_test_options.proto)
    27  	var importedFd *desc.FileDescriptor
    28  	for _, dep := range fds[0].GetDependencies() {
    29  		if dep.GetName() == "desc_test_options.proto" {
    30  			importedFd = dep
    31  			break
    32  		}
    33  	}
    34  	testutil.Require(t, importedFd != nil)
    35  
    36  	// create description of source code info
    37  	// (human readable so diffs in source control are comprehensible)
    38  	var buf bytes.Buffer
    39  	for _, fd := range fds {
    40  		printSourceCodeInfo(fd, &buf)
    41  	}
    42  	printSourceCodeInfo(importedFd, &buf)
    43  	actual := buf.String()
    44  
    45  	if regenerateMode {
    46  		// re-generate the file
    47  		err = ioutil.WriteFile("test-source-info.txt", buf.Bytes(), 0666)
    48  		testutil.Ok(t, err)
    49  	}
    50  
    51  	b, err := ioutil.ReadFile("test-source-info.txt")
    52  	testutil.Ok(t, err)
    53  	golden := string(b)
    54  
    55  	testutil.Eq(t, golden, actual, "wrong source code info")
    56  }
    57  
    58  // NB: this function can be used to manually inspect the source code info for a
    59  // descriptor, in a manner that is much easier to read and check than raw
    60  // descriptor form.
    61  func printSourceCodeInfo(fd *desc.FileDescriptor, out io.Writer) {
    62  	_, _ = fmt.Fprintf(out, "---- %s ----\n", fd.GetName())
    63  	msg := fd.AsFileDescriptorProto().ProtoReflect()
    64  	var reg protoregistry.Types
    65  	internal.RegisterExtensionsVisibleToFile(&reg, fd.UnwrapFile())
    66  
    67  	for _, loc := range fd.AsFileDescriptorProto().GetSourceCodeInfo().GetLocation() {
    68  		var buf bytes.Buffer
    69  		findLocation(msg, &reg, loc.Path, &buf)
    70  		_, _ = fmt.Fprintf(out, "\n\n%s:\n", buf.String())
    71  		if len(loc.Span) == 3 {
    72  			_, _ = fmt.Fprintf(out, "%s:%d:%d\n", fd.GetName(), loc.Span[0]+1, loc.Span[1]+1)
    73  			_, _ = fmt.Fprintf(out, "%s:%d:%d\n", fd.GetName(), loc.Span[0]+1, loc.Span[2]+1)
    74  		} else {
    75  			_, _ = fmt.Fprintf(out, "%s:%d:%d\n", fd.GetName(), loc.Span[0]+1, loc.Span[1]+1)
    76  			_, _ = fmt.Fprintf(out, "%s:%d:%d\n", fd.GetName(), loc.Span[2]+1, loc.Span[3]+1)
    77  		}
    78  		if len(loc.LeadingDetachedComments) > 0 {
    79  			for i, comment := range loc.LeadingDetachedComments {
    80  				_, _ = fmt.Fprintf(out, "    Leading detached comment [%d]:\n%s\n", i, comment)
    81  			}
    82  		}
    83  		if loc.LeadingComments != nil {
    84  			_, _ = fmt.Fprintf(out, "    Leading comments:\n%s\n", loc.GetLeadingComments())
    85  		}
    86  		if loc.TrailingComments != nil {
    87  			_, _ = fmt.Fprintf(out, "    Trailing comments:\n%s\n", loc.GetTrailingComments())
    88  		}
    89  	}
    90  }
    91  
    92  func findLocation(msg protoreflect.Message, reg protoregistry.ExtensionTypeResolver, path []int32, buf *bytes.Buffer) {
    93  	if len(path) == 0 {
    94  		return
    95  	}
    96  
    97  	fieldNumber := protoreflect.FieldNumber(path[0])
    98  	md := msg.Descriptor()
    99  	fld := md.Fields().ByNumber(fieldNumber)
   100  	if fld == nil {
   101  		xt, err := reg.FindExtensionByNumber(md.FullName(), fieldNumber)
   102  		if err == nil {
   103  			fld = xt.TypeDescriptor()
   104  		}
   105  	}
   106  	if fld == nil {
   107  		panic(fmt.Sprintf("could not find field with tag %d in message of type %s", path[0], md.FullName()))
   108  	}
   109  
   110  	var name string
   111  	if fld.IsExtension() {
   112  		name = "(" + string(fld.FullName()) + ")"
   113  	} else {
   114  		name = string(fld.Name())
   115  	}
   116  	_, _ = fmt.Fprintf(buf, " > %s", name)
   117  	path = path[1:]
   118  	idx := -1
   119  	if fld.Cardinality() == protoreflect.Repeated && len(path) > 0 {
   120  		idx = int(path[0])
   121  		_, _ = fmt.Fprintf(buf, "[%d]", path[0])
   122  		path = path[1:]
   123  	}
   124  
   125  	if len(path) > 0 {
   126  		if fld.Kind() != protoreflect.MessageKind && fld.Kind() != protoreflect.GroupKind {
   127  			panic(fmt.Sprintf("path indicates tag %d, but field %v is %v, not a message", path[0], name, fld.Kind()))
   128  		}
   129  		var present bool
   130  		var next protoreflect.Message
   131  		if idx == -1 {
   132  			present = msg.Has(fld)
   133  			next = msg.Get(fld).Message()
   134  		} else {
   135  			list := msg.Get(fld).List()
   136  			present = idx < list.Len()
   137  			if present {
   138  				next = list.Get(idx).Message()
   139  			} else {
   140  				next = list.NewElement().Message()
   141  			}
   142  		}
   143  
   144  		if !present {
   145  			buf.WriteString(" !!! ")
   146  		}
   147  
   148  		findLocation(next, reg, path, buf)
   149  	}
   150  }