google.golang.org/grpc@v1.72.2/test/balancer_test.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package test
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"fmt"
    25  	"net"
    26  	"reflect"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/google/go-cmp/cmp"
    31  
    32  	"google.golang.org/grpc"
    33  	"google.golang.org/grpc/attributes"
    34  	"google.golang.org/grpc/balancer"
    35  	"google.golang.org/grpc/balancer/pickfirst"
    36  	"google.golang.org/grpc/codes"
    37  	"google.golang.org/grpc/connectivity"
    38  	"google.golang.org/grpc/credentials"
    39  	"google.golang.org/grpc/credentials/insecure"
    40  	"google.golang.org/grpc/internal"
    41  	"google.golang.org/grpc/internal/balancer/stub"
    42  	"google.golang.org/grpc/internal/balancerload"
    43  	"google.golang.org/grpc/internal/grpcsync"
    44  	"google.golang.org/grpc/internal/grpcutil"
    45  	imetadata "google.golang.org/grpc/internal/metadata"
    46  	"google.golang.org/grpc/internal/stubserver"
    47  	"google.golang.org/grpc/internal/testutils"
    48  	"google.golang.org/grpc/metadata"
    49  	"google.golang.org/grpc/resolver"
    50  	"google.golang.org/grpc/resolver/manual"
    51  	"google.golang.org/grpc/status"
    52  	"google.golang.org/grpc/testdata"
    53  
    54  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    55  	testpb "google.golang.org/grpc/interop/grpc_testing"
    56  )
    57  
// testBalancerName is the name reported by testBalancer.Name and the name the
// tests in this file put into their service config to select that balancer.
const testBalancerName = "testbalancer"
    59  
// testBalancer creates one subconn with the first address from resolved
// addresses.
//
// It's used to test whether options for NewSubConn are applied correctly.
//
// The same value acts as both the balancer.Builder (Build, Name) and the
// balancer.Balancer it builds, so the test that registers it can inspect the
// recorded fields afterwards.
type testBalancer struct {
	// cc is the ClientConn handed to Build.
	cc balancer.ClientConn
	// sc is the single SubConn created by the first UpdateClientConnState
	// call; it is nil until then.
	sc balancer.SubConn

	// newSubConnOptions are the options passed to NewSubConn. The
	// StateListener field is overwritten in UpdateClientConnState.
	newSubConnOptions balancer.NewSubConnOptions
	// pickInfos records the PickInfo of every pick that returned the SubConn,
	// with the Ctx field cleared.
	pickInfos []balancer.PickInfo
	// pickExtraMDs records, for each such pick, the value returned by
	// grpcutil.ExtraMetadata for the pick's context.
	pickExtraMDs []metadata.MD
	// doneInfo records the DoneInfo passed to each pick's Done callback.
	doneInfo []balancer.DoneInfo
}
    73  
// Build stores cc on the receiver and returns the receiver itself as the
// balancer, so every Build call on the same builder yields the same
// testBalancer value. The build options are ignored.
func (b *testBalancer) Build(cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer {
	b.cc = cc
	return b
}
    78  
// Name returns testBalancerName.
func (*testBalancer) Name() string {
	return testBalancerName
}
    82  
// ResolverError is not implemented by testBalancer; it panics if called.
func (*testBalancer) ResolverError(error) {
	panic("not implemented")
}
    86  
    87  func (b *testBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
    88  	// Only create a subconn at the first time.
    89  	if b.sc == nil {
    90  		var err error
    91  		b.newSubConnOptions.StateListener = b.updateSubConnState
    92  		b.sc, err = b.cc.NewSubConn(state.ResolverState.Addresses, b.newSubConnOptions)
    93  		if err != nil {
    94  			logger.Errorf("testBalancer: failed to NewSubConn: %v", err)
    95  			return nil
    96  		}
    97  		b.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: &picker{err: balancer.ErrNoSubConnAvailable, bal: b}})
    98  		b.sc.Connect()
    99  	}
   100  	return nil
   101  }
   102  
// UpdateSubConnState panics unconditionally. testBalancer receives SubConn
// state changes through the StateListener it sets in UpdateClientConnState
// (updateSubConnState), so a call to this method is treated as unexpected.
func (b *testBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) {
	panic(fmt.Sprintf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, s))
}
   106  
   107  func (b *testBalancer) updateSubConnState(s balancer.SubConnState) {
   108  	logger.Infof("testBalancer: updateSubConnState: %v", s)
   109  
   110  	switch s.ConnectivityState {
   111  	case connectivity.Ready:
   112  		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{bal: b}})
   113  	case connectivity.Idle:
   114  		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{bal: b, idle: true}})
   115  	case connectivity.Connecting:
   116  		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrNoSubConnAvailable, bal: b}})
   117  	case connectivity.TransientFailure:
   118  		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrTransientFailure, bal: b}})
   119  	}
   120  }
   121  
// Close is a no-op.
func (b *testBalancer) Close() {}
   123  
// ExitIdle is a no-op; an idle SubConn is reconnected by the idle picker
// instead (see picker.Pick).
func (b *testBalancer) ExitIdle() {}
   125  
