google.golang.org/grpc@v1.72.2/test/balancer_test.go (about) 1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package test 20 21 import ( 22 "context" 23 "errors" 24 "fmt" 25 "net" 26 "reflect" 27 "testing" 28 "time" 29 30 "github.com/google/go-cmp/cmp" 31 32 "google.golang.org/grpc" 33 "google.golang.org/grpc/attributes" 34 "google.golang.org/grpc/balancer" 35 "google.golang.org/grpc/balancer/pickfirst" 36 "google.golang.org/grpc/codes" 37 "google.golang.org/grpc/connectivity" 38 "google.golang.org/grpc/credentials" 39 "google.golang.org/grpc/credentials/insecure" 40 "google.golang.org/grpc/internal" 41 "google.golang.org/grpc/internal/balancer/stub" 42 "google.golang.org/grpc/internal/balancerload" 43 "google.golang.org/grpc/internal/grpcsync" 44 "google.golang.org/grpc/internal/grpcutil" 45 imetadata "google.golang.org/grpc/internal/metadata" 46 "google.golang.org/grpc/internal/stubserver" 47 "google.golang.org/grpc/internal/testutils" 48 "google.golang.org/grpc/metadata" 49 "google.golang.org/grpc/resolver" 50 "google.golang.org/grpc/resolver/manual" 51 "google.golang.org/grpc/status" 52 "google.golang.org/grpc/testdata" 53 54 testgrpc "google.golang.org/grpc/interop/grpc_testing" 55 testpb "google.golang.org/grpc/interop/grpc_testing" 56 ) 57 58 const testBalancerName = "testbalancer" 59 60 // testBalancer creates one subconn with the first address from resolved 61 // addresses. 
62 // 63 // It's used to test whether options for NewSubConn are applied correctly. 64 type testBalancer struct { 65 cc balancer.ClientConn 66 sc balancer.SubConn 67 68 newSubConnOptions balancer.NewSubConnOptions 69 pickInfos []balancer.PickInfo 70 pickExtraMDs []metadata.MD 71 doneInfo []balancer.DoneInfo 72 } 73 74 func (b *testBalancer) Build(cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer { 75 b.cc = cc 76 return b 77 } 78 79 func (*testBalancer) Name() string { 80 return testBalancerName 81 } 82 83 func (*testBalancer) ResolverError(error) { 84 panic("not implemented") 85 } 86 87 func (b *testBalancer) UpdateClientConnState(state balancer.ClientConnState) error { 88 // Only create a subconn at the first time. 89 if b.sc == nil { 90 var err error 91 b.newSubConnOptions.StateListener = b.updateSubConnState 92 b.sc, err = b.cc.NewSubConn(state.ResolverState.Addresses, b.newSubConnOptions) 93 if err != nil { 94 logger.Errorf("testBalancer: failed to NewSubConn: %v", err) 95 return nil 96 } 97 b.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: &picker{err: balancer.ErrNoSubConnAvailable, bal: b}}) 98 b.sc.Connect() 99 } 100 return nil 101 } 102 103 func (b *testBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) { 104 panic(fmt.Sprintf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, s)) 105 } 106 107 func (b *testBalancer) updateSubConnState(s balancer.SubConnState) { 108 logger.Infof("testBalancer: updateSubConnState: %v", s) 109 110 switch s.ConnectivityState { 111 case connectivity.Ready: 112 b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{bal: b}}) 113 case connectivity.Idle: 114 b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{bal: b, idle: true}}) 115 case connectivity.Connecting: 116 b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: 
balancer.ErrNoSubConnAvailable, bal: b}}) 117 case connectivity.TransientFailure: 118 b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrTransientFailure, bal: b}}) 119 } 120 } 121 122 func (b *testBalancer) Close() {} 123 124 func (b *testBalancer) ExitIdle() {} 125 126 type picker struct { 127 err error 128 bal *testBalancer 129 idle bool 130 } 131 132 func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { 133 if p.err != nil { 134 return balancer.PickResult{}, p.err 135 } 136 if p.idle { 137 p.bal.sc.Connect() 138 return balancer.PickResult{}, balancer.ErrNoSubConnAvailable 139 } 140 extraMD, _ := grpcutil.ExtraMetadata(info.Ctx) 141 info.Ctx = nil // Do not validate context. 142 p.bal.pickInfos = append(p.bal.pickInfos, info) 143 p.bal.pickExtraMDs = append(p.bal.pickExtraMDs, extraMD) 144 return balancer.PickResult{SubConn: p.bal.sc, Done: func(d balancer.DoneInfo) { p.bal.doneInfo = append(p.bal.doneInfo, d) }}, nil 145 } 146 147 func (s) TestCredsBundleFromBalancer(t *testing.T) { 148 balancer.Register(&testBalancer{ 149 newSubConnOptions: balancer.NewSubConnOptions{ 150 CredsBundle: &testCredsBundle{}, 151 }, 152 }) 153 te := newTest(t, env{name: "creds-bundle", network: "tcp", balancer: ""}) 154 te.tapHandle = authHandle 155 te.customDialOptions = []grpc.DialOption{ 156 grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, testBalancerName)), 157 } 158 creds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 159 if err != nil { 160 t.Fatalf("Failed to generate credentials %v", err) 161 } 162 te.customServerOptions = []grpc.ServerOption{ 163 grpc.Creds(creds), 164 } 165 te.startServer(&testServer{}) 166 defer te.tearDown() 167 168 cc := te.clientConn() 169 tc := testgrpc.NewTestServiceClient(cc) 170 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 171 defer 
cancel() 172 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { 173 t.Fatalf("Test failed. Reason: %v", err) 174 } 175 } 176 177 func (s) TestPickExtraMetadata(t *testing.T) { 178 for _, e := range listTestEnv() { 179 testPickExtraMetadata(t, e) 180 } 181 } 182 183 func testPickExtraMetadata(t *testing.T, e env) { 184 te := newTest(t, e) 185 b := &testBalancer{} 186 balancer.Register(b) 187 const ( 188 testUserAgent = "test-user-agent" 189 testSubContentType = "proto" 190 ) 191 192 te.customDialOptions = []grpc.DialOption{ 193 grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, testBalancerName)), 194 grpc.WithUserAgent(testUserAgent), 195 } 196 te.startServer(&testServer{security: e.security}) 197 defer te.tearDown() 198 199 // Trigger the extra-metadata-adding code path. 200 defer func(old string) { internal.GRPCResolverSchemeExtraMetadata = old }(internal.GRPCResolverSchemeExtraMetadata) 201 internal.GRPCResolverSchemeExtraMetadata = "passthrough" 202 203 cc := te.clientConn() 204 tc := testgrpc.NewTestServiceClient(cc) 205 206 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 207 defer cancel() 208 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { 209 t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil) 210 } 211 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.CallContentSubtype(testSubContentType)); err != nil { 212 t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil) 213 } 214 215 want := []metadata.MD{ 216 // First RPC doesn't have sub-content-type. 217 {"content-type": []string{"application/grpc"}}, 218 // Second RPC has sub-content-type "proto". 
219 {"content-type": []string{"application/grpc+proto"}}, 220 } 221 if diff := cmp.Diff(want, b.pickExtraMDs); diff != "" { 222 t.Fatalf("unexpected diff in metadata (-want, +got): %s", diff) 223 } 224 } 225 226 func (s) TestDoneInfo(t *testing.T) { 227 for _, e := range listTestEnv() { 228 testDoneInfo(t, e) 229 } 230 } 231 232 func testDoneInfo(t *testing.T, e env) { 233 te := newTest(t, e) 234 b := &testBalancer{} 235 balancer.Register(b) 236 te.customDialOptions = []grpc.DialOption{ 237 grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, testBalancerName)), 238 } 239 te.userAgent = failAppUA 240 te.startServer(&testServer{security: e.security}) 241 defer te.tearDown() 242 243 cc := te.clientConn() 244 tc := testgrpc.NewTestServiceClient(cc) 245 246 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 247 defer cancel() 248 wantErr := detailedError 249 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !testutils.StatusErrEqual(err, wantErr) { 250 t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", status.Convert(err).Proto(), status.Convert(wantErr).Proto()) 251 } 252 if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil { 253 t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err) 254 } 255 256 if len(b.doneInfo) < 1 || !testutils.StatusErrEqual(b.doneInfo[0].Err, wantErr) { 257 t.Fatalf("b.doneInfo = %v; want b.doneInfo[0].Err = %v", b.doneInfo, wantErr) 258 } 259 if len(b.doneInfo) < 2 || !reflect.DeepEqual(b.doneInfo[1].Trailer, testTrailerMetadata) { 260 t.Fatalf("b.doneInfo = %v; want b.doneInfo[1].Trailer = %v", b.doneInfo, testTrailerMetadata) 261 } 262 if len(b.pickInfos) != len(b.doneInfo) { 263 t.Fatalf("Got %d picks, but %d doneInfo, want equal amount", len(b.pickInfos), len(b.doneInfo)) 264 } 265 // To test done() is always called, even if it's returned with a non-Ready 266 // SubConn. 267 // 268 // Stop server and at the same time send RPCs. 
There are chances that picker 269 // is not updated in time, causing a non-Ready SubConn to be returned. 270 finished := make(chan struct{}) 271 go func() { 272 for i := 0; i < 20; i++ { 273 tc.UnaryCall(ctx, &testpb.SimpleRequest{}) 274 } 275 close(finished) 276 }() 277 te.srv.Stop() 278 <-finished 279 if len(b.pickInfos) != len(b.doneInfo) { 280 t.Fatalf("Got %d picks, %d doneInfo, want equal amount", len(b.pickInfos), len(b.doneInfo)) 281 } 282 } 283 284 const loadMDKey = "X-Endpoint-Load-Metrics-Bin" 285 286 type testLoadParser struct{} 287 288 func (*testLoadParser) Parse(md metadata.MD) any { 289 vs := md.Get(loadMDKey) 290 if len(vs) == 0 { 291 return nil 292 } 293 return vs[0] 294 } 295 296 func init() { 297 balancerload.SetParser(&testLoadParser{}) 298 } 299 300 func (s) TestDoneLoads(t *testing.T) { 301 testDoneLoads(t) 302 } 303 304 func testDoneLoads(t *testing.T) { 305 b := &testBalancer{} 306 balancer.Register(b) 307 308 const testLoad = "test-load-,-should-be-orca" 309 310 ss := &stubserver.StubServer{ 311 EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { 312 grpc.SetTrailer(ctx, metadata.Pairs(loadMDKey, testLoad)) 313 return &testpb.Empty{}, nil 314 }, 315 } 316 if err := ss.Start(nil, grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, testBalancerName))); err != nil { 317 t.Fatalf("error starting testing server: %v", err) 318 } 319 defer ss.Stop() 320 321 tc := testgrpc.NewTestServiceClient(ss.CC) 322 323 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 324 defer cancel() 325 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { 326 t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil) 327 } 328 329 piWant := []balancer.PickInfo{ 330 {FullMethodName: "/grpc.testing.TestService/EmptyCall"}, 331 } 332 if !reflect.DeepEqual(b.pickInfos, piWant) { 333 t.Fatalf("b.pickInfos = %v; want %v", b.pickInfos, piWant) 334 } 335 336 if 
len(b.doneInfo) < 1 { 337 t.Fatalf("b.doneInfo = %v, want length 1", b.doneInfo) 338 } 339 gotLoad, _ := b.doneInfo[0].ServerLoad.(string) 340 if gotLoad != testLoad { 341 t.Fatalf("b.doneInfo[0].ServerLoad = %v; want = %v", b.doneInfo[0].ServerLoad, testLoad) 342 } 343 } 344 345 type aiPicker struct { 346 result balancer.PickResult 347 err error 348 } 349 350 func (aip *aiPicker) Pick(_ balancer.PickInfo) (balancer.PickResult, error) { 351 return aip.result, aip.err 352 } 353 354 // attrTransportCreds is a transport credential implementation which stores 355 // Attributes from the ClientHandshakeInfo struct passed in the context locally 356 // for the test to inspect. 357 type attrTransportCreds struct { 358 credentials.TransportCredentials 359 attr *attributes.Attributes 360 } 361 362 func (ac *attrTransportCreds) ClientHandshake(ctx context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 363 ai := credentials.ClientHandshakeInfoFromContext(ctx) 364 ac.attr = ai.Attributes 365 return rawConn, nil, nil 366 } 367 func (ac *attrTransportCreds) Info() credentials.ProtocolInfo { 368 return credentials.ProtocolInfo{} 369 } 370 func (ac *attrTransportCreds) Clone() credentials.TransportCredentials { 371 return nil 372 } 373 374 // TestAddressAttributesInNewSubConn verifies that the Attributes passed from a 375 // balancer in the resolver.Address that is passes to NewSubConn reaches all the 376 // way to the ClientHandshake method of the credentials configured on the parent 377 // channel. 378 func (s) TestAddressAttributesInNewSubConn(t *testing.T) { 379 const ( 380 testAttrKey = "foo" 381 testAttrVal = "bar" 382 attrBalancerName = "attribute-balancer" 383 ) 384 385 // Register a stub balancer which adds attributes to the first address that 386 // it receives and then calls NewSubConn on it. 
387 bf := stub.BalancerFuncs{ 388 UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { 389 addrs := ccs.ResolverState.Addresses 390 if len(addrs) == 0 { 391 return nil 392 } 393 394 // Only use the first address. 395 attr := attributes.New(testAttrKey, testAttrVal) 396 addrs[0].Attributes = attr 397 var sc balancer.SubConn 398 sc, err := bd.ClientConn.NewSubConn([]resolver.Address{addrs[0]}, balancer.NewSubConnOptions{ 399 StateListener: func(state balancer.SubConnState) { 400 bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}}) 401 }, 402 }) 403 if err != nil { 404 return err 405 } 406 sc.Connect() 407 return nil 408 }, 409 } 410 stub.Register(attrBalancerName, bf) 411 t.Logf("Registered balancer %s...", attrBalancerName) 412 413 r := manual.NewBuilderWithScheme("whatever") 414 t.Logf("Registered manual resolver with scheme %s...", r.Scheme()) 415 416 lis, err := net.Listen("tcp", "localhost:0") 417 if err != nil { 418 t.Fatal(err) 419 } 420 stub := &stubserver.StubServer{ 421 Listener: lis, 422 EmptyCallF: func(_ context.Context, _ *testpb.Empty) (*testpb.Empty, error) { 423 return &testpb.Empty{}, nil 424 }, 425 S: grpc.NewServer(), 426 } 427 stubserver.StartTestService(t, stub) 428 defer stub.S.Stop() 429 t.Logf("Started gRPC server at %s...", lis.Addr().String()) 430 431 creds := &attrTransportCreds{} 432 dopts := []grpc.DialOption{ 433 grpc.WithTransportCredentials(creds), 434 grpc.WithResolvers(r), 435 grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, attrBalancerName)), 436 } 437 cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...) 438 if err != nil { 439 t.Fatal(err) 440 } 441 defer cc.Close() 442 tc := testgrpc.NewTestServiceClient(cc) 443 t.Log("Created a ClientConn...") 444 445 // The first RPC should fail because there's no address. 
446 ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) 447 defer cancel() 448 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded { 449 t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) 450 } 451 t.Log("Made an RPC which was expected to fail...") 452 453 state := resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}} 454 r.UpdateState(state) 455 t.Logf("Pushing resolver state update: %v through the manual resolver", state) 456 457 // The second RPC should succeed. 458 ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) 459 defer cancel() 460 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { 461 t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err) 462 } 463 t.Log("Made an RPC which succeeded...") 464 465 wantAttr := attributes.New(testAttrKey, testAttrVal) 466 if gotAttr := creds.attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) { 467 t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr) 468 } 469 } 470 471 // TestMetadataInAddressAttributes verifies that the metadata added to 472 // address.Attributes will be sent with the RPCs. 473 func (s) TestMetadataInAddressAttributes(t *testing.T) { 474 const ( 475 testMDKey = "test-md" 476 testMDValue = "test-md-value" 477 mdBalancerName = "metadata-balancer" 478 ) 479 480 // Register a stub balancer which adds metadata to the first address that it 481 // receives and then calls NewSubConn on it. 482 bf := stub.BalancerFuncs{ 483 UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { 484 addrs := ccs.ResolverState.Addresses 485 if len(addrs) == 0 { 486 return nil 487 } 488 // Only use the first address. 
489 var sc balancer.SubConn 490 sc, err := bd.ClientConn.NewSubConn([]resolver.Address{ 491 imetadata.Set(addrs[0], metadata.Pairs(testMDKey, testMDValue)), 492 }, balancer.NewSubConnOptions{ 493 StateListener: func(state balancer.SubConnState) { 494 bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}}) 495 }, 496 }) 497 if err != nil { 498 return err 499 } 500 sc.Connect() 501 return nil 502 }, 503 } 504 stub.Register(mdBalancerName, bf) 505 t.Logf("Registered balancer %s...", mdBalancerName) 506 507 testMDChan := make(chan []string, 1) 508 ss := &stubserver.StubServer{ 509 EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { 510 md, ok := metadata.FromIncomingContext(ctx) 511 if ok { 512 select { 513 case testMDChan <- md[testMDKey]: 514 case <-ctx.Done(): 515 return nil, ctx.Err() 516 } 517 } 518 return &testpb.Empty{}, nil 519 }, 520 } 521 if err := ss.Start(nil, grpc.WithDefaultServiceConfig( 522 fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, mdBalancerName), 523 )); err != nil { 524 t.Fatalf("Error starting endpoint server: %v", err) 525 } 526 defer ss.Stop() 527 528 // The RPC should succeed with the expected md. 529 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 530 defer cancel() 531 if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { 532 t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err) 533 } 534 t.Log("Made an RPC which succeeded...") 535 536 // The server should receive the test metadata. 537 md1 := <-testMDChan 538 if len(md1) == 0 || md1[0] != testMDValue { 539 t.Fatalf("got md: %v, want %v", md1, []string{testMDValue}) 540 } 541 } 542 543 // TestServersSwap creates two servers and verifies the client switches between 544 // them when the name resolver reports the first and then the second. 
545 func (s) TestServersSwap(t *testing.T) { 546 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 547 defer cancel() 548 549 // Initialize servers 550 reg := func(username string) (addr string, cleanup func()) { 551 lis, err := net.Listen("tcp", "localhost:0") 552 if err != nil { 553 t.Fatalf("Error while listening. Err: %v", err) 554 } 555 556 stub := &stubserver.StubServer{ 557 Listener: lis, 558 UnaryCallF: func(_ context.Context, _ *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { 559 return &testpb.SimpleResponse{Username: username}, nil 560 }, 561 S: grpc.NewServer(), 562 } 563 stubserver.StartTestService(t, stub) 564 return lis.Addr().String(), stub.S.Stop 565 } 566 const one = "1" 567 addr1, cleanup := reg(one) 568 defer cleanup() 569 const two = "2" 570 addr2, cleanup := reg(two) 571 defer cleanup() 572 573 // Initialize client 574 r := manual.NewBuilderWithScheme("whatever") 575 r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: addr1}}}) 576 cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r)) 577 if err != nil { 578 t.Fatalf("Error creating client: %v", err) 579 } 580 defer cc.Close() 581 client := testgrpc.NewTestServiceClient(cc) 582 583 // Confirm we are connected to the first server 584 if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil || res.Username != one { 585 t.Fatalf("UnaryCall(_) = %v, %v; want {Username: %q}, nil", res, err, one) 586 } 587 588 // Update resolver to report only the second server 589 r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: addr2}}}) 590 591 // Loop until new RPCs talk to server two. 
592 for i := 0; i < 2000; i++ { 593 if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil { 594 t.Fatalf("UnaryCall(_) = _, %v; want _, nil", err) 595 } else if res.Username == two { 596 break // pass 597 } 598 time.Sleep(5 * time.Millisecond) 599 } 600 } 601 602 func (s) TestWaitForReady(t *testing.T) { 603 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 604 defer cancel() 605 606 // Initialize server 607 lis, err := net.Listen("tcp", "localhost:0") 608 if err != nil { 609 t.Fatalf("Error while listening. Err: %v", err) 610 } 611 const one = "1" 612 stub := &stubserver.StubServer{ 613 Listener: lis, 614 UnaryCallF: func(_ context.Context, _ *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { 615 return &testpb.SimpleResponse{Username: one}, nil 616 }, 617 S: grpc.NewServer(), 618 } 619 stubserver.StartTestService(t, stub) 620 defer stub.S.Stop() 621 622 // Initialize client 623 r := manual.NewBuilderWithScheme("whatever") 624 625 cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r)) 626 if err != nil { 627 t.Fatalf("Error creating client: %v", err) 628 } 629 defer cc.Close() 630 cc.Connect() 631 client := testgrpc.NewTestServiceClient(cc) 632 633 // Report an error so non-WFR RPCs will give up early. 634 r.CC().ReportError(errors.New("fake resolver error")) 635 636 // Ensure the client is not connected to anything and fails non-WFR RPCs. 
637 if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Unavailable { 638 t.Fatalf("UnaryCall(_) = %v, %v; want _, Code()=%v", res, err, codes.Unavailable) 639 } 640 641 errChan := make(chan error, 1) 642 go func() { 643 if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}, grpc.WaitForReady(true)); err != nil || res.Username != one { 644 errChan <- fmt.Errorf("UnaryCall(_) = %v, %v; want {Username: %q}, nil", res, err, one) 645 } 646 close(errChan) 647 }() 648 649 select { 650 case err := <-errChan: 651 t.Errorf("unexpected receive from errChan before addresses provided") 652 t.Fatal(err.Error()) 653 case <-time.After(5 * time.Millisecond): 654 } 655 656 // Resolve the server. The WFR RPC should unblock and use it. 657 r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) 658 659 if err := <-errChan; err != nil { 660 t.Fatal(err.Error()) 661 } 662 } 663 664 // authorityOverrideTransportCreds returns the configured authority value in its 665 // Info() method. 666 type authorityOverrideTransportCreds struct { 667 credentials.TransportCredentials 668 authorityOverride string 669 } 670 671 func (ao *authorityOverrideTransportCreds) ClientHandshake(_ context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 672 return rawConn, nil, nil 673 } 674 func (ao *authorityOverrideTransportCreds) Info() credentials.ProtocolInfo { 675 return credentials.ProtocolInfo{ServerName: ao.authorityOverride} 676 } 677 func (ao *authorityOverrideTransportCreds) Clone() credentials.TransportCredentials { 678 return &authorityOverrideTransportCreds{authorityOverride: ao.authorityOverride} 679 } 680 681 // TestAuthorityInBuildOptions tests that the Authority field in 682 // balancer.BuildOptions is setup correctly from gRPC. 
683 func (s) TestAuthorityInBuildOptions(t *testing.T) { 684 const dialTarget = "test.server" 685 686 tests := []struct { 687 name string 688 dopts []grpc.DialOption 689 wantAuthority string 690 }{ 691 { 692 name: "authority from dial target", 693 dopts: []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, 694 wantAuthority: dialTarget, 695 }, 696 { 697 name: "authority from dial option", 698 dopts: []grpc.DialOption{ 699 grpc.WithTransportCredentials(insecure.NewCredentials()), 700 grpc.WithAuthority("authority-override"), 701 }, 702 wantAuthority: "authority-override", 703 }, 704 { 705 name: "authority from transport creds", 706 dopts: []grpc.DialOption{grpc.WithTransportCredentials(&authorityOverrideTransportCreds{authorityOverride: "authority-override-from-transport-creds"})}, 707 wantAuthority: "authority-override-from-transport-creds", 708 }, 709 } 710 711 for _, test := range tests { 712 t.Run(test.name, func(t *testing.T) { 713 authorityCh := make(chan string, 1) 714 bf := stub.BalancerFuncs{ 715 UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { 716 select { 717 case authorityCh <- bd.BuildOptions.Authority: 718 default: 719 } 720 721 addrs := ccs.ResolverState.Addresses 722 if len(addrs) == 0 { 723 return nil 724 } 725 726 // Only use the first address. 
727 var sc balancer.SubConn 728 sc, err := bd.ClientConn.NewSubConn([]resolver.Address{addrs[0]}, balancer.NewSubConnOptions{ 729 StateListener: func(state balancer.SubConnState) { 730 bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}}) 731 }, 732 }) 733 if err != nil { 734 return err 735 } 736 sc.Connect() 737 return nil 738 }, 739 } 740 balancerName := "stub-balancer-" + test.name 741 stub.Register(balancerName, bf) 742 t.Logf("Registered balancer %s...", balancerName) 743 744 lis, err := testutils.LocalTCPListener() 745 if err != nil { 746 t.Fatal(err) 747 } 748 749 stub := &stubserver.StubServer{ 750 Listener: lis, 751 EmptyCallF: func(_ context.Context, _ *testpb.Empty) (*testpb.Empty, error) { 752 return &testpb.Empty{}, nil 753 }, 754 S: grpc.NewServer(), 755 } 756 stubserver.StartTestService(t, stub) 757 defer stub.S.Stop() 758 t.Logf("Started gRPC server at %s...", lis.Addr().String()) 759 760 r := manual.NewBuilderWithScheme("whatever") 761 t.Logf("Registered manual resolver with scheme %s...", r.Scheme()) 762 r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) 763 764 dopts := append([]grpc.DialOption{ 765 grpc.WithResolvers(r), 766 grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, balancerName)), 767 }, test.dopts...) 768 cc, err := grpc.NewClient(r.Scheme()+":///"+dialTarget, dopts...) 
769 if err != nil { 770 t.Fatal(err) 771 } 772 defer cc.Close() 773 tc := testgrpc.NewTestServiceClient(cc) 774 t.Log("Created a ClientConn...") 775 776 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 777 defer cancel() 778 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { 779 t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err) 780 } 781 t.Log("Made an RPC which succeeded...") 782 783 select { 784 case <-ctx.Done(): 785 t.Fatal("timeout when waiting for Authority in balancer.BuildOptions") 786 case gotAuthority := <-authorityCh: 787 if gotAuthority != test.wantAuthority { 788 t.Fatalf("Authority in balancer.BuildOptions is %s, want %s", gotAuthority, test.wantAuthority) 789 } 790 } 791 }) 792 } 793 } 794 795 // testCCWrapper wraps a balancer.ClientConn and intercepts UpdateState and 796 // returns a custom picker which injects arbitrary metadata on a per-call basis. 797 type testCCWrapper struct { 798 balancer.ClientConn 799 } 800 801 func (t *testCCWrapper) UpdateState(state balancer.State) { 802 state.Picker = &wrappedPicker{p: state.Picker} 803 t.ClientConn.UpdateState(state) 804 } 805 806 const ( 807 metadataHeaderInjectedByBalancer = "metadata-header-injected-by-balancer" 808 metadataHeaderInjectedByApplication = "metadata-header-injected-by-application" 809 metadataValueInjectedByBalancer = "metadata-value-injected-by-balancer" 810 metadataValueInjectedByApplication = "metadata-value-injected-by-application" 811 ) 812 813 // wrappedPicker wraps the picker returned by the pick_first 814 type wrappedPicker struct { 815 p balancer.Picker 816 } 817 818 func (wp *wrappedPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { 819 res, err := wp.p.Pick(info) 820 if err != nil { 821 return balancer.PickResult{}, err 822 } 823 824 if res.Metadata == nil { 825 res.Metadata = metadata.Pairs(metadataHeaderInjectedByBalancer, metadataValueInjectedByBalancer) 826 } else { 827 
res.Metadata.Append(metadataHeaderInjectedByBalancer, metadataValueInjectedByBalancer)
	}
	return res, nil
}

