github.com/searKing/golang/go@v1.2.117/net/resolver/dns/dns_resolver_test.go (about)

     1  // Copyright 2021 The searKing Author. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package dns
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"net"
    12  	"os"
    13  	"reflect"
    14  	"strings"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/searKing/golang/go/net/resolver"
    20  	testing_ "github.com/searKing/golang/go/testing"
    21  	"github.com/searKing/golang/go/testing/leakcheck"
    22  )
    23  
    24  func TestMain(m *testing.M) {
    25  	// Set a non-zero duration only for tests which are actually testing that
    26  	// feature.
    27  	replaceDNSResRate(time.Duration(0)) // No need to clean up since we os.Exit
    28  	overrideDefaultResolver(false)      // No need to clean up since we os.Exit
    29  	code := m.Run()
    30  	os.Exit(code)
    31  }
    32  
    33  const (
    34  	defaultTestTimeout      = 10 * time.Second
    35  	defaultTestShortTimeout = 10 * time.Millisecond
    36  )
    37  
// testClientConn is a fake resolver.ClientConn that records what the DNS
// resolver pushes to it. The embedded resolver.ClientConn is never set by the
// tests in this file, so only the methods defined here should be invoked.
//
// m1 guards state and updateStateCalls; UpdateState also reads updateStateErr
// while holding m1, so code changing updateStateErr after the resolver has
// started should hold m1 too. errChan receives every error passed to
// ReportError. target records the endpoint the resolver was built for; none
// of the methods in this file read it.
type testClientConn struct {
	resolver.ClientConn // For unimplemented functions
	target              string
	m1                  sync.Mutex
	state               resolver.State
	updateStateCalls    int
	errChan             chan error
	updateStateErr      error
}
    47  
    48  func (t *testClientConn) UpdateState(s resolver.State) error {
    49  	t.m1.Lock()
    50  	defer t.m1.Unlock()
    51  	t.state = s
    52  	t.updateStateCalls++
    53  	// This error determines whether DNS Resolver actually decides to exponentially backoff or not.
    54  	// This can be any error.
    55  	return t.updateStateErr
    56  }
    57  
    58  func (t *testClientConn) getState() (resolver.State, int) {
    59  	t.m1.Lock()
    60  	defer t.m1.Unlock()
    61  	return t.state, t.updateStateCalls
    62  }
    63  
    64  func scFromState(s resolver.State) string {
    65  	return ""
    66  }
    67  
    68  func (t *testClientConn) ReportError(err error) {
    69  	t.errChan <- err
    70  }
    71  
