github.com/bakjos/protoreflect@v1.9.2/dynamic/extension_registry_test.go (about)

     1  package dynamic
     2  
     3  import (
     4  	"sort"
     5  	"testing"
     6  
     7  	"github.com/bakjos/protoreflect/desc"
     8  	"github.com/bakjos/protoreflect/internal/testprotos"
     9  	"github.com/bakjos/protoreflect/internal/testutil"
    10  )
    11  
    12  func TestExtensionRegistry_AddExtension(t *testing.T) {
    13  	er := &ExtensionRegistry{}
    14  	file, err := desc.LoadFileDescriptor("desc_test1.proto")
    15  	testutil.Ok(t, err)
    16  
    17  	err = er.AddExtension(file.GetExtensions()...)
    18  	testutil.Ok(t, err)
    19  
    20  	fds := er.AllExtensionsForType("testprotos.AnotherTestMessage")
    21  	sort.Sort(fields(fds))
    22  
    23  	testutil.Eq(t, []desc.Descriptor{
    24  		file.FindSymbol("testprotos.xtm"),
    25  		file.FindSymbol("testprotos.xs"),
    26  		file.FindSymbol("testprotos.xi"),
    27  		file.FindSymbol("testprotos.xui"),
    28  	}, fds)
    29  
    30  	checkFindExtension(t, er, fds)
    31  }
    32  
    33  func TestExtensionRegistry_AddExtensionDesc(t *testing.T) {
    34  	er := &ExtensionRegistry{}
    35  
    36  	err := er.AddExtensionDesc(testprotos.E_Xtm, testprotos.E_Xs, testprotos.E_Xi)
    37  	testutil.Ok(t, err)
    38  
    39  	fds := er.AllExtensionsForType("testprotos.AnotherTestMessage")
    40  	sort.Sort(fields(fds))
    41  
    42  	file, err := desc.LoadFileDescriptor("desc_test1.proto")
    43  	testutil.Ok(t, err)
    44  
    45  	testutil.Eq(t, 3, len(fds))
    46  	testutil.Eq(t, file.FindSymbol("testprotos.xtm"), fds[0])
    47  	testutil.Eq(t, file.FindSymbol("testprotos.xs"), fds[1])
    48  	testutil.Eq(t, file.FindSymbol("testprotos.xi"), fds[2])
    49  
    50  	checkFindExtension(t, er, fds)
    51  }
    52  
    53  func TestExtensionRegistry_AddExtensionsFromFile(t *testing.T) {
    54  	er := &ExtensionRegistry{}
    55  	file, err := desc.LoadFileDescriptor("desc_test1.proto")
    56  	testutil.Ok(t, err)
    57  
    58  	er.AddExtensionsFromFile(file)
    59  
    60  	fds := er.AllExtensionsForType("testprotos.AnotherTestMessage")
    61  	sort.Sort(fields(fds))
    62  
    63  	testutil.Eq(t, 5, len(fds))
    64  	testutil.Eq(t, file.FindSymbol("testprotos.xtm"), fds[0])
    65  	testutil.Eq(t, file.FindSymbol("testprotos.xs"), fds[1])
    66  	testutil.Eq(t, file.FindSymbol("testprotos.xi"), fds[2])
    67  	testutil.Eq(t, file.FindSymbol("testprotos.xui"), fds[3])
    68  	testutil.Eq(t, file.FindSymbol("testprotos.TestMessage.NestedMessage.AnotherNestedMessage.flags"), fds[4])
    69  
    70  	checkFindExtension(t, er, fds)
    71  }
    72  
    73  func TestExtensionRegistry_Empty(t *testing.T) {
    74  	er := ExtensionRegistry{}
    75  	fds := er.AllExtensionsForType("testprotos.AnotherTestMessage")
    76  	testutil.Eq(t, 0, len(fds))
    77  }
    78  
    79  func TestExtensionRegistry_Defaults(t *testing.T) {
    80  	er := NewExtensionRegistryWithDefaults()
    81  
    82  	fds := er.AllExtensionsForType("testprotos.AnotherTestMessage")
    83  	sort.Sort(fields(fds))
    84  
    85  	file, err := desc.LoadFileDescriptor("desc_test1.proto")
    86  	testutil.Ok(t, err)
    87  
    88  	testutil.Eq(t, 5, len(fds))
    89  	testutil.Eq(t, file.FindSymbol("testprotos.xtm").AsProto(), fds[0].AsProto())
    90  	testutil.Eq(t, file.FindSymbol("testprotos.xs").AsProto(), fds[1].AsProto())
    91  	testutil.Eq(t, file.FindSymbol("testprotos.xi").AsProto(), fds[2].AsProto())
    92  	testutil.Eq(t, file.FindSymbol("testprotos.xui").AsProto(), fds[3].AsProto())
    93  	testutil.Eq(t, file.FindSymbol("testprotos.TestMessage.NestedMessage.AnotherNestedMessage.flags").AsProto(), fds[4].AsProto())
    94  
    95  	checkFindExtension(t, er, fds)
    96  }
    97  
    98  func checkFindExtension(t *testing.T, er *ExtensionRegistry, fds []*desc.FieldDescriptor) {
    99  	for _, fd := range fds {
   100  		testutil.Eq(t, fd, er.FindExtension(fd.GetOwner().GetFullyQualifiedName(), fd.GetNumber()))
   101  		testutil.Eq(t, fd, er.FindExtensionByName(fd.GetOwner().GetFullyQualifiedName(), fd.GetFullyQualifiedName()))
   102  	}
   103  }
   104  
   105  type fields []*desc.FieldDescriptor
   106  
   107  func (f fields) Len() int {
   108  	return len(f)
   109  }
   110  
   111  func (f fields) Less(i, j int) bool {
   112  	return f[i].GetNumber() < f[j].GetNumber()
   113  }
   114  
   115  func (f fields) Swap(i, j int) {
   116  	f[i], f[j] = f[j], f[i]
   117  }