github.com/jhump/protocompile@v0.0.0-20221021153901-4f6f732835e8/parser/parser_test.go (about)

     1  package parser
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"path/filepath"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/require"
    14  	"google.golang.org/protobuf/types/descriptorpb"
    15  
    16  	"github.com/jhump/protocompile/reporter"
    17  )
    18  
    19  func TestEmptyParse(t *testing.T) {
    20  	errHandler := reporter.NewHandler(nil)
    21  	ast, err := Parse("foo.proto", bytes.NewReader(nil), errHandler)
    22  	assert.Nil(t, err)
    23  	result, err := ResultFromAST(ast, true, errHandler)
    24  	assert.Nil(t, err)
    25  	fd := result.Proto()
    26  	assert.Equal(t, "foo.proto", fd.GetName())
    27  	assert.Equal(t, 0, len(fd.GetDependency()))
    28  	assert.Equal(t, 0, len(fd.GetMessageType()))
    29  	assert.Equal(t, 0, len(fd.GetEnumType()))
    30  	assert.Equal(t, 0, len(fd.GetExtension()))
    31  	assert.Equal(t, 0, len(fd.GetService()))
    32  }
    33  
    34  func TestJunkParse(t *testing.T) {
    35  	errHandler := reporter.NewHandler(nil)
    36  	// inputs that have been found in the past to cause panics by oss-fuzz
    37  	inputs := map[string]string{
    38  		"case-34232": `'';`,
    39  		"case-34238": `.`,
    40  	}
    41  	for name, input := range inputs {
    42  		protoName := fmt.Sprintf("%s.proto", name)
    43  		_, err := Parse(protoName, strings.NewReader(input), errHandler)
    44  		// we expect this to error... but we don't want it to panic
    45  		assert.NotNil(t, err, "junk input should have returned error")
    46  		t.Logf("error from parse: %v", err)
    47  	}
    48  }
    49  
    50  func TestSimpleParse(t *testing.T) {
    51  	protos := map[string]Result{}
    52  
    53  	// Just verify that we can successfully parse the same files we use for
    54  	// testing. We do a *very* shallow check of what was parsed because we know
    55  	// it won't be fully correct until after linking. (So that will be tested
    56  	// below, where we parse *and* link.)
    57  	res, err := parseFileForTest("../internal/testprotos/desc_test1.proto")
    58  	if assert.Nil(t, err, "%v", err) {
    59  		fd := res.Proto()
    60  		assert.Equal(t, "../internal/testprotos/desc_test1.proto", fd.GetName())
    61  		assert.Equal(t, "testprotos", fd.GetPackage())
    62  		assert.True(t, hasExtension(fd, "xtm"))
    63  		assert.True(t, hasMessage(fd, "TestMessage"))
    64  		protos[fd.GetName()] = res
    65  	}
    66  
    67  	res, err = parseFileForTest("../internal/testprotos/desc_test2.proto")
    68  	if assert.Nil(t, err, "%v", err) {
    69  		fd := res.Proto()
    70  		assert.Equal(t, "../internal/testprotos/desc_test2.proto", fd.GetName())
    71  		assert.Equal(t, "testprotos", fd.GetPackage())
    72  		assert.True(t, hasExtension(fd, "groupx"))
    73  		assert.True(t, hasMessage(fd, "GroupX"))
    74  		assert.True(t, hasMessage(fd, "Frobnitz"))
    75  		protos[fd.GetName()] = res
    76  	}
    77  
    78  	res, err = parseFileForTest("../internal/testprotos/desc_test_defaults.proto")
    79  	if assert.Nil(t, err, "%v", err) {
    80  		fd := res.Proto()
    81  		assert.Equal(t, "../internal/testprotos/desc_test_defaults.proto", fd.GetName())
    82  		assert.Equal(t, "testprotos", fd.GetPackage())
    83  		assert.True(t, hasMessage(fd, "PrimitiveDefaults"))
    84  		protos[fd.GetName()] = res
    85  	}
    86  
    87  	res, err = parseFileForTest("../internal/testprotos/desc_test_field_types.proto")
    88  	if assert.Nil(t, err, "%v", err) {
    89  		fd := res.Proto()
    90  		assert.Equal(t, "../internal/testprotos/desc_test_field_types.proto", fd.GetName())
    91  		assert.Equal(t, "testprotos", fd.GetPackage())
    92  		assert.True(t, hasEnum(fd, "TestEnum"))
    93  		assert.True(t, hasMessage(fd, "UnaryFields"))
    94  		protos[fd.GetName()] = res
    95  	}
    96  
    97  	res, err = parseFileForTest("../internal/testprotos/desc_test_options.proto")
    98  	if assert.Nil(t, err, "%v", err) {
    99  		fd := res.Proto()
   100  		assert.Equal(t, "../internal/testprotos/desc_test_options.proto", fd.GetName())
   101  		assert.Equal(t, "testprotos", fd.GetPackage())
   102  		assert.True(t, hasExtension(fd, "mfubar"))
   103  		assert.True(t, hasEnum(fd, "ReallySimpleEnum"))
   104  		assert.True(t, hasMessage(fd, "ReallySimpleMessage"))
   105  		protos[fd.GetName()] = res
   106  	}
   107  
   108  	res, err = parseFileForTest("../internal/testprotos/desc_test_proto3.proto")
   109  	if assert.Nil(t, err, "%v", err) {
   110  		fd := res.Proto()
   111  		assert.Equal(t, "../internal/testprotos/desc_test_proto3.proto", fd.GetName())
   112  		assert.Equal(t, "testprotos", fd.GetPackage())
   113  		assert.True(t, hasEnum(fd, "Proto3Enum"))
   114  		assert.True(t, hasService(fd, "TestService"))
   115  		protos[fd.GetName()] = res
   116  	}
   117  
   118  	res, err = parseFileForTest("../internal/testprotos/desc_test_wellknowntypes.proto")
   119  	if assert.Nil(t, err, "%v", err) {
   120  		fd := res.Proto()
   121  		assert.Equal(t, "../internal/testprotos/desc_test_wellknowntypes.proto", fd.GetName())
   122  		assert.Equal(t, "testprotos", fd.GetPackage())
   123  		assert.True(t, hasMessage(fd, "TestWellKnownTypes"))
   124  		protos[fd.GetName()] = res
   125  	}
   126  
   127  	res, err = parseFileForTest("../internal/testprotos/nopkg/desc_test_nopkg.proto")
   128  	if assert.Nil(t, err, "%v", err) {
   129  		fd := res.Proto()
   130  		assert.Equal(t, "../internal/testprotos/nopkg/desc_test_nopkg.proto", fd.GetName())
   131  		assert.Equal(t, "", fd.GetPackage())
   132  		protos[fd.GetName()] = res
   133  	}
   134  
   135  	res, err = parseFileForTest("../internal/testprotos/nopkg/desc_test_nopkg_new.proto")
   136  	if assert.Nil(t, err, "%v", err) {
   137  		fd := res.Proto()
   138  		assert.Equal(t, "../internal/testprotos/nopkg/desc_test_nopkg_new.proto", fd.GetName())
   139  		assert.Equal(t, "", fd.GetPackage())
   140  		assert.True(t, hasMessage(fd, "TopLevel"))
   141  		protos[fd.GetName()] = res
   142  	}
   143  
   144  	res, err = parseFileForTest("../internal/testprotos/pkg/desc_test_pkg.proto")
   145  	if assert.Nil(t, err, "%v", err) {
   146  		fd := res.Proto()
   147  		assert.Equal(t, "../internal/testprotos/pkg/desc_test_pkg.proto", fd.GetName())
   148  		assert.Equal(t, "jhump.protocompile.test", fd.GetPackage())
   149  		assert.True(t, hasEnum(fd, "Foo"))
   150  		assert.True(t, hasMessage(fd, "Bar"))
   151  		protos[fd.GetName()] = res
   152  	}
   153  }
   154  
   155  func parseFileForTest(filename string) (Result, error) {
   156  	f, err := os.Open(filename)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  	defer func() {
   161  		_ = f.Close()
   162  	}()
   163  	errHandler := reporter.NewHandler(nil)
   164  	res, err := Parse(filename, f, errHandler)
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  	return ResultFromAST(res, true, errHandler)
   169  }
   170  
   171  func hasExtension(fd *descriptorpb.FileDescriptorProto, name string) bool {
   172  	for _, ext := range fd.Extension {
   173  		if ext.GetName() == name {
   174  			return true
   175  		}
   176  	}
   177  	return false
   178  }
   179  
   180  func hasMessage(fd *descriptorpb.FileDescriptorProto, name string) bool {
   181  	for _, md := range fd.MessageType {
   182  		if md.GetName() == name {
   183  			return true
   184  		}
   185  	}
   186  	return false
   187  }
   188  
   189  func hasEnum(fd *descriptorpb.FileDescriptorProto, name string) bool {
   190  	for _, ed := range fd.EnumType {
   191  		if ed.GetName() == name {
   192  			return true
   193  		}
   194  	}
   195  	return false
   196  }
   197  
   198  func hasService(fd *descriptorpb.FileDescriptorProto, name string) bool {
   199  	for _, sd := range fd.Service {
   200  		if sd.GetName() == name {
   201  			return true
   202  		}
   203  	}
   204  	return false
   205  }
   206  
   207  func TestAggregateValueInUninterpretedOptions(t *testing.T) {
   208  	res, err := parseFileForTest("../internal/testprotos/desc_test_complex.proto")
   209  	if !assert.Nil(t, err) {
   210  		t.FailNow()
   211  	}
   212  	fd := res.Proto()
   213  
   214  	// service TestTestService, method UserAuth; first option
   215  	aggregateValue1 := *fd.Service[0].Method[0].Options.UninterpretedOption[0].AggregateValue
   216  	assert.Equal(t, "authenticated : true permission : { action : LOGIN entity : \"client\" }", aggregateValue1)
   217  
   218  	// service TestTestService, method Get; first option
   219  	aggregateValue2 := *fd.Service[0].Method[1].Options.UninterpretedOption[0].AggregateValue
   220  	assert.Equal(t, "authenticated : true permission : { action : READ entity : \"user\" }", aggregateValue2)
   221  
   222  	// message Another; first option
   223  	aggregateValue3 := *fd.MessageType[4].Options.UninterpretedOption[0].AggregateValue
   224  	assert.Equal(t, "foo : \"abc\" s < name : \"foo\" , id : 123 > , array : [ 1 , 2 , 3 ] , r : [ < name : \"f\" > , { name : \"s\" } , { id : 456 } ] ,", aggregateValue3)
   225  
   226  	// message Test.Nested._NestedNested; second option (rept)
   227  	//  (Test.Nested is at index 1 instead of 0 because of implicit nested message from map field m)
   228  	aggregateValue4 := *fd.MessageType[1].NestedType[1].NestedType[0].Options.UninterpretedOption[1].AggregateValue
   229  	assert.Equal(t, "foo : \"goo\" [ foo . bar . Test . Nested . _NestedNested . _garblez ] : \"boo\"", aggregateValue4)
   230  }
   231  
   232  func TestBasicSuccess(t *testing.T) {
   233  	r := readerForTestdata(t, "largeproto.proto")
   234  	handler := reporter.NewHandler(nil)
   235  
   236  	fileNode, err := Parse("largeproto.proto", r, handler)
   237  	require.NoError(t, err)
   238  
   239  	result, err := ResultFromAST(fileNode, true, handler)
   240  	require.NoError(t, err)
   241  	require.NoError(t, handler.Error())
   242  
   243  	assert.Equal(t, "proto3", result.AST().Syntax.Syntax.AsString())
   244  }
   245  
   246  func BenchmarkBasicSuccess(b *testing.B) {
   247  	r := readerForTestdata(b, "largeproto.proto")
   248  	bs, err := io.ReadAll(r)
   249  	require.NoError(b, err)
   250  
   251  	b.ResetTimer()
   252  	for i := 0; i < b.N; i++ {
   253  		b.ReportAllocs()
   254  		byteReader := bytes.NewReader(bs)
   255  		handler := reporter.NewHandler(nil)
   256  
   257  		fileNode, err := Parse("largeproto.proto", byteReader, handler)
   258  		require.NoError(b, err)
   259  
   260  		result, err := ResultFromAST(fileNode, true, handler)
   261  		require.NoError(b, err)
   262  		require.NoError(b, handler.Error())
   263  
   264  		assert.Equal(b, "proto3", result.AST().Syntax.Syntax.AsString())
   265  	}
   266  }
   267  
   268  func readerForTestdata(t testing.TB, filename string) io.Reader {
   269  	file, err := os.Open(filepath.Join("testdata", filename))
   270  	require.NoError(t, err)
   271  
   272  	return file
   273  }