// picker is the balancer.Picker used by testBalancer.
type picker struct {
	// err, when non-nil, is returned from every Pick.
	err error
	// bal is the balancer whose SubConn is picked and whose pickInfos,
	// pickExtraMDs and doneInfo slices are appended to.
	bal *testBalancer
	// idle makes Pick call Connect on the SubConn and fail the pick with
	// balancer.ErrNoSubConnAvailable.
	idle bool
}
   131  
   132  func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
   133  	if p.err != nil {
   134  		return balancer.PickResult{}, p.err
   135  	}
   136  	if p.idle {
   137  		p.bal.sc.Connect()
   138  		return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
   139  	}
   140  	extraMD, _ := grpcutil.ExtraMetadata(info.Ctx)
   141  	info.Ctx = nil // Do not validate context.
   142  	p.bal.pickInfos = append(p.bal.pickInfos, info)
   143  	p.bal.pickExtraMDs = append(p.bal.pickExtraMDs, extraMD)
   144  	return balancer.PickResult{SubConn: p.bal.sc, Done: func(d balancer.DoneInfo) { p.bal.doneInfo = append(p.bal.doneInfo, d) }}, nil
   145  }
   146  
   147  func (s) TestCredsBundleFromBalancer(t *testing.T) {
   148  	balancer.Register(&testBalancer{
   149  		newSubConnOptions: balancer.NewSubConnOptions{
   150  			CredsBundle: &testCredsBundle{},
   151  		},
   152  	})
   153  	te := newTest(t, env{name: "creds-bundle", network: "tcp", balancer: ""})
   154  	te.tapHandle = authHandle
   155  	te.customDialOptions = []grpc.DialOption{
   156  		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, testBalancerName)),
   157  	}
   158  	creds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
   159  	if err != nil {
   160  		t.Fatalf("Failed to generate credentials %v", err)
   161  	}
   162  	te.customServerOptions = []grpc.ServerOption{
   163  		grpc.Creds(creds),
   164  	}
   165  	te.startServer(&testServer{})
   166  	defer te.tearDown()
   167  
   168  	cc := te.clientConn()
   169  	tc := testgrpc.NewTestServiceClient(cc)
   170  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   171  	defer cancel()
   172  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   173  		t.Fatalf("Test failed. Reason: %v", err)
   174  	}
   175  }
   176  
   177  func (s) TestPickExtraMetadata(t *testing.T) {
   178  	for _, e := range listTestEnv() {
   179  		testPickExtraMetadata(t, e)
   180  	}
   181  }
   182  
   183  func testPickExtraMetadata(t *testing.T, e env) {
   184  	te := newTest(t, e)
   185  	b := &testBalancer{}
   186  	balancer.Register(b)
   187  	const (
   188  		testUserAgent      = "test-user-agent"
   189  		testSubContentType = "proto"
   190  	)
   191  
   192  	te.customDialOptions = []grpc.DialOption{
   193  		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, testBalancerName)),
   194  		grpc.WithUserAgent(testUserAgent),
   195  	}
   196  	te.startServer(&testServer{security: e.security})
   197  	defer te.tearDown()
   198  
   199  	// Trigger the extra-metadata-adding code path.
   200  	defer func(old string) { internal.GRPCResolverSchemeExtraMetadata = old }(internal.GRPCResolverSchemeExtraMetadata)
   201  	internal.GRPCResolverSchemeExtraMetadata = "passthrough"
   202  
   203  	cc := te.clientConn()
   204  	tc := testgrpc.NewTestServiceClient(cc)
   205  
   206  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   207  	defer cancel()
   208  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
   209  		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil)
   210  	}
   211  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.CallContentSubtype(testSubContentType)); err != nil {
   212  		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil)
   213  	}
   214  
   215  	want := []metadata.MD{
   216  		// First RPC doesn't have sub-content-type.
   217  		{"content-type": []string{"application/grpc"}},
   218  		// Second RPC has sub-content-type "proto".
   219  		{"content-type": []string{"application/grpc+proto"}},
   220  	}
   221  	if diff := cmp.Diff(want, b.pickExtraMDs); diff != "" {
   222  		t.Fatalf("unexpected diff in metadata (-want, +got): %s", diff)
   223  	}
   224  }
   225  
   226  func (s) TestDoneInfo(t *testing.T) {
   227  	for _, e := range listTestEnv() {
   228  		testDoneInfo(t, e)
   229  	}
   230  }
   231  
// testDoneInfo checks the DoneInfo that the picker's Done callback receives.
//
// It makes one EmptyCall that is expected to fail with detailedError and one
// UnaryCall that is expected to succeed, then verifies that the first
// recorded DoneInfo carries that error, the second carries
// testTrailerMetadata as its trailer, and that every recorded pick has a
// matching Done call. It finally repeats the pick/done count check while the
// server is being stopped.
func testDoneInfo(t *testing.T, e env) {
	te := newTest(t, e)
	b := &testBalancer{}
	balancer.Register(b)
	te.customDialOptions = []grpc.DialOption{
		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, testBalancerName)),
	}
	// NOTE(review): failAppUA is defined elsewhere; presumably it makes the
	// test server fail EmptyCall with detailedError — confirm.
	te.userAgent = failAppUA
	te.startServer(&testServer{security: e.security})
	defer te.tearDown()

	cc := te.clientConn()
	tc := testgrpc.NewTestServiceClient(cc)

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	wantErr := detailedError
	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !testutils.StatusErrEqual(err, wantErr) {
		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", status.Convert(err).Proto(), status.Convert(wantErr).Proto())
	}
	if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
		t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
	}

	// doneInfo[0] belongs to the failed EmptyCall, doneInfo[1] to the
	// successful UnaryCall.
	if len(b.doneInfo) < 1 || !testutils.StatusErrEqual(b.doneInfo[0].Err, wantErr) {
		t.Fatalf("b.doneInfo = %v; want b.doneInfo[0].Err = %v", b.doneInfo, wantErr)
	}
	if len(b.doneInfo) < 2 || !reflect.DeepEqual(b.doneInfo[1].Trailer, testTrailerMetadata) {
		t.Fatalf("b.doneInfo = %v; want b.doneInfo[1].Trailer = %v", b.doneInfo, testTrailerMetadata)
	}
	if len(b.pickInfos) != len(b.doneInfo) {
		t.Fatalf("Got %d picks, but %d doneInfo, want equal amount", len(b.pickInfos), len(b.doneInfo))
	}
	// To test done() is always called, even if it's returned with a non-Ready
	// SubConn.
	//
	// Stop server and at the same time send RPCs. There are chances that picker
	// is not updated in time, causing a non-Ready SubConn to be returned.
	finished := make(chan struct{})
	go func() {
		// The RPC results are deliberately ignored: these calls race with the
		// server shutdown and only the pick/done bookkeeping is checked.
		for i := 0; i < 20; i++ {
			tc.UnaryCall(ctx, &testpb.SimpleRequest{})
		}
		close(finished)
	}()
	te.srv.Stop()
	// Wait for all 20 RPCs to return before reading the recorded slices.
	<-finished
	if len(b.pickInfos) != len(b.doneInfo) {
		t.Fatalf("Got %d picks, %d doneInfo, want equal amount", len(b.pickInfos), len(b.doneInfo))
	}
}
   283  
// loadMDKey is the metadata key that testLoadParser reads the load value
// from and that the TestDoneLoads server sets in its trailer.
const loadMDKey = "X-Endpoint-Load-Metrics-Bin"
   285  
// testLoadParser is the load parser installed via balancerload.SetParser in
// this file's init function.
type testLoadParser struct{}
   287  
   288  func (*testLoadParser) Parse(md metadata.MD) any {
   289  	vs := md.Get(loadMDKey)
   290  	if len(vs) == 0 {
   291  		return nil
   292  	}
   293  	return vs[0]
   294  }
   295  
// init installs testLoadParser as the balancerload parser for the whole test
// binary.
func init() {
	balancerload.SetParser(&testLoadParser{})
}
   299  
