github.com/jhump/protocompile@v0.0.0-20221021153901-4f6f732835e8/parser/parser_test.go (about) 1 package parser 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "os" 8 "path/filepath" 9 "strings" 10 "testing" 11 12 "github.com/stretchr/testify/assert" 13 "github.com/stretchr/testify/require" 14 "google.golang.org/protobuf/types/descriptorpb" 15 16 "github.com/jhump/protocompile/reporter" 17 ) 18 19 func TestEmptyParse(t *testing.T) { 20 errHandler := reporter.NewHandler(nil) 21 ast, err := Parse("foo.proto", bytes.NewReader(nil), errHandler) 22 assert.Nil(t, err) 23 result, err := ResultFromAST(ast, true, errHandler) 24 assert.Nil(t, err) 25 fd := result.Proto() 26 assert.Equal(t, "foo.proto", fd.GetName()) 27 assert.Equal(t, 0, len(fd.GetDependency())) 28 assert.Equal(t, 0, len(fd.GetMessageType())) 29 assert.Equal(t, 0, len(fd.GetEnumType())) 30 assert.Equal(t, 0, len(fd.GetExtension())) 31 assert.Equal(t, 0, len(fd.GetService())) 32 } 33 34 func TestJunkParse(t *testing.T) { 35 errHandler := reporter.NewHandler(nil) 36 // inputs that have been found in the past to cause panics by oss-fuzz 37 inputs := map[string]string{ 38 "case-34232": `'';`, 39 "case-34238": `.`, 40 } 41 for name, input := range inputs { 42 protoName := fmt.Sprintf("%s.proto", name) 43 _, err := Parse(protoName, strings.NewReader(input), errHandler) 44 // we expect this to error... but we don't want it to panic 45 assert.NotNil(t, err, "junk input should have returned error") 46 t.Logf("error from parse: %v", err) 47 } 48 } 49 50 func TestSimpleParse(t *testing.T) { 51 protos := map[string]Result{} 52 53 // Just verify that we can successfully parse the same files we use for 54 // testing. We do a *very* shallow check of what was parsed because we know 55 // it won't be fully correct until after linking. (So that will be tested 56 // below, where we parse *and* link.) 57 res, err := parseFileForTest("../internal/testprotos/desc_test1.proto") 58 if assert.Nil(t, err, "%v", err) { 59 fd := res.Proto() 60 assert.Equal(t, "../internal/testprotos/desc_test1.proto", fd.GetName()) 61 assert.Equal(t, "testprotos", fd.GetPackage()) 62 assert.True(t, hasExtension(fd, "xtm")) 63 assert.True(t, hasMessage(fd, "TestMessage")) 64 protos[fd.GetName()] = res 65 } 66 67 res, err = parseFileForTest("../internal/testprotos/desc_test2.proto") 68 if assert.Nil(t, err, "%v", err) { 69 fd := res.Proto() 70 assert.Equal(t, "../internal/testprotos/desc_test2.proto", fd.GetName()) 71 assert.Equal(t, "testprotos", fd.GetPackage()) 72 assert.True(t, hasExtension(fd, "groupx")) 73 assert.True(t, hasMessage(fd, "GroupX")) 74 assert.True(t, hasMessage(fd, "Frobnitz")) 75 protos[fd.GetName()] = res 76 } 77 78 res, err = parseFileForTest("../internal/testprotos/desc_test_defaults.proto") 79 if assert.Nil(t, err, "%v", err) { 80 fd := res.Proto() 81 assert.Equal(t, "../internal/testprotos/desc_test_defaults.proto", fd.GetName()) 82 assert.Equal(t, "testprotos", fd.GetPackage()) 83 assert.True(t, hasMessage(fd, "PrimitiveDefaults")) 84 protos[fd.GetName()] = res 85 } 86 87 res, err = parseFileForTest("../internal/testprotos/desc_test_field_types.proto") 88 if assert.Nil(t, err, "%v", err) { 89 fd := res.Proto() 90 assert.Equal(t, "../internal/testprotos/desc_test_field_types.proto", fd.GetName()) 91 assert.Equal(t, "testprotos", fd.GetPackage()) 92 assert.True(t, hasEnum(fd, "TestEnum")) 93 assert.True(t, hasMessage(fd, "UnaryFields")) 94 protos[fd.GetName()] = res 95 } 96 97 res, err = parseFileForTest("../internal/testprotos/desc_test_options.proto") 98 if assert.Nil(t, err, "%v", err) { 99 fd := res.Proto() 100 assert.Equal(t, "../internal/testprotos/desc_test_options.proto", fd.GetName()) 101 assert.Equal(t, "testprotos", fd.GetPackage()) 102 assert.True(t, hasExtension(fd, "mfubar")) 103 assert.True(t, hasEnum(fd, "ReallySimpleEnum")) 104 assert.True(t, hasMessage(fd, "ReallySimpleMessage")) 105 protos[fd.GetName()] = res 106 } 107 108 res, err = parseFileForTest("../internal/testprotos/desc_test_proto3.proto") 109 if assert.Nil(t, err, "%v", err) { 110 fd := res.Proto() 111 assert.Equal(t, "../internal/testprotos/desc_test_proto3.proto", fd.GetName()) 112 assert.Equal(t, "testprotos", fd.GetPackage()) 113 assert.True(t, hasEnum(fd, "Proto3Enum")) 114 assert.True(t, hasService(fd, "TestService")) 115 protos[fd.GetName()] = res 116 } 117 118 res, err = parseFileForTest("../internal/testprotos/desc_test_wellknowntypes.proto") 119 if assert.Nil(t, err, "%v", err) { 120 fd := res.Proto() 121 assert.Equal(t, "../internal/testprotos/desc_test_wellknowntypes.proto", fd.GetName()) 122 assert.Equal(t, "testprotos", fd.GetPackage()) 123 assert.True(t, hasMessage(fd, "TestWellKnownTypes")) 124 protos[fd.GetName()] = res 125 } 126 127 res, err = parseFileForTest("../internal/testprotos/nopkg/desc_test_nopkg.proto") 128 if assert.Nil(t, err, "%v", err) { 129 fd := res.Proto() 130 assert.Equal(t, "../internal/testprotos/nopkg/desc_test_nopkg.proto", fd.GetName()) 131 assert.Equal(t, "", fd.GetPackage()) 132 protos[fd.GetName()] = res 133 } 134 135 res, err = parseFileForTest("../internal/testprotos/nopkg/desc_test_nopkg_new.proto") 136 if assert.Nil(t, err, "%v", err) { 137 fd := res.Proto() 138 assert.Equal(t, "../internal/testprotos/nopkg/desc_test_nopkg_new.proto", fd.GetName()) 139 assert.Equal(t, "", fd.GetPackage()) 140 assert.True(t, hasMessage(fd, "TopLevel")) 141 protos[fd.GetName()] = res 142 } 143 144 res, err = parseFileForTest("../internal/testprotos/pkg/desc_test_pkg.proto") 145 if assert.Nil(t, err, "%v", err) { 146 fd := res.Proto() 147 assert.Equal(t, "../internal/testprotos/pkg/desc_test_pkg.proto", fd.GetName()) 148 assert.Equal(t, "jhump.protocompile.test", fd.GetPackage()) 149 assert.True(t, hasEnum(fd, "Foo")) 150 assert.True(t, hasMessage(fd, "Bar")) 151 protos[fd.GetName()] = res 152 } 153 } 154 155 func parseFileForTest(filename string) (Result, error) { 156 f, err := os.Open(filename) 157 if err != nil { 158 return nil, err 159 } 160 defer func() { 161 _ = f.Close() 162 }() 163 errHandler := reporter.NewHandler(nil) 164 res, err := Parse(filename, f, errHandler) 165 if err != nil { 166 return nil, err 167 } 168 return ResultFromAST(res, true, errHandler) 169 } 170 171 func hasExtension(fd *descriptorpb.FileDescriptorProto, name string) bool { 172 for _, ext := range fd.Extension { 173 if ext.GetName() == name { 174 return true 175 } 176 } 177 return false 178 } 179 180 func hasMessage(fd *descriptorpb.FileDescriptorProto, name string) bool { 181 for _, md := range fd.MessageType { 182 if md.GetName() == name { 183 return true 184 } 185 } 186 return false 187 } 188 189 func hasEnum(fd *descriptorpb.FileDescriptorProto, name string) bool { 190 for _, ed := range fd.EnumType { 191 if ed.GetName() == name { 192 return true 193 } 194 } 195 return false 196 } 197 198 func hasService(fd *descriptorpb.FileDescriptorProto, name string) bool { 199 for _, sd := range fd.Service { 200 if sd.GetName() == name { 201 return true 202 } 203 } 204 return false 205 } 206 207 func TestAggregateValueInUninterpretedOptions(t *testing.T) { 208 res, err := parseFileForTest("../internal/testprotos/desc_test_complex.proto") 209 if !assert.Nil(t, err) { 210 t.FailNow() 211 } 212 fd := res.Proto() 213 214 // service TestTestService, method UserAuth; first option 215 aggregateValue1 := *fd.Service[0].Method[0].Options.UninterpretedOption[0].AggregateValue 216 assert.Equal(t, "authenticated : true permission : { action : LOGIN entity : \"client\" }", aggregateValue1) 217 218 // service TestTestService, method Get; first option 219 aggregateValue2 := *fd.Service[0].Method[1].Options.UninterpretedOption[0].AggregateValue 220 assert.Equal(t, "authenticated : true permission : { action : READ entity : \"user\" }", aggregateValue2) 221 222 // message Another; first option 223 aggregateValue3 := *fd.MessageType[4].Options.UninterpretedOption[0].AggregateValue 224 assert.Equal(t, "foo : \"abc\" s < name : \"foo\" , id : 123 > , array : [ 1 , 2 , 3 ] , r : [ < name : \"f\" > , { name : \"s\" } , { id : 456 } ] ,", aggregateValue3) 225 226 // message Test.Nested._NestedNested; second option (rept) 227 // (Test.Nested is at index 1 instead of 0 because of implicit nested message from map field m) 228 aggregateValue4 := *fd.MessageType[1].NestedType[1].NestedType[0].Options.UninterpretedOption[1].AggregateValue 229 assert.Equal(t, "foo : \"goo\" [ foo . bar . Test . Nested . _NestedNested . _garblez ] : \"boo\"", aggregateValue4) 230 } 231 232 func TestBasicSuccess(t *testing.T) { 233 r := readerForTestdata(t, "largeproto.proto") 234 handler := reporter.NewHandler(nil) 235 236 fileNode, err := Parse("largeproto.proto", r, handler) 237 require.NoError(t, err) 238 239 result, err := ResultFromAST(fileNode, true, handler) 240 require.NoError(t, err) 241 require.NoError(t, handler.Error()) 242 243 assert.Equal(t, "proto3", result.AST().Syntax.Syntax.AsString()) 244 } 245 246 func BenchmarkBasicSuccess(b *testing.B) { 247 r := readerForTestdata(b, "largeproto.proto") 248 bs, err := io.ReadAll(r) 249 require.NoError(b, err) 250 251 b.ResetTimer() 252 for i := 0; i < b.N; i++ { 253 b.ReportAllocs() 254 byteReader := bytes.NewReader(bs) 255 handler := reporter.NewHandler(nil) 256 257 fileNode, err := Parse("largeproto.proto", byteReader, handler) 258 require.NoError(b, err) 259 260 result, err := ResultFromAST(fileNode, true, handler) 261 require.NoError(b, err) 262 require.NoError(b, handler.Error()) 263 264 assert.Equal(b, "proto3", result.AST().Syntax.Syntax.AsString()) 265 } 266 } 267 268 func readerForTestdata(t testing.TB, filename string) io.Reader { 269 file, err := os.Open(filepath.Join("testdata", filename)) 270 require.NoError(t, err) 271 272 return file 273 }