// testResolver is the fake DNS backend that overrideDefaultResolver installs
// as defaultResolver. Host and SRV queries are answered from hostLookupTbl
// and srvLookupTbl; TXT queries simply echo the queried host.
type testResolver struct {
	// A write to this channel is made when this resolver receives a resolution
	// request. Tests can rely on reading from this channel to be notified about
	// resolution requests instead of sleeping for a predefined period of time.
	// It is nil unless overrideDefaultResolver was called with pushOnLookup.
	lookupHostCh *testing_.Channel
}
    78  
    79  func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
    80  	if tr.lookupHostCh != nil {
    81  		tr.lookupHostCh.Send(nil)
    82  	}
    83  	return hostLookup(host)
    84  }
    85  
    86  func (*testResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
    87  	return srvLookup(service, proto, name)
    88  }
    89  
    90  func (*testResolver) LookupTXT(ctx context.Context, host string) ([]string, error) {
    91  	return []string{host}, nil
    92  }
    93  
    94  // overrideDefaultResolver overrides the defaultResolver used by the code with
    95  // an instance of the testResolver. pushOnLookup controls whether the
    96  // testResolver created here pushes lookupHost events on its channel.
    97  func overrideDefaultResolver(pushOnLookup bool) func() {
    98  	oldResolver := defaultResolver
    99  
   100  	var lookupHostCh *testing_.Channel
   101  	if pushOnLookup {
   102  		lookupHostCh = testing_.NewChannel()
   103  	}
   104  	defaultResolver = &testResolver{lookupHostCh: lookupHostCh}
   105  
   106  	return func() {
   107  		defaultResolver = oldResolver
   108  	}
   109  }
   110  
   111  func replaceDNSResRate(d time.Duration) func() {
   112  	oldMinDNSResRate := minDNSResRate
   113  	minDNSResRate = d
   114  
   115  	return func() {
   116  		minDNSResRate = oldMinDNSResRate
   117  	}
   118  }
   119  
   120  var hostLookupTbl = struct {
   121  	sync.Mutex
   122  	tbl map[string][]string
   123  }{
   124  	tbl: map[string][]string{
   125  		"foo.bar.com":          {"1.2.3.4", "5.6.7.8"},
   126  		"ipv4.single.fake":     {"1.2.3.4"},
   127  		"srv.ipv4.single.fake": {"2.4.6.8"},
   128  		"srv.ipv4.multi.fake":  {},
   129  		"srv.ipv6.single.fake": {},
   130  		"srv.ipv6.multi.fake":  {},
   131  		"ipv4.multi.fake":      {"1.2.3.4", "5.6.7.8", "9.10.11.12"},
   132  		"ipv6.single.fake":     {"2607:f8b0:400a:801::1001"},
   133  		"ipv6.multi.fake":      {"2607:f8b0:400a:801::1001", "2607:f8b0:400a:801::1002", "2607:f8b0:400a:801::1003"},
   134  	},
   135  }
   136  
   137  func hostLookup(host string) ([]string, error) {
   138  	hostLookupTbl.Lock()
   139  	defer hostLookupTbl.Unlock()
   140  	if addrs, ok := hostLookupTbl.tbl[host]; ok {
   141  		return addrs, nil
   142  	}
   143  	return nil, &net.DNSError{
   144  		Err:         "hostLookup error",
   145  		Name:        host,
   146  		Server:      "fake",
   147  		IsTemporary: true,
   148  	}
   149  }
   150  
   151  var srvLookupTbl = struct {
   152  	sync.Mutex
   153  	tbl map[string][]*net.SRV
   154  }{
   155  	tbl: map[string][]*net.SRV{
   156  		"_grpclb._tcp.srv.ipv4.single.fake": {&net.SRV{Target: "ipv4.single.fake", Port: 1234}},
   157  		"_grpclb._tcp.srv.ipv4.multi.fake":  {&net.SRV{Target: "ipv4.multi.fake", Port: 1234}},
   158  		"_grpclb._tcp.srv.ipv6.single.fake": {&net.SRV{Target: "ipv6.single.fake", Port: 1234}},
   159  		"_grpclb._tcp.srv.ipv6.multi.fake":  {&net.SRV{Target: "ipv6.multi.fake", Port: 1234}},
   160  	},
   161  }
   162  
   163  func srvLookup(service, proto, name string) (string, []*net.SRV, error) {
   164  	cname := "_" + service + "._" + proto + "." + name
   165  	srvLookupTbl.Lock()
   166  	defer srvLookupTbl.Unlock()
   167  	if srvs, cnt := srvLookupTbl.tbl[cname]; cnt {
   168  		return cname, srvs, nil
   169  	}
   170  	return "", nil, &net.DNSError{
   171  		Err:         "srvLookup error",
   172  		Name:        cname,
   173  		Server:      "fake",
   174  		IsTemporary: true,
   175  	}
   176  }
   177  
   178  func TestResolve(t *testing.T) {
   179  	testDNSResolver(t)
   180  	testDNSResolverWithSRV(t)
   181  	testDNSResolveNow(t)
   182  	testIPResolver(t)
   183  }
   184  
   185  func testDNSResolver(t *testing.T) {
   186  	defer leakcheck.Check(t)
   187  	defer func(nt func(d time.Duration) *time.Timer) {
   188  		newTimer = nt
   189  	}(newTimer)
   190  	newTimer = func(_ time.Duration) *time.Timer {
   191  		// Will never fire on its own, will protect from triggering exponential backoff.
   192  		return time.NewTimer(time.Hour)
   193  	}
   194  	tests := []struct {
   195  		target   string
   196  		addrWant []resolver.Address
   197  	}{
   198  		{
   199  			"foo.bar.com",
   200  			[]resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}},
   201  		},
   202  		{
   203  			"foo.bar.com:1234",
   204  			[]resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}},
   205  		},
   206  		{
   207  			"srv.ipv4.single.fake",
   208  			[]resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}},
   209  		},
   210  		{
   211  			"srv.ipv4.multi.fake",
   212  			nil,
   213  		},
   214  		{
   215  			"srv.ipv6.single.fake",
   216  			nil,
   217  		},
   218  		{
   219  			"srv.ipv6.multi.fake",
   220  			nil,
   221  		},
   222  	}
   223  
   224  	for _, a := range tests {
   225  		b := NewBuilder()
   226  		cc := &testClientConn{target: a.target}
   227  		r, err := b.Build(context.Background(), resolver.Target{Endpoint: a.target}, resolver.BuildWithClientConn(cc))
   228  		if err != nil {
   229  			t.Fatalf("%v\n", err)
   230  		}
   231  		var state resolver.State
   232  		var cnt int
   233  		for i := 0; i < 2000; i++ {
   234  			state, cnt = cc.getState()
   235  			if cnt > 0 {
   236  				break
   237  			}
   238  			time.Sleep(time.Millisecond)
   239  		}
   240  		if cnt == 0 {
   241  			t.Fatalf("UpdateState not called after 2s; aborting")
   242  		}
   243  		if !reflect.DeepEqual(a.addrWant, state.Addresses) {
   244  			t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant)
   245  		}
   246  		r.Close()
   247  	}
   248  }
   249  
   250  // DNS Resolver immediately starts polling on an error from grpc. This should continue until the ClientConn doesn't
   251  // send back an error from updating the DNS Resolver's state.
   252  func TestDNSResolverExponentialBackoff(t *testing.T) {
   253  	//defer leakcheck.Check(t)
   254  	defer func(nt func(d time.Duration) *time.Timer) {
   255  		newTimer = nt
   256  	}(newTimer)
   257  	timerChan := testing_.NewChannel()
   258  	newTimer = func(d time.Duration) *time.Timer {
   259  		// Will never fire on its own, allows this test to call timer immediately.
   260  		t := time.NewTimer(time.Hour)
   261  		timerChan.Send(t)
   262  		return t
   263  	}
   264  	tests := []struct {
   265  		name     string
   266  		target   string
   267  		addrWant []resolver.Address
   268  	}{
   269  		{
   270  			"happy case default port",
   271  			"foo.bar.com",
   272  			[]resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}},
   273  		},
   274  		{
   275  			"happy case specified port",
   276  			"foo.bar.com:1234",
   277  			[]resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}},
   278  		},
   279  		{
   280  			"happy case another default port",
   281  			"srv.ipv4.single.fake",
   282  			[]resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}},
   283  		},
   284  	}
   285  	for _, test := range tests {
   286  		t.Run(test.name, func(t *testing.T) {
   287  			func() {
   288  				ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   289  				defer ctxCancel()
   290  				err := timerChan.Clear(ctx)
   291  				if err != nil {
   292  					t.Fatalf("Error clear timer from mock NewTimer call: %v", err)
   293  				}
   294  			}()
   295  			b := NewBuilder()
   296  			cc := &testClientConn{target: test.target}
   297  			// Cause ClientConn to return an error.
   298  			cc.updateStateErr = resolver.ErrBadResolverState
   299  			r, err := b.Build(context.Background(), resolver.Target{Endpoint: test.target}, resolver.BuildWithClientConn(cc))
   300  			if err != nil {
   301  				t.Fatalf("Error building resolver for target %v: %v", test.target, err)
   302  			}
   303  			var state resolver.State
   304  			var cnt int
   305  			for i := 0; i < 2000; i++ {
   306  				state, cnt = cc.getState()
   307  				if cnt > 0 {
   308  					break
   309  				}
   310  				time.Sleep(time.Millisecond)
   311  			}
   312  			if cnt == 0 {
   313  				t.Fatalf("UpdateState not called after 2s; aborting")
   314  			}
   315  			if !reflect.DeepEqual(test.addrWant, state.Addresses) {
   316  				t.Errorf("Resolved addresses of target: %q = %+v, want %+v", test.target, state.Addresses, test.addrWant)
   317  			}
   318  			ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   319  			defer ctxCancel()
   320  			// Cause timer to go off 10 times, and see if it calls updateState() correctly.
   321  			for i := 0; i < 10; i++ {
   322  				timer, err := timerChan.Receive(ctx)
   323  				if err != nil {
   324  					t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
   325  				}
   326  				timerPointer := timer.(*time.Timer)
   327  				timerPointer.Reset(0)
   328  			}
   329  			// Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call
   330  			// ClientConn update state.
   331  			deadline := time.Now().Add(defaultTestTimeout)
   332  			for {
   333  				cc.m1.Lock()
   334  				got := cc.updateStateCalls
   335  				cc.m1.Unlock()
   336  				if got == 11 {
   337  					break
   338  				}
   339  
   340  				if time.Now().After(deadline) {
   341  					t.Fatalf("Exponential backoff is not working as expected - should update state 11 times instead of %d", got)
   342  				}
   343  
   344  				time.Sleep(time.Millisecond)
   345  			}
   346  
   347  			// Update resolver.ClientConn to not return an error anymore - this should stop it from backing off.
   348  			cc.updateStateErr = nil
   349  			timer, err := timerChan.Receive(ctx)
   350  			if err != nil {
   351  				t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
   352  			}
   353  			timerPointer := timer.(*time.Timer)
   354  			timerPointer.Reset(0)
   355  			// Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call
   356  			// ClientConn update state the final time. The DNS Resolver should then stop polling.
   357  			deadline = time.Now().Add(defaultTestTimeout)
   358  			for {
   359  				cc.m1.Lock()
   360  				got := cc.updateStateCalls
   361  				cc.m1.Unlock()
   362  				if got == 12 {
   363  					break
   364  				}
   365  
   366  				if time.Now().After(deadline) {
   367  					t.Fatalf("Exponential backoff is not working as expected - should stop backing off at 12 total UpdateState calls instead of %d", got)
   368  				}
   369  
   370  				_, err := timerChan.ReceiveOrFail()
   371  				if err {
   372  					t.Fatalf("Should not poll again after Client Conn stops returning error.")
   373  				}
   374  
   375  				time.Sleep(time.Millisecond)
   376  			}
   377  			r.Close()
   378  		})
   379  	}
   380  }
   381  
   382  func testDNSResolverWithSRV(t *testing.T) {
   383  	EnableSRVLookups = true
   384  	defer func() {
   385  		EnableSRVLookups = false
   386  	}()
   387  	defer leakcheck.Check(t)
   388  	defer func(nt func(d time.Duration) *time.Timer) {
   389  		newTimer = nt
   390  	}(newTimer)
   391  	newTimer = func(_ time.Duration) *time.Timer {
   392  		// Will never fire on its own, will protect from triggering exponential backoff.
   393  		return time.NewTimer(time.Hour)
   394  	}
   395  	tests := []struct {
   396  		target   string
   397  		addrWant []resolver.Address
   398  	}{
   399  		{
   400  			"foo.bar.com",
   401  			[]resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}},
   402  		},
   403  		{
   404  			"foo.bar.com:1234",
   405  			[]resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}},
   406  		},
   407  		{
   408  			"srv.ipv4.single.fake",
   409  			[]resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}},
   410  		},
   411  		{
   412  			"srv.ipv4.multi.fake",
   413  			nil,
   414  		},
   415  		{
   416  			"srv.ipv6.single.fake",
   417  			nil,
   418  		},
   419  		{
   420  			"srv.ipv6.multi.fake",
   421  			nil,
   422  		},
   423  	}
   424  
   425  	for _, a := range tests {
   426  		b := NewBuilder()
   427  		cc := &testClientConn{target: a.target}
   428  		r, err := b.Build(context.Background(), resolver.Target{Endpoint: a.target}, resolver.BuildWithClientConn(cc))
   429  		if err != nil {
   430  			t.Fatalf("%v\n", err)
   431  		}
   432  		defer r.Close()
   433  		var state resolver.State
   434  		var cnt int
   435  		for i := 0; i < 2000; i++ {
   436  			state, cnt = cc.getState()
   437  			if cnt > 0 {
   438  				break
   439  			}
   440  			time.Sleep(time.Millisecond)
   441  		}
   442  		if cnt == 0 {
   443  			t.Fatalf("UpdateState not called after 2s; aborting")
   444  		}
   445  		if !reflect.DeepEqual(a.addrWant, state.Addresses) {
   446  			t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant)
   447  		}
   448  	}
   449  }
   450  
   451  func testDNSResolveNow(t *testing.T) {
   452  	defer leakcheck.Check(t)
   453  	tests := []struct {
   454  		target   string
   455  		addrWant []resolver.Address
   456  	}{
   457  		{
   458  			"foo.bar.com",
   459  			[]resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}},
   460  		},
   461  	}
   462  
   463  	for _, a := range tests {
   464  		b := NewBuilder()
   465  		cc := &testClientConn{target: a.target}
   466  		r, err := b.Build(context.Background(), resolver.Target{Endpoint: a.target}, resolver.BuildWithClientConn(cc))
   467  		if err != nil {
   468  			t.Fatalf("%v\n", err)
   469  		}
   470  		defer r.Close()
   471  		var state resolver.State
   472  		var cnt int
   473  		for i := 0; i < 2000; i++ {
   474  			state, cnt = cc.getState()
   475  			if cnt > 0 {
   476  				break
   477  			}
   478  			time.Sleep(time.Millisecond)
   479  		}
   480  		if cnt == 0 {
   481  			t.Fatalf("UpdateState not called after 2s; aborting.  state=%v", state)
   482  		}
   483  		if !reflect.DeepEqual(a.addrWant, state.Addresses) {
   484  			t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant)
   485  		}
   486  
   487  		r.ResolveNow(context.Background())
   488  		for i := 0; i < 2000; i++ {
   489  			state, cnt = cc.getState()
   490  			if cnt == 2 {
   491  				break
   492  			}
   493  			time.Sleep(time.Millisecond)
   494  		}
   495  		if cnt != 2 {
   496  			t.Fatalf("UpdateState not called after 2s; aborting.  state=%v", state)
   497  		}
   498  	}
   499  }
   500  
