google.golang.org/grpc@v1.74.2/xds/internal/balancer/clusterimpl/balancer_test.go (about)

     1  /*
     2   *
     3   * Copyright 2020 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package clusterimpl
    20  
    21  import (
    22  	"context"
    23  	"encoding/json"
    24  	"errors"
    25  	"fmt"
    26  	"sort"
    27  	"strings"
    28  	"sync"
    29  	"testing"
    30  	"time"
    31  
    32  	"google.golang.org/grpc/balancer"
    33  	"google.golang.org/grpc/balancer/base"
    34  	"google.golang.org/grpc/balancer/roundrobin"
    35  	"google.golang.org/grpc/connectivity"
    36  	"google.golang.org/grpc/internal"
    37  	"google.golang.org/grpc/internal/balancer/stub"
    38  	"google.golang.org/grpc/internal/grpctest"
    39  	internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
    40  	"google.golang.org/grpc/internal/testutils"
    41  	"google.golang.org/grpc/internal/xds"
    42  	"google.golang.org/grpc/internal/xds/bootstrap"
    43  	"google.golang.org/grpc/resolver"
    44  	"google.golang.org/grpc/serviceconfig"
    45  	xdsinternal "google.golang.org/grpc/xds/internal"
    46  	"google.golang.org/grpc/xds/internal/clients"
    47  	"google.golang.org/grpc/xds/internal/testutils/fakeclient"
    48  	"google.golang.org/grpc/xds/internal/xdsclient"
    49  
    50  	v3orcapb "github.com/cncf/xds/go/xds/data/orca/v3"
    51  	"github.com/google/go-cmp/cmp"
    52  	"github.com/google/go-cmp/cmp/cmpopts"
    53  )
    54  
    55  const (
    56  	defaultTestTimeout      = 5 * time.Second
    57  	defaultShortTestTimeout = 100 * time.Microsecond
    58  
    59  	testClusterName = "test-cluster"
    60  	testServiceName = "test-eds-service"
    61  
    62  	testNamedMetricsKey1 = "test-named1"
    63  	testNamedMetricsKey2 = "test-named2"
    64  )
    65  
var (
	// testBackendEndpoints is the single-backend endpoint list fed to the
	// balancer via the resolver state in most tests.
	testBackendEndpoints = []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: "1.1.1.1:1"}}}}
	// cmpOpts compares loadData trees exactly, treating empty and nil
	// collections as equal and sorting slices for deterministic diffs.
	cmpOpts = cmp.Options{cmpopts.EquateEmpty(), cmp.AllowUnexported(loadData{}, localityData{}, requestData{}, serverLoadData{}), sortDataSlice}
	// toleranceCmpOpt is like cmpOpts but allows a small absolute tolerance
	// for floating-point load values.
	toleranceCmpOpt = cmp.Options{cmpopts.EquateApprox(0, 1e-5), cmp.AllowUnexported(loadData{}, localityData{}, requestData{}, serverLoadData{})}
)
    71  
// s embeds grpctest.Tester so that the test methods declared on it are run
// through grpctest.RunSubTests (see Test below).
type s struct {
	grpctest.Tester
}
    75  
// Test is the single entry point that dispatches every (s) test method in
// this file as a subtest.
func Test(t *testing.T) {
	grpctest.RunSubTests(t, s{})
}
    79  
// testLoadReporter records load data pertaining to a single cluster.
//
// It implements loadReporter interface for the picker. Tests can use it to
// override the loadStore in the picker to verify load reporting.
type testLoadReporter struct {
	// cluster and service identify which cluster/EDS service the recorded
	// data is attributed to when building reports in stats().
	cluster, service string

	// mu guards drops and localityRPCCount.
	mu sync.Mutex
	// drops counts dropped calls per drop category.
	drops map[string]uint64
	// localityRPCCount holds per-locality RPC counters and server loads.
	localityRPCCount map[clients.Locality]*rpcCountData
}
    91  
    92  // CallStarted records a call started for the clients.Locality.
    93  func (lr *testLoadReporter) CallStarted(locality clients.Locality) {
    94  	lr.mu.Lock()
    95  	defer lr.mu.Unlock()
    96  	if _, ok := lr.localityRPCCount[locality]; !ok {
    97  		lr.localityRPCCount[locality] = &rpcCountData{}
    98  	}
    99  	lr.localityRPCCount[locality].inProgress++
   100  	lr.localityRPCCount[locality].issued++
   101  }
   102  
   103  // CallFinished records a call finished for the clients.Locality.
   104  func (lr *testLoadReporter) CallFinished(locality clients.Locality, err error) {
   105  	lr.mu.Lock()
   106  	defer lr.mu.Unlock()
   107  	if lr.localityRPCCount == nil {
   108  		return
   109  	}
   110  	lrc := lr.localityRPCCount[locality]
   111  	lrc.inProgress--
   112  	if err == nil {
   113  		lrc.succeeded++
   114  	} else {
   115  		lrc.errored++
   116  	}
   117  }
   118  
   119  // CallServerLoad records a server load for the clients.Locality.
   120  func (lr *testLoadReporter) CallServerLoad(locality clients.Locality, name string, val float64) {
   121  	lr.mu.Lock()
   122  	defer lr.mu.Unlock()
   123  	if lr.localityRPCCount == nil {
   124  		return
   125  	}
   126  	lrc, ok := lr.localityRPCCount[locality]
   127  	if !ok {
   128  		return
   129  	}
   130  	if lrc.serverLoads == nil {
   131  		lrc.serverLoads = make(map[string]*rpcLoadData)
   132  	}
   133  	if _, ok := lrc.serverLoads[name]; !ok {
   134  		lrc.serverLoads[name] = &rpcLoadData{}
   135  	}
   136  	rld := lrc.serverLoads[name]
   137  	rld.add(val)
   138  }
   139  
// CallDropped records a call dropped for the category. An empty category is
// allowed; stats() counts it toward totalDrops but not per-category drops.
func (lr *testLoadReporter) CallDropped(category string) {
	lr.mu.Lock()
	defer lr.mu.Unlock()
	lr.drops[category]++
}
   146  
// stats returns all loads reported since the previous call and resets every
// counter (drops, request counts — including inProgress — and server loads)
// so the next call reports only new data.
//
// It returns nil if the store doesn't contain any (new) data.
func (lr *testLoadReporter) stats() *loadData {
	lr.mu.Lock()
	defer lr.mu.Unlock()

	sd := newLoadData(lr.cluster, lr.service)
	for category, val := range lr.drops {
		if val == 0 {
			continue
		}
		if category != "" {
			// Skip drops without category. They are counted in total_drops, but
			// not in per category. One example is drops by circuit breaking.
			sd.drops[category] = val
		}
		sd.totalDrops += val
		lr.drops[category] = 0 // clear drops for next report
	}
	for locality, countData := range lr.localityRPCCount {
		if countData.succeeded == 0 && countData.errored == 0 && countData.inProgress == 0 && countData.issued == 0 {
			// No activity for this locality since the last report.
			continue
		}

		// Snapshot the counters into the report before zeroing them.
		ld := localityData{
			requestStats: requestData{
				succeeded:  countData.succeeded,
				errored:    countData.errored,
				inProgress: countData.inProgress,
				issued:     countData.issued,
			},
			loadStats: make(map[string]serverLoadData),
		}
		// clear localityRPCCount for next report
		countData.succeeded = 0
		countData.errored = 0
		countData.inProgress = 0
		countData.issued = 0
		for key, rld := range countData.serverLoads {
			s, c := rld.loadAndClear() // get and clear serverLoads for next report
			if c == 0 {
				continue
			}
			ld.loadStats[key] = serverLoadData{sum: s, count: c}
		}
		sd.localityStats[locality] = ld
	}
	// Report nil when nothing was collected so callers can distinguish "no
	// activity" from an all-zero report.
	if sd.totalDrops == 0 && len(sd.drops) == 0 && len(sd.localityStats) == 0 {
		return nil
	}
	return sd
}
   201  
// loadData contains all load data reported to the LoadStore since the most recent
// call to stats().
type loadData struct {
	// cluster is the name of the cluster this data is for.
	cluster string
	// service is the name of the EDS service this data is for.
	service string
	// totalDrops is the total number of dropped requests, including drops
	// without a category (e.g. circuit breaking).
	totalDrops uint64
	// drops is the number of dropped requests per (non-empty) category.
	drops map[string]uint64
	// localityStats contains load reports per locality.
	localityStats map[clients.Locality]localityData
}
   216  
// localityData contains load data for a single locality.
type localityData struct {
	// requestStats contains counts of requests made to the locality.
	requestStats requestData
	// loadStats contains server load data for requests made to the locality,
	// indexed by the load type (e.g. named ORCA metric key).
	loadStats map[string]serverLoadData
}
   225  
// requestData contains request counts accumulated since the last report.
type requestData struct {
	// succeeded is the number of succeeded requests.
	succeeded uint64
	// errored is the number of requests which ran into errors.
	errored uint64
	// inProgress is the number of requests in flight.
	inProgress uint64
	// issued is the total number requests that were sent.
	issued uint64
}
   237  
// serverLoadData contains aggregated server load data for one load key.
type serverLoadData struct {
	// count is the number of load reports.
	count uint64
	// sum is the total value of all load reports.
	sum float64
}
   245  
   246  func newLoadData(cluster, service string) *loadData {
   247  	return &loadData{
   248  		cluster:       cluster,
   249  		service:       service,
   250  		drops:         make(map[string]uint64),
   251  		localityStats: make(map[clients.Locality]localityData),
   252  	}
   253  }
   254  
// rpcCountData holds the mutable per-locality counters behind
// testLoadReporter (guarded by testLoadReporter.mu).
type rpcCountData struct {
	succeeded   uint64                  // RPCs finished without error
	errored     uint64                  // RPCs finished with an error
	inProgress  uint64                  // RPCs started but not yet finished
	issued      uint64                  // total RPCs started
	serverLoads map[string]*rpcLoadData // per-load-key value accumulators
}
   262  
   263  type rpcLoadData struct {
   264  	sum   float64
   265  	count uint64
   266  }
   267  
   268  func (rld *rpcLoadData) add(v float64) {
   269  	rld.sum += v
   270  	rld.count++
   271  }
   272  
   273  func (rld *rpcLoadData) loadAndClear() (s float64, c uint64) {
   274  	s, rld.sum = rld.sum, 0
   275  	c, rld.count = rld.count, 0
   276  	return s, c
   277  }
   278  
// init swaps the WRR constructor for the testutils implementation so the
// drop tests below see a predictable pick/drop sequence (e.g. "every even
// RPC is dropped" in TestDropByCategory).
func init() {
	NewRandomWRR = testutils.NewTestWRR
}
   282  
   283  var sortDataSlice = cmp.Transformer("SortDataSlice", func(in []*loadData) []*loadData {
   284  	out := append([]*loadData(nil), in...) // Copy input to avoid mutating it
   285  	sort.Slice(out,
   286  		func(i, j int) bool {
   287  			if out[i].cluster < out[j].cluster {
   288  				return true
   289  			}
   290  			if out[i].cluster == out[j].cluster {
   291  				return out[i].service < out[j].service
   292  			}
   293  			return false
   294  		},
   295  	)
   296  	return out
   297  })
   298  
   299  func verifyLoadStoreData(wantStoreData, gotStoreData *loadData) error {
   300  	if diff := cmp.Diff(wantStoreData, gotStoreData, cmpOpts); diff != "" {
   301  		return fmt.Errorf("store.stats() returned unexpected diff (-want +got):\n%s", diff)
   302  	}
   303  	return nil
   304  }
   305  
// TestDropByCategory verifies that the balancer correctly drops the picks, and
// that the drops are reported.
//
// It configures a 1/2 drop policy, issues RPCs through the picker, and checks
// the recorded load data; then it updates the config to a 1/4 drop policy and
// repeats the check. The deterministic WRR installed in init() makes the
// exact dropped indices predictable.
func (s) TestDropByCategory(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()

	defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName)
	xdsC := fakeclient.NewClient()

	builder := balancer.Get(Name)
	cc := testutils.NewBalancerClientConn(t)
	b := builder.Build(cc, balancer.BuildOptions{})
	defer b.Close()

	// First config: drop dropNumerator/dropDenominator = 1/2 of requests.
	const (
		dropReason      = "test-dropping-category"
		dropNumerator   = 1
		dropDenominator = 2
	)
	testLRSServerConfig, err := bootstrap.ServerConfigForTesting(bootstrap.ServerConfigTestingOptions{
		URI:          "trafficdirector.googleapis.com:443",
		ChannelCreds: []bootstrap.ChannelCreds{{Type: "google_default"}},
	})
	if err != nil {
		t.Fatalf("Failed to create LRS server config for testing: %v", err)
	}
	if err := b.UpdateClientConnState(balancer.ClientConnState{
		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: testBackendEndpoints}, xdsC),
		BalancerConfig: &LBConfig{
			Cluster:             testClusterName,
			EDSServiceName:      testServiceName,
			LoadReportingServer: testLRSServerConfig,
			DropCategories: []DropConfig{{
				Category:           dropReason,
				RequestsPerMillion: million * dropNumerator / dropDenominator,
			}},
			ChildPolicy: &internalserviceconfig.BalancerConfig{
				Name: roundrobin.Name,
			},
		},
	}); err != nil {
		t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
	}

	// The balancer should start load reporting to the configured LRS server.
	got, err := xdsC.WaitForReportLoad(ctx)
	if err != nil {
		t.Fatalf("xdsClient.ReportLoad failed with error: %v", err)
	}
	if got.Server != testLRSServerConfig {
		t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got.Server, testLRSServerConfig)
	}

	sc1 := <-cc.NewSubConnCh
	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
	// This should get the connecting picker.
	if err := cc.WaitForPickerWithErr(ctx, balancer.ErrNoSubConnAvailable); err != nil {
		t.Fatal(err.Error())
	}

	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
	// Test pick with one backend.

	testClusterLoadReporter := &testLoadReporter{cluster: testClusterName, service: testServiceName, drops: make(map[string]uint64), localityRPCCount: make(map[clients.Locality]*rpcCountData)}

	const rpcCount = 24
	if err := cc.WaitForPicker(ctx, func(p balancer.Picker) error {
		// Override the loadStore in the picker with testClusterLoadReporter.
		picker := p.(*picker)
		originalLoadStore := picker.loadStore
		picker.loadStore = testClusterLoadReporter
		defer func() { picker.loadStore = originalLoadStore }()

		for i := 0; i < rpcCount; i++ {
			gotSCSt, err := p.Pick(balancer.PickInfo{})
			// Even RPCs are dropped.
			if i%2 == 0 {
				if err == nil || !strings.Contains(err.Error(), "dropped") {
					return fmt.Errorf("pick.Pick, got %v, %v, want error RPC dropped", gotSCSt, err)
				}
				continue
			}
			if err != nil || gotSCSt.SubConn != sc1 {
				return fmt.Errorf("picker.Pick, got %v, %v, want SubConn=%v", gotSCSt, err, sc1)
			}
			if gotSCSt.Done == nil {
				continue
			}
			// Fail 1/4th of the requests that are not dropped.
			if i%8 == 1 {
				gotSCSt.Done(balancer.DoneInfo{Err: fmt.Errorf("test error")})
			} else {
				gotSCSt.Done(balancer.DoneInfo{})
			}
		}
		return nil
	}); err != nil {
		t.Fatal(err.Error())
	}

	// Dump load data from the store and compare with expected counts.
	const dropCount = rpcCount * dropNumerator / dropDenominator
	wantStatsData0 := &loadData{
		cluster:    testClusterName,
		service:    testServiceName,
		totalDrops: dropCount,
		drops:      map[string]uint64{dropReason: dropCount},
		localityStats: map[clients.Locality]localityData{
			{}: {requestStats: requestData{
				succeeded: (rpcCount - dropCount) * 3 / 4,
				errored:   (rpcCount - dropCount) / 4,
				issued:    rpcCount - dropCount,
			}},
		},
	}

	gotStatsData0 := testClusterLoadReporter.stats()
	if err := verifyLoadStoreData(wantStatsData0, gotStatsData0); err != nil {
		t.Fatal(err)
	}

	// Send an update with new drop configs.
	const (
		dropReason2      = "test-dropping-category-2"
		dropNumerator2   = 1
		dropDenominator2 = 4
	)
	if err := b.UpdateClientConnState(balancer.ClientConnState{
		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: testBackendEndpoints}, xdsC),
		BalancerConfig: &LBConfig{
			Cluster:             testClusterName,
			EDSServiceName:      testServiceName,
			LoadReportingServer: testLRSServerConfig,
			DropCategories: []DropConfig{{
				Category:           dropReason2,
				RequestsPerMillion: million * dropNumerator2 / dropDenominator2,
			}},
			ChildPolicy: &internalserviceconfig.BalancerConfig{
				Name: roundrobin.Name,
			},
		},
	}); err != nil {
		t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
	}

	if err := cc.WaitForPicker(ctx, func(p balancer.Picker) error {
		// Override the loadStore in the picker with testClusterLoadReporter.
		picker := p.(*picker)
		originalLoadStore := picker.loadStore
		picker.loadStore = testClusterLoadReporter
		defer func() { picker.loadStore = originalLoadStore }()
		for i := 0; i < rpcCount; i++ {
			gotSCSt, err := p.Pick(balancer.PickInfo{})
			// With the 1/4 policy, every fourth RPC is dropped.
			if i%4 == 0 {
				if err == nil || !strings.Contains(err.Error(), "dropped") {
					return fmt.Errorf("pick.Pick, got %v, %v, want error RPC dropped", gotSCSt, err)
				}
				continue
			}
			if err != nil || gotSCSt.SubConn != sc1 {
				return fmt.Errorf("picker.Pick, got %v, %v, want SubConn=%v", gotSCSt, err, sc1)
			}
			if gotSCSt.Done != nil {
				gotSCSt.Done(balancer.DoneInfo{})
			}
		}
		return nil
	}); err != nil {
		t.Fatal(err.Error())
	}

	const dropCount2 = rpcCount * dropNumerator2 / dropDenominator2
	wantStatsData1 := &loadData{
		cluster:    testClusterName,
		service:    testServiceName,
		totalDrops: dropCount2,
		drops:      map[string]uint64{dropReason2: dropCount2},
		localityStats: map[clients.Locality]localityData{
			{}: {requestStats: requestData{
				succeeded: rpcCount - dropCount2,
				issued:    rpcCount - dropCount2,
			}},
		},
	}

	gotStatsData1 := testClusterLoadReporter.stats()
	if err := verifyLoadStoreData(wantStatsData1, gotStatsData1); err != nil {
		t.Fatal(err)
	}
}
   496  
   497  // TestDropCircuitBreaking verifies that the balancer correctly drops the picks
   498  // due to circuit breaking, and that the drops are reported.
   499  func (s) TestDropCircuitBreaking(t *testing.T) {
   500  	defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName)
   501  	xdsC := fakeclient.NewClient()
   502  
   503  	builder := balancer.Get(Name)
   504  	cc := testutils.NewBalancerClientConn(t)
   505  	b := builder.Build(cc, balancer.BuildOptions{})
   506  	defer b.Close()
   507  
   508  	var maxRequest uint32 = 50
   509  	testLRSServerConfig, err := bootstrap.ServerConfigForTesting(bootstrap.ServerConfigTestingOptions{
   510  		URI:          "trafficdirector.googleapis.com:443",
   511  		ChannelCreds: []bootstrap.ChannelCreds{{Type: "google_default"}},
   512  	})
   513  	if err != nil {
   514  		t.Fatalf("Failed to create LRS server config for testing: %v", err)
   515  	}
   516  	if err := b.UpdateClientConnState(balancer.ClientConnState{
   517  		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: testBackendEndpoints}, xdsC),
   518  		BalancerConfig: &LBConfig{
   519  			Cluster:               testClusterName,
   520  			EDSServiceName:        testServiceName,
   521  			LoadReportingServer:   testLRSServerConfig,
   522  			MaxConcurrentRequests: &maxRequest,
   523  			ChildPolicy: &internalserviceconfig.BalancerConfig{
   524  				Name: roundrobin.Name,
   525  			},
   526  		},
   527  	}); err != nil {
   528  		t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
   529  	}
   530  
   531  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   532  	defer cancel()
   533  
   534  	got, err := xdsC.WaitForReportLoad(ctx)
   535  	if err != nil {
   536  		t.Fatalf("xdsClient.ReportLoad failed with error: %v", err)
   537  	}
   538  	if got.Server != testLRSServerConfig {
   539  		t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got.Server, testLRSServerConfig)
   540  	}
   541  
   542  	sc1 := <-cc.NewSubConnCh
   543  	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
   544  	// This should get the connecting picker.
   545  	if err := cc.WaitForPickerWithErr(ctx, balancer.ErrNoSubConnAvailable); err != nil {
   546  		t.Fatal(err.Error())
   547  	}
   548  
   549  	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
   550  	// Test pick with one backend.
   551  	testClusterLoadReporter := &testLoadReporter{cluster: testClusterName, service: testServiceName, drops: make(map[string]uint64), localityRPCCount: make(map[clients.Locality]*rpcCountData)}
   552  	const rpcCount = 100
   553  	if err := cc.WaitForPicker(ctx, func(p balancer.Picker) error {
   554  		dones := []func(){}
   555  		// Override the loadStore in the picker with testClusterLoadReporter.
   556  		picker := p.(*picker)
   557  		originalLoadStore := picker.loadStore
   558  		picker.loadStore = testClusterLoadReporter
   559  		defer func() { picker.loadStore = originalLoadStore }()
   560  
   561  		for i := 0; i < rpcCount; i++ {
   562  			gotSCSt, err := p.Pick(balancer.PickInfo{})
   563  			if i < 50 && err != nil {
   564  				return fmt.Errorf("The first 50%% picks should be non-drops, got error %v", err)
   565  			} else if i > 50 && err == nil {
   566  				return fmt.Errorf("The second 50%% picks should be drops, got error <nil>")
   567  			}
   568  			dones = append(dones, func() {
   569  				if gotSCSt.Done != nil {
   570  					gotSCSt.Done(balancer.DoneInfo{})
   571  				}
   572  			})
   573  		}
   574  		for _, done := range dones {
   575  			done()
   576  		}
   577  
   578  		dones = []func(){}
   579  		// Pick without drops.
   580  		for i := 0; i < 50; i++ {
   581  			gotSCSt, err := p.Pick(balancer.PickInfo{})
   582  			if err != nil {
   583  				t.Errorf("The third 50%% picks should be non-drops, got error %v", err)
   584  			}
   585  			dones = append(dones, func() {
   586  				if gotSCSt.Done != nil {
   587  					// Fail these requests to test error counts in the load
   588  					// report.
   589  					gotSCSt.Done(balancer.DoneInfo{Err: fmt.Errorf("test error")})
   590  				}
   591  			})
   592  		}
   593  		for _, done := range dones {
   594  			done()
   595  		}
   596  
   597  		return nil
   598  	}); err != nil {
   599  		t.Fatal(err.Error())
   600  	}
   601  
   602  	// Dump load data from the store and compare with expected counts.
   603  	wantStatsData0 := &loadData{
   604  		cluster:    testClusterName,
   605  		service:    testServiceName,
   606  		totalDrops: uint64(maxRequest),
   607  		localityStats: map[clients.Locality]localityData{
   608  			{}: {requestStats: requestData{
   609  				succeeded: uint64(rpcCount - maxRequest),
   610  				errored:   50,
   611  				issued:    uint64(rpcCount - maxRequest + 50),
   612  			}},
   613  		},
   614  	}
   615  
   616  	gotStatsData0 := testClusterLoadReporter.stats()
   617  	if err := verifyLoadStoreData(wantStatsData0, gotStatsData0); err != nil {
   618  		t.Fatal(err)
   619  	}
   620  }
   621  
// TestPickerUpdateAfterClose covers the case where a child policy sends a
// picker update after the cluster_impl policy is closed. Because picker updates
// are handled in the run() goroutine, which exits before Close() returns, we
// expect the above picker update to be dropped.
func (s) TestPickerUpdateAfterClose(t *testing.T) {
	defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName)
	xdsC := fakeclient.NewClient()

	builder := balancer.Get(Name)
	cc := testutils.NewBalancerClientConn(t)
	b := builder.Build(cc, balancer.BuildOptions{})

	// Create a stub balancer which waits for the cluster_impl policy to be
	// closed before sending a picker update (upon receipt of a subConn state
	// change).
	closeCh := make(chan struct{})
	const childPolicyName = "stubBalancer-TestPickerUpdateAfterClose"
	stub.Register(childPolicyName, stub.BalancerFuncs{
		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
			// Create a subConn which will be used later on to test the race
			// between StateListener() and Close().
			sc, err := bd.ClientConn.NewSubConn(ccs.ResolverState.Addresses, balancer.NewSubConnOptions{
				StateListener: func(balancer.SubConnState) {
					go func() {
						// Wait for Close() to be called on the parent policy before
						// sending the picker update.
						<-closeCh
						bd.ClientConn.UpdateState(balancer.State{
							Picker: base.NewErrPicker(errors.New("dummy error picker")),
						})
					}()
				},
			})
			if err != nil {
				return err
			}
			sc.Connect()
			return nil
		},
	})

	var maxRequest uint32 = 50
	if err := b.UpdateClientConnState(balancer.ClientConnState{
		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: testBackendEndpoints}, xdsC),
		BalancerConfig: &LBConfig{
			Cluster:               testClusterName,
			EDSServiceName:        testServiceName,
			MaxConcurrentRequests: &maxRequest,
			ChildPolicy: &internalserviceconfig.BalancerConfig{
				Name: childPolicyName,
			},
		},
	}); err != nil {
		// Close explicitly here since the deferred Close is not registered
		// for this test (Close must happen before closeCh is closed below).
		b.Close()
		t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
	}

	// Send a subConn state change to trigger a picker update. The stub balancer
	// that we use as the child policy will not send a picker update until the
	// parent policy is closed.
	sc1 := <-cc.NewSubConnCh
	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
	b.Close()
	close(closeCh)

	// A picker update arriving now must be dropped; wait a short while to be
	// confident none is delivered.
	select {
	case <-cc.NewPickerCh:
		t.Fatalf("unexpected picker update after balancer is closed")
	case <-time.After(defaultShortTestTimeout):
	}
}
   693  
// TestClusterNameInAddressAttributes covers the case that cluster name is
// attached to the subconn address attributes.
func (s) TestClusterNameInAddressAttributes(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()

	defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName)
	xdsC := fakeclient.NewClient()

	builder := balancer.Get(Name)
	cc := testutils.NewBalancerClientConn(t)
	b := builder.Build(cc, balancer.BuildOptions{})
	defer b.Close()

	if err := b.UpdateClientConnState(balancer.ClientConnState{
		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: testBackendEndpoints}, xdsC),
		BalancerConfig: &LBConfig{
			Cluster:        testClusterName,
			EDSServiceName: testServiceName,
			ChildPolicy: &internalserviceconfig.BalancerConfig{
				Name: roundrobin.Name,
			},
		},
	}); err != nil {
		t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
	}

	sc1 := <-cc.NewSubConnCh
	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
	// This should get the connecting picker.
	if err := cc.WaitForPickerWithErr(ctx, balancer.ErrNoSubConnAvailable); err != nil {
		t.Fatal(err.Error())
	}

	// The new subConn's address must carry the configured cluster name in its
	// attributes (used for the xDS TLS handshake).
	addrs1 := <-cc.NewSubConnAddrsCh
	if got, want := addrs1[0].Addr, testBackendEndpoints[0].Addresses[0].Addr; got != want {
		t.Fatalf("sc is created with addr %v, want %v", got, want)
	}
	cn, ok := xds.GetXDSHandshakeClusterName(addrs1[0].Attributes)
	if !ok || cn != testClusterName {
		t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn, ok, testClusterName)
	}

	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
	// Test pick with one backend.
	if err := cc.WaitForRoundRobinPicker(ctx, sc1); err != nil {
		t.Fatal(err.Error())
	}

	// Change the cluster name (and the backend); subsequent subConns must be
	// created with the new cluster name attached.
	const testClusterName2 = "test-cluster-2"
	var addr2 = resolver.Address{Addr: "2.2.2.2"}
	if err := b.UpdateClientConnState(balancer.ClientConnState{
		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: []resolver.Endpoint{{Addresses: []resolver.Address{addr2}}}}, xdsC),
		BalancerConfig: &LBConfig{
			Cluster:        testClusterName2,
			EDSServiceName: testServiceName,
			ChildPolicy: &internalserviceconfig.BalancerConfig{
				Name: roundrobin.Name,
			},
		},
	}); err != nil {
		t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
	}

	addrs2 := <-cc.NewSubConnAddrsCh
	if got, want := addrs2[0].Addr, addr2.Addr; got != want {
		t.Fatalf("sc is created with addr %v, want %v", got, want)
	}
	// New addresses should have the new cluster name.
	cn2, ok := xds.GetXDSHandshakeClusterName(addrs2[0].Attributes)
	if !ok || cn2 != testClusterName2 {
		t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn2, ok, testClusterName2)
	}
}
   768  
// TestReResolution verifies that when a SubConn turns transient failure,
// re-resolution is triggered.
//
// It exercises the cycle twice: TransientFailure -> ResolveNow -> Ready ->
// TransientFailure -> ResolveNow.
func (s) TestReResolution(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()

	defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName)
	xdsC := fakeclient.NewClient()

	builder := balancer.Get(Name)
	cc := testutils.NewBalancerClientConn(t)
	b := builder.Build(cc, balancer.BuildOptions{})
	defer b.Close()

	if err := b.UpdateClientConnState(balancer.ClientConnState{
		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: testBackendEndpoints}, xdsC),
		BalancerConfig: &LBConfig{
			Cluster:        testClusterName,
			EDSServiceName: testServiceName,
			ChildPolicy: &internalserviceconfig.BalancerConfig{
				Name: roundrobin.Name,
			},
		},
	}); err != nil {
		t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
	}

	sc1 := <-cc.NewSubConnCh
	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
	// This should get the connecting picker.
	if err := cc.WaitForPickerWithErr(ctx, balancer.ErrNoSubConnAvailable); err != nil {
		t.Fatal(err.Error())
	}

	sc1.UpdateState(balancer.SubConnState{
		ConnectivityState: connectivity.TransientFailure,
		ConnectionError:   errors.New("test error"),
	})
	// This should get the transient failure picker.
	if err := cc.WaitForErrPicker(ctx); err != nil {
		t.Fatal(err.Error())
	}

	// The transient failure should trigger a re-resolution.
	select {
	case <-cc.ResolveNowCh:
	case <-time.After(defaultTestTimeout):
		t.Fatalf("timeout waiting for ResolveNow()")
	}

	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
	// Test pick with one backend.
	if err := cc.WaitForRoundRobinPicker(ctx, sc1); err != nil {
		t.Fatal(err.Error())
	}

	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure})
	// This should get the transient failure picker.
	if err := cc.WaitForErrPicker(ctx); err != nil {
		t.Fatal(err.Error())
	}

	// The transient failure should trigger a re-resolution.
	select {
	case <-cc.ResolveNowCh:
	case <-time.After(defaultTestTimeout):
		t.Fatalf("timeout waiting for ResolveNow()")
	}
}
   838  
   839  func (s) TestLoadReporting(t *testing.T) {
   840  	var testLocality = clients.Locality{
   841  		Region:  "test-region",
   842  		Zone:    "test-zone",
   843  		SubZone: "test-sub-zone",
   844  	}
   845  
   846  	xdsC := fakeclient.NewClient()
   847  
   848  	builder := balancer.Get(Name)
   849  	cc := testutils.NewBalancerClientConn(t)
   850  	b := builder.Build(cc, balancer.BuildOptions{})
   851  	defer b.Close()
   852  
   853  	endpoints := make([]resolver.Endpoint, len(testBackendEndpoints))
   854  	for i, e := range testBackendEndpoints {
   855  		endpoints[i] = xdsinternal.SetLocalityIDInEndpoint(e, testLocality)
   856  		for j, a := range e.Addresses {
   857  			endpoints[i].Addresses[j] = xdsinternal.SetLocalityID(a, testLocality)
   858  		}
   859  	}
   860  	testLRSServerConfig, err := bootstrap.ServerConfigForTesting(bootstrap.ServerConfigTestingOptions{
   861  		URI:          "trafficdirector.googleapis.com:443",
   862  		ChannelCreds: []bootstrap.ChannelCreds{{Type: "google_default"}},
   863  	})
   864  	if err != nil {
   865  		t.Fatalf("Failed to create LRS server config for testing: %v", err)
   866  	}
   867  	if err := b.UpdateClientConnState(balancer.ClientConnState{
   868  		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: endpoints}, xdsC),
   869  		BalancerConfig: &LBConfig{
   870  			Cluster:             testClusterName,
   871  			EDSServiceName:      testServiceName,
   872  			LoadReportingServer: testLRSServerConfig,
   873  			// Locality:                testLocality,
   874  			ChildPolicy: &internalserviceconfig.BalancerConfig{
   875  				Name: roundrobin.Name,
   876  			},
   877  		},
   878  	}); err != nil {
   879  		t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
   880  	}
   881  
   882  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   883  	defer cancel()
   884  
   885  	got, err := xdsC.WaitForReportLoad(ctx)
   886  	if err != nil {
   887  		t.Fatalf("xdsClient.ReportLoad failed with error: %v", err)
   888  	}
   889  	if got.Server != testLRSServerConfig {
   890  		t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got.Server, testLRSServerConfig)
   891  	}
   892  
   893  	sc1 := <-cc.NewSubConnCh
   894  	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
   895  	// This should get the connecting picker.
   896  	if err := cc.WaitForPickerWithErr(ctx, balancer.ErrNoSubConnAvailable); err != nil {
   897  		t.Fatal(err.Error())
   898  	}
   899  
   900  	scs := balancer.SubConnState{ConnectivityState: connectivity.Ready}
   901  	sca := internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address))
   902  	sca(&scs, endpoints[0].Addresses[0])
   903  	sc1.UpdateState(scs)
   904  	// Test pick with one backend.
   905  	testClusterLoadReporter := &testLoadReporter{cluster: testClusterName, service: testServiceName, drops: make(map[string]uint64), localityRPCCount: make(map[clients.Locality]*rpcCountData)}
   906  	const successCount = 5
   907  	const errorCount = 5
   908  	if err := cc.WaitForPicker(ctx, func(p balancer.Picker) error {
   909  		// Override the loadStore in the picker with testClusterLoadReporter.
   910  		picker := p.(*picker)
   911  		originalLoadStore := picker.loadStore
   912  		picker.loadStore = testClusterLoadReporter
   913  		defer func() { picker.loadStore = originalLoadStore }()
   914  		for i := 0; i < successCount; i++ {
   915  			gotSCSt, err := p.Pick(balancer.PickInfo{})
   916  			if gotSCSt.SubConn != sc1 {
   917  				return fmt.Errorf("picker.Pick, got %v, %v, want SubConn=%v", gotSCSt, err, sc1)
   918  			}
   919  			lr := &v3orcapb.OrcaLoadReport{
   920  				NamedMetrics: map[string]float64{testNamedMetricsKey1: 3.14, testNamedMetricsKey2: 2.718},
   921  			}
   922  			gotSCSt.Done(balancer.DoneInfo{ServerLoad: lr})
   923  		}
   924  		for i := 0; i < errorCount; i++ {
   925  			gotSCSt, err := p.Pick(balancer.PickInfo{})
   926  			if gotSCSt.SubConn != sc1 {
   927  				return fmt.Errorf("picker.Pick, got %v, %v, want SubConn=%v", gotSCSt, err, sc1)
   928  			}
   929  			gotSCSt.Done(balancer.DoneInfo{Err: fmt.Errorf("error")})
   930  		}
   931  		return nil
   932  	}); err != nil {
   933  		t.Fatal(err.Error())
   934  	}
   935  
   936  	// Dump load data from the store and compare with expected counts.
   937  	sd := testClusterLoadReporter.stats()
   938  	if sd == nil {
   939  		t.Fatalf("loads for cluster %v not found in store", testClusterName)
   940  	}
   941  	if sd.cluster != testClusterName || sd.service != testServiceName {
   942  		t.Fatalf("got unexpected load for %q, %q, want %q, %q", sd.cluster, sd.service, testClusterName, testServiceName)
   943  	}
   944  	localityData, ok := sd.localityStats[testLocality]
   945  	if !ok {
   946  		t.Fatalf("loads for %v not found in store", testLocality)
   947  	}
   948  	reqStats := localityData.requestStats
   949  	if reqStats.succeeded != successCount {
   950  		t.Errorf("got succeeded %v, want %v", reqStats.succeeded, successCount)
   951  	}
   952  	if reqStats.errored != errorCount {
   953  		t.Errorf("got errord %v, want %v", reqStats.errored, errorCount)
   954  	}
   955  	if reqStats.inProgress != 0 {
   956  		t.Errorf("got inProgress %v, want %v", reqStats.inProgress, 0)
   957  	}
   958  	wantLoadStats := map[string]serverLoadData{
   959  		testNamedMetricsKey1: {count: 5, sum: 15.7},  // aggregation of 5 * 3.14 = 15.7
   960  		testNamedMetricsKey2: {count: 5, sum: 13.59}, // aggregation of 5 * 2.718 = 13.59
   961  	}
   962  	if diff := cmp.Diff(wantLoadStats, localityData.loadStats, toleranceCmpOpt); diff != "" {
   963  		t.Errorf("localityData.LoadStats returned unexpected diff (-want +got):\n%s", diff)
   964  	}
   965  	b.Close()
   966  	if err := xdsC.WaitForCancelReportLoad(ctx); err != nil {
   967  		t.Fatalf("unexpected error waiting form load report to be canceled: %v", err)
   968  	}
   969  }
   970  
   971  // TestUpdateLRSServer covers the cases
   972  // - the init config specifies "" as the LRS server
   973  // - config modifies LRS server to a different string
   974  // - config sets LRS server to nil to stop load reporting
   975  func (s) TestUpdateLRSServer(t *testing.T) {
   976  	var testLocality = clients.Locality{
   977  		Region:  "test-region",
   978  		Zone:    "test-zone",
   979  		SubZone: "test-sub-zone",
   980  	}
   981  
   982  	xdsC := fakeclient.NewClient()
   983  
   984  	builder := balancer.Get(Name)
   985  	cc := testutils.NewBalancerClientConn(t)
   986  	b := builder.Build(cc, balancer.BuildOptions{})
   987  	defer b.Close()
   988  
   989  	endpoints := make([]resolver.Endpoint, len(testBackendEndpoints))
   990  	for i, e := range testBackendEndpoints {
   991  		endpoints[i] = xdsinternal.SetLocalityIDInEndpoint(e, testLocality)
   992  	}
   993  	testLRSServerConfig, err := bootstrap.ServerConfigForTesting(bootstrap.ServerConfigTestingOptions{
   994  		URI:          "trafficdirector.googleapis.com:443",
   995  		ChannelCreds: []bootstrap.ChannelCreds{{Type: "google_default"}},
   996  	})
   997  	if err != nil {
   998  		t.Fatalf("Failed to create LRS server config for testing: %v", err)
   999  	}
  1000  	if err := b.UpdateClientConnState(balancer.ClientConnState{
  1001  		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: endpoints}, xdsC),
  1002  		BalancerConfig: &LBConfig{
  1003  			Cluster:             testClusterName,
  1004  			EDSServiceName:      testServiceName,
  1005  			LoadReportingServer: testLRSServerConfig,
  1006  			ChildPolicy: &internalserviceconfig.BalancerConfig{
  1007  				Name: roundrobin.Name,
  1008  			},
  1009  		},
  1010  	}); err != nil {
  1011  		t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
  1012  	}
  1013  
  1014  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1015  	defer cancel()
  1016  
  1017  	got, err := xdsC.WaitForReportLoad(ctx)
  1018  	if err != nil {
  1019  		t.Fatalf("xdsClient.ReportLoad failed with error: %v", err)
  1020  	}
  1021  	if got.Server != testLRSServerConfig {
  1022  		t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got.Server, testLRSServerConfig)
  1023  	}
  1024  
  1025  	testLRSServerConfig2, err := bootstrap.ServerConfigForTesting(bootstrap.ServerConfigTestingOptions{
  1026  		URI:          "trafficdirector-another.googleapis.com:443",
  1027  		ChannelCreds: []bootstrap.ChannelCreds{{Type: "google_default"}},
  1028  	})
  1029  	if err != nil {
  1030  		t.Fatalf("Failed to create LRS server config for testing: %v", err)
  1031  	}
  1032  
  1033  	// Update LRS server to a different name.
  1034  	if err := b.UpdateClientConnState(balancer.ClientConnState{
  1035  		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: endpoints}, xdsC),
  1036  		BalancerConfig: &LBConfig{
  1037  			Cluster:             testClusterName,
  1038  			EDSServiceName:      testServiceName,
  1039  			LoadReportingServer: testLRSServerConfig2,
  1040  			ChildPolicy: &internalserviceconfig.BalancerConfig{
  1041  				Name: roundrobin.Name,
  1042  			},
  1043  		},
  1044  	}); err != nil {
  1045  		t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
  1046  	}
  1047  	if err := xdsC.WaitForCancelReportLoad(ctx); err != nil {
  1048  		t.Fatalf("unexpected error waiting form load report to be canceled: %v", err)
  1049  	}
  1050  	got2, err2 := xdsC.WaitForReportLoad(ctx)
  1051  	if err2 != nil {
  1052  		t.Fatalf("xdsClient.ReportLoad failed with error: %v", err2)
  1053  	}
  1054  	if got2.Server != testLRSServerConfig2 {
  1055  		t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got2.Server, testLRSServerConfig2)
  1056  	}
  1057  
  1058  	// Update LRS server to nil, to disable LRS.
  1059  	if err := b.UpdateClientConnState(balancer.ClientConnState{
  1060  		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: endpoints}, xdsC),
  1061  		BalancerConfig: &LBConfig{
  1062  			Cluster:        testClusterName,
  1063  			EDSServiceName: testServiceName,
  1064  			ChildPolicy: &internalserviceconfig.BalancerConfig{
  1065  				Name: roundrobin.Name,
  1066  			},
  1067  		},
  1068  	}); err != nil {
  1069  		t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
  1070  	}
  1071  	if err := xdsC.WaitForCancelReportLoad(ctx); err != nil {
  1072  		t.Fatalf("unexpected error waiting form load report to be canceled: %v", err)
  1073  	}
  1074  
  1075  	shortCtx, shortCancel := context.WithTimeout(context.Background(), defaultShortTestTimeout)
  1076  	defer shortCancel()
  1077  	if s, err := xdsC.WaitForReportLoad(shortCtx); err != context.DeadlineExceeded {
  1078  		t.Fatalf("unexpected load report to server: %q", s)
  1079  	}
  1080  }
  1081  
  1082  // Test verifies that child policies was updated on receipt of
  1083  // configuration update.
  1084  func (s) TestChildPolicyUpdatedOnConfigUpdate(t *testing.T) {
  1085  	xdsC := fakeclient.NewClient()
  1086  
  1087  	builder := balancer.Get(Name)
  1088  	cc := testutils.NewBalancerClientConn(t)
  1089  	b := builder.Build(cc, balancer.BuildOptions{})
  1090  	defer b.Close()
  1091  
  1092  	// Keep track of which child policy was updated
  1093  	updatedChildPolicy := ""
  1094  
  1095  	// Create stub balancers to track config updates
  1096  	const (
  1097  		childPolicyName1 = "stubBalancer1"
  1098  		childPolicyName2 = "stubBalancer2"
  1099  	)
  1100  
  1101  	stub.Register(childPolicyName1, stub.BalancerFuncs{
  1102  		UpdateClientConnState: func(_ *stub.BalancerData, _ balancer.ClientConnState) error {
  1103  			updatedChildPolicy = childPolicyName1
  1104  			return nil
  1105  		},
  1106  	})
  1107  
  1108  	stub.Register(childPolicyName2, stub.BalancerFuncs{
  1109  		UpdateClientConnState: func(_ *stub.BalancerData, _ balancer.ClientConnState) error {
  1110  			updatedChildPolicy = childPolicyName2
  1111  			return nil
  1112  		},
  1113  	})
  1114  
  1115  	// Initial config update with childPolicyName1
  1116  	if err := b.UpdateClientConnState(balancer.ClientConnState{
  1117  		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: testBackendEndpoints}, xdsC),
  1118  		BalancerConfig: &LBConfig{
  1119  			Cluster: testClusterName,
  1120  			ChildPolicy: &internalserviceconfig.BalancerConfig{
  1121  				Name: childPolicyName1,
  1122  			},
  1123  		},
  1124  	}); err != nil {
  1125  		t.Fatalf("Error updating the config: %v", err)
  1126  	}
  1127  
  1128  	if updatedChildPolicy != childPolicyName1 {
  1129  		t.Fatal("Child policy 1 was not updated on initial configuration update.")
  1130  	}
  1131  
  1132  	// Second config update with childPolicyName2
  1133  	if err := b.UpdateClientConnState(balancer.ClientConnState{
  1134  		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: testBackendEndpoints}, xdsC),
  1135  		BalancerConfig: &LBConfig{
  1136  			Cluster: testClusterName,
  1137  			ChildPolicy: &internalserviceconfig.BalancerConfig{
  1138  				Name: childPolicyName2,
  1139  			},
  1140  		},
  1141  	}); err != nil {
  1142  		t.Fatalf("Error updating the config: %v", err)
  1143  	}
  1144  
  1145  	if updatedChildPolicy != childPolicyName2 {
  1146  		t.Fatal("Child policy 2 was not updated after child policy name change.")
  1147  	}
  1148  }
  1149  
  1150  // Test verifies that config update fails if child policy config
  1151  // failed to parse.
  1152  func (s) TestFailedToParseChildPolicyConfig(t *testing.T) {
  1153  	xdsC := fakeclient.NewClient()
  1154  
  1155  	builder := balancer.Get(Name)
  1156  	cc := testutils.NewBalancerClientConn(t)
  1157  	b := builder.Build(cc, balancer.BuildOptions{})
  1158  	defer b.Close()
  1159  
  1160  	// Create a stub balancer which fails to ParseConfig.
  1161  	const parseConfigError = "failed to parse config"
  1162  	const childPolicyName = "stubBalancer-FailedToParseChildPolicyConfig"
  1163  	stub.Register(childPolicyName, stub.BalancerFuncs{
  1164  		ParseConfig: func(_ json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
  1165  			return nil, errors.New(parseConfigError)
  1166  		},
  1167  	})
  1168  
  1169  	err := b.UpdateClientConnState(balancer.ClientConnState{
  1170  		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: testBackendEndpoints}, xdsC),
  1171  		BalancerConfig: &LBConfig{
  1172  			Cluster: testClusterName,
  1173  			ChildPolicy: &internalserviceconfig.BalancerConfig{
  1174  				Name: childPolicyName,
  1175  			},
  1176  		},
  1177  	})
  1178  
  1179  	if err == nil || !strings.Contains(err.Error(), parseConfigError) {
  1180  		t.Fatalf("Got error: %v, want error: %s", err, parseConfigError)
  1181  	}
  1182  }
  1183  
// Test verifies that the picker is updated synchronously on receipt of a
// configuration update: by the time the client conn update completes, the
// new picker must already have been sent.
func (s) TestPickerUpdatedSynchronouslyOnConfigUpdate(t *testing.T) {
	// Override the pickerUpdateHook to be notified that picker was updated.
	// Buffered so the hook never blocks on send.
	pickerUpdated := make(chan struct{}, 1)
	origNewPickerUpdated := pickerUpdateHook
	pickerUpdateHook = func() {
		pickerUpdated <- struct{}{}
	}
	defer func() { pickerUpdateHook = origNewPickerUpdated }()

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	// Override the clientConnUpdateHook to ensure client conn was updated.
	clientConnUpdateDone := make(chan struct{}, 1)
	origClientConnUpdateHook := clientConnUpdateHook
	clientConnUpdateHook = func() {
		// Verify that picker was updated before the completion of
		// client conn update.
		// NOTE(review): assumes this hook runs on the test goroutine during
		// UpdateClientConnState; t.Fatal must not be called from any other
		// goroutine — confirm against the hook's call site.
		select {
		case <-pickerUpdated:
		case <-ctx.Done():
			t.Fatal("Client conn update completed before picker update.")
		}
		clientConnUpdateDone <- struct{}{}
	}
	defer func() { clientConnUpdateHook = origClientConnUpdateHook }()

	defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName)
	xdsC := fakeclient.NewClient()

	builder := balancer.Get(Name)
	cc := testutils.NewBalancerClientConn(t)
	b := builder.Build(cc, balancer.BuildOptions{})
	defer b.Close()

	// Register a stub child policy that immediately pushes an error picker
	// to its ClientConn upon receiving a config update, so the parent's
	// picker update path is exercised during UpdateClientConnState.
	stub.Register(t.Name(), stub.BalancerFuncs{
		UpdateClientConnState: func(bd *stub.BalancerData, _ balancer.ClientConnState) error {
			bd.ClientConn.UpdateState(balancer.State{
				Picker: base.NewErrPicker(errors.New("dummy error picker")),
			})
			return nil
		},
	})

	// Push a config naming the stub as the child policy; the hooks above
	// check the ordering of picker update vs. client conn update completion.
	if err := b.UpdateClientConnState(balancer.ClientConnState{
		ResolverState: xdsclient.SetClient(resolver.State{Endpoints: testBackendEndpoints}, xdsC),
		BalancerConfig: &LBConfig{
			Cluster:        testClusterName,
			EDSServiceName: testServiceName,
			ChildPolicy: &internalserviceconfig.BalancerConfig{
				Name: t.Name(),
			},
		},
	}); err != nil {
		t.Fatalf("Unexpected error from UpdateClientConnState: %v", err)
	}

	// Wait for the client conn update (and, transitively, the picker
	// update) to complete.
	select {
	case <-clientConnUpdateDone:
	case <-ctx.Done():
		t.Fatal("Timed out waiting for client conn update to be completed.")
	}
}