github.com/searKing/golang/go@v1.2.117/net/resolver/dns/dns_resolver_test.go (about) 1 // Copyright 2021 The searKing Author. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package dns 6 7 import ( 8 "context" 9 "errors" 10 "fmt" 11 "net" 12 "os" 13 "reflect" 14 "strings" 15 "sync" 16 "testing" 17 "time" 18 19 "github.com/searKing/golang/go/net/resolver" 20 testing_ "github.com/searKing/golang/go/testing" 21 "github.com/searKing/golang/go/testing/leakcheck" 22 ) 23 24 func TestMain(m *testing.M) { 25 // Set a non-zero duration only for tests which are actually testing that 26 // feature. 27 replaceDNSResRate(time.Duration(0)) // No need to clean up since we os.Exit 28 overrideDefaultResolver(false) // No need to clean up since we os.Exit 29 code := m.Run() 30 os.Exit(code) 31 } 32 33 const ( 34 defaultTestTimeout = 10 * time.Second 35 defaultTestShortTimeout = 10 * time.Millisecond 36 ) 37 38 type testClientConn struct { 39 resolver.ClientConn // For unimplemented functions 40 target string 41 m1 sync.Mutex 42 state resolver.State 43 updateStateCalls int 44 errChan chan error 45 updateStateErr error 46 } 47 48 func (t *testClientConn) UpdateState(s resolver.State) error { 49 t.m1.Lock() 50 defer t.m1.Unlock() 51 t.state = s 52 t.updateStateCalls++ 53 // This error determines whether DNS Resolver actually decides to exponentially backoff or not. 54 // This can be any error. 55 return t.updateStateErr 56 } 57 58 func (t *testClientConn) getState() (resolver.State, int) { 59 t.m1.Lock() 60 defer t.m1.Unlock() 61 return t.state, t.updateStateCalls 62 } 63 64 func scFromState(s resolver.State) string { 65 return "" 66 } 67 68 func (t *testClientConn) ReportError(err error) { 69 t.errChan <- err 70 } 71 72 type testResolver struct { 73 // A write to this channel is made when this resolver receives a resolution 74 // request. Tests can rely on reading from this channel to be notified about 75 // resolution requests instead of sleeping for a predefined period of time. 76 lookupHostCh *testing_.Channel 77 } 78 79 func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) { 80 if tr.lookupHostCh != nil { 81 tr.lookupHostCh.Send(nil) 82 } 83 return hostLookup(host) 84 } 85 86 func (*testResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) { 87 return srvLookup(service, proto, name) 88 } 89 90 func (*testResolver) LookupTXT(ctx context.Context, host string) ([]string, error) { 91 return []string{host}, nil 92 } 93 94 // overrideDefaultResolver overrides the defaultResolver used by the code with 95 // an instance of the testResolver. pushOnLookup controls whether the 96 // testResolver created here pushes lookupHost events on its channel. 97 func overrideDefaultResolver(pushOnLookup bool) func() { 98 oldResolver := defaultResolver 99 100 var lookupHostCh *testing_.Channel 101 if pushOnLookup { 102 lookupHostCh = testing_.NewChannel() 103 } 104 defaultResolver = &testResolver{lookupHostCh: lookupHostCh} 105 106 return func() { 107 defaultResolver = oldResolver 108 } 109 } 110 111 func replaceDNSResRate(d time.Duration) func() { 112 oldMinDNSResRate := minDNSResRate 113 minDNSResRate = d 114 115 return func() { 116 minDNSResRate = oldMinDNSResRate 117 } 118 } 119 120 var hostLookupTbl = struct { 121 sync.Mutex 122 tbl map[string][]string 123 }{ 124 tbl: map[string][]string{ 125 "foo.bar.com": {"1.2.3.4", "5.6.7.8"}, 126 "ipv4.single.fake": {"1.2.3.4"}, 127 "srv.ipv4.single.fake": {"2.4.6.8"}, 128 "srv.ipv4.multi.fake": {}, 129 "srv.ipv6.single.fake": {}, 130 "srv.ipv6.multi.fake": {}, 131 "ipv4.multi.fake": {"1.2.3.4", "5.6.7.8", "9.10.11.12"}, 132 "ipv6.single.fake": {"2607:f8b0:400a:801::1001"}, 133 "ipv6.multi.fake": {"2607:f8b0:400a:801::1001", "2607:f8b0:400a:801::1002", "2607:f8b0:400a:801::1003"}, 134 }, 135 } 136 137 func hostLookup(host string) ([]string, error) { 138 hostLookupTbl.Lock() 139 defer hostLookupTbl.Unlock() 140 if addrs, ok := hostLookupTbl.tbl[host]; ok { 141 return addrs, nil 142 } 143 return nil, &net.DNSError{ 144 Err: "hostLookup error", 145 Name: host, 146 Server: "fake", 147 IsTemporary: true, 148 } 149 } 150 151 var srvLookupTbl = struct { 152 sync.Mutex 153 tbl map[string][]*net.SRV 154 }{ 155 tbl: map[string][]*net.SRV{ 156 "_grpclb._tcp.srv.ipv4.single.fake": {&net.SRV{Target: "ipv4.single.fake", Port: 1234}}, 157 "_grpclb._tcp.srv.ipv4.multi.fake": {&net.SRV{Target: "ipv4.multi.fake", Port: 1234}}, 158 "_grpclb._tcp.srv.ipv6.single.fake": {&net.SRV{Target: "ipv6.single.fake", Port: 1234}}, 159 "_grpclb._tcp.srv.ipv6.multi.fake": {&net.SRV{Target: "ipv6.multi.fake", Port: 1234}}, 160 }, 161 } 162 163 func srvLookup(service, proto, name string) (string, []*net.SRV, error) { 164 cname := "_" + service + "._" + proto + "." + name 165 srvLookupTbl.Lock() 166 defer srvLookupTbl.Unlock() 167 if srvs, cnt := srvLookupTbl.tbl[cname]; cnt { 168 return cname, srvs, nil 169 } 170 return "", nil, &net.DNSError{ 171 Err: "srvLookup error", 172 Name: cname, 173 Server: "fake", 174 IsTemporary: true, 175 } 176 } 177 178 func TestResolve(t *testing.T) { 179 testDNSResolver(t) 180 testDNSResolverWithSRV(t) 181 testDNSResolveNow(t) 182 testIPResolver(t) 183 } 184 185 func testDNSResolver(t *testing.T) { 186 defer leakcheck.Check(t) 187 defer func(nt func(d time.Duration) *time.Timer) { 188 newTimer = nt 189 }(newTimer) 190 newTimer = func(_ time.Duration) *time.Timer { 191 // Will never fire on its own, will protect from triggering exponential backoff. 192 return time.NewTimer(time.Hour) 193 } 194 tests := []struct { 195 target string 196 addrWant []resolver.Address 197 }{ 198 { 199 "foo.bar.com", 200 []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, 201 }, 202 { 203 "foo.bar.com:1234", 204 []resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}}, 205 }, 206 { 207 "srv.ipv4.single.fake", 208 []resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}}, 209 }, 210 { 211 "srv.ipv4.multi.fake", 212 nil, 213 }, 214 { 215 "srv.ipv6.single.fake", 216 nil, 217 }, 218 { 219 "srv.ipv6.multi.fake", 220 nil, 221 }, 222 } 223 224 for _, a := range tests { 225 b := NewBuilder() 226 cc := &testClientConn{target: a.target} 227 r, err := b.Build(context.Background(), resolver.Target{Endpoint: a.target}, resolver.BuildWithClientConn(cc)) 228 if err != nil { 229 t.Fatalf("%v\n", err) 230 } 231 var state resolver.State 232 var cnt int 233 for i := 0; i < 2000; i++ { 234 state, cnt = cc.getState() 235 if cnt > 0 { 236 break 237 } 238 time.Sleep(time.Millisecond) 239 } 240 if cnt == 0 { 241 t.Fatalf("UpdateState not called after 2s; aborting") 242 } 243 if !reflect.DeepEqual(a.addrWant, state.Addresses) { 244 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant) 245 } 246 r.Close() 247 } 248 } 249 250 // DNS Resolver immediately starts polling on an error from grpc. This should continue until the ClientConn doesn't 251 // send back an error from updating the DNS Resolver's state. 252 func TestDNSResolverExponentialBackoff(t *testing.T) { 253 //defer leakcheck.Check(t) 254 defer func(nt func(d time.Duration) *time.Timer) { 255 newTimer = nt 256 }(newTimer) 257 timerChan := testing_.NewChannel() 258 newTimer = func(d time.Duration) *time.Timer { 259 // Will never fire on its own, allows this test to call timer immediately. 260 t := time.NewTimer(time.Hour) 261 timerChan.Send(t) 262 return t 263 } 264 tests := []struct { 265 name string 266 target string 267 addrWant []resolver.Address 268 }{ 269 { 270 "happy case default port", 271 "foo.bar.com", 272 []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, 273 }, 274 { 275 "happy case specified port", 276 "foo.bar.com:1234", 277 []resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}}, 278 }, 279 { 280 "happy case another default port", 281 "srv.ipv4.single.fake", 282 []resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}}, 283 }, 284 } 285 for _, test := range tests { 286 t.Run(test.name, func(t *testing.T) { 287 func() { 288 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 289 defer ctxCancel() 290 err := timerChan.Clear(ctx) 291 if err != nil { 292 t.Fatalf("Error clear timer from mock NewTimer call: %v", err) 293 } 294 }() 295 b := NewBuilder() 296 cc := &testClientConn{target: test.target} 297 // Cause ClientConn to return an error. 298 cc.updateStateErr = resolver.ErrBadResolverState 299 r, err := b.Build(context.Background(), resolver.Target{Endpoint: test.target}, resolver.BuildWithClientConn(cc)) 300 if err != nil { 301 t.Fatalf("Error building resolver for target %v: %v", test.target, err) 302 } 303 var state resolver.State 304 var cnt int 305 for i := 0; i < 2000; i++ { 306 state, cnt = cc.getState() 307 if cnt > 0 { 308 break 309 } 310 time.Sleep(time.Millisecond) 311 } 312 if cnt == 0 { 313 t.Fatalf("UpdateState not called after 2s; aborting") 314 } 315 if !reflect.DeepEqual(test.addrWant, state.Addresses) { 316 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", test.target, state.Addresses, test.addrWant) 317 } 318 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 319 defer ctxCancel() 320 // Cause timer to go off 10 times, and see if it calls updateState() correctly. 321 for i := 0; i < 10; i++ { 322 timer, err := timerChan.Receive(ctx) 323 if err != nil { 324 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) 325 } 326 timerPointer := timer.(*time.Timer) 327 timerPointer.Reset(0) 328 } 329 // Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call 330 // ClientConn update state. 331 deadline := time.Now().Add(defaultTestTimeout) 332 for { 333 cc.m1.Lock() 334 got := cc.updateStateCalls 335 cc.m1.Unlock() 336 if got == 11 { 337 break 338 } 339 340 if time.Now().After(deadline) { 341 t.Fatalf("Exponential backoff is not working as expected - should update state 11 times instead of %d", got) 342 } 343 344 time.Sleep(time.Millisecond) 345 } 346 347 // Update resolver.ClientConn to not return an error anymore - this should stop it from backing off. 348 cc.updateStateErr = nil 349 timer, err := timerChan.Receive(ctx) 350 if err != nil { 351 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) 352 } 353 timerPointer := timer.(*time.Timer) 354 timerPointer.Reset(0) 355 // Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call 356 // ClientConn update state the final time. The DNS Resolver should then stop polling. 357 deadline = time.Now().Add(defaultTestTimeout) 358 for { 359 cc.m1.Lock() 360 got := cc.updateStateCalls 361 cc.m1.Unlock() 362 if got == 12 { 363 break 364 } 365 366 if time.Now().After(deadline) { 367 t.Fatalf("Exponential backoff is not working as expected - should stop backing off at 12 total UpdateState calls instead of %d", got) 368 } 369 370 _, err := timerChan.ReceiveOrFail() 371 if err { 372 t.Fatalf("Should not poll again after Client Conn stops returning error.") 373 } 374 375 time.Sleep(time.Millisecond) 376 } 377 r.Close() 378 }) 379 } 380 } 381 382 func testDNSResolverWithSRV(t *testing.T) { 383 EnableSRVLookups = true 384 defer func() { 385 EnableSRVLookups = false 386 }() 387 defer leakcheck.Check(t) 388 defer func(nt func(d time.Duration) *time.Timer) { 389 newTimer = nt 390 }(newTimer) 391 newTimer = func(_ time.Duration) *time.Timer { 392 // Will never fire on its own, will protect from triggering exponential backoff. 393 return time.NewTimer(time.Hour) 394 } 395 tests := []struct { 396 target string 397 addrWant []resolver.Address 398 }{ 399 { 400 "foo.bar.com", 401 []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, 402 }, 403 { 404 "foo.bar.com:1234", 405 []resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}}, 406 }, 407 { 408 "srv.ipv4.single.fake", 409 []resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}}, 410 }, 411 { 412 "srv.ipv4.multi.fake", 413 nil, 414 }, 415 { 416 "srv.ipv6.single.fake", 417 nil, 418 }, 419 { 420 "srv.ipv6.multi.fake", 421 nil, 422 }, 423 } 424 425 for _, a := range tests { 426 b := NewBuilder() 427 cc := &testClientConn{target: a.target} 428 r, err := b.Build(context.Background(), resolver.Target{Endpoint: a.target}, resolver.BuildWithClientConn(cc)) 429 if err != nil { 430 t.Fatalf("%v\n", err) 431 } 432 defer r.Close() 433 var state resolver.State 434 var cnt int 435 for i := 0; i < 2000; i++ { 436 state, cnt = cc.getState() 437 if cnt > 0 { 438 break 439 } 440 time.Sleep(time.Millisecond) 441 } 442 if cnt == 0 { 443 t.Fatalf("UpdateState not called after 2s; aborting") 444 } 445 if !reflect.DeepEqual(a.addrWant, state.Addresses) { 446 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant) 447 } 448 } 449 } 450 451 func testDNSResolveNow(t *testing.T) { 452 defer leakcheck.Check(t) 453 tests := []struct { 454 target string 455 addrWant []resolver.Address 456 }{ 457 { 458 "foo.bar.com", 459 []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, 460 }, 461 } 462 463 for _, a := range tests { 464 b := NewBuilder() 465 cc := &testClientConn{target: a.target} 466 r, err := b.Build(context.Background(), resolver.Target{Endpoint: a.target}, resolver.BuildWithClientConn(cc)) 467 if err != nil { 468 t.Fatalf("%v\n", err) 469 } 470 defer r.Close() 471 var state resolver.State 472 var cnt int 473 for i := 0; i < 2000; i++ { 474 state, cnt = cc.getState() 475 if cnt > 0 { 476 break 477 } 478 time.Sleep(time.Millisecond) 479 } 480 if cnt == 0 { 481 t.Fatalf("UpdateState not called after 2s; aborting. state=%v", state) 482 } 483 if !reflect.DeepEqual(a.addrWant, state.Addresses) { 484 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant) 485 } 486 487 r.ResolveNow(context.Background()) 488 for i := 0; i < 2000; i++ { 489 state, cnt = cc.getState() 490 if cnt == 2 { 491 break 492 } 493 time.Sleep(time.Millisecond) 494 } 495 if cnt != 2 { 496 t.Fatalf("UpdateState not called after 2s; aborting. state=%v", state) 497 } 498 } 499 } 500 501 const colonDefaultPort = ":" + defaultPort 502 503 func testIPResolver(t *testing.T) { 504 defer leakcheck.Check(t) 505 defer func(nt func(d time.Duration) *time.Timer) { 506 newTimer = nt 507 }(newTimer) 508 newTimer = func(_ time.Duration) *time.Timer { 509 // Will never fire on its own, will protect from triggering exponential backoff. 510 return time.NewTimer(time.Hour) 511 } 512 tests := []struct { 513 target string 514 want []resolver.Address 515 }{ 516 {"127.0.0.1", []resolver.Address{{Addr: "127.0.0.1" + colonDefaultPort}}}, 517 {"127.0.0.1:12345", []resolver.Address{{Addr: "127.0.0.1:12345"}}}, 518 {"::1", []resolver.Address{{Addr: "[::1]" + colonDefaultPort}}}, 519 {"[::1]:12345", []resolver.Address{{Addr: "[::1]:12345"}}}, 520 {"[::1]", []resolver.Address{{Addr: "[::1]:443"}}}, 521 {"2001:db8:85a3::8a2e:370:7334", []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]" + colonDefaultPort}}}, 522 {"[2001:db8:85a3::8a2e:370:7334]", []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]" + colonDefaultPort}}}, 523 {"[2001:db8:85a3::8a2e:370:7334]:12345", []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]:12345"}}}, 524 {"[2001:db8::1]:http", []resolver.Address{{Addr: "[2001:db8::1]:http"}}}, 525 } 526 527 for _, v := range tests { 528 b := NewBuilder() 529 cc := &testClientConn{target: v.target} 530 r, err := b.Build(context.Background(), resolver.Target{Endpoint: v.target}, resolver.BuildWithClientConn(cc)) 531 if err != nil { 532 t.Fatalf("%v\n", err) 533 } 534 var state resolver.State 535 var cnt int 536 for { 537 state, cnt = cc.getState() 538 if cnt > 0 { 539 break 540 } 541 time.Sleep(time.Millisecond) 542 } 543 if !reflect.DeepEqual(v.want, state.Addresses) { 544 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", v.target, state.Addresses, v.want) 545 } 546 r.ResolveNow(context.Background()) 547 for i := 0; i < 50; i++ { 548 state, cnt = cc.getState() 549 if cnt > 1 { 550 t.Fatalf("Unexpected second call by resolver to UpdateState. state: %v", state) 551 } 552 time.Sleep(time.Millisecond) 553 } 554 r.Close() 555 } 556 } 557 558 func TestResolveFunc(t *testing.T) { 559 defer leakcheck.Check(t) 560 defer func(nt func(d time.Duration) *time.Timer) { 561 newTimer = nt 562 }(newTimer) 563 newTimer = func(d time.Duration) *time.Timer { 564 // Will never fire on its own, will protect from triggering exponential backoff. 565 return time.NewTimer(time.Hour) 566 } 567 tests := []struct { 568 addr string 569 want error 570 }{ 571 // TODO(yuxuanli): More false cases? 572 {"www.google.com", nil}, 573 {"foo.bar:12345", nil}, 574 {"127.0.0.1", nil}, 575 {"::", nil}, 576 {"127.0.0.1:12345", nil}, 577 {"[::1]:80", nil}, 578 {"[2001:db8:a0b:12f0::1]:21", nil}, 579 {":80", nil}, 580 {"127.0.0...1:12345", nil}, 581 {"[fe80::1%lo0]:80", nil}, 582 {"golang.org:http", nil}, 583 {"[2001:db8::1]:http", nil}, 584 {"[2001:db8::1]:", errEndsWithColon}, 585 {":", errEndsWithColon}, 586 {"", errMissingAddr}, 587 {"[2001:db8:a0b:12f0::1", fmt.Errorf("invalid target address [2001:db8:a0b:12f0::1, error info: address [2001:db8:a0b:12f0::1:443: missing ']' in address")}, 588 } 589 590 b := NewBuilder() 591 for _, v := range tests { 592 cc := &testClientConn{target: v.addr, errChan: make(chan error, 1)} 593 r, err := b.Build(context.Background(), resolver.Target{Endpoint: v.addr}, resolver.BuildWithClientConn(cc)) 594 if err == nil { 595 r.Close() 596 } 597 if !reflect.DeepEqual(err, v.want) { 598 t.Errorf("Build(%q, cc, _) = %v, want %v", v.addr, err, v.want) 599 } 600 } 601 } 602 603 func TestDNSResolverRetry(t *testing.T) { 604 b := NewBuilder() 605 target := "ipv4.single.fake" 606 cc := &testClientConn{target: target} 607 r, err := b.Build(context.Background(), resolver.Target{Endpoint: target}, resolver.BuildWithClientConn(cc)) 608 if err != nil { 609 t.Fatalf("%v\n", err) 610 } 611 defer r.Close() 612 var state resolver.State 613 for i := 0; i < 2000; i++ { 614 state, _ = cc.getState() 615 if len(state.Addresses) == 1 { 616 break 617 } 618 time.Sleep(time.Millisecond) 619 } 620 if len(state.Addresses) != 1 { 621 t.Fatalf("UpdateState not called with 1 address after 2s; aborting. state=%v", state) 622 } 623 want := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}} 624 if !reflect.DeepEqual(want, state.Addresses) { 625 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, want) 626 } 627 // mutate the host lookup table so the target has 0 address returned. 628 revertTbl := mutateTbl(target) 629 // trigger a resolve that will get empty address list 630 r.ResolveNow(context.Background()) 631 for i := 0; i < 2000; i++ { 632 state, _ = cc.getState() 633 if len(state.Addresses) == 0 { 634 break 635 } 636 time.Sleep(time.Millisecond) 637 } 638 if len(state.Addresses) != 0 { 639 t.Fatalf("UpdateState not called with 0 address after 2s; aborting. state=%v", state) 640 } 641 revertTbl() 642 // wait for the retry to happen in two seconds. 643 r.ResolveNow(context.Background()) 644 for i := 0; i < 2000; i++ { 645 state, _ = cc.getState() 646 if len(state.Addresses) == 1 { 647 break 648 } 649 time.Sleep(time.Millisecond) 650 } 651 if !reflect.DeepEqual(want, state.Addresses) { 652 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, want) 653 } 654 } 655 656 func TestCustomAuthority(t *testing.T) { 657 defer leakcheck.Check(t) 658 defer func(nt func(d time.Duration) *time.Timer) { 659 newTimer = nt 660 }(newTimer) 661 newTimer = func(d time.Duration) *time.Timer { 662 // Will never fire on its own, will protect from triggering exponential backoff. 663 return time.NewTimer(time.Hour) 664 } 665 666 tests := []struct { 667 authority string 668 authorityWant string 669 expectError bool 670 }{ 671 { 672 "4.3.2.1:" + defaultDNSSvrPort, 673 "4.3.2.1:" + defaultDNSSvrPort, 674 false, 675 }, 676 { 677 "4.3.2.1:123", 678 "4.3.2.1:123", 679 false, 680 }, 681 { 682 "4.3.2.1", 683 "4.3.2.1:" + defaultDNSSvrPort, 684 false, 685 }, 686 { 687 "::1", 688 "[::1]:" + defaultDNSSvrPort, 689 false, 690 }, 691 { 692 "[::1]", 693 "[::1]:" + defaultDNSSvrPort, 694 false, 695 }, 696 { 697 "[::1]:123", 698 "[::1]:123", 699 false, 700 }, 701 { 702 "dnsserver.com", 703 "dnsserver.com:" + defaultDNSSvrPort, 704 false, 705 }, 706 { 707 ":123", 708 "localhost:123", 709 false, 710 }, 711 { 712 ":", 713 "", 714 true, 715 }, 716 { 717 "[::1]:", 718 "", 719 true, 720 }, 721 { 722 "dnsserver.com:", 723 "", 724 true, 725 }, 726 } 727 oldCustomAuthorityDialler := customAuthorityDialler 728 defer func() { 729 customAuthorityDialler = oldCustomAuthorityDialler 730 }() 731 732 for _, a := range tests { 733 errChan := make(chan error, 1) 734 customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) { 735 if authority != a.authorityWant { 736 errChan <- fmt.Errorf("wrong custom authority passed to resolver. input: %s expected: %s actual: %s", a.authority, a.authorityWant, authority) 737 } else { 738 errChan <- nil 739 } 740 return func(ctx context.Context, network, address string) (net.Conn, error) { 741 return nil, errors.New("no need to dial") 742 } 743 } 744 745 b := NewBuilder() 746 cc := &testClientConn{target: "foo.bar.com", errChan: make(chan error, 1)} 747 r, err := b.Build(context.Background(), resolver.Target{Endpoint: "foo.bar.com", Authority: a.authority}, resolver.BuildWithClientConn(cc)) 748 749 if err == nil { 750 r.Close() 751 752 err = <-errChan 753 if err != nil { 754 t.Errorf(err.Error()) 755 } 756 757 if a.expectError { 758 t.Errorf("custom authority should have caused an error: %s", a.authority) 759 } 760 } else if !a.expectError { 761 t.Errorf("unexpected error using custom authority %s: %s", a.authority, err) 762 } 763 } 764 } 765 766 // TestRateLimitedResolve exercises the rate limit enforced on re-resolution 767 // requests. It sets the re-resolution rate to a small value and repeatedly 768 // calls ResolveNow() and ensures only the expected number of resolution 769 // requests are made. 770 771 func TestRateLimitedResolve(t *testing.T) { 772 defer leakcheck.Check(t) 773 defer func(nt func(d time.Duration) *time.Timer) { 774 newTimer = nt 775 }(newTimer) 776 newTimer = func(d time.Duration) *time.Timer { 777 // Will never fire on its own, will protect from triggering exponential 778 // backoff. 779 return time.NewTimer(time.Hour) 780 } 781 defer func(nt func(d time.Duration) *time.Timer) { 782 newTimer = nt 783 }(newTimer) 784 785 timerChan := testing_.NewChannel() 786 newTimer = func(d time.Duration) *time.Timer { 787 // Will never fire on its own, allows this test to call timer 788 // immediately. 789 t := time.NewTimer(time.Hour) 790 timerChan.Send(t) 791 return t 792 } 793 794 // Create a new testResolver{} for this test because we want the exact count 795 // of the number of times the resolver was invoked. 796 nc := overrideDefaultResolver(true) 797 defer nc() 798 799 target := "foo.bar.com" 800 b := NewBuilder() 801 cc := &testClientConn{target: target} 802 803 r, err := b.Build(context.Background(), resolver.Target{Endpoint: target}, resolver.BuildWithClientConn(cc)) 804 if err != nil { 805 t.Fatalf("resolver.Build() returned error: %v\n", err) 806 } 807 defer r.Close() 808 809 dnsR, ok := r.(*dnsResolver) 810 if !ok { 811 t.Fatalf("resolver.Build() returned unexpected type: %T\n", dnsR) 812 } 813 814 tr, ok := dnsR.resolver.(*testResolver) 815 if !ok { 816 t.Fatalf("delegate resolver returned unexpected type: %T\n", tr) 817 } 818 819 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 820 defer cancel() 821 822 // Wait for the first resolution request to be done. This happens as part 823 // of the first iteration of the for loop in watcher(). 824 if _, err := tr.lookupHostCh.Receive(ctx); err != nil { 825 t.Fatalf("Timed out waiting for lookup() call.") 826 } 827 828 // Call Resolve Now 100 times, shouldn't continue onto next iteration of 829 // watcher, thus shouldn't lookup again. 830 for i := 0; i <= 100; i++ { 831 r.ResolveNow(context.Background()) 832 } 833 834 continueCtx, continueCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) 835 defer continueCancel() 836 837 if _, err := tr.lookupHostCh.Receive(continueCtx); err == nil { 838 t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.") 839 } 840 841 // Make the DNSMinResRate timer fire immediately (by receiving it, then 842 // resetting to 0), this will unblock the resolver which is currently 843 // blocked on the DNS Min Res Rate timer going off, which will allow it to 844 // continue to the next iteration of the watcher loop. 845 timer, err := timerChan.Receive(ctx) 846 if err != nil { 847 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) 848 } 849 timerPointer := timer.(*time.Timer) 850 timerPointer.Reset(0) 851 852 // Now that DNS Min Res Rate timer has gone off, it should lookup again. 853 if _, err := tr.lookupHostCh.Receive(ctx); err != nil { 854 t.Fatalf("Timed out waiting for lookup() call.") 855 } 856 857 // Resolve Now 1000 more times, shouldn't lookup again as DNS Min Res Rate 858 // timer has not gone off. 859 for i := 0; i < 1000; i++ { 860 r.ResolveNow(context.Background()) 861 } 862 863 if _, err = tr.lookupHostCh.Receive(continueCtx); err == nil { 864 t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.") 865 } 866 867 // Make the DNSMinResRate timer fire immediately again. 868 timer, err = timerChan.Receive(ctx) 869 if err != nil { 870 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) 871 } 872 timerPointer = timer.(*time.Timer) 873 timerPointer.Reset(0) 874 875 // Now that DNS Min Res Rate timer has gone off, it should lookup again. 876 if _, err = tr.lookupHostCh.Receive(ctx); err != nil { 877 t.Fatalf("Timed out waiting for lookup() call.") 878 } 879 880 wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}} 881 var state resolver.State 882 for { 883 var cnt int 884 state, cnt = cc.getState() 885 if cnt > 0 { 886 break 887 } 888 time.Sleep(time.Millisecond) 889 } 890 if !reflect.DeepEqual(state.Addresses, wantAddrs) { 891 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, wantAddrs) 892 } 893 } 894 895 // DNS Resolver immediately starts polling on an error. This will cause the re-resolution to return another error. 896 // Thus, test that it constantly sends errors to the grpc.ClientConn. 897 func TestReportError(t *testing.T) { 898 const target = "notfoundaddress" 899 defer func(nt func(d time.Duration) *time.Timer) { 900 newTimer = nt 901 }(newTimer) 902 timerChan := testing_.NewChannel() 903 newTimer = func(d time.Duration) *time.Timer { 904 // Will never fire on its own, allows this test to call timer immediately. 905 t := time.NewTimer(time.Hour) 906 timerChan.Send(t) 907 return t 908 } 909 cc := &testClientConn{target: target, errChan: make(chan error)} 910 totalTimesCalledError := 0 911 b := NewBuilder() 912 r, err := b.Build(context.Background(), resolver.Target{Endpoint: target}, resolver.BuildWithClientConn(cc)) 913 if err != nil { 914 t.Fatalf("Error building resolver for target %v: %v", target, err) 915 } 916 // Should receive first error. 917 err = <-cc.errChan 918 if !strings.Contains(err.Error(), "hostLookup error") { 919 t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err) 920 } 921 totalTimesCalledError++ 922 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 923 defer ctxCancel() 924 timer, err := timerChan.Receive(ctx) 925 if err != nil { 926 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) 927 } 928 timerPointer := timer.(*time.Timer) 929 timerPointer.Reset(0) 930 defer r.Close() 931 932 // Cause timer to go off 10 times, and see if it matches DNS Resolver updating Error. 933 for i := 0; i < 10; i++ { 934 // Should call ReportError(). 935 err = <-cc.errChan 936 if !strings.Contains(err.Error(), "hostLookup error") { 937 t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err) 938 } 939 totalTimesCalledError++ 940 timer, err := timerChan.Receive(ctx) 941 if err != nil { 942 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) 943 } 944 timerPointer := timer.(*time.Timer) 945 timerPointer.Reset(0) 946 } 947 948 if totalTimesCalledError != 11 { 949 t.Errorf("ReportError() not called 11 times, instead called %d times.", totalTimesCalledError) 950 } 951 // Clean up final watcher iteration. 952 <-cc.errChan 953 _, err = timerChan.Receive(ctx) 954 if err != nil { 955 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) 956 } 957 } 958 959 func mutateTbl(target string) func() { 960 hostLookupTbl.Lock() 961 oldHostTblEntry := hostLookupTbl.tbl[target] 962 hostLookupTbl.tbl[target] = hostLookupTbl.tbl[target][:len(oldHostTblEntry)-1] 963 hostLookupTbl.Unlock() 964 return func() { 965 hostLookupTbl.Lock() 966 hostLookupTbl.tbl[target] = oldHostTblEntry 967 hostLookupTbl.Unlock() 968 } 969 }