// TestDoneLoads is the test entry point for testDoneLoads.
func (s) TestDoneLoads(t *testing.T) {
	testDoneLoads(t)
}
   303  
   304  func testDoneLoads(t *testing.T) {
   305  	b := &testBalancer{}
   306  	balancer.Register(b)
   307  
   308  	const testLoad = "test-load-,-should-be-orca"
   309  
   310  	ss := &stubserver.StubServer{
   311  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   312  			grpc.SetTrailer(ctx, metadata.Pairs(loadMDKey, testLoad))
   313  			return &testpb.Empty{}, nil
   314  		},
   315  	}
   316  	if err := ss.Start(nil, grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, testBalancerName))); err != nil {
   317  		t.Fatalf("error starting testing server: %v", err)
   318  	}
   319  	defer ss.Stop()
   320  
   321  	tc := testgrpc.NewTestServiceClient(ss.CC)
   322  
   323  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   324  	defer cancel()
   325  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   326  		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil)
   327  	}
   328  
   329  	piWant := []balancer.PickInfo{
   330  		{FullMethodName: "/grpc.testing.TestService/EmptyCall"},
   331  	}
   332  	if !reflect.DeepEqual(b.pickInfos, piWant) {
   333  		t.Fatalf("b.pickInfos = %v; want %v", b.pickInfos, piWant)
   334  	}
   335  
   336  	if len(b.doneInfo) < 1 {
   337  		t.Fatalf("b.doneInfo = %v, want length 1", b.doneInfo)
   338  	}
   339  	gotLoad, _ := b.doneInfo[0].ServerLoad.(string)
   340  	if gotLoad != testLoad {
   341  		t.Fatalf("b.doneInfo[0].ServerLoad = %v; want = %v", b.doneInfo[0].ServerLoad, testLoad)
   342  	}
   343  }
   344  
// aiPicker is a picker that returns a fixed result and error from every
// Pick. The stub balancers in this file build one per SubConn state change.
type aiPicker struct {
	// result is returned from every Pick.
	result balancer.PickResult
	// err is returned from every Pick alongside result.
	err error
}
   349  
// Pick ignores the pick info and returns the configured result and error.
func (aip *aiPicker) Pick(_ balancer.PickInfo) (balancer.PickResult, error) {
	return aip.result, aip.err
}
   353  
// attrTransportCreds is a transport credential implementation which stores
// Attributes from the ClientHandshakeInfo struct passed in the context locally
// for the test to inspect.
type attrTransportCreds struct {
	// The embedded interface is left nil by the tests in this file; only the
	// methods overridden below are implemented.
	credentials.TransportCredentials
	// attr holds the Attributes seen by the most recent ClientHandshake.
	attr *attributes.Attributes
}
   361  
   362  func (ac *attrTransportCreds) ClientHandshake(ctx context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   363  	ai := credentials.ClientHandshakeInfoFromContext(ctx)
   364  	ac.attr = ai.Attributes
   365  	return rawConn, nil, nil
   366  }
// Info returns an empty ProtocolInfo.
func (ac *attrTransportCreds) Info() credentials.ProtocolInfo {
	return credentials.ProtocolInfo{}
}
// Clone returns nil; attrTransportCreds does not support being cloned.
func (ac *attrTransportCreds) Clone() credentials.TransportCredentials {
	return nil
}
   373  
