github.com/jhump/protocompile@v0.0.0-20221021153901-4f6f732835e8/options/options_test.go (about)

     1  package options
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"google.golang.org/protobuf/types/descriptorpb"
    10  
    11  	"github.com/jhump/protocompile/parser"
    12  	"github.com/jhump/protocompile/reporter"
    13  )
    14  
    15  type ident string
    16  type aggregate string
    17  
    18  func TestOptionsInUnlinkedFiles(t *testing.T) {
    19  	testCases := []struct {
    20  		contents         string
    21  		uninterpreted    map[string]interface{}
    22  		checkInterpreted func(*testing.T, *descriptorpb.FileDescriptorProto)
    23  	}{
    24  		{
    25  			// file options
    26  			contents: `option go_package = "foo.bar"; option (must.link) = "FOO";`,
    27  			uninterpreted: map[string]interface{}{
    28  				"test.proto:(must.link)": "FOO",
    29  			},
    30  			checkInterpreted: func(t *testing.T, fd *descriptorpb.FileDescriptorProto) {
    31  				assert.Equal(t, "foo.bar", fd.GetOptions().GetGoPackage())
    32  			},
    33  		},
    34  		{
    35  			// message options
    36  			contents: `message Test { option (must.link) = 1.234; option deprecated = true; }`,
    37  			uninterpreted: map[string]interface{}{
    38  				"Test:(must.link)": 1.234,
    39  			},
    40  			checkInterpreted: func(t *testing.T, fd *descriptorpb.FileDescriptorProto) {
    41  				assert.True(t, fd.GetMessageType()[0].GetOptions().GetDeprecated())
    42  			},
    43  		},
    44  		{
    45  			// field options and pseudo-options
    46  			contents: `message Test { optional string uid = 1 [(must.link) = 10101, (must.link) = 20202, default = "fubar", json_name = "UID", deprecated = true]; }`,
    47  			uninterpreted: map[string]interface{}{
    48  				"Test.uid:(must.link)":   10101,
    49  				"Test.uid:(must.link)#1": 20202,
    50  			},
    51  			checkInterpreted: func(t *testing.T, fd *descriptorpb.FileDescriptorProto) {
    52  				assert.Equal(t, "fubar", fd.GetMessageType()[0].GetField()[0].GetDefaultValue())
    53  				assert.Equal(t, "UID", fd.GetMessageType()[0].GetField()[0].GetJsonName())
    54  				assert.True(t, fd.GetMessageType()[0].GetField()[0].GetOptions().GetDeprecated())
    55  			},
    56  		},
    57  		{
    58  			// field where default is uninterpretable
    59  			contents: `enum TestEnum{ ZERO = 0; ONE = 1; } message Test { optional TestEnum uid = 1 [(must.link) = {foo: bar}, default = ONE, json_name = "UID", deprecated = true]; }`,
    60  			uninterpreted: map[string]interface{}{
    61  				"Test.uid:(must.link)": aggregate("foo : bar"),
    62  				"Test.uid:default":     ident("ONE"),
    63  			},
    64  			checkInterpreted: func(t *testing.T, fd *descriptorpb.FileDescriptorProto) {
    65  				assert.Equal(t, "UID", fd.GetMessageType()[0].GetField()[0].GetJsonName())
    66  				assert.True(t, fd.GetMessageType()[0].GetField()[0].GetOptions().GetDeprecated())
    67  			},
    68  		},
    69  		{
    70  			// one-of options
    71  			contents: `message Test { oneof x { option (must.link) = true; option deprecated = true; string uid = 1; uint64 nnn = 2; } }`,
    72  			uninterpreted: map[string]interface{}{
    73  				"Test.x:(must.link)": ident("true"),
    74  				"Test.x:deprecated":  ident("true"), // one-ofs do not have deprecated option :/
    75  			},
    76  		},
    77  		{
    78  			// extension range options
    79  			contents: `message Test { extensions 100 to 200 [(must.link) = "foo", deprecated = true]; }`,
    80  			uninterpreted: map[string]interface{}{
    81  				"Test.100-200:(must.link)": "foo",
    82  				"Test.100-200:deprecated":  ident("true"), // extension ranges do not have deprecated option :/
    83  			},
    84  		},
    85  		{
    86  			// enum options
    87  			contents: `enum Test { option allow_alias = true; option deprecated = true; option (must.link) = 123.456; ZERO = 0; ZILCH = 0; }`,
    88  			uninterpreted: map[string]interface{}{
    89  				"Test:(must.link)": 123.456,
    90  			},
    91  			checkInterpreted: func(t *testing.T, fd *descriptorpb.FileDescriptorProto) {
    92  				assert.True(t, fd.GetEnumType()[0].GetOptions().GetDeprecated())
    93  				assert.True(t, fd.GetEnumType()[0].GetOptions().GetAllowAlias())
    94  			},
    95  		},
    96  		{
    97  			// enum value options
    98  			contents: `enum Test { ZERO = 0 [deprecated = true, (must.link) = -222]; }`,
    99  			uninterpreted: map[string]interface{}{
   100  				"Test.ZERO:(must.link)": -222,
   101  			},
   102  			checkInterpreted: func(t *testing.T, fd *descriptorpb.FileDescriptorProto) {
   103  				assert.True(t, fd.GetEnumType()[0].GetValue()[0].GetOptions().GetDeprecated())
   104  			},
   105  		},
   106  		{
   107  			// service options
   108  			contents: `service Test { option deprecated = true; option (must.link) = {foo:1, foo:2, bar:3}; }`,
   109  			uninterpreted: map[string]interface{}{
   110  				"Test:(must.link)": aggregate("foo : 1 , foo : 2 , bar : 3"),
   111  			},
   112  			checkInterpreted: func(t *testing.T, fd *descriptorpb.FileDescriptorProto) {
   113  				assert.True(t, fd.GetService()[0].GetOptions().GetDeprecated())
   114  			},
   115  		},
   116  		{
   117  			// method options
   118  			contents: `import "google/protobuf/empty.proto"; service Test { rpc Foo (google.protobuf.Empty) returns (google.protobuf.Empty) { option deprecated = true; option (must.link) = FOO; } }`,
   119  			uninterpreted: map[string]interface{}{
   120  				"Test.Foo:(must.link)": ident("FOO"),
   121  			},
   122  			checkInterpreted: func(t *testing.T, fd *descriptorpb.FileDescriptorProto) {
   123  				assert.True(t, fd.GetService()[0].GetMethod()[0].GetOptions().GetDeprecated())
   124  			},
   125  		},
   126  	}
   127  
   128  	for i, tc := range testCases {
   129  		h := reporter.NewHandler(nil)
   130  		ast, err := parser.Parse("test.proto", strings.NewReader(tc.contents), h)
   131  		if !assert.Nil(t, err, "case #%d failed to parse", i) {
   132  			continue
   133  		}
   134  		res, err := parser.ResultFromAST(ast, true, h)
   135  		if !assert.Nil(t, err, "case #%d failed to produce descriptor proto", i) {
   136  			continue
   137  		}
   138  		_, err = InterpretUnlinkedOptions(res)
   139  		if !assert.Nil(t, err, "case #%d failed to interpret options", i) {
   140  			continue
   141  		}
   142  		actual := map[string]interface{}{}
   143  		buildUninterpretedMapForFile(res.Proto(), actual)
   144  		assert.Equal(t, tc.uninterpreted, actual, "case #%d resulted in wrong uninterpreted options", i)
   145  		if tc.checkInterpreted != nil {
   146  			tc.checkInterpreted(t, res.Proto())
   147  		}
   148  	}
   149  }
   150  
   151  func buildUninterpretedMapForFile(fd *descriptorpb.FileDescriptorProto, opts map[string]interface{}) {
   152  	buildUninterpretedMap(fd.GetName(), fd.GetOptions().GetUninterpretedOption(), opts)
   153  	for _, md := range fd.GetMessageType() {
   154  		buildUninterpretedMapForMessage(fd.GetPackage(), md, opts)
   155  	}
   156  	for _, extd := range fd.GetExtension() {
   157  		buildUninterpretedMap(qualify(fd.GetPackage(), extd.GetName()), extd.GetOptions().GetUninterpretedOption(), opts)
   158  	}
   159  	for _, ed := range fd.GetEnumType() {
   160  		buildUninterpretedMapForEnum(fd.GetPackage(), ed, opts)
   161  	}
   162  	for _, sd := range fd.GetService() {
   163  		svcFqn := qualify(fd.GetPackage(), sd.GetName())
   164  		buildUninterpretedMap(svcFqn, sd.GetOptions().GetUninterpretedOption(), opts)
   165  		for _, mtd := range sd.GetMethod() {
   166  			buildUninterpretedMap(qualify(svcFqn, mtd.GetName()), mtd.GetOptions().GetUninterpretedOption(), opts)
   167  		}
   168  	}
   169  }
   170  
   171  func buildUninterpretedMapForMessage(qual string, md *descriptorpb.DescriptorProto, opts map[string]interface{}) {
   172  	fqn := qualify(qual, md.GetName())
   173  	buildUninterpretedMap(fqn, md.GetOptions().GetUninterpretedOption(), opts)
   174  	for _, fld := range md.GetField() {
   175  		buildUninterpretedMap(qualify(fqn, fld.GetName()), fld.GetOptions().GetUninterpretedOption(), opts)
   176  	}
   177  	for _, ood := range md.GetOneofDecl() {
   178  		buildUninterpretedMap(qualify(fqn, ood.GetName()), ood.GetOptions().GetUninterpretedOption(), opts)
   179  	}
   180  	for _, extr := range md.GetExtensionRange() {
   181  		buildUninterpretedMap(qualify(fqn, fmt.Sprintf("%d-%d", extr.GetStart(), extr.GetEnd()-1)), extr.GetOptions().GetUninterpretedOption(), opts)
   182  	}
   183  	for _, nmd := range md.GetNestedType() {
   184  		buildUninterpretedMapForMessage(fqn, nmd, opts)
   185  	}
   186  	for _, extd := range md.GetExtension() {
   187  		buildUninterpretedMap(qualify(fqn, extd.GetName()), extd.GetOptions().GetUninterpretedOption(), opts)
   188  	}
   189  	for _, ed := range md.GetEnumType() {
   190  		buildUninterpretedMapForEnum(fqn, ed, opts)
   191  	}
   192  }
   193  
   194  func buildUninterpretedMapForEnum(qual string, ed *descriptorpb.EnumDescriptorProto, opts map[string]interface{}) {
   195  	fqn := qualify(qual, ed.GetName())
   196  	buildUninterpretedMap(fqn, ed.GetOptions().GetUninterpretedOption(), opts)
   197  	for _, evd := range ed.GetValue() {
   198  		buildUninterpretedMap(qualify(fqn, evd.GetName()), evd.GetOptions().GetUninterpretedOption(), opts)
   199  	}
   200  }
   201  
   202  func buildUninterpretedMap(prefix string, uos []*descriptorpb.UninterpretedOption, opts map[string]interface{}) {
   203  	for _, uo := range uos {
   204  		parts := make([]string, len(uo.GetName()))
   205  		for i, np := range uo.GetName() {
   206  			if np.GetIsExtension() {
   207  				parts[i] = fmt.Sprintf("(%s)", np.GetNamePart())
   208  			} else {
   209  				parts[i] = np.GetNamePart()
   210  			}
   211  		}
   212  		uoName := fmt.Sprintf("%s:%s", prefix, strings.Join(parts, "."))
   213  		key := uoName
   214  		i := 0
   215  		for {
   216  			if _, ok := opts[key]; !ok {
   217  				break
   218  			}
   219  			i++
   220  			key = fmt.Sprintf("%s#%d", uoName, i)
   221  		}
   222  		var val interface{}
   223  		switch {
   224  		case uo.AggregateValue != nil:
   225  			val = aggregate(uo.GetAggregateValue())
   226  		case uo.IdentifierValue != nil:
   227  			val = ident(uo.GetIdentifierValue())
   228  		case uo.DoubleValue != nil:
   229  			val = uo.GetDoubleValue()
   230  		case uo.PositiveIntValue != nil:
   231  			val = int(uo.GetPositiveIntValue())
   232  		case uo.NegativeIntValue != nil:
   233  			val = int(uo.GetNegativeIntValue())
   234  		default:
   235  			val = string(uo.GetStringValue())
   236  		}
   237  		opts[key] = val
   238  	}
   239  }
   240  
   241  func qualify(qualifier, name string) string {
   242  	if qualifier == "" {
   243  		return name
   244  	}
   245  	return qualifier + "." + name
   246  }