github.com/Big-big-orange/protoreflect@v0.0.0-20240408141420-285cedfdf6a4/grpcreflect/client_test.go (about) 1 package grpcreflect 2 3 import ( 4 "context" 5 "encoding/base64" 6 "errors" 7 "fmt" 8 "io" 9 "net" 10 "os" 11 "sort" 12 "sync" 13 "sync/atomic" 14 "testing" 15 "time" 16 17 "google.golang.org/grpc" 18 "google.golang.org/grpc/codes" 19 "google.golang.org/grpc/credentials/insecure" 20 "google.golang.org/grpc/reflection" 21 reflectv1alpha "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" 22 "google.golang.org/grpc/status" 23 _ "google.golang.org/protobuf/types/known/apipb" 24 _ "google.golang.org/protobuf/types/known/emptypb" 25 _ "google.golang.org/protobuf/types/known/fieldmaskpb" 26 _ "google.golang.org/protobuf/types/known/sourcecontextpb" 27 _ "google.golang.org/protobuf/types/known/typepb" 28 _ "google.golang.org/protobuf/types/pluginpb" 29 30 "github.com/Big-big-orange/protoreflect/desc" 31 "github.com/Big-big-orange/protoreflect/internal" 32 testprotosgrpc "github.com/Big-big-orange/protoreflect/internal/testprotos/grpc" 33 "github.com/Big-big-orange/protoreflect/internal/testutil" 34 ) 35 36 var client *Client 37 38 func TestMain(m *testing.M) { 39 code := 1 40 defer func() { 41 p := recover() 42 if p != nil { 43 _, _ = fmt.Fprintf(os.Stderr, "PANIC: %v\n", p) 44 } 45 os.Exit(code) 46 }() 47 48 svr := grpc.NewServer() 49 testprotosgrpc.RegisterDummyServiceServer(svr, testService{}) 50 reflection.Register(svr) 51 l, err := net.Listen("tcp", "127.0.0.1:0") 52 if err != nil { 53 panic(fmt.Sprintf("Failed to open server socket: %s", err.Error())) 54 } 55 go func() { 56 _ = svr.Serve(l) 57 }() 58 defer svr.Stop() 59 60 // create grpc client 61 addr := l.Addr().String() 62 cconn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) 63 if err != nil { 64 panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error())) 65 } 66 defer func() { 67 _ = cconn.Close() 68 }() 69 70 stub := reflectv1alpha.NewServerReflectionClient(cconn) 71 client = NewClientV1Alpha(context.Background(), stub) 72 73 code = m.Run() 74 } 75 76 func TestFileByFileName(t *testing.T) { 77 fd, err := client.FileByFilename("desc_test1.proto") 78 testutil.Ok(t, err) 79 // shallow check that the descriptor appears correct and complete 80 testutil.Eq(t, "desc_test1.proto", fd.GetName()) 81 testutil.Eq(t, "testprotos", fd.GetPackage()) 82 md := fd.GetMessageTypes()[0] 83 testutil.Eq(t, "TestMessage", md.GetName()) 84 md = md.GetNestedMessageTypes()[0] 85 testutil.Eq(t, "NestedMessage", md.GetName()) 86 md = md.GetNestedMessageTypes()[0] 87 testutil.Eq(t, "AnotherNestedMessage", md.GetName()) 88 md = md.GetNestedMessageTypes()[0] 89 testutil.Eq(t, "YetAnotherNestedMessage", md.GetName()) 90 ed := md.GetNestedEnumTypes()[0] 91 testutil.Eq(t, "DeeplyNestedEnum", ed.GetName()) 92 93 _, err = client.FileByFilename("does not exist") 94 testutil.Require(t, IsElementNotFoundError(err)) 95 } 96 97 func TestFileByFileNameForWellKnownProtos(t *testing.T) { 98 wellKnownProtos := map[string][]string{ 99 "google/protobuf/any.proto": {"google.protobuf.Any"}, 100 "google/protobuf/api.proto": {"google.protobuf.Api", "google.protobuf.Method", "google.protobuf.Mixin"}, 101 "google/protobuf/descriptor.proto": {"google.protobuf.FileDescriptorSet", "google.protobuf.DescriptorProto"}, 102 "google/protobuf/duration.proto": {"google.protobuf.Duration"}, 103 "google/protobuf/empty.proto": {"google.protobuf.Empty"}, 104 "google/protobuf/field_mask.proto": {"google.protobuf.FieldMask"}, 105 "google/protobuf/source_context.proto": {"google.protobuf.SourceContext"}, 106 "google/protobuf/struct.proto": {"google.protobuf.Struct", "google.protobuf.Value", "google.protobuf.NullValue"}, 107 "google/protobuf/timestamp.proto": {"google.protobuf.Timestamp"}, 108 "google/protobuf/type.proto": {"google.protobuf.Type", "google.protobuf.Field", "google.protobuf.Syntax"}, 109 "google/protobuf/wrappers.proto": {"google.protobuf.DoubleValue", "google.protobuf.Int32Value", "google.protobuf.StringValue"}, 110 "google/protobuf/compiler/plugin.proto": {"google.protobuf.compiler.CodeGeneratorRequest"}, 111 } 112 113 for file, types := range wellKnownProtos { 114 fd, err := client.FileByFilename(file) 115 testutil.Ok(t, err) 116 testutil.Eq(t, file, fd.GetName()) 117 for _, typ := range types { 118 d := fd.FindSymbol(typ) 119 testutil.Require(t, d != nil) 120 } 121 122 // also try loading via alternate name 123 file = internal.StdFileAliases[file] 124 if file == "" { 125 // not a file that has a known alternate, so nothing else to check... 126 continue 127 } 128 fd, err = client.FileByFilename(file) 129 testutil.Ok(t, err) 130 testutil.Eq(t, file, fd.GetName()) 131 for _, typ := range types { 132 d := fd.FindSymbol(typ) 133 testutil.Require(t, d != nil) 134 } 135 } 136 } 137 138 func TestFileContainingSymbol(t *testing.T) { 139 fd, err := client.FileContainingSymbol("TopLevel") 140 testutil.Ok(t, err) 141 // shallow check that the descriptor appears correct and complete 142 testutil.Eq(t, "nopkg/desc_test_nopkg_new.proto", fd.GetName()) 143 testutil.Eq(t, "", fd.GetPackage()) 144 md := fd.GetMessageTypes()[0] 145 testutil.Eq(t, "TopLevel", md.GetName()) 146 testutil.Eq(t, "i", md.GetFields()[0].GetName()) 147 testutil.Eq(t, "j", md.GetFields()[1].GetName()) 148 testutil.Eq(t, "k", md.GetFields()[2].GetName()) 149 testutil.Eq(t, "l", md.GetFields()[3].GetName()) 150 testutil.Eq(t, "m", md.GetFields()[4].GetName()) 151 testutil.Eq(t, "n", md.GetFields()[5].GetName()) 152 testutil.Eq(t, "o", md.GetFields()[6].GetName()) 153 testutil.Eq(t, "p", md.GetFields()[7].GetName()) 154 testutil.Eq(t, "q", md.GetFields()[8].GetName()) 155 testutil.Eq(t, "r", md.GetFields()[9].GetName()) 156 testutil.Eq(t, "s", md.GetFields()[10].GetName()) 157 testutil.Eq(t, "t", md.GetFields()[11].GetName()) 158 159 _, err = client.FileContainingSymbol("does not exist") 160 testutil.Require(t, IsElementNotFoundError(err)) 161 } 162 163 func TestFileContainingExtension(t *testing.T) { 164 fd, err := client.FileContainingExtension("TopLevel", 100) 165 testutil.Ok(t, err) 166 // shallow check that the descriptor appears correct and complete 167 testutil.Eq(t, "desc_test2.proto", fd.GetName()) 168 testutil.Eq(t, "testprotos", fd.GetPackage()) 169 testutil.Eq(t, 4, len(fd.GetMessageTypes())) 170 testutil.Eq(t, "Frobnitz", fd.GetMessageTypes()[0].GetName()) 171 testutil.Eq(t, "Whatchamacallit", fd.GetMessageTypes()[1].GetName()) 172 testutil.Eq(t, "Whatzit", fd.GetMessageTypes()[2].GetName()) 173 testutil.Eq(t, "GroupX", fd.GetMessageTypes()[3].GetName()) 174 175 testutil.Eq(t, "desc_test1.proto", fd.GetDependencies()[0].GetName()) 176 testutil.Eq(t, "pkg/desc_test_pkg.proto", fd.GetDependencies()[1].GetName()) 177 testutil.Eq(t, "nopkg/desc_test_nopkg.proto", fd.GetDependencies()[2].GetName()) 178 179 _, err = client.FileContainingExtension("does not exist", 100) 180 testutil.Require(t, IsElementNotFoundError(err)) 181 _, err = client.FileContainingExtension("TopLevel", -9) 182 testutil.Require(t, IsElementNotFoundError(err)) 183 } 184 185 func TestAllExtensionNumbersForType(t *testing.T) { 186 nums, err := client.AllExtensionNumbersForType("TopLevel") 187 testutil.Ok(t, err) 188 inums := make([]int, len(nums)) 189 for idx, v := range nums { 190 inums[idx] = int(v) 191 } 192 sort.Ints(inums) 193 testutil.Eq(t, []int{100, 104}, inums) 194 195 nums, err = client.AllExtensionNumbersForType("testprotos.AnotherTestMessage") 196 testutil.Ok(t, err) 197 testutil.Eq(t, 5, len(nums)) 198 inums = make([]int, len(nums)) 199 for idx, v := range nums { 200 inums[idx] = int(v) 201 } 202 sort.Ints(inums) 203 testutil.Eq(t, []int{100, 101, 102, 103, 200}, inums) 204 205 _, err = client.AllExtensionNumbersForType("does not exist") 206 testutil.Require(t, IsElementNotFoundError(err)) 207 } 208 209 func TestListServices(t *testing.T) { 210 s, err := client.ListServices() 211 testutil.Ok(t, err) 212 213 sort.Strings(s) 214 testutil.Eq(t, []string{ 215 "grpc.reflection.v1.ServerReflection", 216 "grpc.reflection.v1alpha.ServerReflection", 217 "testprotos.DummyService", 218 }, s) 219 } 220 221 func TestReset(t *testing.T) { 222 _, err := client.ListServices() 223 testutil.Ok(t, err) 224 225 // save the current stream 226 stream := client.stream 227 // intercept cancellation 228 cancel := client.cancel 229 var cancelled int32 230 client.cancel = func() { 231 atomic.StoreInt32(&cancelled, 1) 232 cancel() 233 } 234 235 client.Reset() 236 testutil.Eq(t, int32(1), atomic.LoadInt32(&cancelled)) 237 testutil.Eq(t, nil, client.stream) 238 239 _, err = client.ListServices() 240 testutil.Ok(t, err) 241 242 // stream was re-created 243 testutil.Eq(t, true, client.stream != nil && client.stream != stream) 244 } 245 246 func TestRecover(t *testing.T) { 247 _, err := client.ListServices() 248 testutil.Ok(t, err) 249 250 // kill the stream 251 stream := client.stream 252 err = client.stream.CloseSend() 253 testutil.Ok(t, err) 254 255 // it should auto-recover and re-create stream 256 _, err = client.ListServices() 257 testutil.Ok(t, err) 258 testutil.Eq(t, true, client.stream != nil && client.stream != stream) 259 } 260 261 func TestMultipleFiles(t *testing.T) { 262 svr := grpc.NewServer() 263 reflectv1alpha.RegisterServerReflectionServer(svr, testReflectionServer{}) 264 265 l, err := net.Listen("tcp", "127.0.0.1:0") 266 testutil.Ok(t, err, "failed to listen") 267 ctx, cancel := context.WithCancel(context.Background()) 268 defer cancel() 269 go func() { 270 defer cancel() 271 if err := svr.Serve(l); err != nil { 272 t.Logf("serve returned error: %v", err) 273 } 274 }() 275 time.Sleep(100 * time.Millisecond) // give server a chance to start 276 testutil.Ok(t, ctx.Err(), "failed to start server") 277 defer func() { 278 svr.Stop() 279 }() 280 281 dialCtx, dialCancel := context.WithTimeout(ctx, 3*time.Second) 282 defer dialCancel() 283 cc, err := grpc.DialContext(dialCtx, l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) 284 testutil.Ok(t, err, "failed ot dial %v", l.Addr().String()) 285 cl := reflectv1alpha.NewServerReflectionClient(cc) 286 287 client := NewClientV1Alpha(ctx, cl) 288 defer client.Reset() 289 svcs, err := client.ListServices() 290 testutil.Ok(t, err, "failed to list services") 291 for _, svc := range svcs { 292 fd, err := client.FileContainingSymbol(svc) 293 testutil.Ok(t, err, "failed to file for service %v", svc) 294 sd := fd.FindSymbol(svc) 295 _, ok := sd.(*desc.ServiceDescriptor) 296 testutil.Require(t, ok, "symbol for %s is not a service descriptor, instead is %T", svc, sd) 297 } 298 } 299 300 type testReflectionServer struct{} 301 302 func (t testReflectionServer) ServerReflectionInfo(server reflectv1alpha.ServerReflection_ServerReflectionInfoServer) error { 303 const svcA_file = "ChdzYW5kYm94L3NlcnZpY2VfQS5wcm90bxIHc2FuZGJveCIWCghSZXF1ZXN0QRIKCgJpZBgBIAEoBSIYCglSZXNwb25zZUESCwoDc3RyGAEgASgJMj0KCVNlcnZpY2VfQRIwCgdFeGVjdXRlEhEuc2FuZGJveC5SZXF1ZXN0QRoSLnNhbmRib3guUmVzcG9uc2VBYgZwcm90bzM=" 304 const svcB_file = "ChdzYW5kYm94L1NlcnZpY2VfQi5wcm90bxIHc2FuZGJveCIWCghSZXF1ZXN0QhIKCgJpZBgBIAEoBSIYCglSZXNwb25zZUISCwoDc3RyGAEgASgJMj0KCVNlcnZpY2VfQhIwCgdFeGVjdXRlEhEuc2FuZGJveC5SZXF1ZXN0QhoSLnNhbmRib3guUmVzcG9uc2VCYgZwcm90bzM=" 305 306 for { 307 req, err := server.Recv() 308 if err == io.EOF { 309 return nil 310 } else if err != nil { 311 return err 312 } 313 var resp reflectv1alpha.ServerReflectionResponse 314 resp.OriginalRequest = req 315 switch req := req.MessageRequest.(type) { 316 case *reflectv1alpha.ServerReflectionRequest_FileByFilename: 317 switch req.FileByFilename { 318 case "sandbox/service_A.proto": 319 resp.MessageResponse = msgResponseForFiles(svcA_file) 320 case "sandbox/service_B.proto": 321 resp.MessageResponse = msgResponseForFiles(svcB_file) 322 default: 323 resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{ 324 ErrorResponse: &reflectv1alpha.ErrorResponse{ 325 ErrorCode: int32(codes.NotFound), 326 ErrorMessage: "not found", 327 }, 328 } 329 } 330 case *reflectv1alpha.ServerReflectionRequest_FileContainingSymbol: 331 switch req.FileContainingSymbol { 332 case "sandbox.Service_A": 333 resp.MessageResponse = msgResponseForFiles(svcA_file) 334 case "sandbox.Service_B": 335 // HERE is where we return two files instead of one 336 resp.MessageResponse = msgResponseForFiles(svcA_file, svcB_file) 337 default: 338 resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{ 339 ErrorResponse: &reflectv1alpha.ErrorResponse{ 340 ErrorCode: int32(codes.NotFound), 341 ErrorMessage: "not found", 342 }, 343 } 344 } 345 case *reflectv1alpha.ServerReflectionRequest_ListServices: 346 resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ListServicesResponse{ 347 ListServicesResponse: &reflectv1alpha.ListServiceResponse{ 348 Service: []*reflectv1alpha.ServiceResponse{ 349 {Name: "sandbox.Service_A"}, 350 {Name: "sandbox.Service_B"}, 351 }, 352 }, 353 } 354 default: 355 resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{ 356 ErrorResponse: &reflectv1alpha.ErrorResponse{ 357 ErrorCode: int32(codes.NotFound), 358 ErrorMessage: "not found", 359 }, 360 } 361 } 362 if err := server.Send(&resp); err != nil { 363 return err 364 } 365 } 366 } 367 368 func msgResponseForFiles(files ...string) *reflectv1alpha.ServerReflectionResponse_FileDescriptorResponse { 369 descs := make([][]byte, len(files)) 370 for i, f := range files { 371 b, err := base64.StdEncoding.DecodeString(f) 372 if err != nil { 373 panic(err) 374 } 375 descs[i] = b 376 } 377 return &reflectv1alpha.ServerReflectionResponse_FileDescriptorResponse{ 378 FileDescriptorResponse: &reflectv1alpha.FileDescriptorResponse{ 379 FileDescriptorProto: descs, 380 }, 381 } 382 } 383 384 func TestAutoVersion(t *testing.T) { 385 t.Run("v1", func(t *testing.T) { 386 testClientAuto(t, 387 func(s *grpc.Server) { 388 reflection.RegisterV1(s) 389 testprotosgrpc.RegisterDummyServiceServer(s, testService{}) 390 }, 391 []string{ 392 "grpc.reflection.v1.ServerReflection", 393 "testprotos.DummyService", 394 }, 395 []string{ 396 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 397 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 398 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 399 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 400 }) 401 }) 402 403 t.Run("v1alpha", func(t *testing.T) { 404 testClientAuto(t, 405 func(s *grpc.Server) { 406 impl := reflection.NewServer(reflection.ServerOptions{Services: s}) 407 reflectv1alpha.RegisterServerReflectionServer(s, impl) 408 testprotosgrpc.RegisterDummyServiceServer(s, testService{}) 409 }, 410 []string{ 411 "grpc.reflection.v1alpha.ServerReflection", 412 "testprotos.DummyService", 413 }, 414 []string{ 415 // first one fails, so falls back to v1alpha 416 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 417 "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo", 418 // next two use v1alpha 419 "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo", 420 "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo", 421 // final one retries v1 422 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 423 "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo", 424 }) 425 }) 426 427 t.Run("both", func(t *testing.T) { 428 testClientAuto(t, 429 func(s *grpc.Server) { 430 reflection.Register(s) 431 testprotosgrpc.RegisterDummyServiceServer(s, testService{}) 432 }, 433 []string{ 434 "grpc.reflection.v1.ServerReflection", 435 "grpc.reflection.v1alpha.ServerReflection", 436 "testprotos.DummyService", 437 }, 438 []string{ 439 // never uses v1alpha since v1 works 440 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 441 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 442 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 443 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 444 }) 445 }) 446 447 t.Run("fallback-on-unavailable", testClientAutoOnUnavailable) 448 } 449 450 func testClientAuto(t *testing.T, register func(*grpc.Server), expectedServices []string, expectedLog []string) { 451 var capture captureStreamNames 452 svr := grpc.NewServer(grpc.StreamInterceptor(capture.intercept), grpc.UnknownServiceHandler(capture.handleUnknown)) 453 register(svr) 454 l, err := net.Listen("tcp", "127.0.0.1:0") 455 if err != nil { 456 panic(fmt.Sprintf("Failed to open server socket: %s", err.Error())) 457 } 458 go func() { 459 err := svr.Serve(l) 460 testutil.Ok(t, err) 461 }() 462 defer svr.Stop() 463 464 cconn, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) 465 if err != nil { 466 panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error())) 467 } 468 defer func() { 469 err := cconn.Close() 470 testutil.Ok(t, err) 471 }() 472 client := NewClientAuto(context.Background(), cconn) 473 now := time.Now() 474 client.now = func() time.Time { 475 return now 476 } 477 478 svcs, err := client.ListServices() 479 testutil.Ok(t, err) 480 sort.Strings(svcs) 481 testutil.Eq(t, expectedServices, svcs) 482 client.Reset() 483 484 _, err = client.FileContainingSymbol(svcs[0]) 485 testutil.Ok(t, err) 486 client.Reset() 487 488 // at the threshold, but not quite enough to retry 489 now = now.Add(time.Hour) 490 _, err = client.ListServices() 491 testutil.Ok(t, err) 492 client.Reset() 493 494 // 1 ns more, and we've crossed threshold and will retry 495 now = now.Add(1) 496 _, err = client.ListServices() 497 testutil.Ok(t, err) 498 client.Reset() 499 500 actualLog := capture.names() 501 testutil.Eq(t, expectedLog, actualLog) 502 } 503 504 type captureStreamNames struct { 505 mu sync.Mutex 506 log []string 507 } 508 509 func (c *captureStreamNames) names() []string { 510 c.mu.Lock() 511 defer c.mu.Unlock() 512 ret := make([]string, len(c.log)) 513 copy(ret, c.log) 514 return ret 515 } 516 517 func (c *captureStreamNames) intercept(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 518 c.mu.Lock() 519 c.log = append(c.log, info.FullMethod) 520 c.mu.Unlock() 521 return handler(srv, ss) 522 } 523 524 func (c *captureStreamNames) handleUnknown(_ interface{}, _ grpc.ServerStream) error { 525 return status.Errorf(codes.Unimplemented, "WTF?") 526 } 527 528 func testClientAutoOnUnavailable(t *testing.T) { 529 l, err := net.Listen("tcp", "127.0.0.1:0") 530 if err != nil { 531 panic(fmt.Sprintf("Failed to open server socket: %s", err.Error())) 532 } 533 captureConn := &captureListener{Listener: l} 534 535 var capture captureStreamNames 536 svr := grpc.NewServer( 537 grpc.StreamInterceptor(capture.intercept), 538 grpc.UnknownServiceHandler(func(_ interface{}, _ grpc.ServerStream) error { 539 // On unknown method, forcibly close the net.Conn, without sending 540 // back any reply, which should result in an "unavailable" error. 541 return captureConn.latest().Close() 542 }), 543 ) 544 impl := reflection.NewServer(reflection.ServerOptions{Services: svr}) 545 reflectv1alpha.RegisterServerReflectionServer(svr, impl) 546 testprotosgrpc.RegisterDummyServiceServer(svr, testService{}) 547 548 go func() { 549 err := svr.Serve(captureConn) 550 testutil.Ok(t, err) 551 }() 552 defer svr.Stop() 553 554 var captureErrs captureErrors 555 cconn, err := grpc.Dial( 556 l.Addr().String(), 557 grpc.WithTransportCredentials(insecure.NewCredentials()), 558 grpc.WithStreamInterceptor(captureErrs.intercept), 559 ) 560 if err != nil { 561 panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error())) 562 } 563 defer func() { 564 err := cconn.Close() 565 testutil.Ok(t, err) 566 }() 567 client := NewClientAuto(context.Background(), cconn) 568 now := time.Now() 569 client.now = func() time.Time { 570 return now 571 } 572 573 svcs, err := client.ListServices() 574 testutil.Ok(t, err) 575 sort.Strings(svcs) 576 testutil.Eq(t, []string{ 577 "grpc.reflection.v1alpha.ServerReflection", 578 "testprotos.DummyService", 579 }, svcs) 580 581 // It should have tried v1 first and failed then tried v1alpha. 582 actualLog := capture.names() 583 testutil.Eq(t, []string{ 584 "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", 585 "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo", 586 }, actualLog) 587 588 // Make sure the error code observed by the client was unavailable and not unimplemented. 589 actualCodes := captureErrs.codes() 590 testutil.Eq(t, []codes.Code{codes.Unavailable}, actualCodes) 591 } 592 593 type captureListener struct { 594 net.Listener 595 mu sync.Mutex 596 conn net.Conn 597 } 598 599 func (c *captureListener) Accept() (net.Conn, error) { 600 conn, err := c.Listener.Accept() 601 if err == nil { 602 c.mu.Lock() 603 c.conn = conn 604 c.mu.Unlock() 605 } 606 return conn, err 607 } 608 609 func (c *captureListener) latest() net.Conn { 610 c.mu.Lock() 611 defer c.mu.Unlock() 612 return c.conn 613 } 614 615 type captureErrors struct { 616 mu sync.Mutex 617 observed []codes.Code 618 } 619 620 func (c *captureErrors) intercept(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 621 stream, err := streamer(ctx, desc, cc, method, opts...) 622 if err != nil { 623 c.observe(err) 624 return nil, err 625 } 626 return &captureErrorStream{ClientStream: stream, c: c}, nil 627 } 628 629 func (c *captureErrors) observe(err error) { 630 c.mu.Lock() 631 c.observed = append(c.observed, status.Code(err)) 632 c.mu.Unlock() 633 } 634 635 func (c *captureErrors) codes() []codes.Code { 636 c.mu.Lock() 637 defer c.mu.Unlock() 638 ret := make([]codes.Code, len(c.observed)) 639 copy(ret, c.observed) 640 return ret 641 } 642 643 type captureErrorStream struct { 644 grpc.ClientStream 645 c *captureErrors 646 done int32 647 } 648 649 func (c *captureErrorStream) RecvMsg(m interface{}) error { 650 err := c.ClientStream.RecvMsg(m) 651 if err == nil || errors.Is(err, io.EOF) { 652 return nil 653 } 654 // Only record one error per RPC. 655 if atomic.CompareAndSwapInt32(&c.done, 0, 1) { 656 c.c.observe(err) 657 } 658 return err 659 }