// TestAddressAttributesInNewSubConn verifies that the Attributes passed from a
// balancer in the resolver.Address that is passed to NewSubConn reach all the
// way to the ClientHandshake method of the credentials configured on the parent
// channel.
func (s) TestAddressAttributesInNewSubConn(t *testing.T) {
	const (
		testAttrKey      = "foo"
		testAttrVal      = "bar"
		attrBalancerName = "attribute-balancer"
	)

	// Register a stub balancer which adds attributes to the first address that
	// it receives and then calls NewSubConn on it.
	bf := stub.BalancerFuncs{
		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
			addrs := ccs.ResolverState.Addresses
			if len(addrs) == 0 {
				return nil
			}

			// Only use the first address.
			attr := attributes.New(testAttrKey, testAttrVal)
			addrs[0].Attributes = attr
			// sc is declared before the call so that the StateListener
			// closure below can refer to the SubConn it belongs to.
			var sc balancer.SubConn
			sc, err := bd.ClientConn.NewSubConn([]resolver.Address{addrs[0]}, balancer.NewSubConnOptions{
				StateListener: func(state balancer.SubConnState) {
					bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}})
				},
			})
			if err != nil {
				return err
			}
			sc.Connect()
			return nil
		},
	}
	stub.Register(attrBalancerName, bf)
	t.Logf("Registered balancer %s...", attrBalancerName)

	r := manual.NewBuilderWithScheme("whatever")
	t.Logf("Registered manual resolver with scheme %s...", r.Scheme())

	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}
	// From here on the local variable stub shadows the stub package.
	stub := &stubserver.StubServer{
		Listener: lis,
		EmptyCallF: func(_ context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
			return &testpb.Empty{}, nil
		},
		S: grpc.NewServer(),
	}
	stubserver.StartTestService(t, stub)
	defer stub.S.Stop()
	t.Logf("Started gRPC server at %s...", lis.Addr().String())

	// creds records the attributes seen during the client handshake.
	creds := &attrTransportCreds{}
	dopts := []grpc.DialOption{
		grpc.WithTransportCredentials(creds),
		grpc.WithResolvers(r),
		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, attrBalancerName)),
	}
	cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...)
	if err != nil {
		t.Fatal(err)
	}
	defer cc.Close()
	tc := testgrpc.NewTestServiceClient(cc)
	t.Log("Created a ClientConn...")

	// The first RPC should fail because there's no address.
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
	defer cancel()
	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
	}
	t.Log("Made an RPC which was expected to fail...")

	state := resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}
	r.UpdateState(state)
	t.Logf("Pushing resolver state update: %v through the manual resolver", state)

	// The second RPC should succeed.
	ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
	}
	t.Log("Made an RPC which succeeded...")

	// The credentials must have seen the attributes set by the balancer.
	wantAttr := attributes.New(testAttrKey, testAttrVal)
	if gotAttr := creds.attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) {
		t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr)
	}
}
   470  
   471  // TestMetadataInAddressAttributes verifies that the metadata added to
   472  // address.Attributes will be sent with the RPCs.
   473  func (s) TestMetadataInAddressAttributes(t *testing.T) {
   474  	const (
   475  		testMDKey      = "test-md"
   476  		testMDValue    = "test-md-value"
   477  		mdBalancerName = "metadata-balancer"
   478  	)
   479  
   480  	// Register a stub balancer which adds metadata to the first address that it
   481  	// receives and then calls NewSubConn on it.
   482  	bf := stub.BalancerFuncs{
   483  		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
   484  			addrs := ccs.ResolverState.Addresses
   485  			if len(addrs) == 0 {
   486  				return nil
   487  			}
   488  			// Only use the first address.
   489  			var sc balancer.SubConn
   490  			sc, err := bd.ClientConn.NewSubConn([]resolver.Address{
   491  				imetadata.Set(addrs[0], metadata.Pairs(testMDKey, testMDValue)),
   492  			}, balancer.NewSubConnOptions{
   493  				StateListener: func(state balancer.SubConnState) {
   494  					bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}})
   495  				},
   496  			})
   497  			if err != nil {
   498  				return err
   499  			}
   500  			sc.Connect()
   501  			return nil
   502  		},
   503  	}
   504  	stub.Register(mdBalancerName, bf)
   505  	t.Logf("Registered balancer %s...", mdBalancerName)
   506  
   507  	testMDChan := make(chan []string, 1)
   508  	ss := &stubserver.StubServer{
   509  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   510  			md, ok := metadata.FromIncomingContext(ctx)
   511  			if ok {
   512  				select {
   513  				case testMDChan <- md[testMDKey]:
   514  				case <-ctx.Done():
   515  					return nil, ctx.Err()
   516  				}
   517  			}
   518  			return &testpb.Empty{}, nil
   519  		},
   520  	}
   521  	if err := ss.Start(nil, grpc.WithDefaultServiceConfig(
   522  		fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, mdBalancerName),
   523  	)); err != nil {
   524  		t.Fatalf("Error starting endpoint server: %v", err)
   525  	}
   526  	defer ss.Stop()
   527  
   528  	// The RPC should succeed with the expected md.
   529  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   530  	defer cancel()
   531  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   532  		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
   533  	}
   534  	t.Log("Made an RPC which succeeded...")
   535  
   536  	// The server should receive the test metadata.
   537  	md1 := <-testMDChan
   538  	if len(md1) == 0 || md1[0] != testMDValue {
   539  		t.Fatalf("got md: %v, want %v", md1, []string{testMDValue})
   540  	}
   541  }
   542  
   543  // TestServersSwap creates two servers and verifies the client switches between
   544  // them when the name resolver reports the first and then the second.
   545  func (s) TestServersSwap(t *testing.T) {
   546  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   547  	defer cancel()
   548  
   549  	// Initialize servers
   550  	reg := func(username string) (addr string, cleanup func()) {
   551  		lis, err := net.Listen("tcp", "localhost:0")
   552  		if err != nil {
   553  			t.Fatalf("Error while listening. Err: %v", err)
   554  		}
   555  
   556  		stub := &stubserver.StubServer{
   557  			Listener: lis,
   558  			UnaryCallF: func(_ context.Context, _ *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   559  				return &testpb.SimpleResponse{Username: username}, nil
   560  			},
   561  			S: grpc.NewServer(),
   562  		}
   563  		stubserver.StartTestService(t, stub)
   564  		return lis.Addr().String(), stub.S.Stop
   565  	}
   566  	const one = "1"
   567  	addr1, cleanup := reg(one)
   568  	defer cleanup()
   569  	const two = "2"
   570  	addr2, cleanup := reg(two)
   571  	defer cleanup()
   572  
   573  	// Initialize client
   574  	r := manual.NewBuilderWithScheme("whatever")
   575  	r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: addr1}}})
   576  	cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r))
   577  	if err != nil {
   578  		t.Fatalf("Error creating client: %v", err)
   579  	}
   580  	defer cc.Close()
   581  	client := testgrpc.NewTestServiceClient(cc)
   582  
   583  	// Confirm we are connected to the first server
   584  	if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil || res.Username != one {
   585  		t.Fatalf("UnaryCall(_) = %v, %v; want {Username: %q}, nil", res, err, one)
   586  	}
   587  
   588  	// Update resolver to report only the second server
   589  	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: addr2}}})
   590  
   591  	// Loop until new RPCs talk to server two.
   592  	for i := 0; i < 2000; i++ {
   593  		if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
   594  			t.Fatalf("UnaryCall(_) = _, %v; want _, nil", err)
   595  		} else if res.Username == two {
   596  			break // pass
   597  		}
   598  		time.Sleep(5 * time.Millisecond)
   599  	}
   600  }
   601  
   602  func (s) TestWaitForReady(t *testing.T) {
   603  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   604  	defer cancel()
   605  
   606  	// Initialize server
   607  	lis, err := net.Listen("tcp", "localhost:0")
   608  	if err != nil {
   609  		t.Fatalf("Error while listening. Err: %v", err)
   610  	}
   611  	const one = "1"
   612  	stub := &stubserver.StubServer{
   613  		Listener: lis,
   614  		UnaryCallF: func(_ context.Context, _ *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   615  			return &testpb.SimpleResponse{Username: one}, nil
   616  		},
   617  		S: grpc.NewServer(),
   618  	}
   619  	stubserver.StartTestService(t, stub)
   620  	defer stub.S.Stop()
   621  
   622  	// Initialize client
   623  	r := manual.NewBuilderWithScheme("whatever")
   624  
   625  	cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r))
   626  	if err != nil {
   627  		t.Fatalf("Error creating client: %v", err)
   628  	}
   629  	defer cc.Close()
   630  	cc.Connect()
   631  	client := testgrpc.NewTestServiceClient(cc)
   632  
   633  	// Report an error so non-WFR RPCs will give up early.
   634  	r.CC().ReportError(errors.New("fake resolver error"))
   635  
   636  	// Ensure the client is not connected to anything and fails non-WFR RPCs.
   637  	if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Unavailable {
   638  		t.Fatalf("UnaryCall(_) = %v, %v; want _, Code()=%v", res, err, codes.Unavailable)
   639  	}
   640  
   641  	errChan := make(chan error, 1)
   642  	go func() {
   643  		if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}, grpc.WaitForReady(true)); err != nil || res.Username != one {
   644  			errChan <- fmt.Errorf("UnaryCall(_) = %v, %v; want {Username: %q}, nil", res, err, one)
   645  		}
   646  		close(errChan)
   647  	}()
   648  
   649  	select {
   650  	case err := <-errChan:
   651  		t.Errorf("unexpected receive from errChan before addresses provided")
   652  		t.Fatal(err.Error())
   653  	case <-time.After(5 * time.Millisecond):
   654  	}
   655  
   656  	// Resolve the server.  The WFR RPC should unblock and use it.
   657  	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})
   658  
   659  	if err := <-errChan; err != nil {
   660  		t.Fatal(err.Error())
   661  	}
   662  }
   663  
// authorityOverrideTransportCreds returns the configured authority value in its
// Info() method.
type authorityOverrideTransportCreds struct {
	// The embedded interface is left nil by the tests in this file; only the
	// methods overridden below are implemented.
	credentials.TransportCredentials
	// authorityOverride is reported as ProtocolInfo.ServerName by Info.
	authorityOverride string
}
   670  
