github.com/jhump/protoreflect@v1.16.0/grpcreflect/client_test.go (about)

     1  package grpcreflect
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"os"
    11  	"sort"
    12  	"sync"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/golang/protobuf/proto"
    18  	"google.golang.org/grpc"
    19  	"google.golang.org/grpc/codes"
    20  	"google.golang.org/grpc/credentials/insecure"
    21  	"google.golang.org/grpc/reflection"
    22  	reflectv1alpha "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
    23  	"google.golang.org/grpc/status"
    24  	"google.golang.org/protobuf/reflect/protodesc"
    25  	"google.golang.org/protobuf/reflect/protoreflect"
    26  	"google.golang.org/protobuf/reflect/protoregistry"
    27  	"google.golang.org/protobuf/types/descriptorpb"
    28  	"google.golang.org/protobuf/types/dynamicpb"
    29  	_ "google.golang.org/protobuf/types/known/apipb"
    30  	_ "google.golang.org/protobuf/types/known/emptypb"
    31  	_ "google.golang.org/protobuf/types/known/fieldmaskpb"
    32  	_ "google.golang.org/protobuf/types/known/sourcecontextpb"
    33  	_ "google.golang.org/protobuf/types/known/typepb"
    34  	_ "google.golang.org/protobuf/types/pluginpb"
    35  
    36  	"github.com/jhump/protoreflect/desc"
    37  	"github.com/jhump/protoreflect/internal"
    38  	testprotosgrpc "github.com/jhump/protoreflect/internal/testprotos/grpc"
    39  	"github.com/jhump/protoreflect/internal/testutil"
    40  )
    41  
    42  var client *Client
    43  
    44  func TestMain(m *testing.M) {
    45  	code := 1
    46  	defer func() {
    47  		p := recover()
    48  		if p != nil {
    49  			_, _ = fmt.Fprintf(os.Stderr, "PANIC: %v\n", p)
    50  		}
    51  		os.Exit(code)
    52  	}()
    53  
    54  	svr := grpc.NewServer()
    55  	testprotosgrpc.RegisterDummyServiceServer(svr, testService{})
    56  	reflection.Register(svr)
    57  	l, err := net.Listen("tcp", "127.0.0.1:0")
    58  	if err != nil {
    59  		panic(fmt.Sprintf("Failed to open server socket: %s", err.Error()))
    60  	}
    61  	go func() {
    62  		_ = svr.Serve(l)
    63  	}()
    64  	defer svr.Stop()
    65  
    66  	// create grpc client
    67  	addr := l.Addr().String()
    68  	cconn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
    69  	if err != nil {
    70  		panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error()))
    71  	}
    72  	defer func() {
    73  		_ = cconn.Close()
    74  	}()
    75  
    76  	stub := reflectv1alpha.NewServerReflectionClient(cconn)
    77  	client = NewClientV1Alpha(context.Background(), stub)
    78  
    79  	code = m.Run()
    80  }
    81  
    82  func TestFileByFileName(t *testing.T) {
    83  	fd, err := client.FileByFilename("desc_test1.proto")
    84  	testutil.Ok(t, err)
    85  	// shallow check that the descriptor appears correct and complete
    86  	testutil.Eq(t, "desc_test1.proto", fd.GetName())
    87  	testutil.Eq(t, "testprotos", fd.GetPackage())
    88  	md := fd.GetMessageTypes()[0]
    89  	testutil.Eq(t, "TestMessage", md.GetName())
    90  	md = md.GetNestedMessageTypes()[0]
    91  	testutil.Eq(t, "NestedMessage", md.GetName())
    92  	md = md.GetNestedMessageTypes()[0]
    93  	testutil.Eq(t, "AnotherNestedMessage", md.GetName())
    94  	md = md.GetNestedMessageTypes()[0]
    95  	testutil.Eq(t, "YetAnotherNestedMessage", md.GetName())
    96  	ed := md.GetNestedEnumTypes()[0]
    97  	testutil.Eq(t, "DeeplyNestedEnum", ed.GetName())
    98  
    99  	_, err = client.FileByFilename("does not exist")
   100  	testutil.Require(t, IsElementNotFoundError(err))
   101  }
   102  
   103  func TestFileByFileNameForWellKnownProtos(t *testing.T) {
   104  	wellKnownProtos := map[string][]string{
   105  		"google/protobuf/any.proto":             {"google.protobuf.Any"},
   106  		"google/protobuf/api.proto":             {"google.protobuf.Api", "google.protobuf.Method", "google.protobuf.Mixin"},
   107  		"google/protobuf/descriptor.proto":      {"google.protobuf.FileDescriptorSet", "google.protobuf.DescriptorProto"},
   108  		"google/protobuf/duration.proto":        {"google.protobuf.Duration"},
   109  		"google/protobuf/empty.proto":           {"google.protobuf.Empty"},
   110  		"google/protobuf/field_mask.proto":      {"google.protobuf.FieldMask"},
   111  		"google/protobuf/source_context.proto":  {"google.protobuf.SourceContext"},
   112  		"google/protobuf/struct.proto":          {"google.protobuf.Struct", "google.protobuf.Value", "google.protobuf.NullValue"},
   113  		"google/protobuf/timestamp.proto":       {"google.protobuf.Timestamp"},
   114  		"google/protobuf/type.proto":            {"google.protobuf.Type", "google.protobuf.Field", "google.protobuf.Syntax"},
   115  		"google/protobuf/wrappers.proto":        {"google.protobuf.DoubleValue", "google.protobuf.Int32Value", "google.protobuf.StringValue"},
   116  		"google/protobuf/compiler/plugin.proto": {"google.protobuf.compiler.CodeGeneratorRequest"},
   117  	}
   118  
   119  	for file, types := range wellKnownProtos {
   120  		fd, err := client.FileByFilename(file)
   121  		testutil.Ok(t, err)
   122  		testutil.Eq(t, file, fd.GetName())
   123  		for _, typ := range types {
   124  			d := fd.FindSymbol(typ)
   125  			testutil.Require(t, d != nil)
   126  		}
   127  
   128  		// also try loading via alternate name
   129  		file = internal.StdFileAliases[file]
   130  		if file == "" {
   131  			// not a file that has a known alternate, so nothing else to check...
   132  			continue
   133  		}
   134  		fd, err = client.FileByFilename(file)
   135  		testutil.Ok(t, err)
   136  		testutil.Eq(t, file, fd.GetName())
   137  		for _, typ := range types {
   138  			d := fd.FindSymbol(typ)
   139  			testutil.Require(t, d != nil)
   140  		}
   141  	}
   142  }
   143  
   144  func TestFileContainingSymbol(t *testing.T) {
   145  	fd, err := client.FileContainingSymbol("TopLevel")
   146  	testutil.Ok(t, err)
   147  	// shallow check that the descriptor appears correct and complete
   148  	testutil.Eq(t, "nopkg/desc_test_nopkg_new.proto", fd.GetName())
   149  	testutil.Eq(t, "", fd.GetPackage())
   150  	md := fd.GetMessageTypes()[0]
   151  	testutil.Eq(t, "TopLevel", md.GetName())
   152  	testutil.Eq(t, "i", md.GetFields()[0].GetName())
   153  	testutil.Eq(t, "j", md.GetFields()[1].GetName())
   154  	testutil.Eq(t, "k", md.GetFields()[2].GetName())
   155  	testutil.Eq(t, "l", md.GetFields()[3].GetName())
   156  	testutil.Eq(t, "m", md.GetFields()[4].GetName())
   157  	testutil.Eq(t, "n", md.GetFields()[5].GetName())
   158  	testutil.Eq(t, "o", md.GetFields()[6].GetName())
   159  	testutil.Eq(t, "p", md.GetFields()[7].GetName())
   160  	testutil.Eq(t, "q", md.GetFields()[8].GetName())
   161  	testutil.Eq(t, "r", md.GetFields()[9].GetName())
   162  	testutil.Eq(t, "s", md.GetFields()[10].GetName())
   163  	testutil.Eq(t, "t", md.GetFields()[11].GetName())
   164  
   165  	_, err = client.FileContainingSymbol("does not exist")
   166  	testutil.Require(t, IsElementNotFoundError(err))
   167  }
   168  
   169  func TestFileContainingExtension(t *testing.T) {
   170  	fd, err := client.FileContainingExtension("TopLevel", 100)
   171  	testutil.Ok(t, err)
   172  	// shallow check that the descriptor appears correct and complete
   173  	testutil.Eq(t, "desc_test2.proto", fd.GetName())
   174  	testutil.Eq(t, "testprotos", fd.GetPackage())
   175  	testutil.Eq(t, 4, len(fd.GetMessageTypes()))
   176  	testutil.Eq(t, "Frobnitz", fd.GetMessageTypes()[0].GetName())
   177  	testutil.Eq(t, "Whatchamacallit", fd.GetMessageTypes()[1].GetName())
   178  	testutil.Eq(t, "Whatzit", fd.GetMessageTypes()[2].GetName())
   179  	testutil.Eq(t, "GroupX", fd.GetMessageTypes()[3].GetName())
   180  
   181  	testutil.Eq(t, "desc_test1.proto", fd.GetDependencies()[0].GetName())
   182  	testutil.Eq(t, "pkg/desc_test_pkg.proto", fd.GetDependencies()[1].GetName())
   183  	testutil.Eq(t, "nopkg/desc_test_nopkg.proto", fd.GetDependencies()[2].GetName())
   184  
   185  	_, err = client.FileContainingExtension("does not exist", 100)
   186  	testutil.Require(t, IsElementNotFoundError(err))
   187  	_, err = client.FileContainingExtension("TopLevel", -9)
   188  	testutil.Require(t, IsElementNotFoundError(err))
   189  }
   190  
   191  func TestAllExtensionNumbersForType(t *testing.T) {
   192  	nums, err := client.AllExtensionNumbersForType("TopLevel")
   193  	testutil.Ok(t, err)
   194  	inums := make([]int, len(nums))
   195  	for idx, v := range nums {
   196  		inums[idx] = int(v)
   197  	}
   198  	sort.Ints(inums)
   199  	testutil.Eq(t, []int{100, 104}, inums)
   200  
   201  	nums, err = client.AllExtensionNumbersForType("testprotos.AnotherTestMessage")
   202  	testutil.Ok(t, err)
   203  	testutil.Eq(t, 5, len(nums))
   204  	inums = make([]int, len(nums))
   205  	for idx, v := range nums {
   206  		inums[idx] = int(v)
   207  	}
   208  	sort.Ints(inums)
   209  	testutil.Eq(t, []int{100, 101, 102, 103, 200}, inums)
   210  
   211  	_, err = client.AllExtensionNumbersForType("does not exist")
   212  	testutil.Require(t, IsElementNotFoundError(err))
   213  }
   214  
   215  func TestListServices(t *testing.T) {
   216  	s, err := client.ListServices()
   217  	testutil.Ok(t, err)
   218  
   219  	sort.Strings(s)
   220  	testutil.Eq(t, []string{
   221  		"grpc.reflection.v1.ServerReflection",
   222  		"grpc.reflection.v1alpha.ServerReflection",
   223  		"testprotos.DummyService",
   224  	}, s)
   225  }
   226  
   227  func TestReset(t *testing.T) {
   228  	_, err := client.ListServices()
   229  	testutil.Ok(t, err)
   230  
   231  	// save the current stream
   232  	stream := client.stream
   233  	// intercept cancellation
   234  	cancel := client.cancel
   235  	var cancelled atomic.Bool
   236  	client.cancel = func() {
   237  		cancelled.Store(true)
   238  		cancel()
   239  	}
   240  
   241  	client.Reset()
   242  	testutil.Eq(t, true, cancelled.Load())
   243  	testutil.Eq(t, nil, client.stream)
   244  
   245  	_, err = client.ListServices()
   246  	testutil.Ok(t, err)
   247  
   248  	// stream was re-created
   249  	testutil.Eq(t, true, client.stream != nil && client.stream != stream)
   250  }
   251  
   252  func TestRecover(t *testing.T) {
   253  	_, err := client.ListServices()
   254  	testutil.Ok(t, err)
   255  
   256  	// kill the stream
   257  	stream := client.stream
   258  	err = client.stream.CloseSend()
   259  	testutil.Ok(t, err)
   260  
   261  	// it should auto-recover and re-create stream
   262  	_, err = client.ListServices()
   263  	testutil.Ok(t, err)
   264  	testutil.Eq(t, true, client.stream != nil && client.stream != stream)
   265  }
   266  
   267  func TestMultipleFiles(t *testing.T) {
   268  	svr := grpc.NewServer()
   269  	reflectv1alpha.RegisterServerReflectionServer(svr, testReflectionServer{})
   270  
   271  	l, err := net.Listen("tcp", "127.0.0.1:0")
   272  	testutil.Ok(t, err, "failed to listen")
   273  	ctx, cancel := context.WithCancel(context.Background())
   274  	defer cancel()
   275  	go func() {
   276  		defer cancel()
   277  		if err := svr.Serve(l); err != nil {
   278  			t.Logf("serve returned error: %v", err)
   279  		}
   280  	}()
   281  	time.Sleep(100 * time.Millisecond) // give server a chance to start
   282  	testutil.Ok(t, ctx.Err(), "failed to start server")
   283  	defer func() {
   284  		svr.Stop()
   285  	}()
   286  
   287  	dialCtx, dialCancel := context.WithTimeout(ctx, 3*time.Second)
   288  	defer dialCancel()
   289  	cc, err := grpc.DialContext(dialCtx, l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
   290  	testutil.Ok(t, err, "failed ot dial %v", l.Addr().String())
   291  	cl := reflectv1alpha.NewServerReflectionClient(cc)
   292  
   293  	client := NewClientV1Alpha(ctx, cl)
   294  	defer client.Reset()
   295  	svcs, err := client.ListServices()
   296  	testutil.Ok(t, err, "failed to list services")
   297  	for _, svc := range svcs {
   298  		fd, err := client.FileContainingSymbol(svc)
   299  		testutil.Ok(t, err, "failed to file for service %v", svc)
   300  		sd := fd.FindSymbol(svc)
   301  		_, ok := sd.(*desc.ServiceDescriptor)
   302  		testutil.Require(t, ok, "symbol for %s is not a service descriptor, instead is %T", svc, sd)
   303  	}
   304  }
   305  
   306  func TestAllowMissingFileDescriptors(t *testing.T) {
   307  	svr := grpc.NewServer()
   308  	files := createFilesWithMissingDeps(t)
   309  	reflectionSvc := reflection.NewServer(reflection.ServerOptions{
   310  		DescriptorResolver: files,
   311  		ExtensionResolver:  files,
   312  	})
   313  	reflectv1alpha.RegisterServerReflectionServer(svr, reflectionSvc)
   314  
   315  	l, err := net.Listen("tcp", "127.0.0.1:0")
   316  	testutil.Ok(t, err, "failed to listen")
   317  	ctx, cancel := context.WithCancel(context.Background())
   318  	defer cancel()
   319  	go func() {
   320  		defer cancel()
   321  		if err := svr.Serve(l); err != nil {
   322  			t.Logf("serve returned error: %v", err)
   323  		}
   324  	}()
   325  	time.Sleep(100 * time.Millisecond) // give server a chance to start
   326  	testutil.Ok(t, ctx.Err(), "failed to start server")
   327  	defer func() {
   328  		svr.Stop()
   329  	}()
   330  
   331  	dialCtx, dialCancel := context.WithTimeout(ctx, 3*time.Second)
   332  	defer dialCancel()
   333  	cc, err := grpc.DialContext(dialCtx, l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
   334  	testutil.Ok(t, err, "failed ot dial %v", l.Addr().String())
   335  	cl := reflectv1alpha.NewServerReflectionClient(cc)
   336  
   337  	client := NewClientV1Alpha(ctx, cl)
   338  	defer client.Reset()
   339  
   340  	// First we try some things that should fail due to missing descriptors.
   341  	_, err = client.FileByFilename("foo/bar/this.proto")
   342  	testutil.Nok(t, err)
   343  	_, err = client.FileContainingSymbol("foo.bar.Bar")
   344  	testutil.Nok(t, err)
   345  	_, err = client.FileContainingExtension("google.protobuf.MessageOptions", 10101)
   346  	testutil.Nok(t, err)
   347  
   348  	client.AllowMissingFileDescriptors()
   349  	// Now the above queries should succeed.
   350  	file, err := client.FileByFilename("foo/bar/this.proto")
   351  	testutil.Ok(t, err)
   352  	testutil.Require(t, file != nil)
   353  	testutil.Eq(t, "foo/bar/this.proto", file.GetName())
   354  	_, err = client.FileContainingSymbol("foo.bar.Bar")
   355  	testutil.Ok(t, err)
   356  	testutil.Require(t, file != nil)
   357  	testutil.Eq(t, "foo/bar/this.proto", file.GetName())
   358  	_, err = client.FileContainingExtension("google.protobuf.MessageOptions", 10101)
   359  	testutil.Ok(t, err)
   360  	testutil.Require(t, file != nil)
   361  	testutil.Eq(t, "foo/bar/this.proto", file.GetName())
   362  }
   363  
   364  func TestFileWithoutDeps(t *testing.T) {
   365  	fd := &descriptorpb.FileDescriptorProto{
   366  		Dependency: []string{
   367  			"foo/bar.proto",
   368  			"foo/public/bar.proto", // missing
   369  			"foo/weak/bar.proto",
   370  			"foo/baz.proto", // missing
   371  			"foo/public/baz.proto",
   372  			"foo/weak/baz.proto", // missing
   373  			"foo/fizz.proto",
   374  			"foo/public/fizz.proto", // missing
   375  			"foo/weak/fizz.proto",
   376  			"foo/buzz.proto", // missing
   377  			"foo/public/buzz.proto",
   378  			"foo/weak/buzz.proto", // missing
   379  		},
   380  		PublicDependency: []int32{1, 4, 7, 10},
   381  		WeakDependency:   []int32{2, 5, 8, 11},
   382  	}
   383  	fd = fileWithoutDeps(fd, []int{1, 3, 5, 7, 9, 11})
   384  	testutil.Eq(t,
   385  		[]string{
   386  			"foo/bar.proto",
   387  			"foo/weak/bar.proto",
   388  			"foo/public/baz.proto",
   389  			"foo/fizz.proto",
   390  			"foo/weak/fizz.proto",
   391  			"foo/public/buzz.proto",
   392  		},
   393  		fd.Dependency)
   394  	testutil.Eq(t, []int32{2, 5}, fd.PublicDependency)
   395  	testutil.Eq(t, []int32{1, 4}, fd.WeakDependency)
   396  }
   397  
   398  type testReflectionServer struct{}
   399  
   400  func (t testReflectionServer) ServerReflectionInfo(server reflectv1alpha.ServerReflection_ServerReflectionInfoServer) error {
   401  	const svcA_file = "ChdzYW5kYm94L3NlcnZpY2VfQS5wcm90bxIHc2FuZGJveCIWCghSZXF1ZXN0QRIKCgJpZBgBIAEoBSIYCglSZXNwb25zZUESCwoDc3RyGAEgASgJMj0KCVNlcnZpY2VfQRIwCgdFeGVjdXRlEhEuc2FuZGJveC5SZXF1ZXN0QRoSLnNhbmRib3guUmVzcG9uc2VBYgZwcm90bzM="
   402  	const svcB_file = "ChdzYW5kYm94L1NlcnZpY2VfQi5wcm90bxIHc2FuZGJveCIWCghSZXF1ZXN0QhIKCgJpZBgBIAEoBSIYCglSZXNwb25zZUISCwoDc3RyGAEgASgJMj0KCVNlcnZpY2VfQhIwCgdFeGVjdXRlEhEuc2FuZGJveC5SZXF1ZXN0QhoSLnNhbmRib3guUmVzcG9uc2VCYgZwcm90bzM="
   403  
   404  	for {
   405  		req, err := server.Recv()
   406  		if err == io.EOF {
   407  			return nil
   408  		} else if err != nil {
   409  			return err
   410  		}
   411  		var resp reflectv1alpha.ServerReflectionResponse
   412  		resp.OriginalRequest = req
   413  		switch req := req.MessageRequest.(type) {
   414  		case *reflectv1alpha.ServerReflectionRequest_FileByFilename:
   415  			switch req.FileByFilename {
   416  			case "sandbox/service_A.proto":
   417  				resp.MessageResponse = msgResponseForFiles(svcA_file)
   418  			case "sandbox/service_B.proto":
   419  				resp.MessageResponse = msgResponseForFiles(svcB_file)
   420  			default:
   421  				resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{
   422  					ErrorResponse: &reflectv1alpha.ErrorResponse{
   423  						ErrorCode:    int32(codes.NotFound),
   424  						ErrorMessage: "not found",
   425  					},
   426  				}
   427  			}
   428  		case *reflectv1alpha.ServerReflectionRequest_FileContainingSymbol:
   429  			switch req.FileContainingSymbol {
   430  			case "sandbox.Service_A":
   431  				resp.MessageResponse = msgResponseForFiles(svcA_file)
   432  			case "sandbox.Service_B":
   433  				// HERE is where we return two files instead of one
   434  				resp.MessageResponse = msgResponseForFiles(svcA_file, svcB_file)
   435  			default:
   436  				resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{
   437  					ErrorResponse: &reflectv1alpha.ErrorResponse{
   438  						ErrorCode:    int32(codes.NotFound),
   439  						ErrorMessage: "not found",
   440  					},
   441  				}
   442  			}
   443  		case *reflectv1alpha.ServerReflectionRequest_ListServices:
   444  			resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ListServicesResponse{
   445  				ListServicesResponse: &reflectv1alpha.ListServiceResponse{
   446  					Service: []*reflectv1alpha.ServiceResponse{
   447  						{Name: "sandbox.Service_A"},
   448  						{Name: "sandbox.Service_B"},
   449  					},
   450  				},
   451  			}
   452  		default:
   453  			resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{
   454  				ErrorResponse: &reflectv1alpha.ErrorResponse{
   455  					ErrorCode:    int32(codes.NotFound),
   456  					ErrorMessage: "not found",
   457  				},
   458  			}
   459  		}
   460  		if err := server.Send(&resp); err != nil {
   461  			return err
   462  		}
   463  	}
   464  }
   465  
   466  func msgResponseForFiles(files ...string) *reflectv1alpha.ServerReflectionResponse_FileDescriptorResponse {
   467  	descs := make([][]byte, len(files))
   468  	for i, f := range files {
   469  		b, err := base64.StdEncoding.DecodeString(f)
   470  		if err != nil {
   471  			panic(err)
   472  		}
   473  		descs[i] = b
   474  	}
   475  	return &reflectv1alpha.ServerReflectionResponse_FileDescriptorResponse{
   476  		FileDescriptorResponse: &reflectv1alpha.FileDescriptorResponse{
   477  			FileDescriptorProto: descs,
   478  		},
   479  	}
   480  }
   481  
   482  func TestAutoVersion(t *testing.T) {
   483  	t.Run("v1", func(t *testing.T) {
   484  		testClientAuto(t,
   485  			func(s *grpc.Server) {
   486  				reflection.RegisterV1(s)
   487  				testprotosgrpc.RegisterDummyServiceServer(s, testService{})
   488  			},
   489  			[]string{
   490  				"grpc.reflection.v1.ServerReflection",
   491  				"testprotos.DummyService",
   492  			},
   493  			[]string{
   494  				"/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
   495  				"/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
   496  				"/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
   497  				"/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
   498  			})
   499  	})
   500  
   501  	t.Run("v1alpha", func(t *testing.T) {
   502  		testClientAuto(t,
   503  			func(s *grpc.Server) {
   504  				impl := reflection.NewServer(reflection.ServerOptions{Services: s})
   505  				reflectv1alpha.RegisterServerReflectionServer(s, impl)
   506  				testprotosgrpc.RegisterDummyServiceServer(s, testService{})
   507  			},
   508  			[]string{
   509  				"grpc.reflection.v1alpha.ServerReflection",
   510  				"testprotos.DummyService",
   511  			},
   512  			[]string{
   513  				// first one fails, so falls back to v1alpha
   514  				"/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
   515  				"/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo",
   516  				// next two use v1alpha
   517  				"/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo",
   518  				"/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo",
   519  				// final one retries v1
   520  				"/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
   521  				"/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo",
   522  			})
   523  	})
   524  
   525  	t.Run("both", func(t *testing.T) {
   526  		testClientAuto(t,
   527  			func(s *grpc.Server) {
   528  				reflection.Register(s)
   529  				testprotosgrpc.RegisterDummyServiceServer(s, testService{})
   530  			},
   531  			[]string{
   532  				"grpc.reflection.v1.ServerReflection",
   533  				"grpc.reflection.v1alpha.ServerReflection",
   534  				"testprotos.DummyService",
   535  			},
   536  			[]string{
   537  				// never uses v1alpha since v1 works
   538  				"/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
   539  				"/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
   540  				"/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
   541  				"/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
   542  			})
   543  	})
   544  
   545  	t.Run("fallback-on-unavailable", testClientAutoOnUnavailable)
   546  }
   547  
   548  func testClientAuto(t *testing.T, register func(*grpc.Server), expectedServices []string, expectedLog []string) {
   549  	var capture captureStreamNames
   550  	svr := grpc.NewServer(grpc.StreamInterceptor(capture.intercept), grpc.UnknownServiceHandler(capture.handleUnknown))
   551  	register(svr)
   552  	l, err := net.Listen("tcp", "127.0.0.1:0")
   553  	if err != nil {
   554  		panic(fmt.Sprintf("Failed to open server socket: %s", err.Error()))
   555  	}
   556  	go func() {
   557  		err := svr.Serve(l)
   558  		testutil.Ok(t, err)
   559  	}()
   560  	defer svr.Stop()
   561  
   562  	cconn, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
   563  	if err != nil {
   564  		panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error()))
   565  	}
   566  	defer func() {
   567  		err := cconn.Close()
   568  		testutil.Ok(t, err)
   569  	}()
   570  	client := NewClientAuto(context.Background(), cconn)
   571  	now := time.Now()
   572  	client.now = func() time.Time {
   573  		return now
   574  	}
   575  
   576  	svcs, err := client.ListServices()
   577  	testutil.Ok(t, err)
   578  	sort.Strings(svcs)
   579  	testutil.Eq(t, expectedServices, svcs)
   580  	client.Reset()
   581  
   582  	_, err = client.FileContainingSymbol(svcs[0])
   583  	testutil.Ok(t, err)
   584  	client.Reset()
   585  
   586  	// at the threshold, but not quite enough to retry
   587  	now = now.Add(time.Hour)
   588  	_, err = client.ListServices()
   589  	testutil.Ok(t, err)
   590  	client.Reset()
   591  
   592  	// 1 ns more, and we've crossed threshold and will retry
   593  	now = now.Add(1)
   594  	_, err = client.ListServices()
   595  	testutil.Ok(t, err)
   596  	client.Reset()
   597  
   598  	actualLog := capture.names()
   599  	testutil.Eq(t, expectedLog, actualLog)
   600  }
   601  
   602  type captureStreamNames struct {
   603  	mu  sync.Mutex
   604  	log []string
   605  }
   606  
   607  func (c *captureStreamNames) names() []string {
   608  	c.mu.Lock()
   609  	defer c.mu.Unlock()
   610  	ret := make([]string, len(c.log))
   611  	copy(ret, c.log)
   612  	return ret
   613  }
   614  
   615  func (c *captureStreamNames) intercept(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   616  	c.mu.Lock()
   617  	c.log = append(c.log, info.FullMethod)
   618  	c.mu.Unlock()
   619  	return handler(srv, ss)
   620  }
   621  
   622  func (c *captureStreamNames) handleUnknown(_ interface{}, _ grpc.ServerStream) error {
   623  	return status.Errorf(codes.Unimplemented, "WTF?")
   624  }
   625  
   626  func testClientAutoOnUnavailable(t *testing.T) {
   627  	l, err := net.Listen("tcp", "127.0.0.1:0")
   628  	if err != nil {
   629  		panic(fmt.Sprintf("Failed to open server socket: %s", err.Error()))
   630  	}
   631  	captureConn := &captureListener{Listener: l}
   632  
   633  	var capture captureStreamNames
   634  	svr := grpc.NewServer(
   635  		grpc.StreamInterceptor(capture.intercept),
   636  		grpc.UnknownServiceHandler(func(_ interface{}, _ grpc.ServerStream) error {
   637  			// On unknown method, forcibly close the net.Conn, without sending
   638  			// back any reply, which should result in an "unavailable" error.
   639  			return captureConn.latest().Close()
   640  		}),
   641  	)
   642  	impl := reflection.NewServer(reflection.ServerOptions{Services: svr})
   643  	reflectv1alpha.RegisterServerReflectionServer(svr, impl)
   644  	testprotosgrpc.RegisterDummyServiceServer(svr, testService{})
   645  
   646  	go func() {
   647  		err := svr.Serve(captureConn)
   648  		testutil.Ok(t, err)
   649  	}()
   650  	defer svr.Stop()
   651  
   652  	var captureErrs captureErrors
   653  	cconn, err := grpc.Dial(
   654  		l.Addr().String(),
   655  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   656  		grpc.WithStreamInterceptor(captureErrs.intercept),
   657  	)
   658  	if err != nil {
   659  		panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error()))
   660  	}
   661  	defer func() {
   662  		err := cconn.Close()
   663  		testutil.Ok(t, err)
   664  	}()
   665  	client := NewClientAuto(context.Background(), cconn)
   666  	now := time.Now()
   667  	client.now = func() time.Time {
   668  		return now
   669  	}
   670  
   671  	svcs, err := client.ListServices()
   672  	testutil.Ok(t, err)
   673  	sort.Strings(svcs)
   674  	testutil.Eq(t, []string{
   675  		"grpc.reflection.v1alpha.ServerReflection",
   676  		"testprotos.DummyService",
   677  	}, svcs)
   678  
   679  	// It should have tried v1 first and failed then tried v1alpha.
   680  	actualLog := capture.names()
   681  	testutil.Eq(t, []string{
   682  		"/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
   683  		"/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo",
   684  	}, actualLog)
   685  
   686  	// Make sure the error code observed by the client was unavailable and not unimplemented.
   687  	actualCodes := captureErrs.codes()
   688  	testutil.Eq(t, []codes.Code{codes.Unavailable}, actualCodes)
   689  }
   690  
   691  type captureListener struct {
   692  	net.Listener
   693  	mu   sync.Mutex
   694  	conn net.Conn
   695  }
   696  
   697  func (c *captureListener) Accept() (net.Conn, error) {
   698  	conn, err := c.Listener.Accept()
   699  	if err == nil {
   700  		c.mu.Lock()
   701  		c.conn = conn
   702  		c.mu.Unlock()
   703  	}
   704  	return conn, err
   705  }
   706  
   707  func (c *captureListener) latest() net.Conn {
   708  	c.mu.Lock()
   709  	defer c.mu.Unlock()
   710  	return c.conn
   711  }
   712  
   713  type captureErrors struct {
   714  	mu       sync.Mutex
   715  	observed []codes.Code
   716  }
   717  
   718  func (c *captureErrors) intercept(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
   719  	stream, err := streamer(ctx, desc, cc, method, opts...)
   720  	if err != nil {
   721  		c.observe(err)
   722  		return nil, err
   723  	}
   724  	return &captureErrorStream{ClientStream: stream, c: c}, nil
   725  }
   726  
   727  func (c *captureErrors) observe(err error) {
   728  	c.mu.Lock()
   729  	c.observed = append(c.observed, status.Code(err))
   730  	c.mu.Unlock()
   731  }
   732  
   733  func (c *captureErrors) codes() []codes.Code {
   734  	c.mu.Lock()
   735  	defer c.mu.Unlock()
   736  	ret := make([]codes.Code, len(c.observed))
   737  	copy(ret, c.observed)
   738  	return ret
   739  }
   740  
   741  type captureErrorStream struct {
   742  	grpc.ClientStream
   743  	c    *captureErrors
   744  	done int32
   745  }
   746  
   747  func (c *captureErrorStream) RecvMsg(m interface{}) error {
   748  	err := c.ClientStream.RecvMsg(m)
   749  	if err == nil || errors.Is(err, io.EOF) {
   750  		return nil
   751  	}
   752  	// Only record one error per RPC.
   753  	if atomic.CompareAndSwapInt32(&c.done, 0, 1) {
   754  		c.c.observe(err)
   755  	}
   756  	return err
   757  }
   758  
   759  func createFilesWithMissingDeps(t *testing.T) *files {
   760  	t.Helper()
   761  	var result files
   762  	empty, err := protodesc.NewFile(&descriptorpb.FileDescriptorProto{
   763  		Name:   proto.String("empty.proto"),
   764  		Syntax: proto.String("proto2"),
   765  	}, &result)
   766  	testutil.Ok(t, err)
   767  
   768  	// These will be missing, so we create them as placeholders, so
   769  	// the protobuf-go runtime can resolve imports for them and
   770  	// still build a protoreflect.FileDescriptor.
   771  	err = result.RegisterFile(&placeholder{path: "test/custom/options.proto", FileDescriptor: empty})
   772  	testutil.Ok(t, err)
   773  	err = result.RegisterFile(&placeholder{path: "test/unused.proto", FileDescriptor: empty})
   774  	testutil.Ok(t, err)
   775  
   776  	// register google/protobuf/descriptor.proto from the embedded descriptor in descriptorpb
   777  	err = result.RegisterFile((*descriptorpb.FileDescriptorProto)(nil).ProtoReflect().Descriptor().ParentFile())
   778  	testutil.Ok(t, err)
   779  
   780  	importedFile := &descriptorpb.FileDescriptorProto{
   781  		Name:             proto.String("test/imported.proto"),
   782  		Syntax:           proto.String("proto3"),
   783  		Package:          proto.String("test"),
   784  		Dependency:       []string{"google/protobuf/descriptor.proto", "test/unused.proto"},
   785  		PublicDependency: []int32{1}, // unused is public
   786  		MessageType: []*descriptorpb.DescriptorProto{
   787  			{
   788  				Name: proto.String("Message"),
   789  				Field: []*descriptorpb.FieldDescriptorProto{
   790  					{
   791  						Name:     proto.String("name"),
   792  						Number:   proto.Int32(1),
   793  						Label:    descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   794  						Type:     descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(),
   795  						JsonName: proto.String("name"),
   796  					},
   797  					{
   798  						Name:     proto.String("tags"),
   799  						Number:   proto.Int32(2),
   800  						Label:    descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum(),
   801  						Type:     descriptorpb.FieldDescriptorProto_TYPE_UINT64.Enum(),
   802  						JsonName: proto.String("tags"),
   803  					},
   804  				},
   805  				Extension: []*descriptorpb.FieldDescriptorProto{
   806  					{
   807  						Extendee: proto.String(".google.protobuf.MessageOptions"),
   808  						Name:     proto.String("message_option"),
   809  						Number:   proto.Int32(10101),
   810  						Label:    descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   811  						Type:     descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(),
   812  					},
   813  				},
   814  			},
   815  		},
   816  		EnumType: []*descriptorpb.EnumDescriptorProto{
   817  			{
   818  				Name: proto.String("Enum"),
   819  				Value: []*descriptorpb.EnumValueDescriptorProto{
   820  					{
   821  						Name:   proto.String("VAL0"),
   822  						Number: proto.Int32(0),
   823  					},
   824  					{
   825  						Name:   proto.String("VAL1"),
   826  						Number: proto.Int32(1),
   827  					},
   828  				},
   829  			},
   830  		},
   831  		Extension: []*descriptorpb.FieldDescriptorProto{
   832  			{
   833  				Extendee: proto.String(".google.protobuf.FileOptions"),
   834  				Name:     proto.String("file_option"),
   835  				Number:   proto.Int32(10101),
   836  				Label:    descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   837  				Type:     descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(),
   838  			},
   839  		},
   840  	}
   841  	importedFileDesc, err := protodesc.NewFile(importedFile, &result)
   842  	testutil.Ok(t, err)
   843  	err = result.Files.RegisterFile(importedFileDesc)
   844  	testutil.Ok(t, err)
   845  
   846  	topFile := &descriptorpb.FileDescriptorProto{
   847  		Name:       proto.String("foo/bar/this.proto"),
   848  		Syntax:     proto.String("proto3"),
   849  		Package:    proto.String("foo.bar"),
   850  		Dependency: []string{"test/imported.proto", "test/unused.proto", "test/custom/options.proto"},
   851  		MessageType: []*descriptorpb.DescriptorProto{
   852  			{
   853  				Name: proto.String("Foo"),
   854  				Field: []*descriptorpb.FieldDescriptorProto{
   855  					{
   856  						Name:     proto.String("msg"),
   857  						Number:   proto.Int32(1),
   858  						Label:    descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   859  						Type:     descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
   860  						TypeName: proto.String(".test.Message"),
   861  						JsonName: proto.String("msg"),
   862  					},
   863  					{
   864  						Name:     proto.String("en"),
   865  						Number:   proto.Int32(2),
   866  						Label:    descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum(),
   867  						Type:     descriptorpb.FieldDescriptorProto_TYPE_ENUM.Enum(),
   868  						TypeName: proto.String(".test.Enum"),
   869  						JsonName: proto.String("en"),
   870  					},
   871  				},
   872  			},
   873  			{
   874  				Name: proto.String("Bar"),
   875  				Field: []*descriptorpb.FieldDescriptorProto{
   876  					{
   877  						Name:     proto.String("foos"),
   878  						Number:   proto.Int32(1),
   879  						Label:    descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum(),
   880  						Type:     descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
   881  						TypeName: proto.String(".foo.bar.Foo"),
   882  						JsonName: proto.String("foos"),
   883  					},
   884  				},
   885  			},
   886  		},
   887  	}
   888  	topFileDesc, err := protodesc.NewFile(topFile, &result)
   889  	testutil.Ok(t, err)
   890  	err = result.Files.RegisterFile(topFileDesc)
   891  	testutil.Ok(t, err)
   892  
   893  	return &result
   894  }
   895  
   896  type files struct {
   897  	protoregistry.Files
   898  }
   899  
   900  func (f *files) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
   901  	d, err := f.FindDescriptorByName(field)
   902  	if err != nil {
   903  		return nil, err
   904  	}
   905  	fd, ok := d.(protoreflect.FieldDescriptor)
   906  	if !ok {
   907  		return nil, fmt.Errorf("%s is not a field descriptor but a %T", field, fd)
   908  	}
   909  	if !fd.IsExtension() {
   910  		return nil, fmt.Errorf("%s is a normal field, not an extension", field)
   911  	}
   912  	return asExtensionType(fd), nil
   913  }
   914  
   915  func (f *files) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
   916  	var found protoreflect.ExtensionType
   917  	f.RangeExtensionsByMessage(message, func(xt protoreflect.ExtensionType) bool {
   918  		if xt.TypeDescriptor().Number() == field {
   919  			found = xt
   920  			return false
   921  		}
   922  		return true
   923  	})
   924  	if found == nil {
   925  		return nil, protoregistry.NotFound
   926  	}
   927  	return found, nil
   928  }
   929  
   930  func (f *files) RangeExtensionsByMessage(message protoreflect.FullName, fn func(protoreflect.ExtensionType) bool) {
   931  	f.RangeFiles(func(file protoreflect.FileDescriptor) bool {
   932  		return rangeExtensionsByMessage(file, message, fn)
   933  	})
   934  }
   935  
   936  func rangeExtensionsByMessage(
   937  	container interface {
   938  		Messages() protoreflect.MessageDescriptors
   939  		Extensions() protoreflect.ExtensionDescriptors
   940  	},
   941  	message protoreflect.FullName,
   942  	fn func(protoreflect.ExtensionType) bool,
   943  ) bool {
   944  	for i := 0; i < container.Extensions().Len(); i++ {
   945  		ext := container.Extensions().Get(i)
   946  		if ext.ContainingMessage().FullName() == message {
   947  			if !fn(asExtensionType(ext)) {
   948  				return false
   949  			}
   950  		}
   951  	}
   952  	for i := 0; i < container.Messages().Len(); i++ {
   953  		if !rangeExtensionsByMessage(container.Messages().Get(i), message, fn) {
   954  			return false
   955  		}
   956  	}
   957  	return true
   958  }
   959  
   960  func asExtensionType(fd protoreflect.ExtensionDescriptor) protoreflect.ExtensionType {
   961  	xtd, ok := fd.(protoreflect.ExtensionTypeDescriptor)
   962  	if ok {
   963  		return xtd.Type()
   964  	}
   965  	return dynamicpb.NewExtensionType(fd)
   966  }
   967  
   968  type placeholder struct {
   969  	path string
   970  	protoreflect.FileDescriptor
   971  }
   972  
   973  func (p *placeholder) IsPlaceholder() bool {
   974  	return true
   975  }
   976  
   977  func (p *placeholder) Path() string {
   978  	return p.path
   979  }
   980  
   981  func (p *placeholder) Syntax() protoreflect.Syntax {
   982  	return 0
   983  }