// TestMetadataInPickResult tests the scenario where an LB policy injects
// arbitrary metadata on a per-call basis and verifies that the injected
// metadata makes it all the way to the server RPC handler.
func (s) TestMetadataInPickResult(t *testing.T) {
	t.Log("Starting test backend...")
	// Buffered so the handler can hand over the metadata of the single RPC
	// made by this test without waiting for the test goroutine to read it.
	mdChan := make(chan metadata.MD, 1)
	ss := &stubserver.StubServer{
		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
			md, _ := metadata.FromIncomingContext(ctx)
			select {
			case mdChan <- md:
			case <-ctx.Done():
				return nil, ctx.Err()
			}
			return &testpb.Empty{}, nil
		},
	}
	if err := ss.StartServer(); err != nil {
		t.Fatalf("Starting test backend: %v", err)
	}
	defer ss.Stop()
	t.Logf("Started test backend at %q", ss.Address)

	// Register a test balancer that contains a pick_first balancer and forwards
	// all calls from the ClientConn to it. For state updates from the
	// pick_first balancer, it creates a custom picker which injects arbitrary
	// metadata on a per-call basis.
	stub.Register(t.Name(), stub.BalancerFuncs{
		Init: func(bd *stub.BalancerData) {
			cc := &testCCWrapper{ClientConn: bd.ClientConn}
			bd.Data = balancer.Get(pickfirst.Name).Build(cc, bd.BuildOptions)
		},
		Close: func(bd *stub.BalancerData) {
			bd.Data.(balancer.Balancer).Close()
		},
		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
			bal := bd.Data.(balancer.Balancer)
			return bal.UpdateClientConnState(ccs)
		},
	})

	t.Log("Creating ClientConn to test backend...")
	r := manual.NewBuilderWithScheme("whatever")
	r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: ss.Address}}})
	dopts := []grpc.DialOption{
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithResolvers(r),
		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, t.Name())),
	}
	cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...)
	if err != nil {
		t.Fatalf("grpc.NewClient(): %v", err)
	}
	defer cc.Close()
	tc := testgrpc.NewTestServiceClient(cc)

	t.Log("Making EmptyCall() RPC with custom metadata...")
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	md := metadata.Pairs(metadataHeaderInjectedByApplication, metadataValueInjectedByApplication)
	ctx = metadata.NewOutgoingContext(ctx, md)
	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
		t.Fatalf("EmptyCall() RPC: %v", err)
	}
	t.Log("EmptyCall() RPC succeeded")

	t.Log("Waiting for custom metadata to be received at the test backend...")
	var gotMD metadata.MD
	select {
	case gotMD = <-mdChan:
	case <-ctx.Done():
		t.Fatalf("Timed out waiting for custom metadata to be received at the test backend")
	}

	t.Log("Verifying custom metadata added by the client application is received at the test backend...")
	wantMDVal := []string{metadataValueInjectedByApplication}
	gotMDVal := gotMD.Get(metadataHeaderInjectedByApplication)
	if !cmp.Equal(gotMDVal, wantMDVal) {
		t.Fatalf("Mismatch in custom metadata received at test backend, got: %v, want %v", gotMDVal, wantMDVal)
	}

	t.Log("Verifying custom metadata added by the LB policy is received at the test backend...")
	wantMDVal = []string{metadataValueInjectedByBalancer}
	gotMDVal = gotMD.Get(metadataHeaderInjectedByBalancer)
	if !cmp.Equal(gotMDVal, wantMDVal) {
		t.Fatalf("Mismatch in custom metadata received at test backend, got: %v, want %v", gotMDVal, wantMDVal)
	}
}

