github.com/jhump/protoreflect@v1.16.0/grpcreflect/client_test.go (about) 1 package grpcreflect 2 3 import ( 4 "context" 5 "encoding/base64" 6 "errors" 7 "fmt" 8 "io" 9 "net" 10 "os" 11 "sort" 12 "sync" 13 "sync/atomic" 14 "testing" 15 "time" 16 17 "github.com/golang/protobuf/proto" 18 "google.golang.org/grpc" 19 "google.golang.org/grpc/codes" 20 "google.golang.org/grpc/credentials/insecure" 21 "google.golang.org/grpc/reflection" 22 reflectv1alpha "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" 23 "google.golang.org/grpc/status" 24 "google.golang.org/protobuf/reflect/protodesc" 25 "google.golang.org/protobuf/reflect/protoreflect" 26 "google.golang.org/protobuf/reflect/protoregistry" 27 "google.golang.org/protobuf/types/descriptorpb" 28 "google.golang.org/protobuf/types/dynamicpb" 29 _ "google.golang.org/protobuf/types/known/apipb" 30 _ "google.golang.org/protobuf/types/known/emptypb" 31 _ "google.golang.org/protobuf/types/known/fieldmaskpb" 32 _ "google.golang.org/protobuf/types/known/sourcecontextpb" 33 _ "google.golang.org/protobuf/types/known/typepb" 34 _ "google.golang.org/protobuf/types/pluginpb" 35 36 "github.com/jhump/protoreflect/desc" 37 "github.com/jhump/protoreflect/internal" 38 testprotosgrpc "github.com/jhump/protoreflect/internal/testprotos/grpc" 39 "github.com/jhump/protoreflect/internal/testutil" 40 ) 41 42 var client *Client 43 44 func TestMain(m *testing.M) { 45 code := 1 46 defer func() { 47 p := recover() 48 if p != nil { 49 _, _ = fmt.Fprintf(os.Stderr, "PANIC: %v\n", p) 50 } 51 os.Exit(code) 52 }() 53 54 svr := grpc.NewServer() 55 testprotosgrpc.RegisterDummyServiceServer(svr, testService{}) 56 reflection.Register(svr) 57 l, err := net.Listen("tcp", "127.0.0.1:0") 58 if err != nil { 59 panic(fmt.Sprintf("Failed to open server socket: %s", err.Error())) 60 } 61 go func() { 62 _ = svr.Serve(l) 63 }() 64 defer svr.Stop() 65 66 // create grpc client 67 addr := l.Addr().String() 68 cconn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) 69 if err != nil { 70 panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error())) 71 } 72 defer func() { 73 _ = cconn.Close() 74 }() 75 76 stub := reflectv1alpha.NewServerReflectionClient(cconn) 77 client = NewClientV1Alpha(context.Background(), stub) 78 79 code = m.Run() 80 } 81 82 func TestFileByFileName(t *testing.T) { 83 fd, err := client.FileByFilename("desc_test1.proto") 84 testutil.Ok(t, err) 85 // shallow check that the descriptor appears correct and complete 86 testutil.Eq(t, "desc_test1.proto", fd.GetName()) 87 testutil.Eq(t, "testprotos", fd.GetPackage()) 88 md := fd.GetMessageTypes()[0] 89 testutil.Eq(t, "TestMessage", md.GetName()) 90 md = md.GetNestedMessageTypes()[0] 91 testutil.Eq(t, "NestedMessage", md.GetName()) 92 md = md.GetNestedMessageTypes()[0] 93 testutil.Eq(t, "AnotherNestedMessage", md.GetName()) 94 md = md.GetNestedMessageTypes()[0] 95 testutil.Eq(t, "YetAnotherNestedMessage", md.GetName()) 96 ed := md.GetNestedEnumTypes()[0] 97 testutil.Eq(t, "DeeplyNestedEnum", ed.GetName()) 98 99 _, err = client.FileByFilename("does not exist") 100 testutil.Require(t, IsElementNotFoundError(err)) 101 } 102 103 func TestFileByFileNameForWellKnownProtos(t *testing.T) { 104 wellKnownProtos := map[string][]string{ 105 "google/protobuf/any.proto": {"google.protobuf.Any"}, 106 "google/protobuf/api.proto": {"google.protobuf.Api", "google.protobuf.Method", "google.protobuf.Mixin"}, 107 "google/protobuf/descriptor.proto": {"google.protobuf.FileDescriptorSet", "google.protobuf.DescriptorProto"}, 108 "google/protobuf/duration.proto": {"google.protobuf.Duration"}, 109 "google/protobuf/empty.proto": {"google.protobuf.Empty"}, 110 "google/protobuf/field_mask.proto": {"google.protobuf.FieldMask"}, 111 "google/protobuf/source_context.proto": {"google.protobuf.SourceContext"}, 112 "google/protobuf/struct.proto": {"google.protobuf.Struct", "google.protobuf.Value", "google.protobuf.NullValue"}, 113 "google/protobuf/timestamp.proto": {"google.protobuf.Timestamp"}, 114 "google/protobuf/type.proto": {"google.protobuf.Type", "google.protobuf.Field", "google.protobuf.Syntax"}, 115 "google/protobuf/wrappers.proto": {"google.protobuf.DoubleValue", "google.protobuf.Int32Value", "google.protobuf.StringValue"}, 116 "google/protobuf/compiler/plugin.proto": {"google.protobuf.compiler.CodeGeneratorRequest"}, 117 } 118 119 for file, types := range wellKnownProtos { 120 fd, err := client.FileByFilename(file) 121 testutil.Ok(t, err) 122 testutil.Eq(t, file, fd.GetName()) 123 for _, typ := range types { 124 d := fd.FindSymbol(typ) 125 testutil.Require(t, d != nil) 126 } 127 128 // also try loading via alternate name 129 file = internal.StdFileAliases[file] 130 if file == "" { 131 // not a file that has a known alternate, so nothing else to check... 132 continue 133 } 134 fd, err = client.FileByFilename(file) 135 testutil.Ok(t, err) 136 testutil.Eq(t, file, fd.GetName()) 137 for _, typ := range types { 138 d := fd.FindSymbol(typ) 139 testutil.Require(t, d != nil) 140 } 141 } 142 } 143 144 func TestFileContainingSymbol(t *testing.T) { 145 fd, err := client.FileContainingSymbol("TopLevel") 146 testutil.Ok(t, err) 147 // shallow check that the descriptor appears correct and complete 148 testutil.Eq(t, "nopkg/desc_test_nopkg_new.proto", fd.GetName()) 149 testutil.Eq(t, "", fd.GetPackage()) 150 md := fd.GetMessageTypes()[0] 151 testutil.Eq(t, "TopLevel", md.GetName()) 152 testutil.Eq(t, "i", md.GetFields()[0].GetName()) 153 testutil.Eq(t, "j", md.GetFields()[1].GetName()) 154 testutil.Eq(t, "k", md.GetFields()[2].GetName()) 155 testutil.Eq(t, "l", md.GetFields()[3].GetName()) 156 testutil.Eq(t, "m", md.GetFields()[4].GetName()) 157 testutil.Eq(t, "n", md.GetFields()[5].GetName()) 158 testutil.Eq(t, "o", md.GetFields()[6].GetName()) 159 testutil.Eq(t, "p", md.GetFields()[7].GetName()) 160 testutil.Eq(t, "q", md.GetFields()[8].GetName()) 161 testutil.Eq(t, "r", md.GetFields()[9].GetName()) 162 testutil.Eq(t, "s", md.GetFields()[10].GetName()) 163 testutil.Eq(t, "t", md.GetFields()[11].GetName()) 164 165 _, err = client.FileContainingSymbol("does not exist") 166 testutil.Require(t, IsElementNotFoundError(err)) 167 } 168 169 func TestFileContainingExtension(t *testing.T) { 170 fd, err := client.FileContainingExtension("TopLevel", 100) 171 testutil.Ok(t, err) 172 // shallow check that the descriptor appears correct and complete 173 testutil.Eq(t, "desc_test2.proto", fd.GetName()) 174 testutil.Eq(t, "testprotos", fd.GetPackage()) 175 testutil.Eq(t, 4, len(fd.GetMessageTypes())) 176 testutil.Eq(t, "Frobnitz", fd.GetMessageTypes()[0].GetName()) 177 testutil.Eq(t, "Whatchamacallit", fd.GetMessageTypes()[1].GetName()) 178 testutil.Eq(t, "Whatzit", fd.GetMessageTypes()[2].GetName()) 179 testutil.Eq(t, "GroupX", fd.GetMessageTypes()[3].GetName()) 180 181 testutil.Eq(t, "desc_test1.proto", fd.GetDependencies()[0].GetName()) 182 testutil.Eq(t, "pkg/desc_test_pkg.proto", fd.GetDependencies()[1].GetName()) 183 testutil.Eq(t, "nopkg/desc_test_nopkg.proto", fd.GetDependencies()[2].GetName()) 184 185 _, err = client.FileContainingExtension("does not exist", 100) 186 testutil.Require(t, IsElementNotFoundError(err)) 187 _, err = client.FileContainingExtension("TopLevel", -9) 188 testutil.Require(t, IsElementNotFoundError(err)) 189 } 190 191 func TestAllExtensionNumbersForType(t *testing.T) { 192 nums, err := client.AllExtensionNumbersForType("TopLevel") 193 testutil.Ok(t, err) 194 inums := make([]int, len(nums)) 195 for idx, v := range nums { 196 inums[idx] = int(v) 197 } 198 sort.Ints(inums) 199 testutil.Eq(t, []int{100, 104}, inums) 200 201 nums, err = client.AllExtensionNumbersForType("testprotos.AnotherTestMessage") 202 testutil.Ok(t, err) 203 testutil.Eq(t, 5, len(nums)) 204 inums = make([]int, len(nums)) 205 for idx, v := range nums { 206 inums[idx] = int(v) 207 } 208 sort.Ints(inums) 209 testutil.Eq(t, []int{100, 101, 102, 103, 200}, inums) 210 211 _, err = client.AllExtensionNumbersForType("does not exist") 212 testutil.Require(t, IsElementNotFoundError(err)) 213 } 214 215 func TestListServices(t *testing.T) { 216 s, err := client.ListServices() 217 testutil.Ok(t, err) 218 219 sort.Strings(s) 220 testutil.Eq(t, []string{ 221 "grpc.reflection.v1.ServerReflection", 222 "grpc.reflection.v1alpha.ServerReflection", 223 "testprotos.DummyService", 224 }, s) 225 } 226 227 func TestReset(t *testing.T) { 228 _, err := client.ListServices() 229 testutil.Ok(t, err) 230 231 // save the current stream 232 stream := client.stream 233 // intercept cancellation 234 cancel := client.cancel 235 var cancelled atomic.Bool 236 client.cancel = func() { 237 cancelled.Store(true) 238 cancel() 239 } 240 241 client.Reset() 242 testutil.Eq(t, true, cancelled.Load()) 243 testutil.Eq(t, nil, client.stream) 244 245 _, err = client.ListServices() 246 testutil.Ok(t, err) 247 248 // stream was re-created 249 testutil.Eq(t, true, client.stream != nil && client.stream != stream) 250 } 251 252 func TestRecover(t *testing.T) { 253 _, err := client.ListServices() 254 testutil.Ok(t, err) 255 256 // kill the stream 257 stream := client.stream 258 err = client.stream.CloseSend() 259 testutil.Ok(t, err) 260 261 // it should auto-recover and re-create stream 262 _, err = client.ListServices() 263 testutil.Ok(t, err) 264 testutil.Eq(t, true, client.stream != nil && client.stream != stream) 265 } 266 267 func TestMultipleFiles(t *testing.T) { 268 svr := grpc.NewServer() 269 reflectv1alpha.RegisterServerReflectionServer(svr, testReflectionServer{}) 270 271 l, err := net.Listen("tcp", "127.0.0.1:0") 272 testutil.Ok(t, err, "failed to listen") 273 ctx, cancel := context.WithCancel(context.Background()) 274 defer cancel() 275 go func() { 276 defer cancel() 277 if err := svr.Serve(l); err != nil { 278 t.Logf("serve returned error: %v", err) 279 } 280 }() 281 time.Sleep(100 * time.Millisecond) // give server a chance to start 282 testutil.Ok(t, ctx.Err(), "failed to start server") 283 defer func() { 284 svr.Stop() 285 }() 286 287 dialCtx, dialCancel := context.WithTimeout(ctx, 3*time.Second) 288 defer dialCancel() 289 cc, err := grpc.DialContext(dialCtx, l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) 290 testutil.Ok(t, err, "failed ot dial %v", l.Addr().String()) 291 cl := reflectv1alpha.NewServerReflectionClient(cc) 292 293 client := NewClientV1Alpha(ctx, cl) 294 defer client.Reset() 295 svcs, err := client.ListServices() 296 testutil.Ok(t, err, "failed to list services") 297 for _, svc := range svcs { 298 fd, err := client.FileContainingSymbol(svc) 299 testutil.Ok(t, err, "failed to file for service %v", svc) 300 sd := fd.FindSymbol(svc) 301 _, ok := sd.(*desc.ServiceDescriptor) 302 testutil.Require(t, ok, "symbol for %s is not a service descriptor, instead is %T", svc, sd) 303 } 304 } 305 306 func TestAllowMissingFileDescriptors(t *testing.T) { 307 svr := grpc.NewServer() 308 files := createFilesWithMissingDeps(t) 309 reflectionSvc := reflection.NewServer(reflection.ServerOptions{ 310 DescriptorResolver: files, 311 ExtensionResolver: files, 312 }) 313 reflectv1alpha.RegisterServerReflectionServer(svr, reflectionSvc) 314 315 l, err := net.Listen("tcp", "127.0.0.1:0") 316 testutil.Ok(t, err, "failed to listen") 317 ctx, cancel := context.WithCancel(context.Background()) 318 defer cancel() 319 go func() { 320 defer cancel() 321 if err := svr.Serve(l); err != nil { 322 t.Logf("serve returned error: %v", err) 323 } 324 }() 325 time.Sleep(100 * time.Millisecond) // give server a chance to start 326 testutil.Ok(t, ctx.Err(), "failed to start server") 327 defer func() { 328 svr.Stop() 329 }() 330 331 dialCtx, dialCancel := context.WithTimeout(ctx, 3*time.Second) 332 defer dialCancel() 333 cc, err := grpc.DialContext(dialCtx, l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) 334 testutil.Ok(t, err, "failed ot dial %v", l.Addr().String()) 335 cl := reflectv1alpha.NewServerReflectionClient(cc) 336 337 client := NewClientV1Alpha(ctx, cl) 338 defer client.Reset() 339 340 // First we try some things that should fail due to missing descriptors. 341 _, err = client.FileByFilename("foo/bar/this.proto") 342 testutil.Nok(t, err) 343 _, err = client.FileContainingSymbol("foo.bar.Bar") 344 testutil.Nok(t, err) 345 _, err = client.FileContainingExtension("google.protobuf.MessageOptions", 10101) 346 testutil.Nok(t, err) 347 348 client.AllowMissingFileDescriptors() 349 // Now the above queries should succeed. 350 file, err := client.FileByFilename("foo/bar/this.proto") 351 testutil.Ok(t, err) 352 testutil.Require(t, file != nil) 353 testutil.Eq(t, "foo/bar/this.proto", file.GetName()) 354 _, err = client.FileContainingSymbol("foo.bar.Bar") 355 testutil.Ok(t, err) 356 testutil.Require(t, file != nil) 357 testutil.Eq(t, "foo/bar/this.proto", file.GetName()) 358 _, err = client.FileContainingExtension("google.protobuf.MessageOptions", 10101) 359 testutil.Ok(t, err) 360 testutil.Require(t, file != nil) 361 testutil.Eq(t, "foo/bar/this.proto", file.GetName()) 362 } 363 364 func TestFileWithoutDeps(t *testing.T) { 365 fd := &descriptorpb.FileDescriptorProto{ 366 Dependency: []string{ 367 "foo/bar.proto", 368 "foo/public/bar.proto", // missing 369 "foo/weak/bar.proto", 370 "foo/baz.proto", // missing 371 "foo/public/baz.proto", 372 "foo/weak/baz.proto", // missing 373 "foo/fizz.proto", 374 "foo/public/fizz.proto", // missing 375 "foo/weak/fizz.proto", 376 "foo/buzz.proto", // missing 377 "foo/public/buzz.proto", 378 "foo/weak/buzz.proto", // missing 379 }, 380 PublicDependency: []int32{1, 4, 7, 10}, 381 WeakDependency: []int32{2, 5, 8, 11}, 382 } 383 fd = fileWithoutDeps(fd, []int{1, 3, 5, 7, 9, 11}) 384 testutil.Eq(t, 385 []string{ 386 "foo/bar.proto", 387 "foo/weak/bar.proto", 388 "foo/public/baz.proto", 389 "foo/fizz.proto", 390 "foo/weak/fizz.proto", 391 "foo/public/buzz.proto", 392 }, 393 fd.Dependency) 394 testutil.Eq(t, []int32{2, 5}, fd.PublicDependency) 395 testutil.Eq(t, []int32{1, 4}, fd.WeakDependency) 396 } 397 398 type testReflectionServer struct{} 399 400 func (t testReflectionServer) ServerReflectionInfo(server reflectv1alpha.ServerReflection_ServerReflectionInfoServer) error { 401 const svcA_file = "ChdzYW5kYm94L3NlcnZpY2VfQS5wcm90bxIHc2FuZGJveCIWCghSZXF1ZXN0QRIKCgJpZBgBIAEoBSIYCglSZXNwb25zZUESCwoDc3RyGAEgASgJMj0KCVNlcnZpY2VfQRIwCgdFeGVjdXRlEhEuc2FuZGJveC5SZXF1ZXN0QRoSLnNhbmRib3guUmVzcG9uc2VBYgZwcm90bzM=" 402 const svcB_file = "ChdzYW5kYm94L1NlcnZpY2VfQi5wcm90bxIHc2FuZGJveCIWCghSZXF1ZXN0QhIKCgJpZBgBIAEoBSIYCglSZXNwb25zZUISCwoDc3RyGAEgASgJMj0KCVNlcnZpY2VfQhIwCgdFeGVjdXRlEhEuc2FuZGJveC5SZXF1ZXN0QhoSLnNhbmRib3guUmVzcG9uc2VCYgZwcm90bzM=" 403 404 for { 405 req, err := server.Recv() 406 if err == io.EOF { 407 return nil 408 } else if err != nil { 409 return err 410 } 411 var resp reflectv1alpha.ServerReflectionResponse 412 resp.OriginalRequest = req 413 switch req := req.MessageRequest.(type) { 414 case *reflectv1alpha.ServerReflectionRequest_FileByFilename: 415 switch req.FileByFilename { 416 case "sandbox/service_A.proto": 417 resp.MessageResponse = msgResponseForFiles(svcA_file) 418 case "sandbox/service_B.proto": 419 resp.MessageResponse = msgResponseForFiles(svcB_file) 420 default: 421 resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{ 422 ErrorResponse: &reflectv1alpha.ErrorResponse{ 423 ErrorCode: int32(codes.NotFound), 424 ErrorMessage: "not found", 425 }, 426 } 427 } 428 case *reflectv1alpha.ServerReflectionRequest_FileContainingSymbol: 429 switch req.FileContainingSymbol { 430 case "sandbox.Service_A": 431 resp.MessageResponse = msgResponseForFiles(svcA_file) 432 case "sandbox.Service_B": 433 // HERE is where we return two files instead of one 434 resp.MessageResponse = msgResponseForFiles(svcA_file, svcB_file) 435 default: 436 resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{ 437 ErrorResponse: &reflectv1alpha.ErrorResponse{ 438 ErrorCode: int32(codes.NotFound), 439 ErrorMessage: "not found", 440 }, 441 } 442 } 443 case *reflectv1alpha.ServerReflectionRequest_ListServices: 444 resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ListServicesResponse{ 445 ListServicesResponse: &reflectv1alpha.ListServiceResponse{ 446 Service: []*reflectv1alpha.ServiceResponse{ 447 {Name: "sandbox.Service_A"}, 448 {Name: "sandbox.Service_B"}, 449 }, 450 }, 451 } 452 default: 453 resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{ 454 ErrorResponse: &reflectv1alpha.ErrorResponse{ 455 ErrorCode: int32(codes.NotFound), 456 ErrorMessage: "not found", 457 }, 458 } 459 } 460 if err := server.Send(&resp); err != nil { 461 return err 462 } 463 } 464 } 465 466 func msgResponseForFiles(files ...string) *reflectv1alpha.ServerReflectionResponse_FileDescriptorResponse { 467 descs := make([][]byte, len(files)) 468 for i, f := range files { 469 b, err := base64.StdEncoding.DecodeString(f) 470 if err != nil { 471 panic(err) 472 } 473 descs[i] = b 474 } 475 return &reflectv1alpha.ServerReflectionResponse_FileDescriptorResponse{ 476 FileDescriptorResponse: &reflectv1alpha.FileDescriptorResponse{ 477 FileDescriptorProto: descs, 478 }, 479 } 480 } 481 482 func TestAutoVersion(t *testing.T) { 483 t.Run("v1", func(t *testing.T) { 484 testClientAuto(t, 485 func(s *grpc.Server) { 486 reflection.RegisterV1(s) 487 testprotosgrpc.RegisterDummyServiceServer(s, testService{}) 488 }, 489 []string{ 490 "grpc.reflection.v1.ServerReflection", 491 "testprotos.DummyService", 492 }, 493 []string{ 494 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 495 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 496 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 497 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 498 }) 499 }) 500 501 t.Run("v1alpha", func(t *testing.T) { 502 testClientAuto(t, 503 func(s *grpc.Server) { 504 impl := reflection.NewServer(reflection.ServerOptions{Services: s}) 505 reflectv1alpha.RegisterServerReflectionServer(s, impl) 506 testprotosgrpc.RegisterDummyServiceServer(s, testService{}) 507 }, 508 []string{ 509 "grpc.reflection.v1alpha.ServerReflection", 510 "testprotos.DummyService", 511 }, 512 []string{ 513 // first one fails, so falls back to v1alpha 514 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 515 "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo", 516 // next two use v1alpha 517 "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo", 518 "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo", 519 // final one retries v1 520 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 521 "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo", 522 }) 523 }) 524 525 t.Run("both", func(t *testing.T) { 526 testClientAuto(t, 527 func(s *grpc.Server) { 528 reflection.Register(s) 529 testprotosgrpc.RegisterDummyServiceServer(s, testService{}) 530 }, 531 []string{ 532 "grpc.reflection.v1.ServerReflection", 533 "grpc.reflection.v1alpha.ServerReflection", 534 "testprotos.DummyService", 535 }, 536 []string{ 537 // never uses v1alpha since v1 works 538 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 539 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 540 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 541 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 542 }) 543 }) 544 545 t.Run("fallback-on-unavailable", testClientAutoOnUnavailable) 546 } 547 548 func testClientAuto(t *testing.T, register func(*grpc.Server), expectedServices []string, expectedLog []string) { 549 var capture captureStreamNames 550 svr := grpc.NewServer(grpc.StreamInterceptor(capture.intercept), grpc.UnknownServiceHandler(capture.handleUnknown)) 551 register(svr) 552 l, err := net.Listen("tcp", "127.0.0.1:0") 553 if err != nil { 554 panic(fmt.Sprintf("Failed to open server socket: %s", err.Error())) 555 } 556 go func() { 557 err := svr.Serve(l) 558 testutil.Ok(t, err) 559 }() 560 defer svr.Stop() 561 562 cconn, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) 563 if err != nil { 564 panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error())) 565 } 566 defer func() { 567 err := cconn.Close() 568 testutil.Ok(t, err) 569 }() 570 client := NewClientAuto(context.Background(), cconn) 571 now := time.Now() 572 client.now = func() time.Time { 573 return now 574 } 575 576 svcs, err := client.ListServices() 577 testutil.Ok(t, err) 578 sort.Strings(svcs) 579 testutil.Eq(t, expectedServices, svcs) 580 client.Reset() 581 582 _, err = client.FileContainingSymbol(svcs[0]) 583 testutil.Ok(t, err) 584 client.Reset() 585 586 // at the threshold, but not quite enough to retry 587 now = now.Add(time.Hour) 588 _, err = client.ListServices() 589 testutil.Ok(t, err) 590 client.Reset() 591 592 // 1 ns more, and we've crossed threshold and will retry 593 now = now.Add(1) 594 _, err = client.ListServices() 595 testutil.Ok(t, err) 596 client.Reset() 597 598 actualLog := capture.names() 599 testutil.Eq(t, expectedLog, actualLog) 600 } 601 602 type captureStreamNames struct { 603 mu sync.Mutex 604 log []string 605 } 606 607 func (c *captureStreamNames) names() []string { 608 c.mu.Lock() 609 defer c.mu.Unlock() 610 ret := make([]string, len(c.log)) 611 copy(ret, c.log) 612 return ret 613 } 614 615 func (c *captureStreamNames) intercept(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 616 c.mu.Lock() 617 c.log = append(c.log, info.FullMethod) 618 c.mu.Unlock() 619 return handler(srv, ss) 620 } 621 622 func (c *captureStreamNames) handleUnknown(_ interface{}, _ grpc.ServerStream) error { 623 return status.Errorf(codes.Unimplemented, "WTF?") 624 } 625 626 func testClientAutoOnUnavailable(t *testing.T) { 627 l, err := net.Listen("tcp", "127.0.0.1:0") 628 if err != nil { 629 panic(fmt.Sprintf("Failed to open server socket: %s", err.Error())) 630 } 631 captureConn := &captureListener{Listener: l} 632 633 var capture captureStreamNames 634 svr := grpc.NewServer( 635 grpc.StreamInterceptor(capture.intercept), 636 grpc.UnknownServiceHandler(func(_ interface{}, _ grpc.ServerStream) error { 637 // On unknown method, forcibly close the net.Conn, without sending 638 // back any reply, which should result in an "unavailable" error. 639 return captureConn.latest().Close() 640 }), 641 ) 642 impl := reflection.NewServer(reflection.ServerOptions{Services: svr}) 643 reflectv1alpha.RegisterServerReflectionServer(svr, impl) 644 testprotosgrpc.RegisterDummyServiceServer(svr, testService{}) 645 646 go func() { 647 err := svr.Serve(captureConn) 648 testutil.Ok(t, err) 649 }() 650 defer svr.Stop() 651 652 var captureErrs captureErrors 653 cconn, err := grpc.Dial( 654 l.Addr().String(), 655 grpc.WithTransportCredentials(insecure.NewCredentials()), 656 grpc.WithStreamInterceptor(captureErrs.intercept), 657 ) 658 if err != nil { 659 panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error())) 660 } 661 defer func() { 662 err := cconn.Close() 663 testutil.Ok(t, err) 664 }() 665 client := NewClientAuto(context.Background(), cconn) 666 now := time.Now() 667 client.now = func() time.Time { 668 return now 669 } 670 671 svcs, err := client.ListServices() 672 testutil.Ok(t, err) 673 sort.Strings(svcs) 674 testutil.Eq(t, []string{ 675 "grpc.reflection.v1alpha.ServerReflection", 676 "testprotos.DummyService", 677 }, svcs) 678 679 // It should have tried v1 first and failed then tried v1alpha. 680 actualLog := capture.names() 681 testutil.Eq(t, []string{ 682 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 683 "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo", 684 }, actualLog) 685 686 // Make sure the error code observed by the client was unavailable and not unimplemented. 687 actualCodes := captureErrs.codes() 688 testutil.Eq(t, []codes.Code{codes.Unavailable}, actualCodes) 689 } 690 691 type captureListener struct { 692 net.Listener 693 mu sync.Mutex 694 conn net.Conn 695 } 696 697 func (c *captureListener) Accept() (net.Conn, error) { 698 conn, err := c.Listener.Accept() 699 if err == nil { 700 c.mu.Lock() 701 c.conn = conn 702 c.mu.Unlock() 703 } 704 return conn, err 705 } 706 707 func (c *captureListener) latest() net.Conn { 708 c.mu.Lock() 709 defer c.mu.Unlock() 710 return c.conn 711 } 712 713 type captureErrors struct { 714 mu sync.Mutex 715 observed []codes.Code 716 } 717 718 func (c *captureErrors) intercept(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 719 stream, err := streamer(ctx, desc, cc, method, opts...) 720 if err != nil { 721 c.observe(err) 722 return nil, err 723 } 724 return &captureErrorStream{ClientStream: stream, c: c}, nil 725 } 726 727 func (c *captureErrors) observe(err error) { 728 c.mu.Lock() 729 c.observed = append(c.observed, status.Code(err)) 730 c.mu.Unlock() 731 } 732 733 func (c *captureErrors) codes() []codes.Code { 734 c.mu.Lock() 735 defer c.mu.Unlock() 736 ret := make([]codes.Code, len(c.observed)) 737 copy(ret, c.observed) 738 return ret 739 } 740 741 type captureErrorStream struct { 742 grpc.ClientStream 743 c *captureErrors 744 done int32 745 } 746 747 func (c *captureErrorStream) RecvMsg(m interface{}) error { 748 err := c.ClientStream.RecvMsg(m) 749 if err == nil || errors.Is(err, io.EOF) { 750 return nil 751 } 752 // Only record one error per RPC. 753 if atomic.CompareAndSwapInt32(&c.done, 0, 1) { 754 c.c.observe(err) 755 } 756 return err 757 } 758 759 func createFilesWithMissingDeps(t *testing.T) *files { 760 t.Helper() 761 var result files 762 empty, err := protodesc.NewFile(&descriptorpb.FileDescriptorProto{ 763 Name: proto.String("empty.proto"), 764 Syntax: proto.String("proto2"), 765 }, &result) 766 testutil.Ok(t, err) 767 768 // These will be missing, so we create them as placeholders, so 769 // the protobuf-go runtime can resolve imports for them and 770 // still build a protoreflect.FileDescriptor. 771 err = result.RegisterFile(&placeholder{path: "test/custom/options.proto", FileDescriptor: empty}) 772 testutil.Ok(t, err) 773 err = result.RegisterFile(&placeholder{path: "test/unused.proto", FileDescriptor: empty}) 774 testutil.Ok(t, err) 775 776 // register google/protobuf/descriptor.proto from the embedded descriptor in descriptorpb 777 err = result.RegisterFile((*descriptorpb.FileDescriptorProto)(nil).ProtoReflect().Descriptor().ParentFile()) 778 testutil.Ok(t, err) 779 780 importedFile := &descriptorpb.FileDescriptorProto{ 781 Name: proto.String("test/imported.proto"), 782 Syntax: proto.String("proto3"), 783 Package: proto.String("test"), 784 Dependency: []string{"google/protobuf/descriptor.proto", "test/unused.proto"}, 785 PublicDependency: []int32{1}, // unused is public 786 MessageType: []*descriptorpb.DescriptorProto{ 787 { 788 Name: proto.String("Message"), 789 Field: []*descriptorpb.FieldDescriptorProto{ 790 { 791 Name: proto.String("name"), 792 Number: proto.Int32(1), 793 Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), 794 Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), 795 JsonName: proto.String("name"), 796 }, 797 { 798 Name: proto.String("tags"), 799 Number: proto.Int32(2), 800 Label: descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum(), 801 Type: descriptorpb.FieldDescriptorProto_TYPE_UINT64.Enum(), 802 JsonName: proto.String("tags"), 803 }, 804 }, 805 Extension: []*descriptorpb.FieldDescriptorProto{ 806 { 807 Extendee: proto.String(".google.protobuf.MessageOptions"), 808 Name: proto.String("message_option"), 809 Number: proto.Int32(10101), 810 Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), 811 Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), 812 }, 813 }, 814 }, 815 }, 816 EnumType: []*descriptorpb.EnumDescriptorProto{ 817 { 818 Name: proto.String("Enum"), 819 Value: []*descriptorpb.EnumValueDescriptorProto{ 820 { 821 Name: proto.String("VAL0"), 822 Number: proto.Int32(0), 823 }, 824 { 825 Name: proto.String("VAL1"), 826 Number: proto.Int32(1), 827 }, 828 }, 829 }, 830 }, 831 Extension: []*descriptorpb.FieldDescriptorProto{ 832 { 833 Extendee: proto.String(".google.protobuf.FileOptions"), 834 Name: proto.String("file_option"), 835 Number: proto.Int32(10101), 836 Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), 837 Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), 838 }, 839 }, 840 } 841 importedFileDesc, err := protodesc.NewFile(importedFile, &result) 842 testutil.Ok(t, err) 843 err = result.Files.RegisterFile(importedFileDesc) 844 testutil.Ok(t, err) 845 846 topFile := &descriptorpb.FileDescriptorProto{ 847 Name: proto.String("foo/bar/this.proto"), 848 Syntax: proto.String("proto3"), 849 Package: proto.String("foo.bar"), 850 Dependency: []string{"test/imported.proto", "test/unused.proto", "test/custom/options.proto"}, 851 MessageType: []*descriptorpb.DescriptorProto{ 852 { 853 Name: proto.String("Foo"), 854 Field: []*descriptorpb.FieldDescriptorProto{ 855 { 856 Name: proto.String("msg"), 857 Number: proto.Int32(1), 858 Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), 859 Type: descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(), 860 TypeName: proto.String(".test.Message"), 861 JsonName: proto.String("msg"), 862 }, 863 { 864 Name: proto.String("en"), 865 Number: proto.Int32(2), 866 Label: descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum(), 867 Type: descriptorpb.FieldDescriptorProto_TYPE_ENUM.Enum(), 868 TypeName: proto.String(".test.Enum"), 869 JsonName: proto.String("en"), 870 }, 871 }, 872 }, 873 { 874 Name: proto.String("Bar"), 875 Field: []*descriptorpb.FieldDescriptorProto{ 876 { 877 Name: proto.String("foos"), 878 Number: proto.Int32(1), 879 Label: descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum(), 880 Type: descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(), 881 TypeName: proto.String(".foo.bar.Foo"), 882 JsonName: proto.String("foos"), 883 }, 884 }, 885 }, 886 }, 887 } 888 topFileDesc, err := protodesc.NewFile(topFile, &result) 889 testutil.Ok(t, err) 890 err = result.Files.RegisterFile(topFileDesc) 891 testutil.Ok(t, err) 892 893 return &result 894 } 895 896 type files struct { 897 protoregistry.Files 898 } 899 900 func (f *files) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) { 901 d, err := f.FindDescriptorByName(field) 902 if err != nil { 903 return nil, err 904 } 905 fd, ok := d.(protoreflect.FieldDescriptor) 906 if !ok { 907 return nil, fmt.Errorf("%s is not a field descriptor but a %T", field, fd) 908 } 909 if !fd.IsExtension() { 910 return nil, fmt.Errorf("%s is a normal field, not an extension", field) 911 } 912 return asExtensionType(fd), nil 913 } 914 915 func (f *files) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) { 916 var found protoreflect.ExtensionType 917 f.RangeExtensionsByMessage(message, func(xt protoreflect.ExtensionType) bool { 918 if xt.TypeDescriptor().Number() == field { 919 found = xt 920 return false 921 } 922 return true 923 }) 924 if found == nil { 925 return nil, protoregistry.NotFound 926 } 927 return found, nil 928 } 929 930 func (f *files) RangeExtensionsByMessage(message protoreflect.FullName, fn func(protoreflect.ExtensionType) bool) { 931 f.RangeFiles(func(file protoreflect.FileDescriptor) bool { 932 return rangeExtensionsByMessage(file, message, fn) 933 }) 934 } 935 936 func rangeExtensionsByMessage( 937 container interface { 938 Messages() protoreflect.MessageDescriptors 939 Extensions() protoreflect.ExtensionDescriptors 940 }, 941 message protoreflect.FullName, 942 fn func(protoreflect.ExtensionType) bool, 943 ) bool { 944 for i := 0; i < container.Extensions().Len(); i++ { 945 ext := container.Extensions().Get(i) 946 if ext.ContainingMessage().FullName() == message { 947 if !fn(asExtensionType(ext)) { 948 return false 949 } 950 } 951 } 952 for i := 0; i < container.Messages().Len(); i++ { 953 if !rangeExtensionsByMessage(container.Messages().Get(i), message, fn) { 954 return false 955 } 956 } 957 return true 958 } 959 960 func asExtensionType(fd protoreflect.ExtensionDescriptor) protoreflect.ExtensionType { 961 xtd, ok := fd.(protoreflect.ExtensionTypeDescriptor) 962 if ok { 963 return xtd.Type() 964 } 965 return dynamicpb.NewExtensionType(fd) 966 } 967 968 type placeholder struct { 969 path string 970 protoreflect.FileDescriptor 971 } 972 973 func (p *placeholder) IsPlaceholder() bool { 974 return true 975 } 976 977 func (p *placeholder) Path() string { 978 return p.path 979 } 980 981 func (p *placeholder) Syntax() protoreflect.Syntax { 982 return 0 983 }