google.golang.org/grpc@v1.72.2/internal/resolver/dns/dns_resolver_test.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package dns_test
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"fmt"
    25  	"net"
    26  	"strings"
    27  	"sync/atomic"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/google/go-cmp/cmp"
    32  	"github.com/google/go-cmp/cmp/cmpopts"
    33  	"google.golang.org/grpc/balancer"
    34  	grpclbstate "google.golang.org/grpc/balancer/grpclb/state"
    35  	"google.golang.org/grpc/internal"
    36  	"google.golang.org/grpc/internal/envconfig"
    37  	"google.golang.org/grpc/internal/grpctest"
    38  	"google.golang.org/grpc/internal/resolver/dns"
    39  	dnsinternal "google.golang.org/grpc/internal/resolver/dns/internal"
    40  	"google.golang.org/grpc/internal/testutils"
    41  	"google.golang.org/grpc/resolver"
    42  	dnspublic "google.golang.org/grpc/resolver/dns"
    43  	"google.golang.org/grpc/serviceconfig"
    44  
    45  	_ "google.golang.org/grpc" // To initialize internal.ParseServiceConfig
    46  )
    47  
    48  const (
    49  	txtBytesLimit           = 255
    50  	defaultTestTimeout      = 10 * time.Second
    51  	defaultTestShortTimeout = 10 * time.Millisecond
    52  
    53  	colonDefaultPort = ":443"
    54  )
    55  
    56  type s struct {
    57  	grpctest.Tester
    58  }
    59  
    60  func Test(t *testing.T) {
    61  	grpctest.RunSubTests(t, s{})
    62  }
    63  
    64  // Override the default net.Resolver with a test resolver.
    65  func overrideNetResolver(t *testing.T, r *testNetResolver) {
    66  	origNetResolver := dnsinternal.NewNetResolver
    67  	dnsinternal.NewNetResolver = func(string) (dnsinternal.NetResolver, error) { return r, nil }
    68  	t.Cleanup(func() { dnsinternal.NewNetResolver = origNetResolver })
    69  }
    70  
    71  // Override the DNS minimum resolution interval used by the resolver.
    72  func overrideResolutionInterval(t *testing.T, d time.Duration) {
    73  	origMinResInterval := dns.MinResolutionInterval
    74  	dnspublic.SetMinResolutionInterval(d)
    75  	t.Cleanup(func() { dnspublic.SetMinResolutionInterval(origMinResInterval) })
    76  }
    77  
    78  // Override the timer used by the DNS resolver to fire after a duration of d.
    79  func overrideTimeAfterFunc(t *testing.T, d time.Duration) {
    80  	origTimeAfter := dnsinternal.TimeAfterFunc
    81  	dnsinternal.TimeAfterFunc = func(time.Duration) <-chan time.Time {
    82  		return time.After(d)
    83  	}
    84  	t.Cleanup(func() { dnsinternal.TimeAfterFunc = origTimeAfter })
    85  }
    86  
    87  // Override the timer used by the DNS resolver as follows:
    88  // - use the durChan to read the duration that the resolver wants to wait for
    89  // - use the timerChan to unblock the wait on the timer
    90  func overrideTimeAfterFuncWithChannel(t *testing.T) (durChan chan time.Duration, timeChan chan time.Time) {
    91  	origTimeAfter := dnsinternal.TimeAfterFunc
    92  	durChan = make(chan time.Duration, 1)
    93  	timeChan = make(chan time.Time)
    94  	dnsinternal.TimeAfterFunc = func(d time.Duration) <-chan time.Time {
    95  		select {
    96  		case durChan <- d:
    97  		default:
    98  		}
    99  		return timeChan
   100  	}
   101  	t.Cleanup(func() { dnsinternal.TimeAfterFunc = origTimeAfter })
   102  	return durChan, timeChan
   103  }
   104  
   105  // Override the current time used by the DNS resolver.
   106  func overrideTimeNowFunc(t *testing.T, now time.Time) {
   107  	origTimeNowFunc := dnsinternal.TimeNowFunc
   108  	dnsinternal.TimeNowFunc = func() time.Time { return now }
   109  	t.Cleanup(func() { dnsinternal.TimeNowFunc = origTimeNowFunc })
   110  }
   111  
   112  // Override the remaining wait time to allow re-resolution by DNS resolver.
   113  // Use the timeChan to read the time until resolver needs to wait for
   114  // and return 0 wait time.
   115  func overrideTimeUntilFuncWithChannel(t *testing.T) (timeChan chan time.Time) {
   116  	timeCh := make(chan time.Time, 1)
   117  	origTimeUntil := dnsinternal.TimeUntilFunc
   118  	dnsinternal.TimeUntilFunc = func(t time.Time) time.Duration {
   119  		timeCh <- t
   120  		return 0
   121  	}
   122  	t.Cleanup(func() { dnsinternal.TimeUntilFunc = origTimeUntil })
   123  	return timeCh
   124  }
   125  
   126  func enableSRVLookups(t *testing.T) {
   127  	origEnableSRVLookups := dns.EnableSRVLookups
   128  	dns.EnableSRVLookups = true
   129  	t.Cleanup(func() { dns.EnableSRVLookups = origEnableSRVLookups })
   130  }
   131  
   132  // Builds a DNS resolver for target and returns a couple of channels to read the
   133  // state and error pushed by the resolver respectively.
   134  func buildResolverWithTestClientConn(t *testing.T, target string) (resolver.Resolver, chan resolver.State, chan error) {
   135  	t.Helper()
   136  
   137  	b := resolver.Get("dns")
   138  	if b == nil {
   139  		t.Fatalf("Resolver for dns:/// scheme not registered")
   140  	}
   141  
   142  	stateCh := make(chan resolver.State, 1)
   143  	updateStateF := func(s resolver.State) error {
   144  		select {
   145  		case stateCh <- s:
   146  		default:
   147  		}
   148  		return nil
   149  	}
   150  
   151  	errCh := make(chan error, 1)
   152  	reportErrorF := func(err error) {
   153  		select {
   154  		case errCh <- err:
   155  		default:
   156  		}
   157  	}
   158  
   159  	tcc := &testutils.ResolverClientConn{Logger: t, UpdateStateF: updateStateF, ReportErrorF: reportErrorF}
   160  	r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns:///%s", target))}, tcc, resolver.BuildOptions{})
   161  	if err != nil {
   162  		t.Fatalf("Failed to build DNS resolver for target %q: %v\n", target, err)
   163  	}
   164  	t.Cleanup(func() { r.Close() })
   165  
   166  	return r, stateCh, errCh
   167  }
   168  
   169  // Waits for a state update from the DNS resolver and verifies the following:
   170  // - wantAddrs matches the list of addresses in the update
   171  // - wantBalancerAddrs matches the list of grpclb addresses in the update
   172  // - wantSC matches the service config in the update
   173  func verifyUpdateFromResolver(ctx context.Context, t *testing.T, stateCh chan resolver.State, wantAddrs, wantBalancerAddrs []resolver.Address, wantSC string) {
   174  	t.Helper()
   175  
   176  	var state resolver.State
   177  	select {
   178  	case <-ctx.Done():
   179  		t.Fatal("Timeout when waiting for a state update from the resolver")
   180  	case state = <-stateCh:
   181  	}
   182  
   183  	if !cmp.Equal(state.Addresses, wantAddrs, cmpopts.EquateEmpty()) {
   184  		t.Fatalf("Got addresses: %+v, want: %+v", state.Addresses, wantAddrs)
   185  	}
   186  	if gs := grpclbstate.Get(state); gs == nil {
   187  		if len(wantBalancerAddrs) > 0 {
   188  			t.Fatalf("Got no grpclb addresses. Want %d", len(wantBalancerAddrs))
   189  		}
   190  	} else {
   191  		if !cmp.Equal(gs.BalancerAddresses, wantBalancerAddrs) {
   192  			t.Fatalf("Got grpclb addresses %+v, want %+v", gs.BalancerAddresses, wantBalancerAddrs)
   193  		}
   194  	}
   195  	if wantSC == "{}" {
   196  		if state.ServiceConfig != nil && state.ServiceConfig.Config != nil {
   197  			t.Fatalf("Got service config:\n%s \nWant service config: {}", cmp.Diff(nil, state.ServiceConfig.Config))
   198  		}
   199  
   200  	} else if wantSC != "" {
   201  		wantSCParsed := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(wantSC)
   202  		if !internal.EqualServiceConfigForTesting(state.ServiceConfig.Config, wantSCParsed.Config) {
   203  			t.Fatalf("Got service config:\n%s \nWant service config:\n%s", cmp.Diff(nil, state.ServiceConfig.Config), cmp.Diff(nil, wantSCParsed.Config))
   204  		}
   205  	}
   206  }
   207  