// TestSubConnShutdown confirms that the Shutdown method on subconns and
// RemoveSubConn method on ClientConn properly initiates subconn shutdown.
func (s) TestSubConnShutdown(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()

	testCases := []struct {
		name     string
		shutdown func(cc balancer.ClientConn, sc balancer.SubConn)
	}{{
		name: "ClientConn.RemoveSubConn",
		shutdown: func(cc balancer.ClientConn, sc balancer.SubConn) {
			cc.RemoveSubConn(sc)
		},
	}, {
		name: "SubConn.Shutdown",
		shutdown: func(_ balancer.ClientConn, sc balancer.SubConn) {
			sc.Shutdown()
		},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			gotShutdown := grpcsync.NewEvent()

			bf := stub.BalancerFuncs{
				UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
					// sc is declared ahead of opts so that the listener
					// closure can refer to it. The short variable
					// declaration further down assigns to this same
					// variable, because only err is new in this scope.
					var sc balancer.SubConn
					opts := balancer.NewSubConnOptions{
						StateListener: func(scs balancer.SubConnState) {
							switch scs.ConnectivityState {
							case connectivity.Connecting:
								// Ignored.
							case connectivity.Ready:
								tc.shutdown(bd.ClientConn, sc)
							case connectivity.Shutdown:
								gotShutdown.Fire()
							default:
								t.Errorf("got unexpected state %q in listener", scs.ConnectivityState)
							}
						},
					}
					sc, err := bd.ClientConn.NewSubConn(ccs.ResolverState.Addresses, opts)
					if err != nil {
						return err
					}
					sc.Connect()
					// Report the state as READY to unblock ss.Start(), which waits for ready.
					bd.ClientConn.UpdateState(balancer.State{ConnectivityState: connectivity.Ready})
					return nil
				},
			}

			testBalName := "shutdown-test-balancer-" + tc.name
			stub.Register(testBalName, bf)
			t.Logf("Registered balancer %s...", testBalName)

			ss := &stubserver.StubServer{}
			if err := ss.Start(nil, grpc.WithDefaultServiceConfig(
				fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, testBalName),
			)); err != nil {
				t.Fatalf("Error starting endpoint server: %v", err)
			}
			defer ss.Stop()

			select {
			case <-gotShutdown.Done():
				// Success
			case <-ctx.Done():
				t.Fatalf("Timed out waiting for gotShutdown to be fired.")
			}
		})
	}
}