// colonDefaultPort is defaultPort prefixed with ":", ready to be appended to a
// bare host when building expected addresses.
const colonDefaultPort = ":" + defaultPort
   502  
   503  func testIPResolver(t *testing.T) {
   504  	defer leakcheck.Check(t)
   505  	defer func(nt func(d time.Duration) *time.Timer) {
   506  		newTimer = nt
   507  	}(newTimer)
   508  	newTimer = func(_ time.Duration) *time.Timer {
   509  		// Will never fire on its own, will protect from triggering exponential backoff.
   510  		return time.NewTimer(time.Hour)
   511  	}
   512  	tests := []struct {
   513  		target string
   514  		want   []resolver.Address
   515  	}{
   516  		{"127.0.0.1", []resolver.Address{{Addr: "127.0.0.1" + colonDefaultPort}}},
   517  		{"127.0.0.1:12345", []resolver.Address{{Addr: "127.0.0.1:12345"}}},
   518  		{"::1", []resolver.Address{{Addr: "[::1]" + colonDefaultPort}}},
   519  		{"[::1]:12345", []resolver.Address{{Addr: "[::1]:12345"}}},
   520  		{"[::1]", []resolver.Address{{Addr: "[::1]:443"}}},
   521  		{"2001:db8:85a3::8a2e:370:7334", []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]" + colonDefaultPort}}},
   522  		{"[2001:db8:85a3::8a2e:370:7334]", []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]" + colonDefaultPort}}},
   523  		{"[2001:db8:85a3::8a2e:370:7334]:12345", []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]:12345"}}},
   524  		{"[2001:db8::1]:http", []resolver.Address{{Addr: "[2001:db8::1]:http"}}},
   525  	}
   526  
   527  	for _, v := range tests {
   528  		b := NewBuilder()
   529  		cc := &testClientConn{target: v.target}
   530  		r, err := b.Build(context.Background(), resolver.Target{Endpoint: v.target}, resolver.BuildWithClientConn(cc))
   531  		if err != nil {
   532  			t.Fatalf("%v\n", err)
   533  		}
   534  		var state resolver.State
   535  		var cnt int
   536  		for {
   537  			state, cnt = cc.getState()
   538  			if cnt > 0 {
   539  				break
   540  			}
   541  			time.Sleep(time.Millisecond)
   542  		}
   543  		if !reflect.DeepEqual(v.want, state.Addresses) {
   544  			t.Errorf("Resolved addresses of target: %q = %+v, want %+v", v.target, state.Addresses, v.want)
   545  		}
   546  		r.ResolveNow(context.Background())
   547  		for i := 0; i < 50; i++ {
   548  			state, cnt = cc.getState()
   549  			if cnt > 1 {
   550  				t.Fatalf("Unexpected second call by resolver to UpdateState.  state: %v", state)
   551  			}
   552  			time.Sleep(time.Millisecond)
   553  		}
   554  		r.Close()
   555  	}
   556  }
   557  
   558  func TestResolveFunc(t *testing.T) {
   559  	defer leakcheck.Check(t)
   560  	defer func(nt func(d time.Duration) *time.Timer) {
   561  		newTimer = nt
   562  	}(newTimer)
   563  	newTimer = func(d time.Duration) *time.Timer {
   564  		// Will never fire on its own, will protect from triggering exponential backoff.
   565  		return time.NewTimer(time.Hour)
   566  	}
   567  	tests := []struct {
   568  		addr string
   569  		want error
   570  	}{
   571  		// TODO(yuxuanli): More false cases?
   572  		{"www.google.com", nil},
   573  		{"foo.bar:12345", nil},
   574  		{"127.0.0.1", nil},
   575  		{"::", nil},
   576  		{"127.0.0.1:12345", nil},
   577  		{"[::1]:80", nil},
   578  		{"[2001:db8:a0b:12f0::1]:21", nil},
   579  		{":80", nil},
   580  		{"127.0.0...1:12345", nil},
   581  		{"[fe80::1%lo0]:80", nil},
   582  		{"golang.org:http", nil},
   583  		{"[2001:db8::1]:http", nil},
   584  		{"[2001:db8::1]:", errEndsWithColon},
   585  		{":", errEndsWithColon},
   586  		{"", errMissingAddr},
   587  		{"[2001:db8:a0b:12f0::1", fmt.Errorf("invalid target address [2001:db8:a0b:12f0::1, error info: address [2001:db8:a0b:12f0::1:443: missing ']' in address")},
   588  	}
   589  
   590  	b := NewBuilder()
   591  	for _, v := range tests {
   592  		cc := &testClientConn{target: v.addr, errChan: make(chan error, 1)}
   593  		r, err := b.Build(context.Background(), resolver.Target{Endpoint: v.addr}, resolver.BuildWithClientConn(cc))
   594  		if err == nil {
   595  			r.Close()
   596  		}
   597  		if !reflect.DeepEqual(err, v.want) {
   598  			t.Errorf("Build(%q, cc, _) = %v, want %v", v.addr, err, v.want)
   599  		}
   600  	}
   601  }
   602  
   603  func TestDNSResolverRetry(t *testing.T) {
   604  	b := NewBuilder()
   605  	target := "ipv4.single.fake"
   606  	cc := &testClientConn{target: target}
   607  	r, err := b.Build(context.Background(), resolver.Target{Endpoint: target}, resolver.BuildWithClientConn(cc))
   608  	if err != nil {
   609  		t.Fatalf("%v\n", err)
   610  	}
   611  	defer r.Close()
   612  	var state resolver.State
   613  	for i := 0; i < 2000; i++ {
   614  		state, _ = cc.getState()
   615  		if len(state.Addresses) == 1 {
   616  			break
   617  		}
   618  		time.Sleep(time.Millisecond)
   619  	}
   620  	if len(state.Addresses) != 1 {
   621  		t.Fatalf("UpdateState not called with 1 address after 2s; aborting.  state=%v", state)
   622  	}
   623  	want := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}}
   624  	if !reflect.DeepEqual(want, state.Addresses) {
   625  		t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, want)
   626  	}
   627  	// mutate the host lookup table so the target has 0 address returned.
   628  	revertTbl := mutateTbl(target)
   629  	// trigger a resolve that will get empty address list
   630  	r.ResolveNow(context.Background())
   631  	for i := 0; i < 2000; i++ {
   632  		state, _ = cc.getState()
   633  		if len(state.Addresses) == 0 {
   634  			break
   635  		}
   636  		time.Sleep(time.Millisecond)
   637  	}
   638  	if len(state.Addresses) != 0 {
   639  		t.Fatalf("UpdateState not called with 0 address after 2s; aborting.  state=%v", state)
   640  	}
   641  	revertTbl()
   642  	// wait for the retry to happen in two seconds.
   643  	r.ResolveNow(context.Background())
   644  	for i := 0; i < 2000; i++ {
   645  		state, _ = cc.getState()
   646  		if len(state.Addresses) == 1 {
   647  			break
   648  		}
   649  		time.Sleep(time.Millisecond)
   650  	}
   651  	if !reflect.DeepEqual(want, state.Addresses) {
   652  		t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, want)
   653  	}
   654  }
   655  
   656  func TestCustomAuthority(t *testing.T) {
   657  	defer leakcheck.Check(t)
   658  	defer func(nt func(d time.Duration) *time.Timer) {
   659  		newTimer = nt
   660  	}(newTimer)
   661  	newTimer = func(d time.Duration) *time.Timer {
   662  		// Will never fire on its own, will protect from triggering exponential backoff.
   663  		return time.NewTimer(time.Hour)
   664  	}
   665  
   666  	tests := []struct {
   667  		authority     string
   668  		authorityWant string
   669  		expectError   bool
   670  	}{
   671  		{
   672  			"4.3.2.1:" + defaultDNSSvrPort,
   673  			"4.3.2.1:" + defaultDNSSvrPort,
   674  			false,
   675  		},
   676  		{
   677  			"4.3.2.1:123",
   678  			"4.3.2.1:123",
   679  			false,
   680  		},
   681  		{
   682  			"4.3.2.1",
   683  			"4.3.2.1:" + defaultDNSSvrPort,
   684  			false,
   685  		},
   686  		{
   687  			"::1",
   688  			"[::1]:" + defaultDNSSvrPort,
   689  			false,
   690  		},
   691  		{
   692  			"[::1]",
   693  			"[::1]:" + defaultDNSSvrPort,
   694  			false,
   695  		},
   696  		{
   697  			"[::1]:123",
   698  			"[::1]:123",
   699  			false,
   700  		},
   701  		{
   702  			"dnsserver.com",
   703  			"dnsserver.com:" + defaultDNSSvrPort,
   704  			false,
   705  		},
   706  		{
   707  			":123",
   708  			"localhost:123",
   709  			false,
   710  		},
   711  		{
   712  			":",
   713  			"",
   714  			true,
   715  		},
   716  		{
   717  			"[::1]:",
   718  			"",
   719  			true,
   720  		},
   721  		{
   722  			"dnsserver.com:",
   723  			"",
   724  			true,
   725  		},
   726  	}
   727  	oldCustomAuthorityDialler := customAuthorityDialler
   728  	defer func() {
   729  		customAuthorityDialler = oldCustomAuthorityDialler
   730  	}()
   731  
   732  	for _, a := range tests {
   733  		errChan := make(chan error, 1)
   734  		customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
   735  			if authority != a.authorityWant {
   736  				errChan <- fmt.Errorf("wrong custom authority passed to resolver. input: %s expected: %s actual: %s", a.authority, a.authorityWant, authority)
   737  			} else {
   738  				errChan <- nil
   739  			}
   740  			return func(ctx context.Context, network, address string) (net.Conn, error) {
   741  				return nil, errors.New("no need to dial")
   742  			}
   743  		}
   744  
   745  		b := NewBuilder()
   746  		cc := &testClientConn{target: "foo.bar.com", errChan: make(chan error, 1)}
   747  		r, err := b.Build(context.Background(), resolver.Target{Endpoint: "foo.bar.com", Authority: a.authority}, resolver.BuildWithClientConn(cc))
   748  
   749  		if err == nil {
   750  			r.Close()
   751  
   752  			err = <-errChan
   753  			if err != nil {
   754  				t.Errorf(err.Error())
   755  			}
   756  
   757  			if a.expectError {
   758  				t.Errorf("custom authority should have caused an error: %s", a.authority)
   759  			}
   760  		} else if !a.expectError {
   761  			t.Errorf("unexpected error using custom authority %s: %s", a.authority, err)
   762  		}
   763  	}
   764  }
   765  
   766  // TestRateLimitedResolve exercises the rate limit enforced on re-resolution
   767  // requests. It sets the re-resolution rate to a small value and repeatedly
   768  // calls ResolveNow() and ensures only the expected number of resolution
   769  // requests are made.
   770  
   771  func TestRateLimitedResolve(t *testing.T) {
   772  	defer leakcheck.Check(t)
   773  	defer func(nt func(d time.Duration) *time.Timer) {
   774  		newTimer = nt
   775  	}(newTimer)
   776  	newTimer = func(d time.Duration) *time.Timer {
   777  		// Will never fire on its own, will protect from triggering exponential
   778  		// backoff.
   779  		return time.NewTimer(time.Hour)
   780  	}
   781  	defer func(nt func(d time.Duration) *time.Timer) {
   782  		newTimer = nt
   783  	}(newTimer)
   784  
   785  	timerChan := testing_.NewChannel()
   786  	newTimer = func(d time.Duration) *time.Timer {
   787  		// Will never fire on its own, allows this test to call timer
   788  		// immediately.
   789  		t := time.NewTimer(time.Hour)
   790  		timerChan.Send(t)
   791  		return t
   792  	}
   793  
   794  	// Create a new testResolver{} for this test because we want the exact count
   795  	// of the number of times the resolver was invoked.
   796  	nc := overrideDefaultResolver(true)
   797  	defer nc()
   798  
   799  	target := "foo.bar.com"
   800  	b := NewBuilder()
   801  	cc := &testClientConn{target: target}
   802  
   803  	r, err := b.Build(context.Background(), resolver.Target{Endpoint: target}, resolver.BuildWithClientConn(cc))
   804  	if err != nil {
   805  		t.Fatalf("resolver.Build() returned error: %v\n", err)
   806  	}
   807  	defer r.Close()
   808  
   809  	dnsR, ok := r.(*dnsResolver)
   810  	if !ok {
   811  		t.Fatalf("resolver.Build() returned unexpected type: %T\n", dnsR)
   812  	}
   813  
   814  	tr, ok := dnsR.resolver.(*testResolver)
   815  	if !ok {
   816  		t.Fatalf("delegate resolver returned unexpected type: %T\n", tr)
   817  	}
   818  
   819  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   820  	defer cancel()
   821  
   822  	// Wait for the first resolution request to be done. This happens as part
   823  	// of the first iteration of the for loop in watcher().
   824  	if _, err := tr.lookupHostCh.Receive(ctx); err != nil {
   825  		t.Fatalf("Timed out waiting for lookup() call.")
   826  	}
   827  
   828  	// Call Resolve Now 100 times, shouldn't continue onto next iteration of
   829  	// watcher, thus shouldn't lookup again.
   830  	for i := 0; i <= 100; i++ {
   831  		r.ResolveNow(context.Background())
   832  	}
   833  
   834  	continueCtx, continueCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
   835  	defer continueCancel()
   836  
   837  	if _, err := tr.lookupHostCh.Receive(continueCtx); err == nil {
   838  		t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.")
   839  	}
   840  
   841  	// Make the DNSMinResRate timer fire immediately (by receiving it, then
   842  	// resetting to 0), this will unblock the resolver which is currently
   843  	// blocked on the DNS Min Res Rate timer going off, which will allow it to
   844  	// continue to the next iteration of the watcher loop.
   845  	timer, err := timerChan.Receive(ctx)
   846  	if err != nil {
   847  		t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
   848  	}
   849  	timerPointer := timer.(*time.Timer)
   850  	timerPointer.Reset(0)
   851  
   852  	// Now that DNS Min Res Rate timer has gone off, it should lookup again.
   853  	if _, err := tr.lookupHostCh.Receive(ctx); err != nil {
   854  		t.Fatalf("Timed out waiting for lookup() call.")
   855  	}
   856  
   857  	// Resolve Now 1000 more times, shouldn't lookup again as DNS Min Res Rate
   858  	// timer has not gone off.
   859  	for i := 0; i < 1000; i++ {
   860  		r.ResolveNow(context.Background())
   861  	}
   862  
   863  	if _, err = tr.lookupHostCh.Receive(continueCtx); err == nil {
   864  		t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.")
   865  	}
   866  
   867  	// Make the DNSMinResRate timer fire immediately again.
   868  	timer, err = timerChan.Receive(ctx)
   869  	if err != nil {
   870  		t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
   871  	}
   872  	timerPointer = timer.(*time.Timer)
   873  	timerPointer.Reset(0)
   874  
   875  	// Now that DNS Min Res Rate timer has gone off, it should lookup again.
   876  	if _, err = tr.lookupHostCh.Receive(ctx); err != nil {
   877  		t.Fatalf("Timed out waiting for lookup() call.")
   878  	}
   879  
   880  	wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}
   881  	var state resolver.State
   882  	for {
   883  		var cnt int
   884  		state, cnt = cc.getState()
   885  		if cnt > 0 {
   886  			break
   887  		}
   888  		time.Sleep(time.Millisecond)
   889  	}
   890  	if !reflect.DeepEqual(state.Addresses, wantAddrs) {
   891  		t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, wantAddrs)
   892  	}
   893  }
   894  
   895  // DNS Resolver immediately starts polling on an error. This will cause the re-resolution to return another error.
   896  // Thus, test that it constantly sends errors to the grpc.ClientConn.
   897  func TestReportError(t *testing.T) {
   898  	const target = "notfoundaddress"
   899  	defer func(nt func(d time.Duration) *time.Timer) {
   900  		newTimer = nt
   901  	}(newTimer)
   902  	timerChan := testing_.NewChannel()
   903  	newTimer = func(d time.Duration) *time.Timer {
   904  		// Will never fire on its own, allows this test to call timer immediately.
   905  		t := time.NewTimer(time.Hour)
   906  		timerChan.Send(t)
   907  		return t
   908  	}
   909  	cc := &testClientConn{target: target, errChan: make(chan error)}
   910  	totalTimesCalledError := 0
   911  	b := NewBuilder()
   912  	r, err := b.Build(context.Background(), resolver.Target{Endpoint: target}, resolver.BuildWithClientConn(cc))
   913  	if err != nil {
   914  		t.Fatalf("Error building resolver for target %v: %v", target, err)
   915  	}
   916  	// Should receive first error.
   917  	err = <-cc.errChan
   918  	if !strings.Contains(err.Error(), "hostLookup error") {
   919  		t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err)
   920  	}
   921  	totalTimesCalledError++
   922  	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   923  	defer ctxCancel()
   924  	timer, err := timerChan.Receive(ctx)
   925  	if err != nil {
   926  		t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
   927  	}
   928  	timerPointer := timer.(*time.Timer)
   929  	timerPointer.Reset(0)
   930  	defer r.Close()
   931  
   932  	// Cause timer to go off 10 times, and see if it matches DNS Resolver updating Error.
   933  	for i := 0; i < 10; i++ {
   934  		// Should call ReportError().
   935  		err = <-cc.errChan
   936  		if !strings.Contains(err.Error(), "hostLookup error") {
   937  			t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err)
   938  		}
   939  		totalTimesCalledError++
   940  		timer, err := timerChan.Receive(ctx)
   941  		if err != nil {
   942  			t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
   943  		}
   944  		timerPointer := timer.(*time.Timer)
   945  		timerPointer.Reset(0)
   946  	}
   947  
   948  	if totalTimesCalledError != 11 {
   949  		t.Errorf("ReportError() not called 11 times, instead called %d times.", totalTimesCalledError)
   950  	}
   951  	// Clean up final watcher iteration.
   952  	<-cc.errChan
   953  	_, err = timerChan.Receive(ctx)
   954  	if err != nil {
   955  		t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
   956  	}
   957  }
   958  
   959  func mutateTbl(target string) func() {
   960  	hostLookupTbl.Lock()
   961  	oldHostTblEntry := hostLookupTbl.tbl[target]
   962  	hostLookupTbl.tbl[target] = hostLookupTbl.tbl[target][:len(oldHostTblEntry)-1]
   963  	hostLookupTbl.Unlock()
   964  	return func() {
   965  		hostLookupTbl.Lock()
   966  		hostLookupTbl.tbl[target] = oldHostTblEntry
   967  		hostLookupTbl.Unlock()
   968  	}
   969  }