// txtRecordGood is the service config served by the fake net.Resolver in its
// TXT record. It is an array of 5 entries. The DNS resolver's canarying rule
// matching drops the first three:
//   - entry 1: the client language (CPP, JAVA) does not match
//   - entry 2: the percentage is 0
//   - entry 3: the client host name (localhost) does not match; this assumes
//     the test host is not named localhost
//
// Entries 4 and 5 both satisfy the canarying rules. Entry 4 is used because
// it is the first matching entry; its service config is the one in scJSON.
const txtRecordGood = `
[
	{
		"clientLanguage": [
			"CPP",
			"JAVA"
		],
		"serviceConfig": {
			"loadBalancingPolicy": "grpclb",
			"methodConfig": [
				{
					"name": [
						{
							"service": "all"
						}
					],
					"timeout": "1s"
				}
			]
		}
	},
	{
		"percentage": 0,
		"serviceConfig": {
			"loadBalancingPolicy": "grpclb",
			"methodConfig": [
				{
					"name": [
						{
							"service": "all"
						}
					],
					"timeout": "1s"
				}
			]
		}
	},
	{
		"clientHostName": [
			"localhost"
		],
		"serviceConfig": {
			"loadBalancingPolicy": "grpclb",
			"methodConfig": [
				{
					"name": [
						{
							"service": "all"
						}
					],
					"timeout": "1s"
				}
			]
		}
	},
	{
		"clientLanguage": [
			"GO"
		],
		"percentage": 100,
		"serviceConfig": {
			"loadBalancingPolicy": "round_robin",
			"methodConfig": [
				{
					"name": [
						{
							"service": "foo"
						}
					],
					"waitForReady": true,
					"timeout": "1s"
				},
				{
					"name": [
						{
							"service": "bar"
						}
					],
					"waitForReady": false
				}
			]
		}
	},
	{
		"serviceConfig": {
			"loadBalancingPolicy": "round_robin",
			"methodConfig": [
				{
					"name": [
						{
							"service": "foo",
							"method": "bar"
						}
					],
					"waitForReady": true
				}
			]
		}
	}
]`
   317  
// scJSON is the "serviceConfig" of the entry in txtRecordGood that the DNS
// resolver selects (the fourth entry: client language GO, percentage 100).
// It is the service config expected in updates built from txtRecordGood.
const scJSON = `
{
	"loadBalancingPolicy": "round_robin",
	"methodConfig": [
		{
			"name": [
				{
					"service": "foo"
				}
			],
			"waitForReady": true,
			"timeout": "1s"
		},
		{
			"name": [
				{
					"service": "bar"
				}
			],
			"waitForReady": false
		}
	]
}`
   342  