// subConnStoringCCWrapper wraps a balancer.ClientConn and overrides NewSubConn
// to publish every successfully created SubConn on scChan. When stateListener
// is set, it is also invoked for every SubConn state update, before the
// listener supplied by the caller of NewSubConn.
type subConnStoringCCWrapper struct {
	balancer.ClientConn
	stateListener func(balancer.SubConnState)
	scChan        chan balancer.SubConn
}

// NewSubConn creates a SubConn through the wrapped ClientConn. If
// ccw.stateListener is non-nil, opts.StateListener is replaced by a listener
// that calls ccw.stateListener first and the original listener second.
//
// A SubConn is sent on scChan only when creation succeeded, so readers of the
// channel never observe the result of a failed call. The send is a plain
// channel send and therefore blocks while scChan is full.
func (ccw *subConnStoringCCWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
	if ccw.stateListener != nil {
		origListener := opts.StateListener
		opts.StateListener = func(scs balancer.SubConnState) {
			ccw.stateListener(scs)
			origListener(scs)
		}
	}
	sc, err := ccw.ClientConn.NewSubConn(addrs, opts)
	if err != nil {
		// Do not hand a failed result to the test; report the error to the
		// caller exactly as the wrapped ClientConn returned it.
		return sc, err
	}
	ccw.scChan <- sc
	return sc, nil
}