// ClientHandshake performs no handshake: it returns rawConn unchanged with a
// nil AuthInfo and a nil error.
func (ao *authorityOverrideTransportCreds) ClientHandshake(_ context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	return rawConn, nil, nil
}
// Info returns a ProtocolInfo whose ServerName is the configured authority
// override.
func (ao *authorityOverrideTransportCreds) Info() credentials.ProtocolInfo {
	return credentials.ProtocolInfo{ServerName: ao.authorityOverride}
}
// Clone returns a new authorityOverrideTransportCreds carrying the same
// authority override.
func (ao *authorityOverrideTransportCreds) Clone() credentials.TransportCredentials {
	return &authorityOverrideTransportCreds{authorityOverride: ao.authorityOverride}
}
   680  
// TestAuthorityInBuildOptions tests that the Authority field in
// balancer.BuildOptions is setup correctly from gRPC.
//
// Each subtest registers its own stub balancer, which reports the authority
// from its BuildOptions on a channel, then makes one RPC and compares the
// reported authority with the expected one.
func (s) TestAuthorityInBuildOptions(t *testing.T) {
	const dialTarget = "test.server"

	tests := []struct {
		name          string
		dopts         []grpc.DialOption
		wantAuthority string
	}{
		{
			name:          "authority from dial target",
			dopts:         []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
			wantAuthority: dialTarget,
		},
		{
			name: "authority from dial option",
			dopts: []grpc.DialOption{
				grpc.WithTransportCredentials(insecure.NewCredentials()),
				grpc.WithAuthority("authority-override"),
			},
			wantAuthority: "authority-override",
		},
		{
			name:          "authority from transport creds",
			dopts:         []grpc.DialOption{grpc.WithTransportCredentials(&authorityOverrideTransportCreds{authorityOverride: "authority-override-from-transport-creds"})},
			wantAuthority: "authority-override-from-transport-creds",
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// Buffered with capacity one; together with the non-blocking send
			// below only the first reported authority is kept.
			authorityCh := make(chan string, 1)
			bf := stub.BalancerFuncs{
				UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
					// Non-blocking send: later updates are dropped once the
					// channel already holds a value.
					select {
					case authorityCh <- bd.BuildOptions.Authority:
					default:
					}

					addrs := ccs.ResolverState.Addresses
					if len(addrs) == 0 {
						return nil
					}

					// Only use the first address. sc is declared before the
					// call so that the StateListener closure can refer to it.
					var sc balancer.SubConn
					sc, err := bd.ClientConn.NewSubConn([]resolver.Address{addrs[0]}, balancer.NewSubConnOptions{
						StateListener: func(state balancer.SubConnState) {
							bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}})
						},
					})
					if err != nil {
						return err
					}
					sc.Connect()
					return nil
				},
			}
			// The balancer name is unique per subtest.
			balancerName := "stub-balancer-" + test.name
			stub.Register(balancerName, bf)
			t.Logf("Registered balancer %s...", balancerName)

			lis, err := testutils.LocalTCPListener()
			if err != nil {
				t.Fatal(err)
			}

			// From here on the local variable stub shadows the stub package.
			stub := &stubserver.StubServer{
				Listener: lis,
				EmptyCallF: func(_ context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
					return &testpb.Empty{}, nil
				},
				S: grpc.NewServer(),
			}
			stubserver.StartTestService(t, stub)
			defer stub.S.Stop()
			t.Logf("Started gRPC server at %s...", lis.Addr().String())

			r := manual.NewBuilderWithScheme("whatever")
			t.Logf("Registered manual resolver with scheme %s...", r.Scheme())
			r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})

			// The test case's dial options are appended after the resolver
			// and service config options shared by all cases.
			dopts := append([]grpc.DialOption{
				grpc.WithResolvers(r),
				grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, balancerName)),
			}, test.dopts...)
			cc, err := grpc.NewClient(r.Scheme()+":///"+dialTarget, dopts...)
			if err != nil {
				t.Fatal(err)
			}
			defer cc.Close()
			tc := testgrpc.NewTestServiceClient(cc)
			t.Log("Created a ClientConn...")

			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
			defer cancel()
			if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
			}
			t.Log("Made an RPC which succeeded...")

			// Wait for the authority reported by the balancer, bounded by the
			// test context.
			select {
			case <-ctx.Done():
				t.Fatal("timeout when waiting for Authority in balancer.BuildOptions")
			case gotAuthority := <-authorityCh:
				if gotAuthority != test.wantAuthority {
					t.Fatalf("Authority in balancer.BuildOptions is %s, want %s", gotAuthority, test.wantAuthority)
				}
			}
		})
	}
}
   794  
// testCCWrapper wraps a balancer.ClientConn and intercepts UpdateState calls,
// replacing the picker of every state update with a wrappedPicker, which
// injects arbitrary metadata on a per-call basis.
type testCCWrapper struct {
	// The wrapped ClientConn; its methods are promoted through embedding.
	balancer.ClientConn
}
   800  
   801  func (t *testCCWrapper) UpdateState(state balancer.State) {
   802  	state.Picker = &wrappedPicker{p: state.Picker}
   803  	t.ClientConn.UpdateState(state)
   804  }
   805  
// Metadata key/value pairs used by TestMetadataInPickResult. The "balancer"
// pair is added to pick results by wrappedPicker, while the "application" pair
// is attached to the outgoing RPC context by the test itself.
const (
	metadataHeaderInjectedByBalancer    = "metadata-header-injected-by-balancer"
	metadataHeaderInjectedByApplication = "metadata-header-injected-by-application"
	metadataValueInjectedByBalancer     = "metadata-value-injected-by-balancer"
	metadataValueInjectedByApplication  = "metadata-value-injected-by-application"
)
   812  
