github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/test/balancer_test.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package test
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"fmt"
    25  	"net"
    26  	"reflect"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/google/go-cmp/cmp"
    31  
    32  	grpc "github.com/hxx258456/ccgo/grpc"
    33  	"github.com/hxx258456/ccgo/grpc/attributes"
    34  	"github.com/hxx258456/ccgo/grpc/balancer"
    35  	"github.com/hxx258456/ccgo/grpc/balancer/roundrobin"
    36  	"github.com/hxx258456/ccgo/grpc/codes"
    37  	"github.com/hxx258456/ccgo/grpc/connectivity"
    38  	"github.com/hxx258456/ccgo/grpc/credentials"
    39  	"github.com/hxx258456/ccgo/grpc/credentials/insecure"
    40  	"github.com/hxx258456/ccgo/grpc/internal/balancer/stub"
    41  	"github.com/hxx258456/ccgo/grpc/internal/balancerload"
    42  	"github.com/hxx258456/ccgo/grpc/internal/grpcutil"
    43  	imetadata "github.com/hxx258456/ccgo/grpc/internal/metadata"
    44  	"github.com/hxx258456/ccgo/grpc/internal/stubserver"
    45  	"github.com/hxx258456/ccgo/grpc/internal/testutils"
    46  	"github.com/hxx258456/ccgo/grpc/metadata"
    47  	"github.com/hxx258456/ccgo/grpc/resolver"
    48  	"github.com/hxx258456/ccgo/grpc/resolver/manual"
    49  	"github.com/hxx258456/ccgo/grpc/status"
    50  	testpb "github.com/hxx258456/ccgo/grpc/test/grpc_testing"
    51  	"github.com/hxx258456/ccgo/grpc/testdata"
    52  )
    53  
    54  const testBalancerName = "testbalancer"
    55  
    56  // testBalancer creates one subconn with the first address from resolved
    57  // addresses.
    58  //
    59  // It's used to test whether options for NewSubConn are applied correctly.
    60  type testBalancer struct {
    61  	cc balancer.ClientConn
    62  	sc balancer.SubConn
    63  
    64  	newSubConnOptions balancer.NewSubConnOptions
    65  	pickInfos         []balancer.PickInfo
    66  	pickExtraMDs      []metadata.MD
    67  	doneInfo          []balancer.DoneInfo
    68  }
    69  
    70  func (b *testBalancer) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
    71  	b.cc = cc
    72  	return b
    73  }
    74  
    75  func (*testBalancer) Name() string {
    76  	return testBalancerName
    77  }
    78  
    79  func (*testBalancer) ResolverError(err error) {
    80  	panic("not implemented")
    81  }
    82  
    83  func (b *testBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
    84  	// Only create a subconn at the first time.
    85  	if b.sc == nil {
    86  		var err error
    87  		b.sc, err = b.cc.NewSubConn(state.ResolverState.Addresses, b.newSubConnOptions)
    88  		if err != nil {
    89  			logger.Errorf("testBalancer: failed to NewSubConn: %v", err)
    90  			return nil
    91  		}
    92  		b.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: &picker{err: balancer.ErrNoSubConnAvailable, bal: b}})
    93  		b.sc.Connect()
    94  	}
    95  	return nil
    96  }
    97  
    98  func (b *testBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) {
    99  	logger.Infof("testBalancer: UpdateSubConnState: %p, %v", sc, s)
   100  	if b.sc != sc {
   101  		logger.Infof("testBalancer: ignored state change because sc is not recognized")
   102  		return
   103  	}
   104  	if s.ConnectivityState == connectivity.Shutdown {
   105  		b.sc = nil
   106  		return
   107  	}
   108  
   109  	switch s.ConnectivityState {
   110  	case connectivity.Ready:
   111  		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{sc: sc, bal: b}})
   112  	case connectivity.Idle:
   113  		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{sc: sc, bal: b, idle: true}})
   114  	case connectivity.Connecting:
   115  		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrNoSubConnAvailable, bal: b}})
   116  	case connectivity.TransientFailure:
   117  		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrTransientFailure, bal: b}})
   118  	}
   119  }
   120  
   121  func (b *testBalancer) Close() {}
   122  
   123  func (b *testBalancer) ExitIdle() {}
   124  
   125  type picker struct {
   126  	err  error
   127  	sc   balancer.SubConn
   128  	bal  *testBalancer
   129  	idle bool
   130  }
   131  
   132  func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
   133  	if p.err != nil {
   134  		return balancer.PickResult{}, p.err
   135  	}
   136  	if p.idle {
   137  		p.sc.Connect()
   138  		return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
   139  	}
   140  	extraMD, _ := grpcutil.ExtraMetadata(info.Ctx)
   141  	info.Ctx = nil // Do not validate context.
   142  	p.bal.pickInfos = append(p.bal.pickInfos, info)
   143  	p.bal.pickExtraMDs = append(p.bal.pickExtraMDs, extraMD)
   144  	return balancer.PickResult{SubConn: p.sc, Done: func(d balancer.DoneInfo) { p.bal.doneInfo = append(p.bal.doneInfo, d) }}, nil
   145  }
   146  
   147  func (s) TestCredsBundleFromBalancer(t *testing.T) {
   148  	balancer.Register(&testBalancer{
   149  		newSubConnOptions: balancer.NewSubConnOptions{
   150  			CredsBundle: &testCredsBundle{},
   151  		},
   152  	})
   153  	te := newTest(t, env{name: "creds-bundle", network: "tcp", balancer: ""})
   154  	te.tapHandle = authHandle
   155  	te.customDialOptions = []grpc.DialOption{
   156  		grpc.WithBalancerName(testBalancerName),
   157  	}
   158  	creds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
   159  	if err != nil {
   160  		t.Fatalf("Failed to generate credentials %v", err)
   161  	}
   162  	te.customServerOptions = []grpc.ServerOption{
   163  		grpc.Creds(creds),
   164  	}
   165  	te.startServer(&testServer{})
   166  	defer te.tearDown()
   167  
   168  	cc := te.clientConn()
   169  	tc := testpb.NewTestServiceClient(cc)
   170  	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
   171  		t.Fatalf("Test failed. Reason: %v", err)
   172  	}
   173  }
   174  
   175  func (s) TestPickExtraMetadata(t *testing.T) {
   176  	for _, e := range listTestEnv() {
   177  		testPickExtraMetadata(t, e)
   178  	}
   179  }
   180  
   181  func testPickExtraMetadata(t *testing.T, e env) {
   182  	te := newTest(t, e)
   183  	b := &testBalancer{}
   184  	balancer.Register(b)
   185  	const (
   186  		testUserAgent      = "test-user-agent"
   187  		testSubContentType = "proto"
   188  	)
   189  
   190  	te.customDialOptions = []grpc.DialOption{
   191  		grpc.WithBalancerName(testBalancerName),
   192  		grpc.WithUserAgent(testUserAgent),
   193  	}
   194  	te.startServer(&testServer{security: e.security})
   195  	defer te.tearDown()
   196  
   197  	// Set resolver to xds to trigger the extra metadata code path.
   198  	r := manual.NewBuilderWithScheme("xds")
   199  	resolver.Register(r)
   200  	defer func() {
   201  		resolver.UnregisterForTesting("xds")
   202  	}()
   203  	r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: te.srvAddr}}})
   204  	te.resolverScheme = "xds"
   205  	cc := te.clientConn()
   206  	tc := testpb.NewTestServiceClient(cc)
   207  
   208  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   209  	defer cancel()
   210  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
   211  		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil)
   212  	}
   213  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.CallContentSubtype(testSubContentType)); err != nil {
   214  		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil)
   215  	}
   216  
   217  	want := []metadata.MD{
   218  		// First RPC doesn't have sub-content-type.
   219  		{"content-type": []string{"application/grpc"}},
   220  		// Second RPC has sub-content-type "proto".
   221  		{"content-type": []string{"application/grpc+proto"}},
   222  	}
   223  	if diff := cmp.Diff(want, b.pickExtraMDs); diff != "" {
   224  		t.Fatalf("unexpected diff in metadata (-want, +got): %s", diff)
   225  	}
   226  }
   227  
   228  func (s) TestDoneInfo(t *testing.T) {
   229  	for _, e := range listTestEnv() {
   230  		testDoneInfo(t, e)
   231  	}
   232  }
   233  
   234  func testDoneInfo(t *testing.T, e env) {
   235  	te := newTest(t, e)
   236  	b := &testBalancer{}
   237  	balancer.Register(b)
   238  	te.customDialOptions = []grpc.DialOption{
   239  		grpc.WithBalancerName(testBalancerName),
   240  	}
   241  	te.userAgent = failAppUA
   242  	te.startServer(&testServer{security: e.security})
   243  	defer te.tearDown()
   244  
   245  	cc := te.clientConn()
   246  	tc := testpb.NewTestServiceClient(cc)
   247  
   248  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   249  	defer cancel()
   250  	wantErr := detailedError
   251  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !testutils.StatusErrEqual(err, wantErr) {
   252  		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr)
   253  	}
   254  	if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
   255  		t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
   256  	}
   257  
   258  	if len(b.doneInfo) < 1 || !testutils.StatusErrEqual(b.doneInfo[0].Err, wantErr) {
   259  		t.Fatalf("b.doneInfo = %v; want b.doneInfo[0].Err = %v", b.doneInfo, wantErr)
   260  	}
   261  	if len(b.doneInfo) < 2 || !reflect.DeepEqual(b.doneInfo[1].Trailer, testTrailerMetadata) {
   262  		t.Fatalf("b.doneInfo = %v; want b.doneInfo[1].Trailer = %v", b.doneInfo, testTrailerMetadata)
   263  	}
   264  	if len(b.pickInfos) != len(b.doneInfo) {
   265  		t.Fatalf("Got %d picks, but %d doneInfo, want equal amount", len(b.pickInfos), len(b.doneInfo))
   266  	}
   267  	// To test done() is always called, even if it's returned with a non-Ready
   268  	// SubConn.
   269  	//
   270  	// Stop server and at the same time send RPCs. There are chances that picker
   271  	// is not updated in time, causing a non-Ready SubConn to be returned.
   272  	finished := make(chan struct{})
   273  	go func() {
   274  		for i := 0; i < 20; i++ {
   275  			tc.UnaryCall(ctx, &testpb.SimpleRequest{})
   276  		}
   277  		close(finished)
   278  	}()
   279  	te.srv.Stop()
   280  	<-finished
   281  	if len(b.pickInfos) != len(b.doneInfo) {
   282  		t.Fatalf("Got %d picks, %d doneInfo, want equal amount", len(b.pickInfos), len(b.doneInfo))
   283  	}
   284  }
   285  
   286  const loadMDKey = "X-Endpoint-Load-Metrics-Bin"
   287  
   288  type testLoadParser struct{}
   289  
   290  func (*testLoadParser) Parse(md metadata.MD) interface{} {
   291  	vs := md.Get(loadMDKey)
   292  	if len(vs) == 0 {
   293  		return nil
   294  	}
   295  	return vs[0]
   296  }
   297  
   298  func init() {
   299  	balancerload.SetParser(&testLoadParser{})
   300  }
   301  
   302  func (s) TestDoneLoads(t *testing.T) {
   303  	testDoneLoads(t)
   304  }
   305  
   306  func testDoneLoads(t *testing.T) {
   307  	b := &testBalancer{}
   308  	balancer.Register(b)
   309  
   310  	const testLoad = "test-load-,-should-be-orca"
   311  
   312  	ss := &stubserver.StubServer{
   313  		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
   314  			grpc.SetTrailer(ctx, metadata.Pairs(loadMDKey, testLoad))
   315  			return &testpb.Empty{}, nil
   316  		},
   317  	}
   318  	if err := ss.Start(nil, grpc.WithBalancerName(testBalancerName)); err != nil {
   319  		t.Fatalf("error starting testing server: %v", err)
   320  	}
   321  	defer ss.Stop()
   322  
   323  	tc := testpb.NewTestServiceClient(ss.CC)
   324  
   325  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   326  	defer cancel()
   327  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   328  		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil)
   329  	}
   330  
   331  	piWant := []balancer.PickInfo{
   332  		{FullMethodName: "/grpc.testing.TestService/EmptyCall"},
   333  	}
   334  	if !reflect.DeepEqual(b.pickInfos, piWant) {
   335  		t.Fatalf("b.pickInfos = %v; want %v", b.pickInfos, piWant)
   336  	}
   337  
   338  	if len(b.doneInfo) < 1 {
   339  		t.Fatalf("b.doneInfo = %v, want length 1", b.doneInfo)
   340  	}
   341  	gotLoad, _ := b.doneInfo[0].ServerLoad.(string)
   342  	if gotLoad != testLoad {
   343  		t.Fatalf("b.doneInfo[0].ServerLoad = %v; want = %v", b.doneInfo[0].ServerLoad, testLoad)
   344  	}
   345  }
   346  
   347  const testBalancerKeepAddressesName = "testbalancer-keepingaddresses"
   348  
   349  // testBalancerKeepAddresses keeps the addresses in the builder instead of
   350  // creating SubConns.
   351  //
   352  // It's used to test the addresses balancer gets are correct.
   353  type testBalancerKeepAddresses struct {
   354  	addrsChan chan []resolver.Address
   355  }
   356  
   357  func newTestBalancerKeepAddresses() *testBalancerKeepAddresses {
   358  	return &testBalancerKeepAddresses{
   359  		addrsChan: make(chan []resolver.Address, 10),
   360  	}
   361  }
   362  
   363  func (testBalancerKeepAddresses) ResolverError(err error) {
   364  	panic("not implemented")
   365  }
   366  
   367  func (b *testBalancerKeepAddresses) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
   368  	return b
   369  }
   370  
   371  func (*testBalancerKeepAddresses) Name() string {
   372  	return testBalancerKeepAddressesName
   373  }
   374  
   375  func (b *testBalancerKeepAddresses) UpdateClientConnState(state balancer.ClientConnState) error {
   376  	b.addrsChan <- state.ResolverState.Addresses
   377  	return nil
   378  }
   379  
   380  func (testBalancerKeepAddresses) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) {
   381  	panic("not used")
   382  }
   383  
   384  func (testBalancerKeepAddresses) Close() {}
   385  
   386  func (testBalancerKeepAddresses) ExitIdle() {}
   387  
   388  // Make sure that non-grpclb balancers don't get grpclb addresses even if name
   389  // resolver sends them
   390  func (s) TestNonGRPCLBBalancerGetsNoGRPCLBAddress(t *testing.T) {
   391  	r := manual.NewBuilderWithScheme("whatever")
   392  
   393  	b := newTestBalancerKeepAddresses()
   394  	balancer.Register(b)
   395  
   396  	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r),
   397  		grpc.WithBalancerName(b.Name()))
   398  	if err != nil {
   399  		t.Fatalf("failed to dial: %v", err)
   400  	}
   401  	defer cc.Close()
   402  
   403  	grpclbAddresses := []resolver.Address{{
   404  		Addr:       "grpc.lb.com",
   405  		Type:       resolver.GRPCLB,
   406  		ServerName: "grpc.lb.com",
   407  	}}
   408  
   409  	nonGRPCLBAddresses := []resolver.Address{{
   410  		Addr: "localhost",
   411  		Type: resolver.Backend,
   412  	}}
   413  
   414  	r.UpdateState(resolver.State{
   415  		Addresses: nonGRPCLBAddresses,
   416  	})
   417  	if got := <-b.addrsChan; !reflect.DeepEqual(got, nonGRPCLBAddresses) {
   418  		t.Fatalf("With only backend addresses, balancer got addresses %v, want %v", got, nonGRPCLBAddresses)
   419  	}
   420  
   421  	r.UpdateState(resolver.State{
   422  		Addresses: grpclbAddresses,
   423  	})
   424  	if got := <-b.addrsChan; len(got) != 0 {
   425  		t.Fatalf("With only grpclb addresses, balancer got addresses %v, want empty", got)
   426  	}
   427  
   428  	r.UpdateState(resolver.State{
   429  		Addresses: append(grpclbAddresses, nonGRPCLBAddresses...),
   430  	})
   431  	if got := <-b.addrsChan; !reflect.DeepEqual(got, nonGRPCLBAddresses) {
   432  		t.Fatalf("With both backend and grpclb addresses, balancer got addresses %v, want %v", got, nonGRPCLBAddresses)
   433  	}
   434  }
   435  
   436  type aiPicker struct {
   437  	result balancer.PickResult
   438  	err    error
   439  }
   440  
   441  func (aip *aiPicker) Pick(_ balancer.PickInfo) (balancer.PickResult, error) {
   442  	return aip.result, aip.err
   443  }
   444  
   445  // attrTransportCreds is a transport credential implementation which stores
   446  // Attributes from the ClientHandshakeInfo struct passed in the context locally
   447  // for the test to inspect.
   448  type attrTransportCreds struct {
   449  	credentials.TransportCredentials
   450  	attr *attributes.Attributes
   451  }
   452  
   453  func (ac *attrTransportCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   454  	ai := credentials.ClientHandshakeInfoFromContext(ctx)
   455  	ac.attr = ai.Attributes
   456  	return rawConn, nil, nil
   457  }
   458  func (ac *attrTransportCreds) Info() credentials.ProtocolInfo {
   459  	return credentials.ProtocolInfo{}
   460  }
   461  func (ac *attrTransportCreds) Clone() credentials.TransportCredentials {
   462  	return nil
   463  }
   464  
   465  // TestAddressAttributesInNewSubConn verifies that the Attributes passed from a
   466  // balancer in the resolver.Address that is passes to NewSubConn reaches all the
   467  // way to the ClientHandshake method of the credentials configured on the parent
   468  // channel.
   469  func (s) TestAddressAttributesInNewSubConn(t *testing.T) {
   470  	const (
   471  		testAttrKey      = "foo"
   472  		testAttrVal      = "bar"
   473  		attrBalancerName = "attribute-balancer"
   474  	)
   475  
   476  	// Register a stub balancer which adds attributes to the first address that
   477  	// it receives and then calls NewSubConn on it.
   478  	bf := stub.BalancerFuncs{
   479  		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
   480  			addrs := ccs.ResolverState.Addresses
   481  			if len(addrs) == 0 {
   482  				return nil
   483  			}
   484  
   485  			// Only use the first address.
   486  			attr := attributes.New(testAttrKey, testAttrVal)
   487  			addrs[0].Attributes = attr
   488  			sc, err := bd.ClientConn.NewSubConn([]resolver.Address{addrs[0]}, balancer.NewSubConnOptions{})
   489  			if err != nil {
   490  				return err
   491  			}
   492  			sc.Connect()
   493  			return nil
   494  		},
   495  		UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) {
   496  			bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}})
   497  		},
   498  	}
   499  	stub.Register(attrBalancerName, bf)
   500  	t.Logf("Registered balancer %s...", attrBalancerName)
   501  
   502  	r := manual.NewBuilderWithScheme("whatever")
   503  	t.Logf("Registered manual resolver with scheme %s...", r.Scheme())
   504  
   505  	lis, err := net.Listen("tcp", "localhost:0")
   506  	if err != nil {
   507  		t.Fatal(err)
   508  	}
   509  
   510  	s := grpc.NewServer()
   511  	testpb.RegisterTestServiceServer(s, &testServer{})
   512  	go s.Serve(lis)
   513  	defer s.Stop()
   514  	t.Logf("Started gRPC server at %s...", lis.Addr().String())
   515  
   516  	creds := &attrTransportCreds{}
   517  	dopts := []grpc.DialOption{
   518  		grpc.WithTransportCredentials(creds),
   519  		grpc.WithResolvers(r),
   520  		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, attrBalancerName)),
   521  	}
   522  	cc, err := grpc.Dial(r.Scheme()+":///test.server", dopts...)
   523  	if err != nil {
   524  		t.Fatal(err)
   525  	}
   526  	defer cc.Close()
   527  	tc := testpb.NewTestServiceClient(cc)
   528  	t.Log("Created a ClientConn...")
   529  
   530  	// The first RPC should fail because there's no address.
   531  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   532  	defer cancel()
   533  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
   534  		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
   535  	}
   536  	t.Log("Made an RPC which was expected to fail...")
   537  
   538  	state := resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}
   539  	r.UpdateState(state)
   540  	t.Logf("Pushing resolver state update: %v through the manual resolver", state)
   541  
   542  	// The second RPC should succeed.
   543  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   544  	defer cancel()
   545  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   546  		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
   547  	}
   548  	t.Log("Made an RPC which succeeded...")
   549  
   550  	wantAttr := attributes.New(testAttrKey, testAttrVal)
   551  	if gotAttr := creds.attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) {
   552  		t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr)
   553  	}
   554  }
   555  
   556  // TestMetadataInAddressAttributes verifies that the metadata added to
   557  // address.Attributes will be sent with the RPCs.
   558  func (s) TestMetadataInAddressAttributes(t *testing.T) {
   559  	const (
   560  		testMDKey      = "test-md"
   561  		testMDValue    = "test-md-value"
   562  		mdBalancerName = "metadata-balancer"
   563  	)
   564  
   565  	// Register a stub balancer which adds metadata to the first address that it
   566  	// receives and then calls NewSubConn on it.
   567  	bf := stub.BalancerFuncs{
   568  		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
   569  			addrs := ccs.ResolverState.Addresses
   570  			if len(addrs) == 0 {
   571  				return nil
   572  			}
   573  			// Only use the first address.
   574  			sc, err := bd.ClientConn.NewSubConn([]resolver.Address{
   575  				imetadata.Set(addrs[0], metadata.Pairs(testMDKey, testMDValue)),
   576  			}, balancer.NewSubConnOptions{})
   577  			if err != nil {
   578  				return err
   579  			}
   580  			sc.Connect()
   581  			return nil
   582  		},
   583  		UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) {
   584  			bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}})
   585  		},
   586  	}
   587  	stub.Register(mdBalancerName, bf)
   588  	t.Logf("Registered balancer %s...", mdBalancerName)
   589  
   590  	testMDChan := make(chan []string, 1)
   591  	ss := &stubserver.StubServer{
   592  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   593  			md, ok := metadata.FromIncomingContext(ctx)
   594  			if ok {
   595  				select {
   596  				case testMDChan <- md[testMDKey]:
   597  				case <-ctx.Done():
   598  					return nil, ctx.Err()
   599  				}
   600  			}
   601  			return &testpb.Empty{}, nil
   602  		},
   603  	}
   604  	if err := ss.Start(nil, grpc.WithDefaultServiceConfig(
   605  		fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, mdBalancerName),
   606  	)); err != nil {
   607  		t.Fatalf("Error starting endpoint server: %v", err)
   608  	}
   609  	defer ss.Stop()
   610  
   611  	// The RPC should succeed with the expected md.
   612  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   613  	defer cancel()
   614  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   615  		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
   616  	}
   617  	t.Log("Made an RPC which succeeded...")
   618  
   619  	// The server should receive the test metadata.
   620  	md1 := <-testMDChan
   621  	if len(md1) == 0 || md1[0] != testMDValue {
   622  		t.Fatalf("got md: %v, want %v", md1, []string{testMDValue})
   623  	}
   624  }
   625  
   626  // TestServersSwap creates two servers and verifies the client switches between
   627  // them when the name resolver reports the first and then the second.
   628  func (s) TestServersSwap(t *testing.T) {
   629  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   630  	defer cancel()
   631  
   632  	// Initialize servers
   633  	reg := func(username string) (addr string, cleanup func()) {
   634  		lis, err := net.Listen("tcp", "localhost:0")
   635  		if err != nil {
   636  			t.Fatalf("Error while listening. Err: %v", err)
   637  		}
   638  		s := grpc.NewServer()
   639  		ts := &funcServer{
   640  			unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   641  				return &testpb.SimpleResponse{Username: username}, nil
   642  			},
   643  		}
   644  		testpb.RegisterTestServiceServer(s, ts)
   645  		go s.Serve(lis)
   646  		return lis.Addr().String(), s.Stop
   647  	}
   648  	const one = "1"
   649  	addr1, cleanup := reg(one)
   650  	defer cleanup()
   651  	const two = "2"
   652  	addr2, cleanup := reg(two)
   653  	defer cleanup()
   654  
   655  	// Initialize client
   656  	r := manual.NewBuilderWithScheme("whatever")
   657  	r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: addr1}}})
   658  	cc, err := grpc.DialContext(ctx, r.Scheme()+":///", grpc.WithInsecure(), grpc.WithResolvers(r))
   659  	if err != nil {
   660  		t.Fatalf("Error creating client: %v", err)
   661  	}
   662  	defer cc.Close()
   663  	client := testpb.NewTestServiceClient(cc)
   664  
   665  	// Confirm we are connected to the first server
   666  	if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil || res.Username != one {
   667  		t.Fatalf("UnaryCall(_) = %v, %v; want {Username: %q}, nil", res, err, one)
   668  	}
   669  
   670  	// Update resolver to report only the second server
   671  	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: addr2}}})
   672  
   673  	// Loop until new RPCs talk to server two.
   674  	for i := 0; i < 2000; i++ {
   675  		if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
   676  			t.Fatalf("UnaryCall(_) = _, %v; want _, nil", err)
   677  		} else if res.Username == two {
   678  			break // pass
   679  		}
   680  		time.Sleep(5 * time.Millisecond)
   681  	}
   682  }
   683  
   684  // TestEmptyAddrs verifies client behavior when a working connection is
   685  // removed.  In pick first and round-robin, both will continue using the old
   686  // connections.
   687  func (s) TestEmptyAddrs(t *testing.T) {
   688  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   689  	defer cancel()
   690  
   691  	// Initialize server
   692  	lis, err := net.Listen("tcp", "localhost:0")
   693  	if err != nil {
   694  		t.Fatalf("Error while listening. Err: %v", err)
   695  	}
   696  	s := grpc.NewServer()
   697  	defer s.Stop()
   698  	const one = "1"
   699  	ts := &funcServer{
   700  		unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   701  			return &testpb.SimpleResponse{Username: one}, nil
   702  		},
   703  	}
   704  	testpb.RegisterTestServiceServer(s, ts)
   705  	go s.Serve(lis)
   706  
   707  	// Initialize pickfirst client
   708  	pfr := manual.NewBuilderWithScheme("whatever")
   709  
   710  	pfr.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})
   711  
   712  	pfcc, err := grpc.DialContext(ctx, pfr.Scheme()+":///", grpc.WithInsecure(), grpc.WithResolvers(pfr))
   713  	if err != nil {
   714  		t.Fatalf("Error creating client: %v", err)
   715  	}
   716  	defer pfcc.Close()
   717  	pfclient := testpb.NewTestServiceClient(pfcc)
   718  
   719  	// Confirm we are connected to the server
   720  	if res, err := pfclient.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil || res.Username != one {
   721  		t.Fatalf("UnaryCall(_) = %v, %v; want {Username: %q}, nil", res, err, one)
   722  	}
   723  
   724  	// Remove all addresses.
   725  	pfr.UpdateState(resolver.State{})
   726  
   727  	// Initialize roundrobin client
   728  	rrr := manual.NewBuilderWithScheme("whatever")
   729  
   730  	rrr.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})
   731  
   732  	rrcc, err := grpc.DialContext(ctx, rrr.Scheme()+":///", grpc.WithInsecure(), grpc.WithResolvers(rrr),
   733  		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, roundrobin.Name)))
   734  	if err != nil {
   735  		t.Fatalf("Error creating client: %v", err)
   736  	}
   737  	defer rrcc.Close()
   738  	rrclient := testpb.NewTestServiceClient(rrcc)
   739  
   740  	// Confirm we are connected to the server
   741  	if res, err := rrclient.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil || res.Username != one {
   742  		t.Fatalf("UnaryCall(_) = %v, %v; want {Username: %q}, nil", res, err, one)
   743  	}
   744  
   745  	// Remove all addresses.
   746  	rrr.UpdateState(resolver.State{})
   747  
   748  	// Confirm several new RPCs succeed on pick first.
   749  	for i := 0; i < 10; i++ {
   750  		if _, err := pfclient.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
   751  			t.Fatalf("UnaryCall(_) = _, %v; want _, nil", err)
   752  		}
   753  		time.Sleep(5 * time.Millisecond)
   754  	}
   755  
   756  	// Confirm several new RPCs succeed on round robin.
   757  	for i := 0; i < 10; i++ {
   758  		if _, err := pfclient.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
   759  			t.Fatalf("UnaryCall(_) = _, %v; want _, nil", err)
   760  		}
   761  		time.Sleep(5 * time.Millisecond)
   762  	}
   763  }
   764  
   765  func (s) TestWaitForReady(t *testing.T) {
   766  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   767  	defer cancel()
   768  
   769  	// Initialize server
   770  	lis, err := net.Listen("tcp", "localhost:0")
   771  	if err != nil {
   772  		t.Fatalf("Error while listening. Err: %v", err)
   773  	}
   774  	s := grpc.NewServer()
   775  	defer s.Stop()
   776  	const one = "1"
   777  	ts := &funcServer{
   778  		unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   779  			return &testpb.SimpleResponse{Username: one}, nil
   780  		},
   781  	}
   782  	testpb.RegisterTestServiceServer(s, ts)
   783  	go s.Serve(lis)
   784  
   785  	// Initialize client
   786  	r := manual.NewBuilderWithScheme("whatever")
   787  
   788  	cc, err := grpc.DialContext(ctx, r.Scheme()+":///", grpc.WithInsecure(), grpc.WithResolvers(r))
   789  	if err != nil {
   790  		t.Fatalf("Error creating client: %v", err)
   791  	}
   792  	defer cc.Close()
   793  	client := testpb.NewTestServiceClient(cc)
   794  
   795  	// Report an error so non-WFR RPCs will give up early.
   796  	r.CC.ReportError(errors.New("fake resolver error"))
   797  
   798  	// Ensure the client is not connected to anything and fails non-WFR RPCs.
   799  	if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Unavailable {
   800  		t.Fatalf("UnaryCall(_) = %v, %v; want _, Code()=%v", res, err, codes.Unavailable)
   801  	}
   802  
   803  	errChan := make(chan error, 1)
   804  	go func() {
   805  		if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}, grpc.WaitForReady(true)); err != nil || res.Username != one {
   806  			errChan <- fmt.Errorf("UnaryCall(_) = %v, %v; want {Username: %q}, nil", res, err, one)
   807  		}
   808  		close(errChan)
   809  	}()
   810  
   811  	select {
   812  	case err := <-errChan:
   813  		t.Errorf("unexpected receive from errChan before addresses provided")
   814  		t.Fatal(err.Error())
   815  	case <-time.After(5 * time.Millisecond):
   816  	}
   817  
   818  	// Resolve the server.  The WFR RPC should unblock and use it.
   819  	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})
   820  
   821  	if err := <-errChan; err != nil {
   822  		t.Fatal(err.Error())
   823  	}
   824  }
   825  
   826  // authorityOverrideTransportCreds returns the configured authority value in its
   827  // Info() method.
   828  type authorityOverrideTransportCreds struct {
   829  	credentials.TransportCredentials
   830  	authorityOverride string
   831  }
   832  
   833  func (ao *authorityOverrideTransportCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   834  	return rawConn, nil, nil
   835  }
   836  func (ao *authorityOverrideTransportCreds) Info() credentials.ProtocolInfo {
   837  	return credentials.ProtocolInfo{ServerName: ao.authorityOverride}
   838  }
   839  func (ao *authorityOverrideTransportCreds) Clone() credentials.TransportCredentials {
   840  	return &authorityOverrideTransportCreds{authorityOverride: ao.authorityOverride}
   841  }
   842  
   843  // TestAuthorityInBuildOptions tests that the Authority field in
   844  // balancer.BuildOptions is setup correctly from gRPC.
   845  func (s) TestAuthorityInBuildOptions(t *testing.T) {
   846  	const dialTarget = "test.server"
   847  
   848  	tests := []struct {
   849  		name          string
   850  		dopts         []grpc.DialOption
   851  		wantAuthority string
   852  	}{
   853  		{
   854  			name:          "authority from dial target",
   855  			dopts:         []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
   856  			wantAuthority: dialTarget,
   857  		},
   858  		{
   859  			name: "authority from dial option",
   860  			dopts: []grpc.DialOption{
   861  				grpc.WithTransportCredentials(insecure.NewCredentials()),
   862  				grpc.WithAuthority("authority-override"),
   863  			},
   864  			wantAuthority: "authority-override",
   865  		},
   866  		{
   867  			name:          "authority from transport creds",
   868  			dopts:         []grpc.DialOption{grpc.WithTransportCredentials(&authorityOverrideTransportCreds{authorityOverride: "authority-override-from-transport-creds"})},
   869  			wantAuthority: "authority-override-from-transport-creds",
   870  		},
   871  	}
   872  
   873  	for _, test := range tests {
   874  		t.Run(test.name, func(t *testing.T) {
   875  			authorityCh := make(chan string, 1)
   876  			bf := stub.BalancerFuncs{
   877  				UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
   878  					select {
   879  					case authorityCh <- bd.BuildOptions.Authority:
   880  					default:
   881  					}
   882  
   883  					addrs := ccs.ResolverState.Addresses
   884  					if len(addrs) == 0 {
   885  						return nil
   886  					}
   887  
   888  					// Only use the first address.
   889  					sc, err := bd.ClientConn.NewSubConn([]resolver.Address{addrs[0]}, balancer.NewSubConnOptions{})
   890  					if err != nil {
   891  						return err
   892  					}
   893  					sc.Connect()
   894  					return nil
   895  				},
   896  				UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) {
   897  					bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}})
   898  				},
   899  			}
   900  			balancerName := "stub-balancer-" + test.name
   901  			stub.Register(balancerName, bf)
   902  			t.Logf("Registered balancer %s...", balancerName)
   903  
   904  			lis, err := testutils.LocalTCPListener()
   905  			if err != nil {
   906  				t.Fatal(err)
   907  			}
   908  
   909  			s := grpc.NewServer()
   910  			testpb.RegisterTestServiceServer(s, &testServer{})
   911  			go s.Serve(lis)
   912  			defer s.Stop()
   913  			t.Logf("Started gRPC server at %s...", lis.Addr().String())
   914  
   915  			r := manual.NewBuilderWithScheme("whatever")
   916  			t.Logf("Registered manual resolver with scheme %s...", r.Scheme())
   917  			r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})
   918  
   919  			dopts := append([]grpc.DialOption{
   920  				grpc.WithResolvers(r),
   921  				grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, balancerName)),
   922  			}, test.dopts...)
   923  			cc, err := grpc.Dial(r.Scheme()+":///"+dialTarget, dopts...)
   924  			if err != nil {
   925  				t.Fatal(err)
   926  			}
   927  			defer cc.Close()
   928  			tc := testpb.NewTestServiceClient(cc)
   929  			t.Log("Created a ClientConn...")
   930  
   931  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   932  			defer cancel()
   933  			if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   934  				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
   935  			}
   936  			t.Log("Made an RPC which succeeded...")
   937  
   938  			select {
   939  			case <-ctx.Done():
   940  				t.Fatal("timeout when waiting for Authority in balancer.BuildOptions")
   941  			case gotAuthority := <-authorityCh:
   942  				if gotAuthority != test.wantAuthority {
   943  					t.Fatalf("Authority in balancer.BuildOptions is %s, want %s", gotAuthority, test.wantAuthority)
   944  				}
   945  			}
   946  		})
   947  	}
   948  }