// Test calls RegisterHealthListener on a SubConn to verify that expected health
// updates are sent only to the most recently registered listener.
1017 func (s) TestSubConn_RegisterHealthListener(t *testing.T) { 1018 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1019 defer cancel() 1020 scChan := make(chan balancer.SubConn, 1) 1021 bf := stub.BalancerFuncs{ 1022 Init: func(bd *stub.BalancerData) { 1023 cc := bd.ClientConn 1024 ccw := &subConnStoringCCWrapper{ 1025 ClientConn: cc, 1026 scChan: scChan, 1027 } 1028 bd.Data = balancer.Get(pickfirst.Name).Build(ccw, bd.BuildOptions) 1029 }, 1030 Close: func(bd *stub.BalancerData) { 1031 bd.Data.(balancer.Balancer).Close() 1032 }, 1033 UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { 1034 return bd.Data.(balancer.Balancer).UpdateClientConnState(ccs) 1035 }, 1036 ExitIdle: func(bd *stub.BalancerData) { 1037 bd.Data.(balancer.ExitIdler).ExitIdle() 1038 }, 1039 } 1040 1041 stub.Register(t.Name(), bf) 1042 svcCfg := fmt.Sprintf(`{ "loadBalancingConfig": [{%q: {}}] }`, t.Name()) 1043 backend := stubserver.StartTestService(t, nil) 1044 defer backend.Stop() 1045 opts := []grpc.DialOption{ 1046 grpc.WithTransportCredentials(insecure.NewCredentials()), 1047 grpc.WithDefaultServiceConfig(svcCfg), 1048 } 1049 cc, err := grpc.NewClient(backend.Address, opts...) 1050 if err != nil { 1051 t.Fatalf("grpc.NewClient(%q) failed: %v", backend.Address, err) 1052 1053 } 1054 defer cc.Close() 1055 1056 cc.Connect() 1057 1058 var sc balancer.SubConn 1059 select { 1060 case sc = <-scChan: 1061 case <-ctx.Done(): 1062 t.Fatal("Context timed out waiting for SubConn creation") 1063 } 1064 healthUpdateChan := make(chan balancer.SubConnState, 1) 1065 1066 // Register listener while Ready and verify it gets a health update. 
1067 testutils.AwaitState(ctx, t, cc, connectivity.Ready) 1068 for i := 0; i < 2; i++ { 1069 sc.RegisterHealthListener(func(scs balancer.SubConnState) { 1070 healthUpdateChan <- scs 1071 }) 1072 select { 1073 case scs := <-healthUpdateChan: 1074 if scs.ConnectivityState != connectivity.Ready { 1075 t.Fatalf("Received health update = %v, want = %v", scs.ConnectivityState, connectivity.Ready) 1076 } 1077 case <-ctx.Done(): 1078 t.Fatalf("Context timed out waiting for health update") 1079 } 1080 1081 // No further updates are expected. 1082 select { 1083 case scs := <-healthUpdateChan: 1084 t.Fatalf("Received unexpected health update while channel is in state %v: %v", cc.GetState(), scs) 1085 case <-time.After(defaultTestShortTimeout): 1086 } 1087 } 1088 1089 // Make the SubConn enter IDLE and verify that health updates are recevied 1090 // on registering a new listener. 1091 backend.S.Stop() 1092 backend.S = nil 1093 testutils.AwaitState(ctx, t, cc, connectivity.Idle) 1094 if err := backend.StartServer(); err != nil { 1095 t.Fatalf("Error while restarting the backend server: %v", err) 1096 } 1097 cc.Connect() 1098 testutils.AwaitState(ctx, t, cc, connectivity.Ready) 1099 sc.RegisterHealthListener(func(scs balancer.SubConnState) { 1100 healthUpdateChan <- scs 1101 }) 1102 select { 1103 case scs := <-healthUpdateChan: 1104 if scs.ConnectivityState != connectivity.Ready { 1105 t.Fatalf("Received health update = %v, want = %v", scs.ConnectivityState, connectivity.Ready) 1106 } 1107 case <-ctx.Done(): 1108 t.Fatalf("Context timed out waiting for health update") 1109 } 1110 } 1111 1112 // Test calls RegisterHealthListener on a SubConn twice while handling the 1113 // connectivity update. The test verifies that only the latest listener 1114 // receives the health update. 
1115 func (s) TestSubConn_RegisterHealthListener_RegisterTwice(t *testing.T) { 1116 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1117 defer cancel() 1118 scChan := make(chan balancer.SubConn, 1) 1119 readyUpdateResumeCh := make(chan struct{}) 1120 readyUpdateReceivedCh := make(chan struct{}) 1121 bf := stub.BalancerFuncs{ 1122 Init: func(bd *stub.BalancerData) { 1123 cc := bd.ClientConn 1124 ccw := &subConnStoringCCWrapper{ 1125 ClientConn: cc, 1126 scChan: scChan, 1127 stateListener: func(scs balancer.SubConnState) { 1128 if scs.ConnectivityState != connectivity.Ready { 1129 return 1130 } 1131 close(readyUpdateReceivedCh) 1132 select { 1133 case <-readyUpdateResumeCh: 1134 case <-ctx.Done(): 1135 t.Error("Context timed out waiting for update on ready channel") 1136 } 1137 }, 1138 } 1139 bd.Data = balancer.Get(pickfirst.Name).Build(ccw, bd.BuildOptions) 1140 }, 1141 Close: func(bd *stub.BalancerData) { 1142 bd.Data.(balancer.Balancer).Close() 1143 }, 1144 UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { 1145 return bd.Data.(balancer.Balancer).UpdateClientConnState(ccs) 1146 }, 1147 } 1148 1149 stub.Register(t.Name(), bf) 1150 svcCfg := fmt.Sprintf(`{ "loadBalancingConfig": [{%q: {}}] }`, t.Name()) 1151 backend := stubserver.StartTestService(t, nil) 1152 defer backend.Stop() 1153 opts := []grpc.DialOption{ 1154 grpc.WithTransportCredentials(insecure.NewCredentials()), 1155 grpc.WithDefaultServiceConfig(svcCfg), 1156 } 1157 cc, err := grpc.NewClient(backend.Address, opts...) 1158 if err != nil { 1159 t.Fatalf("grpc.NewClient(%q) failed: %v", backend.Address, err) 1160 1161 } 1162 defer cc.Close() 1163 1164 cc.Connect() 1165 1166 var sc balancer.SubConn 1167 select { 1168 case sc = <-scChan: 1169 case <-ctx.Done(): 1170 t.Fatal("Context timed out waiting for SubConn creation") 1171 } 1172 1173 // Wait for the SubConn to enter READY. 
1174 select { 1175 case <-readyUpdateReceivedCh: 1176 case <-ctx.Done(): 1177 t.Fatalf("Context timed out waiting for SubConn to enter READY") 1178 } 1179 1180 healthChan1 := make(chan balancer.SubConnState, 1) 1181 healthChan2 := make(chan balancer.SubConnState, 1) 1182 1183 sc.RegisterHealthListener(func(scs balancer.SubConnState) { 1184 healthChan1 <- scs 1185 }) 1186 sc.RegisterHealthListener(func(scs balancer.SubConnState) { 1187 healthChan2 <- scs 1188 }) 1189 close(readyUpdateResumeCh) 1190 1191 select { 1192 case scs := <-healthChan2: 1193 if scs.ConnectivityState != connectivity.Ready { 1194 t.Fatalf("Received health update = %v, want = %v", scs.ConnectivityState, connectivity.Ready) 1195 } 1196 case <-ctx.Done(): 1197 t.Fatalf("Context timed out waiting for health update") 1198 } 1199 1200 // No updates should be received on the first listener. 1201 select { 1202 case scs := <-healthChan1: 1203 t.Fatalf("Received unexpected health update on first listener: %v", scs) 1204 case <-time.After(defaultTestShortTimeout): 1205 } 1206 } 1207 1208 // Test calls RegisterHealthListener on a SubConn with a nil listener and 1209 // verifies that the listener registered before the nil listener doesn't receive 1210 // any further updates. 
1211 func (s) TestSubConn_RegisterHealthListener_NilListener(t *testing.T) { 1212 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1213 defer cancel() 1214 scChan := make(chan balancer.SubConn, 1) 1215 readyUpdateResumeCh := make(chan struct{}) 1216 readyUpdateReceivedCh := make(chan struct{}) 1217 bf := stub.BalancerFuncs{ 1218 Init: func(bd *stub.BalancerData) { 1219 cc := bd.ClientConn 1220 ccw := &subConnStoringCCWrapper{ 1221 ClientConn: cc, 1222 scChan: scChan, 1223 stateListener: func(scs balancer.SubConnState) { 1224 if scs.ConnectivityState != connectivity.Ready { 1225 return 1226 } 1227 close(readyUpdateReceivedCh) 1228 select { 1229 case <-readyUpdateResumeCh: 1230 case <-ctx.Done(): 1231 t.Error("Context timed out waiting for update on ready channel") 1232 } 1233 }, 1234 } 1235 bd.Data = balancer.Get(pickfirst.Name).Build(ccw, bd.BuildOptions) 1236 }, 1237 Close: func(bd *stub.BalancerData) { 1238 bd.Data.(balancer.Balancer).Close() 1239 }, 1240 UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { 1241 return bd.Data.(balancer.Balancer).UpdateClientConnState(ccs) 1242 }, 1243 } 1244 1245 stub.Register(t.Name(), bf) 1246 svcCfg := fmt.Sprintf(`{ "loadBalancingConfig": [{%q: {}}] }`, t.Name()) 1247 backend := stubserver.StartTestService(t, nil) 1248 defer backend.Stop() 1249 opts := []grpc.DialOption{ 1250 grpc.WithTransportCredentials(insecure.NewCredentials()), 1251 grpc.WithDefaultServiceConfig(svcCfg), 1252 } 1253 cc, err := grpc.NewClient(backend.Address, opts...) 1254 if err != nil { 1255 t.Fatalf("grpc.NewClient(%q) failed: %v", backend.Address, err) 1256 1257 } 1258 defer cc.Close() 1259 1260 cc.Connect() 1261 1262 var sc balancer.SubConn 1263 select { 1264 case sc = <-scChan: 1265 case <-ctx.Done(): 1266 t.Fatal("Context timed out waiting for SubConn creation") 1267 } 1268 1269 // Wait for the SubConn to enter READY. 
1270 select { 1271 case <-readyUpdateReceivedCh: 1272 case <-ctx.Done(): 1273 t.Fatalf("Context timed out waiting for SubConn to enter READY") 1274 } 1275 1276 healthChan := make(chan balancer.SubConnState, 1) 1277 1278 sc.RegisterHealthListener(func(scs balancer.SubConnState) { 1279 healthChan <- scs 1280 }) 1281 1282 // Registering a nil listener should invalidate the previously registered 1283 // listener. 1284 sc.RegisterHealthListener(nil) 1285 close(readyUpdateResumeCh) 1286 1287 // No updates should be received on the listener. 1288 select { 1289 case scs := <-healthChan: 1290 t.Fatalf("Received unexpected health update on the listener: %v", scs) 1291 case <-time.After(defaultTestShortTimeout): 1292 } 1293 }