github.com/jhump/protocompile@v0.0.0-20221021153901-4f6f732835e8/sourceinfo/source_code_info_test.go (about) 1 package sourceinfo_test 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "testing" 10 11 "github.com/stretchr/testify/assert" 12 "google.golang.org/protobuf/reflect/protodesc" 13 "google.golang.org/protobuf/reflect/protoreflect" 14 "google.golang.org/protobuf/reflect/protoregistry" 15 "google.golang.org/protobuf/types/descriptorpb" 16 17 "github.com/jhump/protocompile" 18 "github.com/jhump/protocompile/linker" 19 ) 20 21 // If true, re-generates the golden output file 22 const regenerateMode = false 23 24 func TestSourceCodeInfo(t *testing.T) { 25 compiler := protocompile.Compiler{ 26 Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{ 27 ImportPaths: []string{"../internal/testprotos"}, 28 }), 29 IncludeSourceInfo: true, 30 } 31 fds, err := compiler.Compile(context.Background(), "desc_test_comments.proto", "desc_test_complex.proto") 32 if !assert.Nil(t, err) { 33 return 34 } 35 // also test that imported files have source code info 36 // (desc_test_comments.proto imports desc_test_options.proto) 37 importedFd := fds[0].FindImportByPath("desc_test_options.proto") 38 if !assert.NotNil(t, importedFd) { 39 return 40 } 41 42 // create description of source code info 43 // (human readable so diffs in source control are comprehensible) 44 var buf bytes.Buffer 45 for _, fd := range fds { 46 printSourceCodeInfo(fd, &buf) 47 } 48 printSourceCodeInfo(importedFd, &buf) 49 actual := buf.String() 50 51 if regenerateMode { 52 // re-generate the file 53 err = ioutil.WriteFile("test-source-info.txt", buf.Bytes(), 0666) 54 if !assert.Nil(t, err) { 55 return 56 } 57 } 58 59 b, err := ioutil.ReadFile("test-source-info.txt") 60 if !assert.Nil(t, err) { 61 return 62 } 63 golden := string(b) 64 65 assert.Equal(t, golden, actual, "wrong source code info") 66 } 67 68 // NB: this function can be used to manually inspect the source code info for a 69 // descriptor, in a manner that is much easier to read and check than raw 70 // descriptor form. 71 func printSourceCodeInfo(fd linker.File, out io.Writer) { 72 fmt.Fprintf(out, "---- %s ----\n", fd.Path()) 73 74 var fdMsg *descriptorpb.FileDescriptorProto 75 if r, ok := fd.(linker.Result); ok { 76 fdMsg = r.Proto() 77 } else { 78 fdMsg = protodesc.ToFileDescriptorProto(fd) 79 } 80 81 for i := 0; i < fd.SourceLocations().Len(); i++ { 82 loc := fd.SourceLocations().Get(i) 83 var buf bytes.Buffer 84 findLocation(linker.ResolverFromFile(fd), fdMsg.ProtoReflect(), fdMsg.ProtoReflect().Descriptor(), loc.Path, &buf) 85 fmt.Fprintf(out, "\n\n%s:\n", buf.String()) 86 fmt.Fprintf(out, "%s:%d:%d\n", fd.Path(), loc.StartLine+1, loc.StartColumn+1) 87 fmt.Fprintf(out, "%s:%d:%d\n", fd.Path(), loc.EndLine+1, loc.EndColumn+1) 88 if len(loc.LeadingDetachedComments) > 0 { 89 for i, comment := range loc.LeadingDetachedComments { 90 fmt.Fprintf(out, " Leading detached comment [%d]:\n%s\n", i, comment) 91 } 92 } 93 if loc.LeadingComments != "" { 94 fmt.Fprintf(out, " Leading comments:\n%s\n", loc.LeadingComments) 95 } 96 if loc.TrailingComments != "" { 97 fmt.Fprintf(out, " Trailing comments:\n%s\n", loc.TrailingComments) 98 } 99 } 100 } 101 102 func findLocation(res protoregistry.ExtensionTypeResolver, msg protoreflect.Message, md protoreflect.MessageDescriptor, path []int32, buf *bytes.Buffer) { 103 if len(path) == 0 { 104 return 105 } 106 107 tag := protoreflect.FieldNumber(path[0]) 108 fld := md.Fields().ByNumber(tag) 109 if fld == nil { 110 ext, err := res.FindExtensionByNumber(md.FullName(), tag) 111 if err != nil { 112 panic(fmt.Sprintf("could not find field with tag %d in message of type %s", path[0], msg.Descriptor().FullName())) 113 } 114 fld = ext.TypeDescriptor() 115 } 116 117 fmt.Fprintf(buf, " > %s", fld.Name()) 118 path = path[1:] 119 idx := -1 120 if fld.Cardinality() == protoreflect.Repeated && len(path) > 0 { 121 idx = int(path[0]) 122 fmt.Fprintf(buf, "[%d]", path[0]) 123 path = path[1:] 124 } 125 126 if len(path) > 0 { 127 var next protoreflect.Message 128 if msg != nil { 129 fldVal := msg.Get(fld) 130 if idx >= 0 { 131 l := fldVal.List() 132 if idx < l.Len() { 133 next = l.Get(idx).Message() 134 } 135 } else { 136 next = fldVal.Message() 137 } 138 } 139 140 if next == nil && msg != nil { 141 buf.WriteString(" !!! ") 142 } 143 144 findLocation(res, next, fld.Message(), path, buf) 145 } 146 }