// wrappedPicker wraps the picker returned by the pick_first LB policy and adds
// balancer-injected metadata to the result of every successful pick.
type wrappedPicker struct {
	// p is the wrapped picker that Pick delegates to.
	p balancer.Picker
}
   817  
   818  func (wp *wrappedPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
   819  	res, err := wp.p.Pick(info)
   820  	if err != nil {
   821  		return balancer.PickResult{}, err
   822  	}
   823  
   824  	if res.Metadata == nil {
   825  		res.Metadata = metadata.Pairs(metadataHeaderInjectedByBalancer, metadataValueInjectedByBalancer)
   826  	} else {
   827  		res.Metadata.Append(metadataHeaderInjectedByBalancer, metadataValueInjectedByBalancer)
   828  	}
   829  	return res, nil
   830  }
   831  
   832  // TestMetadataInPickResult tests the scenario where an LB policy inject
   833  // arbitrary metadata on a per-call basis and verifies that the injected
   834  // metadata makes it all the way to the server RPC handler.
   835  func (s) TestMetadataInPickResult(t *testing.T) {
   836  	t.Log("Starting test backend...")
   837  	mdChan := make(chan metadata.MD, 1)
   838  	ss := &stubserver.StubServer{
   839  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   840  			md, _ := metadata.FromIncomingContext(ctx)
   841  			select {
   842  			case mdChan <- md:
   843  			case <-ctx.Done():
   844  				return nil, ctx.Err()
   845  			}
   846  			return &testpb.Empty{}, nil
   847  		},
   848  	}
   849  	if err := ss.StartServer(); err != nil {
   850  		t.Fatalf("Starting test backend: %v", err)
   851  	}
   852  	defer ss.Stop()
   853  	t.Logf("Started test backend at %q", ss.Address)
   854  
   855  	// Register a test balancer that contains a pick_first balancer and forwards
   856  	// all calls from the ClientConn to it. For state updates from the
   857  	// pick_first balancer, it creates a custom picker which injects arbitrary
   858  	// metadata on a per-call basis.
   859  	stub.Register(t.Name(), stub.BalancerFuncs{
   860  		Init: func(bd *stub.BalancerData) {
   861  			cc := &testCCWrapper{ClientConn: bd.ClientConn}
   862  			bd.Data = balancer.Get(pickfirst.Name).Build(cc, bd.BuildOptions)
   863  		},
   864  		Close: func(bd *stub.BalancerData) {
   865  			bd.Data.(balancer.Balancer).Close()
   866  		},
   867  		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
   868  			bal := bd.Data.(balancer.Balancer)
   869  			return bal.UpdateClientConnState(ccs)
   870  		},
   871  	})
   872  
   873  	t.Log("Creating ClientConn to test backend...")
   874  	r := manual.NewBuilderWithScheme("whatever")
   875  	r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: ss.Address}}})
   876  	dopts := []grpc.DialOption{
   877  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   878  		grpc.WithResolvers(r),
   879  		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, t.Name())),
   880  	}
   881  	cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...)
   882  	if err != nil {
   883  		t.Fatalf("grpc.NewClient(): %v", err)
   884  	}
   885  	defer cc.Close()
   886  	tc := testgrpc.NewTestServiceClient(cc)
   887  
   888  	t.Log("Making EmptyCall() RPC with custom metadata...")
   889  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   890  	defer cancel()
   891  	md := metadata.Pairs(metadataHeaderInjectedByApplication, metadataValueInjectedByApplication)
   892  	ctx = metadata.NewOutgoingContext(ctx, md)
   893  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   894  		t.Fatalf("EmptyCall() RPC: %v", err)
   895  	}
   896  	t.Log("EmptyCall() RPC succeeded")
   897  
   898  	t.Log("Waiting for custom metadata to be received at the test backend...")
   899  	var gotMD metadata.MD
   900  	select {
   901  	case gotMD = <-mdChan:
   902  	case <-ctx.Done():
   903  		t.Fatalf("Timed out waiting for custom metadata to be received at the test backend")
   904  	}
   905  
   906  	t.Log("Verifying custom metadata added by the client application is received at the test backend...")
   907  	wantMDVal := []string{metadataValueInjectedByApplication}
   908  	gotMDVal := gotMD.Get(metadataHeaderInjectedByApplication)
   909  	if !cmp.Equal(gotMDVal, wantMDVal) {
   910  		t.Fatalf("Mismatch in custom metadata received at test backend, got: %v, want %v", gotMDVal, wantMDVal)
   911  	}
   912  
   913  	t.Log("Verifying custom metadata added by the LB policy is received at the test backend...")
   914  	wantMDVal = []string{metadataValueInjectedByBalancer}
   915  	gotMDVal = gotMD.Get(metadataHeaderInjectedByBalancer)
   916  	if !cmp.Equal(gotMDVal, wantMDVal) {
   917  		t.Fatalf("Mismatch in custom metadata received at test backend, got: %v, want %v", gotMDVal, wantMDVal)
   918  	}
   919  }
   920  
   921  // TestSubConnShutdown confirms that the Shutdown method on subconns and
   922  // RemoveSubConn method on ClientConn properly initiates subconn shutdown.
   923  func (s) TestSubConnShutdown(t *testing.T) {
   924  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   925  	defer cancel()
   926  
   927  	testCases := []struct {
   928  		name     string
   929  		shutdown func(cc balancer.ClientConn, sc balancer.SubConn)
   930  	}{{
   931  		name: "ClientConn.RemoveSubConn",
   932  		shutdown: func(cc balancer.ClientConn, sc balancer.SubConn) {
   933  			cc.RemoveSubConn(sc)
   934  		},
   935  	}, {
   936  		name: "SubConn.Shutdown",
   937  		shutdown: func(_ balancer.ClientConn, sc balancer.SubConn) {
   938  			sc.Shutdown()
   939  		},
   940  	}}
   941  
   942  	for _, tc := range testCases {
   943  		t.Run(tc.name, func(t *testing.T) {
   944  			gotShutdown := grpcsync.NewEvent()
   945  
   946  			bf := stub.BalancerFuncs{
   947  				UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
   948  					var sc balancer.SubConn
   949  					opts := balancer.NewSubConnOptions{
   950  						StateListener: func(scs balancer.SubConnState) {
   951  							switch scs.ConnectivityState {
   952  							case connectivity.Connecting:
   953  								// Ignored.
   954  							case connectivity.Ready:
   955  								tc.shutdown(bd.ClientConn, sc)
   956  							case connectivity.Shutdown:
   957  								gotShutdown.Fire()
   958  							default:
   959  								t.Errorf("got unexpected state %q in listener", scs.ConnectivityState)
   960  							}
   961  						},
   962  					}
   963  					sc, err := bd.ClientConn.NewSubConn(ccs.ResolverState.Addresses, opts)
   964  					if err != nil {
   965  						return err
   966  					}
   967  					sc.Connect()
   968  					// Report the state as READY to unblock ss.Start(), which waits for ready.
   969  					bd.ClientConn.UpdateState(balancer.State{ConnectivityState: connectivity.Ready})
   970  					return nil
   971  				},
   972  			}
   973  
   974  			testBalName := "shutdown-test-balancer-" + tc.name
   975  			stub.Register(testBalName, bf)
   976  			t.Logf("Registered balancer %s...", testBalName)
   977  
   978  			ss := &stubserver.StubServer{}
   979  			if err := ss.Start(nil, grpc.WithDefaultServiceConfig(
   980  				fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, testBalName),
   981  			)); err != nil {
   982  				t.Fatalf("Error starting endpoint server: %v", err)
   983  			}
   984  			defer ss.Stop()
   985  
   986  			select {
   987  			case <-gotShutdown.Done():
   988  				// Success
   989  			case <-ctx.Done():
   990  				t.Fatalf("Timed out waiting for gotShutdown to be fired.")
   991  			}
   992  		})
   993  	}
   994  }
   995  
