github.com/jhump/protocompile@v0.0.0-20221021153901-4f6f732835e8/sourceinfo/source_code_info_test.go (about)

     1  package sourceinfo_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  	"google.golang.org/protobuf/reflect/protodesc"
    13  	"google.golang.org/protobuf/reflect/protoreflect"
    14  	"google.golang.org/protobuf/reflect/protoregistry"
    15  	"google.golang.org/protobuf/types/descriptorpb"
    16  
    17  	"github.com/jhump/protocompile"
    18  	"github.com/jhump/protocompile/linker"
    19  )
    20  
    21  // If true, re-generates the golden output file
    22  const regenerateMode = false
    23  
    24  func TestSourceCodeInfo(t *testing.T) {
    25  	compiler := protocompile.Compiler{
    26  		Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{
    27  			ImportPaths: []string{"../internal/testprotos"},
    28  		}),
    29  		IncludeSourceInfo: true,
    30  	}
    31  	fds, err := compiler.Compile(context.Background(), "desc_test_comments.proto", "desc_test_complex.proto")
    32  	if !assert.Nil(t, err) {
    33  		return
    34  	}
    35  	// also test that imported files have source code info
    36  	// (desc_test_comments.proto imports desc_test_options.proto)
    37  	importedFd := fds[0].FindImportByPath("desc_test_options.proto")
    38  	if !assert.NotNil(t, importedFd) {
    39  		return
    40  	}
    41  
    42  	// create description of source code info
    43  	// (human readable so diffs in source control are comprehensible)
    44  	var buf bytes.Buffer
    45  	for _, fd := range fds {
    46  		printSourceCodeInfo(fd, &buf)
    47  	}
    48  	printSourceCodeInfo(importedFd, &buf)
    49  	actual := buf.String()
    50  
    51  	if regenerateMode {
    52  		// re-generate the file
    53  		err = ioutil.WriteFile("test-source-info.txt", buf.Bytes(), 0666)
    54  		if !assert.Nil(t, err) {
    55  			return
    56  		}
    57  	}
    58  
    59  	b, err := ioutil.ReadFile("test-source-info.txt")
    60  	if !assert.Nil(t, err) {
    61  		return
    62  	}
    63  	golden := string(b)
    64  
    65  	assert.Equal(t, golden, actual, "wrong source code info")
    66  }
    67  
    68  // NB: this function can be used to manually inspect the source code info for a
    69  // descriptor, in a manner that is much easier to read and check than raw
    70  // descriptor form.
    71  func printSourceCodeInfo(fd linker.File, out io.Writer) {
    72  	fmt.Fprintf(out, "---- %s ----\n", fd.Path())
    73  
    74  	var fdMsg *descriptorpb.FileDescriptorProto
    75  	if r, ok := fd.(linker.Result); ok {
    76  		fdMsg = r.Proto()
    77  	} else {
    78  		fdMsg = protodesc.ToFileDescriptorProto(fd)
    79  	}
    80  
    81  	for i := 0; i < fd.SourceLocations().Len(); i++ {
    82  		loc := fd.SourceLocations().Get(i)
    83  		var buf bytes.Buffer
    84  		findLocation(linker.ResolverFromFile(fd), fdMsg.ProtoReflect(), fdMsg.ProtoReflect().Descriptor(), loc.Path, &buf)
    85  		fmt.Fprintf(out, "\n\n%s:\n", buf.String())
    86  		fmt.Fprintf(out, "%s:%d:%d\n", fd.Path(), loc.StartLine+1, loc.StartColumn+1)
    87  		fmt.Fprintf(out, "%s:%d:%d\n", fd.Path(), loc.EndLine+1, loc.EndColumn+1)
    88  		if len(loc.LeadingDetachedComments) > 0 {
    89  			for i, comment := range loc.LeadingDetachedComments {
    90  				fmt.Fprintf(out, "    Leading detached comment [%d]:\n%s\n", i, comment)
    91  			}
    92  		}
    93  		if loc.LeadingComments != "" {
    94  			fmt.Fprintf(out, "    Leading comments:\n%s\n", loc.LeadingComments)
    95  		}
    96  		if loc.TrailingComments != "" {
    97  			fmt.Fprintf(out, "    Trailing comments:\n%s\n", loc.TrailingComments)
    98  		}
    99  	}
   100  }
   101  
   102  func findLocation(res protoregistry.ExtensionTypeResolver, msg protoreflect.Message, md protoreflect.MessageDescriptor, path []int32, buf *bytes.Buffer) {
   103  	if len(path) == 0 {
   104  		return
   105  	}
   106  
   107  	tag := protoreflect.FieldNumber(path[0])
   108  	fld := md.Fields().ByNumber(tag)
   109  	if fld == nil {
   110  		ext, err := res.FindExtensionByNumber(md.FullName(), tag)
   111  		if err != nil {
   112  			panic(fmt.Sprintf("could not find field with tag %d in message of type %s", path[0], msg.Descriptor().FullName()))
   113  		}
   114  		fld = ext.TypeDescriptor()
   115  	}
   116  
   117  	fmt.Fprintf(buf, " > %s", fld.Name())
   118  	path = path[1:]
   119  	idx := -1
   120  	if fld.Cardinality() == protoreflect.Repeated && len(path) > 0 {
   121  		idx = int(path[0])
   122  		fmt.Fprintf(buf, "[%d]", path[0])
   123  		path = path[1:]
   124  	}
   125  
   126  	if len(path) > 0 {
   127  		var next protoreflect.Message
   128  		if msg != nil {
   129  			fldVal := msg.Get(fld)
   130  			if idx >= 0 {
   131  				l := fldVal.List()
   132  				if idx < l.Len() {
   133  					next = l.Get(idx).Message()
   134  				}
   135  			} else {
   136  				next = fldVal.Message()
   137  			}
   138  		}
   139  
   140  		if next == nil && msg != nil {
   141  			buf.WriteString(" !!! ")
   142  		}
   143  
   144  		findLocation(res, next, fld.Message(), path, buf)
   145  	}
   146  }