// txtRecordNonMatching is a service config with three entries, none of which
// match the DNS resolver's canarying rules. The entries are identical to the
// first three entries of txtRecordGood. Because nothing matches, the service
// config pushed by the DNS resolver is an empty one.
const txtRecordNonMatching = `
[
	{
		"clientLanguage": [
			"CPP",
			"JAVA"
		],
		"serviceConfig": {
			"loadBalancingPolicy": "grpclb",
			"methodConfig": [
				{
					"name": [
						{
							"service": "all"
						}
					],
					"timeout": "1s"
				}
			]
		}
	},
	{
		"percentage": 0,
		"serviceConfig": {
			"loadBalancingPolicy": "grpclb",
			"methodConfig": [
				{
					"name": [
						{
							"service": "all"
						}
					],
					"timeout": "1s"
				}
			]
		}
	},
	{
		"clientHostName": [
			"localhost"
		],
		"serviceConfig": {
			"loadBalancingPolicy": "grpclb",
			"methodConfig": [
				{
					"name": [
						{
							"service": "all"
						}
					],
					"timeout": "1s"
				}
			]
		}
	}
]`
   402  
   403  // Tests the scenario where a name resolves to a list of addresses, possibly
   404  // some grpclb addresses as well, and a service config. The test verifies that
   405  // the expected update is pushed to the channel.
   406  func (s) TestDNSResolver_Basic(t *testing.T) {
   407  	tests := []struct {
   408  		name              string
   409  		target            string
   410  		hostLookupTable   map[string][]string
   411  		srvLookupTable    map[string][]*net.SRV
   412  		txtLookupTable    map[string][]string
   413  		wantAddrs         []resolver.Address
   414  		wantBalancerAddrs []resolver.Address
   415  		wantSC            string
   416  	}{
   417  		{
   418  			name:   "default_port",
   419  			target: "foo.bar.com",
   420  			hostLookupTable: map[string][]string{
   421  				"foo.bar.com": {"1.2.3.4", "5.6.7.8"},
   422  			},
   423  			txtLookupTable: map[string][]string{
   424  				"_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood),
   425  			},
   426  			wantAddrs:         []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}},
   427  			wantBalancerAddrs: nil,
   428  			wantSC:            scJSON,
   429  		},
   430  		{
   431  			name:   "specified_port",
   432  			target: "foo.bar.com:1234",
   433  			hostLookupTable: map[string][]string{
   434  				"foo.bar.com": {"1.2.3.4", "5.6.7.8"},
   435  			},
   436  			txtLookupTable: map[string][]string{
   437  				"_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood),
   438  			},
   439  			wantAddrs:         []resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}},
   440  			wantBalancerAddrs: nil,
   441  			wantSC:            scJSON,
   442  		},
   443  		{
   444  			name:   "ipv4_with_SRV_and_single_grpclb_address",
   445  			target: "srv.ipv4.single.fake",
   446  			hostLookupTable: map[string][]string{
   447  				"srv.ipv4.single.fake": {"2.4.6.8"},
   448  				"ipv4.single.fake":     {"1.2.3.4"},
   449  			},
   450  			srvLookupTable: map[string][]*net.SRV{
   451  				"_grpclb._tcp.srv.ipv4.single.fake": {&net.SRV{Target: "ipv4.single.fake", Port: 1234}},
   452  			},
   453  			txtLookupTable: map[string][]string{
   454  				"_grpc_config.srv.ipv4.single.fake": txtRecordServiceConfig(txtRecordGood),
   455  			},
   456  			wantAddrs:         []resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}},
   457  			wantBalancerAddrs: []resolver.Address{{Addr: "1.2.3.4:1234", ServerName: "ipv4.single.fake"}},
   458  			wantSC:            scJSON,
   459  		},
   460  		{
   461  			name:   "ipv4_with_SRV_and_multiple_grpclb_address",
   462  			target: "srv.ipv4.multi.fake",
   463  			hostLookupTable: map[string][]string{
   464  				"ipv4.multi.fake": {"1.2.3.4", "5.6.7.8", "9.10.11.12"},
   465  			},
   466  			srvLookupTable: map[string][]*net.SRV{
   467  				"_grpclb._tcp.srv.ipv4.multi.fake": {&net.SRV{Target: "ipv4.multi.fake", Port: 1234}},
   468  			},
   469  			txtLookupTable: map[string][]string{
   470  				"_grpc_config.srv.ipv4.multi.fake": txtRecordServiceConfig(txtRecordGood),
   471  			},
   472  			wantAddrs: nil,
   473  			wantBalancerAddrs: []resolver.Address{
   474  				{Addr: "1.2.3.4:1234", ServerName: "ipv4.multi.fake"},
   475  				{Addr: "5.6.7.8:1234", ServerName: "ipv4.multi.fake"},
   476  				{Addr: "9.10.11.12:1234", ServerName: "ipv4.multi.fake"},
   477  			},
   478  			wantSC: scJSON,
   479  		},
   480  		{
   481  			name:   "ipv6_with_SRV_and_single_grpclb_address",
   482  			target: "srv.ipv6.single.fake",
   483  			hostLookupTable: map[string][]string{
   484  				"srv.ipv6.single.fake": nil,
   485  				"ipv6.single.fake":     {"2607:f8b0:400a:801::1001"},
   486  			},
   487  			srvLookupTable: map[string][]*net.SRV{
   488  				"_grpclb._tcp.srv.ipv6.single.fake": {&net.SRV{Target: "ipv6.single.fake", Port: 1234}},
   489  			},
   490  			txtLookupTable: map[string][]string{
   491  				"_grpc_config.srv.ipv6.single.fake": txtRecordServiceConfig(txtRecordNonMatching),
   492  			},
   493  			wantAddrs:         nil,
   494  			wantBalancerAddrs: []resolver.Address{{Addr: "[2607:f8b0:400a:801::1001]:1234", ServerName: "ipv6.single.fake"}},
   495  			wantSC:            "{}",
   496  		},
   497  		{
   498  			name:   "ipv6_with_SRV_and_multiple_grpclb_address",
   499  			target: "srv.ipv6.multi.fake",
   500  			hostLookupTable: map[string][]string{
   501  				"srv.ipv6.multi.fake": nil,
   502  				"ipv6.multi.fake":     {"2607:f8b0:400a:801::1001", "2607:f8b0:400a:801::1002", "2607:f8b0:400a:801::1003"},
   503  			},
   504  			srvLookupTable: map[string][]*net.SRV{
   505  				"_grpclb._tcp.srv.ipv6.multi.fake": {&net.SRV{Target: "ipv6.multi.fake", Port: 1234}},
   506  			},
   507  			txtLookupTable: map[string][]string{
   508  				"_grpc_config.srv.ipv6.multi.fake": txtRecordServiceConfig(txtRecordNonMatching),
   509  			},
   510  			wantAddrs: nil,
   511  			wantBalancerAddrs: []resolver.Address{
   512  				{Addr: "[2607:f8b0:400a:801::1001]:1234", ServerName: "ipv6.multi.fake"},
   513  				{Addr: "[2607:f8b0:400a:801::1002]:1234", ServerName: "ipv6.multi.fake"},
   514  				{Addr: "[2607:f8b0:400a:801::1003]:1234", ServerName: "ipv6.multi.fake"},
   515  			},
   516  			wantSC: "{}",
   517  		},
   518  	}
   519  
   520  	for _, test := range tests {
   521  		t.Run(test.name, func(t *testing.T) {
   522  			overrideTimeAfterFunc(t, 2*defaultTestTimeout)
   523  			overrideNetResolver(t, &testNetResolver{
   524  				hostLookupTable: test.hostLookupTable,
   525  				srvLookupTable:  test.srvLookupTable,
   526  				txtLookupTable:  test.txtLookupTable,
   527  			})
   528  			enableSRVLookups(t)
   529  			_, stateCh, _ := buildResolverWithTestClientConn(t, test.target)
   530  
   531  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   532  			defer cancel()
   533  			verifyUpdateFromResolver(ctx, t, stateCh, test.wantAddrs, test.wantBalancerAddrs, test.wantSC)
   534  		})
   535  	}
   536  }
   537  
   538  // Tests the case where the channel returns an error for the update pushed by
   539  // the DNS resolver. Verifies that the DNS resolver backs off before trying to
   540  // resolve. Once the channel returns a nil error, the test verifies that the DNS
   541  // resolver does not backoff anymore.
   542  func (s) TestDNSResolver_ExponentialBackoff(t *testing.T) {
   543  	tests := []struct {
   544  		name            string
   545  		target          string
   546  		hostLookupTable map[string][]string
   547  		txtLookupTable  map[string][]string
   548  		wantAddrs       []resolver.Address
   549  		wantSC          string
   550  	}{
   551  		{
   552  			name:            "happy case default port",
   553  			target:          "foo.bar.com",
   554  			hostLookupTable: map[string][]string{"foo.bar.com": {"1.2.3.4", "5.6.7.8"}},
   555  			txtLookupTable: map[string][]string{
   556  				"_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood),
   557  			},
   558  			wantAddrs: []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}},
   559  			wantSC:    scJSON,
   560  		},
   561  		{
   562  			name:            "happy case specified port",
   563  			target:          "foo.bar.com:1234",
   564  			hostLookupTable: map[string][]string{"foo.bar.com": {"1.2.3.4", "5.6.7.8"}},
   565  			txtLookupTable: map[string][]string{
   566  				"_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood),
   567  			},
   568  			wantAddrs: []resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}},
   569  			wantSC:    scJSON,
   570  		},
   571  		{
   572  			name:   "happy case another default port",
   573  			target: "srv.ipv4.single.fake",
   574  			hostLookupTable: map[string][]string{
   575  				"srv.ipv4.single.fake": {"2.4.6.8"},
   576  				"ipv4.single.fake":     {"1.2.3.4"},
   577  			},
   578  			txtLookupTable: map[string][]string{
   579  				"_grpc_config.srv.ipv4.single.fake": txtRecordServiceConfig(txtRecordGood),
   580  			},
   581  			wantAddrs: []resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}},
   582  			wantSC:    scJSON,
   583  		},
   584  	}
   585  	for _, test := range tests {
   586  		t.Run(test.name, func(t *testing.T) {
   587  			durChan, timeChan := overrideTimeAfterFuncWithChannel(t)
   588  			overrideNetResolver(t, &testNetResolver{
   589  				hostLookupTable: test.hostLookupTable,
   590  				txtLookupTable:  test.txtLookupTable,
   591  			})
   592  
   593  			// Set the test clientconn to return error back to the resolver when
   594  			// it pushes an update on the channel.
   595  			var returnNilErr atomic.Bool
   596  			updateStateF := func(s resolver.State) error {
   597  				if returnNilErr.Load() {
   598  					return nil
   599  				}
   600  				return balancer.ErrBadResolverState
   601  			}
   602  			tcc := &testutils.ResolverClientConn{Logger: t, UpdateStateF: updateStateF}
   603  
   604  			b := resolver.Get("dns")
   605  			if b == nil {
   606  				t.Fatalf("Resolver for dns:/// scheme not registered")
   607  			}
   608  			r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns:///%s", test.target))}, tcc, resolver.BuildOptions{})
   609  			if err != nil {
   610  				t.Fatalf("Failed to build DNS resolver for target %q: %v\n", test.target, err)
   611  			}
   612  			defer r.Close()
   613  
   614  			// Expect the DNS resolver to backoff and attempt to re-resolve.
   615  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   616  			defer cancel()
   617  			const retries = 10
   618  			var prevDur time.Duration
   619  			for i := 0; i < retries; i++ {
   620  				select {
   621  				case <-ctx.Done():
   622  					t.Fatalf("(Iteration: %d): Timeout when waiting for DNS resolver to backoff", i)
   623  				case dur := <-durChan:
   624  					if dur <= prevDur {
   625  						t.Fatalf("(Iteration: %d): Unexpected decrease in amount of time to backoff", i)
   626  					}
   627  				}
   628  
   629  				if i == retries-1 {
   630  					// Update resolver.ClientConn to not return an error
   631  					// anymore before last resolution retry to ensure that
   632  					// last resolution attempt doesn't back off.
   633  					returnNilErr.Store(true)
   634  				}
   635  
   636  				// Unblock the DNS resolver's backoff by pushing the current time.
   637  				timeChan <- time.Now()
   638  			}
   639  
   640  			// Verify that the DNS resolver does not backoff anymore.
   641  			sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
   642  			defer sCancel()
   643  			select {
   644  			case <-durChan:
   645  				t.Fatal("Unexpected DNS resolver backoff")
   646  			case <-sCtx.Done():
   647  			}
   648  		})
   649  	}
   650  }
   651  
   652  // Tests the case where the DNS resolver is asked to re-resolve by invoking the
   653  // ResolveNow method.
   654  func (s) TestDNSResolver_ResolveNow(t *testing.T) {
   655  	const target = "foo.bar.com"
   656  
   657  	overrideResolutionInterval(t, 0)
   658  	overrideTimeAfterFunc(t, 0)
   659  	tr := &testNetResolver{
   660  		hostLookupTable: map[string][]string{
   661  			"foo.bar.com": {"1.2.3.4", "5.6.7.8"},
   662  		},
   663  		txtLookupTable: map[string][]string{
   664  			"_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood),
   665  		},
   666  	}
   667  	overrideNetResolver(t, tr)
   668  
   669  	r, stateCh, _ := buildResolverWithTestClientConn(t, target)
   670  
   671  	// Verify that the first update pushed by the resolver matches expectations.
   672  	wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}
   673  	wantSC := scJSON
   674  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   675  	defer cancel()
   676  	verifyUpdateFromResolver(ctx, t, stateCh, wantAddrs, nil, wantSC)
   677  
   678  	// Update state in the fake net.Resolver to return only one address and a
   679  	// new service config.
   680  	tr.UpdateHostLookupTable(map[string][]string{target: {"1.2.3.4"}})
   681  	tr.UpdateTXTLookupTable(map[string][]string{
   682  		"_grpc_config.foo.bar.com": txtRecordServiceConfig(`[{"serviceConfig":{"loadBalancingPolicy": "grpclb"}}]`),
   683  	})
   684  
   685  	// Ask the resolver to re-resolve and verify that the new update matches
   686  	// expectations.
   687  	r.ResolveNow(resolver.ResolveNowOptions{})
   688  	wantAddrs = []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}}
   689  	wantSC = `{"loadBalancingPolicy": "grpclb"}`
   690  	verifyUpdateFromResolver(ctx, t, stateCh, wantAddrs, nil, wantSC)
   691  
   692  	// Update state in the fake resolver to return no addresses and the same
   693  	// service config as before.
   694  	tr.UpdateHostLookupTable(map[string][]string{target: nil})
   695  
   696  	// Ask the resolver to re-resolve and verify that the new update matches
   697  	// expectations.
   698  	r.ResolveNow(resolver.ResolveNowOptions{})
   699  	verifyUpdateFromResolver(ctx, t, stateCh, nil, nil, wantSC)
   700  }
   701  
   702  // Tests the case where the given name is an IP address and verifies that the
   703  // update pushed by the DNS resolver meets expectations.
   704  func (s) TestIPResolver(t *testing.T) {
   705  	tests := []struct {
   706  		name     string
   707  		target   string
   708  		wantAddr []resolver.Address
   709  	}{
   710  		{
   711  			name:     "localhost ipv4 default port",
   712  			target:   "127.0.0.1",
   713  			wantAddr: []resolver.Address{{Addr: "127.0.0.1:443"}},
   714  		},
   715  		{
   716  			name:     "localhost ipv4 non-default port",
   717  			target:   "127.0.0.1:12345",
   718  			wantAddr: []resolver.Address{{Addr: "127.0.0.1:12345"}},
   719  		},
   720  		{
   721  			name:     "localhost ipv6 default port no brackets",
   722  			target:   "::1",
   723  			wantAddr: []resolver.Address{{Addr: "[::1]:443"}},
   724  		},
   725  		{
   726  			name:     "localhost ipv6 default port with brackets",
   727  			target:   "[::1]",
   728  			wantAddr: []resolver.Address{{Addr: "[::1]:443"}},
   729  		},
   730  		{
   731  			name:     "localhost ipv6 non-default port",
   732  			target:   "[::1]:12345",
   733  			wantAddr: []resolver.Address{{Addr: "[::1]:12345"}},
   734  		},
   735  		{
   736  			name:     "ipv6 default port no brackets",
   737  			target:   "2001:db8:85a3::8a2e:370:7334",
   738  			wantAddr: []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]:443"}},
   739  		},
   740  		{
   741  			name:     "ipv6 default port with brackets",
   742  			target:   "[2001:db8:85a3::8a2e:370:7334]",
   743  			wantAddr: []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]:443"}},
   744  		},
   745  		{
   746  			name:     "ipv6 non-default port with brackets",
   747  			target:   "[2001:db8:85a3::8a2e:370:7334]:12345",
   748  			wantAddr: []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]:12345"}},
   749  		},
   750  		{
   751  			name:     "abbreviated ipv6 address",
   752  			target:   "[2001:db8::1]:http",
   753  			wantAddr: []resolver.Address{{Addr: "[2001:db8::1]:http"}},
   754  		},
   755  		{
   756  			name:     "ipv6 with zone and port",
   757  			target:   "[fe80::1%25eth0]:1234",
   758  			wantAddr: []resolver.Address{{Addr: "[fe80::1%eth0]:1234"}},
   759  		},
   760  		{
   761  			name:     "ipv6 with zone and default port",
   762  			target:   "fe80::1%25eth0",
   763  			wantAddr: []resolver.Address{{Addr: "[fe80::1%eth0]:443"}},
   764  		},
   765  	}
   766  
   767  	for _, test := range tests {
   768  		t.Run(test.name, func(t *testing.T) {
   769  			overrideResolutionInterval(t, 0)
   770  			overrideTimeAfterFunc(t, 2*defaultTestTimeout)
   771  			r, stateCh, _ := buildResolverWithTestClientConn(t, test.target)
   772  
   773  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   774  			defer cancel()
   775  			verifyUpdateFromResolver(ctx, t, stateCh, test.wantAddr, nil, "")
   776  
   777  			// Attempt to re-resolve should not result in a state update.
   778  			r.ResolveNow(resolver.ResolveNowOptions{})
   779  			sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
   780  			defer sCancel()
   781  			select {
   782  			case <-sCtx.Done():
   783  			case s := <-stateCh:
   784  				t.Fatalf("Unexpected state update from the resolver: %+v", s)
   785  			}
   786  		})
   787  	}
   788  }
   789  
   790  // Tests the DNS resolver builder with different target names.
   791  func (s) TestResolverBuild(t *testing.T) {
   792  	tests := []struct {
   793  		name    string
   794  		target  string
   795  		wantErr string
   796  	}{
   797  		{
   798  			name:   "valid url",
   799  			target: "www.google.com",
   800  		},
   801  		{
   802  			name:   "host port",
   803  			target: "foo.bar:12345",
   804  		},
   805  		{
   806  			name:   "ipv4 address with default port",
   807  			target: "127.0.0.1",
   808  		},
   809  		{
   810  			name:   "ipv6 address without brackets and default port",
   811  			target: "::",
   812  		},
   813  		{
   814  			name:   "ipv4 address with non-default port",
   815  			target: "127.0.0.1:12345",
   816  		},
   817  		{
   818  			name:   "localhost ipv6 with brackets",
   819  			target: "[::1]:80",
   820  		},
   821  		{
   822  			name:   "ipv6 address with brackets",
   823  			target: "[2001:db8:a0b:12f0::1]:21",
   824  		},
   825  		{
   826  			name:   "empty host with port",
   827  			target: ":80",
   828  		},
   829  		{
   830  			name:   "ipv6 address with zone",
   831  			target: "[fe80::1%25lo0]:80",
   832  		},
   833  		{
   834  			name:   "url with port",
   835  			target: "golang.org:http",
   836  		},
   837  		{
   838  			name:   "ipv6 address with non integer port",
   839  			target: "[2001:db8::1]:http",
   840  		},
   841  		{
   842  			name:    "address ends with colon",
   843  			target:  "[2001:db8::1]:",
   844  			wantErr: dnsinternal.ErrEndsWithColon.Error(),
   845  		},
   846  		{
   847  			name:    "address contains only a colon",
   848  			target:  ":",
   849  			wantErr: dnsinternal.ErrEndsWithColon.Error(),
   850  		},
   851  		{
   852  			name:    "empty address",
   853  			target:  "",
   854  			wantErr: dnsinternal.ErrMissingAddr.Error(),
   855  		},
   856  		{
   857  			name:    "invalid address",
   858  			target:  "[2001:db8:a0b:12f0::1",
   859  			wantErr: "invalid target address",
   860  		},
   861  	}
   862  
   863  	for _, test := range tests {
   864  		t.Run(test.name, func(t *testing.T) {
   865  			overrideTimeAfterFunc(t, 2*defaultTestTimeout)
   866  
   867  			b := resolver.Get("dns")
   868  			if b == nil {
   869  				t.Fatalf("Resolver for dns:/// scheme not registered")
   870  			}
   871  
   872  			tcc := &testutils.ResolverClientConn{Logger: t}
   873  			r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns:///%s", test.target))}, tcc, resolver.BuildOptions{})
   874  			if err != nil {
   875  				if test.wantErr == "" {
   876  					t.Fatalf("DNS resolver build for target %q failed with error: %v", test.target, err)
   877  				}
   878  				if !strings.Contains(err.Error(), test.wantErr) {
   879  					t.Fatalf("DNS resolver build for target %q failed with error: %v, wantErr: %s", test.target, err, test.wantErr)
   880  				}
   881  				return
   882  			}
   883  			if err == nil && test.wantErr != "" {
   884  				t.Fatalf("DNS resolver build for target %q succeeded when expected to fail with error: %s", test.target, test.wantErr)
   885  			}
   886  			r.Close()
   887  		})
   888  	}
   889  }
   890  
   891  // Tests scenarios where fetching of service config is enabled or disabled, and
   892  // verifies that the expected update is pushed by the DNS resolver.
   893  func (s) TestDisableServiceConfig(t *testing.T) {
   894  	tests := []struct {
   895  		name                 string
   896  		target               string
   897  		hostLookupTable      map[string][]string
   898  		txtLookupTable       map[string][]string
   899  		disableServiceConfig bool
   900  		wantAddrs            []resolver.Address
   901  		wantSC               string
   902  	}{
   903  		{
   904  			name:            "false",
   905  			target:          "foo.bar.com",
   906  			hostLookupTable: map[string][]string{"foo.bar.com": {"1.2.3.4", "5.6.7.8"}},
   907  			txtLookupTable: map[string][]string{
   908  				"_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood),
   909  			},
   910  			disableServiceConfig: false,
   911  			wantAddrs:            []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}},
   912  			wantSC:               scJSON,
   913  		},
   914  		{
   915  			name:            "true",
   916  			target:          "foo.bar.com",
   917  			hostLookupTable: map[string][]string{"foo.bar.com": {"1.2.3.4", "5.6.7.8"}},
   918  			txtLookupTable: map[string][]string{
   919  				"_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood),
   920  			},
   921  			disableServiceConfig: true,
   922  			wantAddrs:            []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}},
   923  			wantSC:               "{}",
   924  		},
   925  	}
   926  
   927  	for _, test := range tests {
   928  		t.Run(test.name, func(t *testing.T) {
   929  			overrideTimeAfterFunc(t, 2*defaultTestTimeout)
   930  			overrideNetResolver(t, &testNetResolver{
   931  				hostLookupTable: test.hostLookupTable,
   932  				txtLookupTable:  test.txtLookupTable,
   933  			})
   934  
   935  			b := resolver.Get("dns")
   936  			if b == nil {
   937  				t.Fatalf("Resolver for dns:/// scheme not registered")
   938  			}
   939  
   940  			stateCh := make(chan resolver.State, 1)
   941  			updateStateF := func(s resolver.State) error {
   942  				stateCh <- s
   943  				return nil
   944  			}
   945  			tcc := &testutils.ResolverClientConn{Logger: t, UpdateStateF: updateStateF}
   946  			r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns:///%s", test.target))}, tcc, resolver.BuildOptions{DisableServiceConfig: test.disableServiceConfig})
   947  			if err != nil {
   948  				t.Fatalf("Failed to build DNS resolver for target %q: %v\n", test.target, err)
   949  			}
   950  			defer r.Close()
   951  
   952  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   953  			defer cancel()
   954  			verifyUpdateFromResolver(ctx, t, stateCh, test.wantAddrs, nil, test.wantSC)
   955  		})
   956  	}
   957  }
   958  
   959  // Tests the case where a TXT lookup is expected to return an error. Verifies
   960  // that errors are ignored with the corresponding env var is set.
   961  func (s) TestTXTError(t *testing.T) {
   962  	for _, ignore := range []bool{false, true} {
   963  		t.Run(fmt.Sprintf("%v", ignore), func(t *testing.T) {
   964  			overrideTimeAfterFunc(t, 2*defaultTestTimeout)
   965  			overrideNetResolver(t, &testNetResolver{hostLookupTable: map[string][]string{"ipv4.single.fake": {"1.2.3.4"}}})
   966  
   967  			origTXTIgnore := envconfig.TXTErrIgnore
   968  			envconfig.TXTErrIgnore = ignore
   969  			defer func() { envconfig.TXTErrIgnore = origTXTIgnore }()
   970  
   971  			// There is no entry for "ipv4.single.fake" in the txtLookupTbl
   972  			// maintained by the fake net.Resolver. So, a TXT lookup for this
   973  			// name will return an error.
   974  			_, stateCh, _ := buildResolverWithTestClientConn(t, "ipv4.single.fake")
   975  
   976  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   977  			defer cancel()
   978  			var state resolver.State
   979  			select {
   980  			case <-ctx.Done():
   981  				t.Fatal("Timeout when waiting for a state update from the resolver")
   982  			case state = <-stateCh:
   983  			}
   984  
   985  			if ignore {
   986  				if state.ServiceConfig != nil {
   987  					t.Fatalf("Received non-nil service config: %+v; want nil", state.ServiceConfig)
   988  				}
   989  			} else {
   990  				if state.ServiceConfig == nil || state.ServiceConfig.Err == nil {
   991  					t.Fatalf("Received service config %+v; want non-nil error", state.ServiceConfig)
   992  				}
   993  			}
   994  		})
   995  	}
   996  }
   997  
   998  // Tests different cases for a user's dial target that specifies a non-empty
   999  // authority (or Host field of the URL).
  1000  func (s) TestCustomAuthority(t *testing.T) {
  1001  	tests := []struct {
  1002  		name          string
  1003  		authority     string
  1004  		wantAuthority string
  1005  		wantBuildErr  bool
  1006  	}{
  1007  		{
  1008  			name:          "authority with default DNS port",
  1009  			authority:     "4.3.2.1:53",
  1010  			wantAuthority: "4.3.2.1:53",
  1011  		},
  1012  		{
  1013  			name:          "authority with non-default DNS port",
  1014  			authority:     "4.3.2.1:123",
  1015  			wantAuthority: "4.3.2.1:123",
  1016  		},
  1017  		{
  1018  			name:          "authority with no port",
  1019  			authority:     "4.3.2.1",
  1020  			wantAuthority: "4.3.2.1:53",
  1021  		},
  1022  		{
  1023  			name:          "ipv6 authority with no port",
  1024  			authority:     "::1",
  1025  			wantAuthority: "[::1]:53",
  1026  		},
  1027  		{
  1028  			name:          "ipv6 authority with brackets and no port",
  1029  			authority:     "[::1]",
  1030  			wantAuthority: "[::1]:53",
  1031  		},
  1032  		{
  1033  			name:          "ipv6 authority with brackets and non-default DNS port",
  1034  			authority:     "[::1]:123",
  1035  			wantAuthority: "[::1]:123",
  1036  		},
  1037  		{
  1038  			name:          "host name with no port",
  1039  			authority:     "dnsserver.com",
  1040  			wantAuthority: "dnsserver.com:53",
  1041  		},
  1042  		{
  1043  			name:          "no host port and non-default port",
  1044  			authority:     ":123",
  1045  			wantAuthority: "localhost:123",
  1046  		},
  1047  		{
  1048  			name:          "only colon",
  1049  			authority:     ":",
  1050  			wantAuthority: "",
  1051  			wantBuildErr:  true,
  1052  		},
  1053  		{
  1054  			name:          "ipv6 name ending in colon",
  1055  			authority:     "[::1]:",
  1056  			wantAuthority: "",
  1057  			wantBuildErr:  true,
  1058  		},
  1059  		{
  1060  			name:          "host name ending in colon",
  1061  			authority:     "dnsserver.com:",
  1062  			wantAuthority: "",
  1063  			wantBuildErr:  true,
  1064  		},
  1065  	}
  1066  
  1067  	for _, test := range tests {
  1068  		t.Run(test.name, func(t *testing.T) {
  1069  			overrideTimeAfterFunc(t, 2*defaultTestTimeout)
  1070  
  1071  			// Override the address dialer to verify the authority being passed.
  1072  			origAddressDialer := dnsinternal.AddressDialer
  1073  			errChan := make(chan error, 1)
  1074  			dnsinternal.AddressDialer = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
  1075  				if authority != test.wantAuthority {
  1076  					errChan <- fmt.Errorf("wrong custom authority passed to resolver. target: %s got authority: %s want authority: %s", test.authority, authority, test.wantAuthority)
  1077  				} else {
  1078  					errChan <- nil
  1079  				}
  1080  				return func(ctx context.Context, network, address string) (net.Conn, error) {
  1081  					return nil, errors.New("no need to dial")
  1082  				}
  1083  			}
  1084  			defer func() { dnsinternal.AddressDialer = origAddressDialer }()
  1085  
  1086  			b := resolver.Get("dns")
  1087  			if b == nil {
  1088  				t.Fatalf("Resolver for dns:/// scheme not registered")
  1089  			}
  1090  
  1091  			tcc := &testutils.ResolverClientConn{Logger: t}
  1092  			endpoint := "foo.bar.com"
  1093  			target := resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns://%s/%s", test.authority, endpoint))}
  1094  			r, err := b.Build(target, tcc, resolver.BuildOptions{})
  1095  			if (err != nil) != test.wantBuildErr {
  1096  				t.Fatalf("DNS resolver build for target %+v returned error %v: wantErr: %v\n", target, err, test.wantBuildErr)
  1097  			}
  1098  			if err != nil {
  1099  				return
  1100  			}
  1101  			defer r.Close()
  1102  
  1103  			if err := <-errChan; err != nil {
  1104  				t.Fatal(err)
  1105  			}
  1106  		})
  1107  	}
  1108  }
  1109  
  1110  // TestRateLimitedResolve exercises the rate limit enforced on re-resolution
  1111  // requests. It sets the re-resolution rate to a small value and repeatedly
  1112  // calls ResolveNow() and ensures only the expected number of resolution
  1113  // requests are made.
  1114  func (s) TestRateLimitedResolve(t *testing.T) {
  1115  	const target = "foo.bar.com"
  1116  	_, timeChan := overrideTimeAfterFuncWithChannel(t)
  1117  	tr := &testNetResolver{
  1118  		lookupHostCh:    testutils.NewChannel(),
  1119  		hostLookupTable: map[string][]string{target: {"1.2.3.4", "5.6.7.8"}},
  1120  	}
  1121  	overrideNetResolver(t, tr)
  1122  
  1123  	r, stateCh, _ := buildResolverWithTestClientConn(t, target)
  1124  
  1125  	// Wait for the first resolution request to be done. This happens as part
  1126  	// of the first iteration of the for loop in watcher().
  1127  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1128  	defer cancel()
  1129  	if _, err := tr.lookupHostCh.Receive(ctx); err != nil {
  1130  		t.Fatalf("Timed out waiting for lookup() call.")
  1131  	}
  1132  
  1133  	// Call Resolve Now 100 times, shouldn't continue onto next iteration of
  1134  	// watcher, thus shouldn't lookup again.
  1135  	for i := 0; i <= 100; i++ {
  1136  		r.ResolveNow(resolver.ResolveNowOptions{})
  1137  	}
  1138  
  1139  	continueCtx, continueCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
  1140  	defer continueCancel()
  1141  	if _, err := tr.lookupHostCh.Receive(continueCtx); err == nil {
  1142  		t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.")
  1143  	}
  1144  
  1145  	// Make the DNSMinResRate timer fire immediately, by sending the current
  1146  	// time on it. This will unblock the resolver which is currently blocked on
  1147  	// the DNS Min Res Rate timer going off, which will allow it to continue to
  1148  	// the next iteration of the watcher loop.
  1149  	select {
  1150  	case timeChan <- time.Now():
  1151  	case <-ctx.Done():
  1152  		t.Fatal("Timed out waiting for the DNS resolver to block on DNS Min Res Rate to elapse")
  1153  	}
  1154  
  1155  	// Now that DNS Min Res Rate timer has gone off, it should lookup again.
  1156  	if _, err := tr.lookupHostCh.Receive(ctx); err != nil {
  1157  		t.Fatalf("Timed out waiting for lookup() call.")
  1158  	}
  1159  
  1160  	// Resolve Now 1000 more times, shouldn't lookup again as DNS Min Res Rate
  1161  	// timer has not gone off.
  1162  	for i := 0; i < 1000; i++ {
  1163  		r.ResolveNow(resolver.ResolveNowOptions{})
  1164  	}
  1165  	continueCtx, continueCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
  1166  	defer continueCancel()
  1167  	if _, err := tr.lookupHostCh.Receive(continueCtx); err == nil {
  1168  		t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.")
  1169  	}
  1170  
  1171  	// Make the DNSMinResRate timer fire immediately again.
  1172  	select {
  1173  	case timeChan <- time.Now():
  1174  	case <-ctx.Done():
  1175  		t.Fatal("Timed out waiting for the DNS resolver to block on DNS Min Res Rate to elapse")
  1176  	}
  1177  
  1178  	// Now that DNS Min Res Rate timer has gone off, it should lookup again.
  1179  	if _, err := tr.lookupHostCh.Receive(ctx); err != nil {
  1180  		t.Fatalf("Timed out waiting for lookup() call.")
  1181  	}
  1182  
  1183  	wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}
  1184  	var state resolver.State
  1185  	select {
  1186  	case <-ctx.Done():
  1187  		t.Fatal("Timeout when waiting for a state update from the resolver")
  1188  	case state = <-stateCh:
  1189  	}
  1190  	if !cmp.Equal(state.Addresses, wantAddrs, cmpopts.EquateEmpty()) {
  1191  		t.Fatalf("Got addresses: %+v, want: %+v", state.Addresses, wantAddrs)
  1192  	}
  1193  }
  1194  
  1195  // Test verifies that when the DNS resolver gets an error from the underlying
  1196  // net.Resolver, it reports the error to the channel and backs off and retries.
  1197  func (s) TestReportError(t *testing.T) {
  1198  	durChan, timeChan := overrideTimeAfterFuncWithChannel(t)
  1199  	overrideNetResolver(t, &testNetResolver{})
  1200  
  1201  	const target = "notfoundaddress"
  1202  	_, _, errorCh := buildResolverWithTestClientConn(t, target)
  1203  
  1204  	// Should receive first error.
  1205  	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1206  	defer ctxCancel()
  1207  	select {
  1208  	case <-ctx.Done():
  1209  		t.Fatal("Timeout when waiting for an error from the resolver")
  1210  	case err := <-errorCh:
  1211  		if !strings.Contains(err.Error(), "hostLookup error") {
  1212  			t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err)
  1213  		}
  1214  	}
  1215  
  1216  	// Expect the DNS resolver to backoff and attempt to re-resolve. Every time,
  1217  	// the DNS resolver will receive the same error from the net.Resolver and is
  1218  	// expected to push it to the channel.
  1219  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1220  	defer cancel()
  1221  	const retries = 10
  1222  	var prevDur time.Duration
  1223  	for i := 0; i < retries; i++ {
  1224  		select {
  1225  		case <-ctx.Done():
  1226  			t.Fatalf("(Iteration: %d): Timeout when waiting for DNS resolver to backoff", i)
  1227  		case dur := <-durChan:
  1228  			if dur <= prevDur {
  1229  				t.Fatalf("(Iteration: %d): Unexpected decrease in amount of time to backoff", i)
  1230  			}
  1231  		}
  1232  
  1233  		// Unblock the DNS resolver's backoff by pushing the current time.
  1234  		timeChan <- time.Now()
  1235  
  1236  		select {
  1237  		case <-ctx.Done():
  1238  			t.Fatal("Timeout when waiting for an error from the resolver")
  1239  		case err := <-errorCh:
  1240  			if !strings.Contains(err.Error(), "hostLookup error") {
  1241  				t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err)
  1242  			}
  1243  		}
  1244  	}
  1245  }
  1246  
  1247  // Override the default dns.ResolvingTimeout with a test duration.
  1248  func overrideResolveTimeoutDuration(t *testing.T, dur time.Duration) {
  1249  	t.Helper()
  1250  
  1251  	origDur := dns.ResolvingTimeout
  1252  	dnspublic.SetResolvingTimeout(dur)
  1253  
  1254  	t.Cleanup(func() { dnspublic.SetResolvingTimeout(origDur) })
  1255  }
  1256  
  1257  // Test verifies that the DNS resolver gets timeout error when net.Resolver
  1258  // takes too long to resolve a target.
  1259  func (s) TestResolveTimeout(t *testing.T) {
  1260  	// Set DNS resolving timeout duration to 7ms
  1261  	timeoutDur := 7 * time.Millisecond
  1262  	overrideResolveTimeoutDuration(t, timeoutDur)
  1263  
  1264  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1265  	defer cancel()
  1266  
  1267  	// We are trying to resolve hostname which takes infinity time to resolve.
  1268  	const target = "infinity"
  1269  
  1270  	// Define a testNetResolver with lookupHostCh, an unbuffered channel,
  1271  	// so we can block the resolver until reaching timeout.
  1272  	tr := &testNetResolver{
  1273  		lookupHostCh:    testutils.NewChannelWithSize(0),
  1274  		hostLookupTable: map[string][]string{target: {"1.2.3.4"}},
  1275  	}
  1276  	overrideNetResolver(t, tr)
  1277  
  1278  	_, _, errCh := buildResolverWithTestClientConn(t, target)
  1279  	select {
  1280  	case <-ctx.Done():
  1281  		t.Fatal("Timeout when waiting for the DNS resolver to timeout")
  1282  	case err := <-errCh:
  1283  		if err == nil || !strings.Contains(err.Error(), "context deadline exceeded") {
  1284  			t.Fatalf(`Expected to see Timeout error; got: %v`, err)
  1285  		}
  1286  	}
  1287  }
  1288  
  1289  // Test verifies that changing [MinResolutionInterval] variable correctly effects
  1290  // the resolution behaviour
  1291  func (s) TestMinResolutionInterval(t *testing.T) {
  1292  	const target = "foo.bar.com"
  1293  
  1294  	overrideResolutionInterval(t, 1*time.Millisecond)
  1295  	tr := &testNetResolver{
  1296  		hostLookupTable: map[string][]string{
  1297  			"foo.bar.com": {"1.2.3.4", "5.6.7.8"},
  1298  		},
  1299  		txtLookupTable: map[string][]string{
  1300  			"_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood),
  1301  		},
  1302  	}
  1303  	overrideNetResolver(t, tr)
  1304  
  1305  	r, stateCh, _ := buildResolverWithTestClientConn(t, target)
  1306  
  1307  	wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}
  1308  	wantSC := scJSON
  1309  
  1310  	for i := 0; i < 5; i++ {
  1311  		// set context timeout slightly higher than the min resolution interval to make sure resolutions
  1312  		// happen successfully
  1313  		ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
  1314  		defer cancel()
  1315  
  1316  		verifyUpdateFromResolver(ctx, t, stateCh, wantAddrs, nil, wantSC)
  1317  		r.ResolveNow(resolver.ResolveNowOptions{})
  1318  	}
  1319  }
  1320  
  1321  // TestMinResolutionInterval_NoExtraDelay verifies that there is no extra delay
  1322  // between two resolution requests apart from [MinResolutionInterval].
  1323  func (s) TestMinResolutionInterval_NoExtraDelay(t *testing.T) {
  1324  	tr := &testNetResolver{
  1325  		hostLookupTable: map[string][]string{
  1326  			"foo.bar.com": {"1.2.3.4", "5.6.7.8"},
  1327  		},
  1328  		txtLookupTable: map[string][]string{
  1329  			"_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood),
  1330  		},
  1331  	}
  1332  	overrideNetResolver(t, tr)
  1333  	// Override time.Now() to return a zero value for time. This will allow us
  1334  	// to verify that the call to time.Until is made with the exact
  1335  	// [MinResolutionInterval] that we expect.
  1336  	overrideTimeNowFunc(t, time.Time{})
  1337  	// Override time.Until() to read the time passed to it
  1338  	// and return immediately without any delay
  1339  	timeCh := overrideTimeUntilFuncWithChannel(t)
  1340  
  1341  	r, stateCh, errorCh := buildResolverWithTestClientConn(t, "foo.bar.com")
  1342  
  1343  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1344  	defer cancel()
  1345  
  1346  	// Ensure that the first resolution happens.
  1347  	select {
  1348  	case <-ctx.Done():
  1349  		t.Fatal("Timeout when waiting for DNS resolver")
  1350  	case err := <-errorCh:
  1351  		t.Fatalf("Unexpected error from resolver, %v", err)
  1352  	case <-stateCh:
  1353  	}
  1354  
  1355  	// Request re-resolution and verify that the resolver waits for
  1356  	// [MinResolutionInterval].
  1357  	r.ResolveNow(resolver.ResolveNowOptions{})
  1358  	select {
  1359  	case <-ctx.Done():
  1360  		t.Fatal("Timeout when waiting for DNS resolver")
  1361  	case gotTime := <-timeCh:
  1362  		wantTime := time.Time{}.Add(dns.MinResolutionInterval)
  1363  		if !gotTime.Equal(wantTime) {
  1364  			t.Fatalf("DNS resolver waits for %v time before re-resolution, want %v", gotTime, wantTime)
  1365  		}
  1366  	}
  1367  
  1368  	// Ensure that the re-resolution request actually happens.
  1369  	select {
  1370  	case <-ctx.Done():
  1371  		t.Fatal("Timeout when waiting for an error from the resolver")
  1372  	case err := <-errorCh:
  1373  		t.Fatalf("Unexpected error from resolver, %v", err)
  1374  	case <-stateCh:
  1375  	}
  1376  }