// subConnStoringCCWrapper wraps a balancer.ClientConn and intercepts
// NewSubConn so that tests can get hold of the SubConns created through it
// and, optionally, observe their connectivity state updates.
type subConnStoringCCWrapper struct {
	balancer.ClientConn
	// stateListener, when non-nil, is invoked with each SubConn state update
	// before the listener supplied in the NewSubConn options.
	stateListener func(balancer.SubConnState)
	// scChan receives the SubConns returned by the wrapped NewSubConn.
	scChan        chan balancer.SubConn
}
  1001  
  1002  func (ccw *subConnStoringCCWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
  1003  	if ccw.stateListener != nil {
  1004  		origListener := opts.StateListener
  1005  		opts.StateListener = func(scs balancer.SubConnState) {
  1006  			ccw.stateListener(scs)
  1007  			origListener(scs)
  1008  		}
  1009  	}
  1010  	sc, err := ccw.ClientConn.NewSubConn(addrs, opts)
  1011  	ccw.scChan <- sc
  1012  	return sc, err
  1013  }
  1014  
  1015  // Test calls RegisterHealthListener on a SubConn to verify that expected health
  1016  // updates are sent only to the most recently registered listener.
  1017  func (s) TestSubConn_RegisterHealthListener(t *testing.T) {
  1018  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1019  	defer cancel()
  1020  	scChan := make(chan balancer.SubConn, 1)
  1021  	bf := stub.BalancerFuncs{
  1022  		Init: func(bd *stub.BalancerData) {
  1023  			cc := bd.ClientConn
  1024  			ccw := &subConnStoringCCWrapper{
  1025  				ClientConn: cc,
  1026  				scChan:     scChan,
  1027  			}
  1028  			bd.Data = balancer.Get(pickfirst.Name).Build(ccw, bd.BuildOptions)
  1029  		},
  1030  		Close: func(bd *stub.BalancerData) {
  1031  			bd.Data.(balancer.Balancer).Close()
  1032  		},
  1033  		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
  1034  			return bd.Data.(balancer.Balancer).UpdateClientConnState(ccs)
  1035  		},
  1036  		ExitIdle: func(bd *stub.BalancerData) {
  1037  			bd.Data.(balancer.ExitIdler).ExitIdle()
  1038  		},
  1039  	}
  1040  
  1041  	stub.Register(t.Name(), bf)
  1042  	svcCfg := fmt.Sprintf(`{ "loadBalancingConfig": [{%q: {}}] }`, t.Name())
  1043  	backend := stubserver.StartTestService(t, nil)
  1044  	defer backend.Stop()
  1045  	opts := []grpc.DialOption{
  1046  		grpc.WithTransportCredentials(insecure.NewCredentials()),
  1047  		grpc.WithDefaultServiceConfig(svcCfg),
  1048  	}
  1049  	cc, err := grpc.NewClient(backend.Address, opts...)
  1050  	if err != nil {
  1051  		t.Fatalf("grpc.NewClient(%q) failed: %v", backend.Address, err)
  1052  
  1053  	}
  1054  	defer cc.Close()
  1055  
  1056  	cc.Connect()
  1057  
  1058  	var sc balancer.SubConn
  1059  	select {
  1060  	case sc = <-scChan:
  1061  	case <-ctx.Done():
  1062  		t.Fatal("Context timed out waiting for SubConn creation")
  1063  	}
  1064  	healthUpdateChan := make(chan balancer.SubConnState, 1)
  1065  
  1066  	// Register listener while Ready and verify it gets a health update.
  1067  	testutils.AwaitState(ctx, t, cc, connectivity.Ready)
  1068  	for i := 0; i < 2; i++ {
  1069  		sc.RegisterHealthListener(func(scs balancer.SubConnState) {
  1070  			healthUpdateChan <- scs
  1071  		})
  1072  		select {
  1073  		case scs := <-healthUpdateChan:
  1074  			if scs.ConnectivityState != connectivity.Ready {
  1075  				t.Fatalf("Received health update = %v, want = %v", scs.ConnectivityState, connectivity.Ready)
  1076  			}
  1077  		case <-ctx.Done():
  1078  			t.Fatalf("Context timed out waiting for health update")
  1079  		}
  1080  
  1081  		// No further updates are expected.
  1082  		select {
  1083  		case scs := <-healthUpdateChan:
  1084  			t.Fatalf("Received unexpected health update while channel is in state %v: %v", cc.GetState(), scs)
  1085  		case <-time.After(defaultTestShortTimeout):
  1086  		}
  1087  	}
  1088  
  1089  	// Make the SubConn enter IDLE and verify that health updates are recevied
  1090  	// on registering a new listener.
  1091  	backend.S.Stop()
  1092  	backend.S = nil
  1093  	testutils.AwaitState(ctx, t, cc, connectivity.Idle)
  1094  	if err := backend.StartServer(); err != nil {
  1095  		t.Fatalf("Error while restarting the backend server: %v", err)
  1096  	}
  1097  	cc.Connect()
  1098  	testutils.AwaitState(ctx, t, cc, connectivity.Ready)
  1099  	sc.RegisterHealthListener(func(scs balancer.SubConnState) {
  1100  		healthUpdateChan <- scs
  1101  	})
  1102  	select {
  1103  	case scs := <-healthUpdateChan:
  1104  		if scs.ConnectivityState != connectivity.Ready {
  1105  			t.Fatalf("Received health update = %v, want = %v", scs.ConnectivityState, connectivity.Ready)
  1106  		}
  1107  	case <-ctx.Done():
  1108  		t.Fatalf("Context timed out waiting for health update")
  1109  	}
  1110  }
  1111  
  1112  // Test calls RegisterHealthListener on a SubConn twice while handling the
  1113  // connectivity update. The test verifies that only the latest listener
  1114  // receives the health update.
  1115  func (s) TestSubConn_RegisterHealthListener_RegisterTwice(t *testing.T) {
  1116  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1117  	defer cancel()
  1118  	scChan := make(chan balancer.SubConn, 1)
  1119  	readyUpdateResumeCh := make(chan struct{})
  1120  	readyUpdateReceivedCh := make(chan struct{})
  1121  	bf := stub.BalancerFuncs{
  1122  		Init: func(bd *stub.BalancerData) {
  1123  			cc := bd.ClientConn
  1124  			ccw := &subConnStoringCCWrapper{
  1125  				ClientConn: cc,
  1126  				scChan:     scChan,
  1127  				stateListener: func(scs balancer.SubConnState) {
  1128  					if scs.ConnectivityState != connectivity.Ready {
  1129  						return
  1130  					}
  1131  					close(readyUpdateReceivedCh)
  1132  					select {
  1133  					case <-readyUpdateResumeCh:
  1134  					case <-ctx.Done():
  1135  						t.Error("Context timed out waiting for update on ready channel")
  1136  					}
  1137  				},
  1138  			}
  1139  			bd.Data = balancer.Get(pickfirst.Name).Build(ccw, bd.BuildOptions)
  1140  		},
  1141  		Close: func(bd *stub.BalancerData) {
  1142  			bd.Data.(balancer.Balancer).Close()
  1143  		},
  1144  		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
  1145  			return bd.Data.(balancer.Balancer).UpdateClientConnState(ccs)
  1146  		},
  1147  	}
  1148  
  1149  	stub.Register(t.Name(), bf)
  1150  	svcCfg := fmt.Sprintf(`{ "loadBalancingConfig": [{%q: {}}] }`, t.Name())
  1151  	backend := stubserver.StartTestService(t, nil)
  1152  	defer backend.Stop()
  1153  	opts := []grpc.DialOption{
  1154  		grpc.WithTransportCredentials(insecure.NewCredentials()),
  1155  		grpc.WithDefaultServiceConfig(svcCfg),
  1156  	}
  1157  	cc, err := grpc.NewClient(backend.Address, opts...)
  1158  	if err != nil {
  1159  		t.Fatalf("grpc.NewClient(%q) failed: %v", backend.Address, err)
  1160  
  1161  	}
  1162  	defer cc.Close()
  1163  
  1164  	cc.Connect()
  1165  
  1166  	var sc balancer.SubConn
  1167  	select {
  1168  	case sc = <-scChan:
  1169  	case <-ctx.Done():
  1170  		t.Fatal("Context timed out waiting for SubConn creation")
  1171  	}
  1172  
  1173  	// Wait for the SubConn to enter READY.
  1174  	select {
  1175  	case <-readyUpdateReceivedCh:
  1176  	case <-ctx.Done():
  1177  		t.Fatalf("Context timed out waiting for SubConn to enter READY")
  1178  	}
  1179  
  1180  	healthChan1 := make(chan balancer.SubConnState, 1)
  1181  	healthChan2 := make(chan balancer.SubConnState, 1)
  1182  
  1183  	sc.RegisterHealthListener(func(scs balancer.SubConnState) {
  1184  		healthChan1 <- scs
  1185  	})
  1186  	sc.RegisterHealthListener(func(scs balancer.SubConnState) {
  1187  		healthChan2 <- scs
  1188  	})
  1189  	close(readyUpdateResumeCh)
  1190  
  1191  	select {
  1192  	case scs := <-healthChan2:
  1193  		if scs.ConnectivityState != connectivity.Ready {
  1194  			t.Fatalf("Received health update = %v, want = %v", scs.ConnectivityState, connectivity.Ready)
  1195  		}
  1196  	case <-ctx.Done():
  1197  		t.Fatalf("Context timed out waiting for health update")
  1198  	}
  1199  
  1200  	// No updates should be received on the first listener.
  1201  	select {
  1202  	case scs := <-healthChan1:
  1203  		t.Fatalf("Received unexpected health update on first listener: %v", scs)
  1204  	case <-time.After(defaultTestShortTimeout):
  1205  	}
  1206  }
  1207  
  1208  // Test calls RegisterHealthListener on a SubConn with a nil listener and
  1209  // verifies that the listener registered before the nil listener doesn't receive
  1210  // any further updates.
  1211  func (s) TestSubConn_RegisterHealthListener_NilListener(t *testing.T) {
  1212  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1213  	defer cancel()
  1214  	scChan := make(chan balancer.SubConn, 1)
  1215  	readyUpdateResumeCh := make(chan struct{})
  1216  	readyUpdateReceivedCh := make(chan struct{})
  1217  	bf := stub.BalancerFuncs{
  1218  		Init: func(bd *stub.BalancerData) {
  1219  			cc := bd.ClientConn
  1220  			ccw := &subConnStoringCCWrapper{
  1221  				ClientConn: cc,
  1222  				scChan:     scChan,
  1223  				stateListener: func(scs balancer.SubConnState) {
  1224  					if scs.ConnectivityState != connectivity.Ready {
  1225  						return
  1226  					}
  1227  					close(readyUpdateReceivedCh)
  1228  					select {
  1229  					case <-readyUpdateResumeCh:
  1230  					case <-ctx.Done():
  1231  						t.Error("Context timed out waiting for update on ready channel")
  1232  					}
  1233  				},
  1234  			}
  1235  			bd.Data = balancer.Get(pickfirst.Name).Build(ccw, bd.BuildOptions)
  1236  		},
  1237  		Close: func(bd *stub.BalancerData) {
  1238  			bd.Data.(balancer.Balancer).Close()
  1239  		},
  1240  		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
  1241  			return bd.Data.(balancer.Balancer).UpdateClientConnState(ccs)
  1242  		},
  1243  	}
  1244  
  1245  	stub.Register(t.Name(), bf)
  1246  	svcCfg := fmt.Sprintf(`{ "loadBalancingConfig": [{%q: {}}] }`, t.Name())
  1247  	backend := stubserver.StartTestService(t, nil)
  1248  	defer backend.Stop()
  1249  	opts := []grpc.DialOption{
  1250  		grpc.WithTransportCredentials(insecure.NewCredentials()),
  1251  		grpc.WithDefaultServiceConfig(svcCfg),
  1252  	}
  1253  	cc, err := grpc.NewClient(backend.Address, opts...)
  1254  	if err != nil {
  1255  		t.Fatalf("grpc.NewClient(%q) failed: %v", backend.Address, err)
  1256  
  1257  	}
  1258  	defer cc.Close()
  1259  
  1260  	cc.Connect()
  1261  
  1262  	var sc balancer.SubConn
  1263  	select {
  1264  	case sc = <-scChan:
  1265  	case <-ctx.Done():
  1266  		t.Fatal("Context timed out waiting for SubConn creation")
  1267  	}
  1268  
  1269  	// Wait for the SubConn to enter READY.
  1270  	select {
  1271  	case <-readyUpdateReceivedCh:
  1272  	case <-ctx.Done():
  1273  		t.Fatalf("Context timed out waiting for SubConn to enter READY")
  1274  	}
  1275  
  1276  	healthChan := make(chan balancer.SubConnState, 1)
  1277  
  1278  	sc.RegisterHealthListener(func(scs balancer.SubConnState) {
  1279  		healthChan <- scs
  1280  	})
  1281  
  1282  	// Registering a nil listener should invalidate the previously registered
  1283  	// listener.
  1284  	sc.RegisterHealthListener(nil)
  1285  	close(readyUpdateResumeCh)
  1286  
  1287  	// No updates should be received on the listener.
  1288  	select {
  1289  	case scs := <-healthChan:
  1290  		t.Fatalf("Received unexpected health update on the listener: %v", scs)
  1291  	case <-time.After(defaultTestShortTimeout):
  1292  	}
  1293  }