github.com/letsencrypt/boulder@v0.20251208.0/grpc/internal/resolver/dns/dns_resolver_test.go (about) 1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package dns 20 21 import ( 22 "context" 23 "errors" 24 "fmt" 25 "net" 26 "os" 27 "slices" 28 "strings" 29 "sync" 30 "testing" 31 "time" 32 33 "google.golang.org/grpc/balancer" 34 "google.golang.org/grpc/resolver" 35 36 "github.com/letsencrypt/boulder/grpc/internal/leakcheck" 37 "github.com/letsencrypt/boulder/grpc/internal/testutils" 38 "github.com/letsencrypt/boulder/test" 39 ) 40 41 func TestMain(m *testing.M) { 42 // Set a non-zero duration only for tests which are actually testing that 43 // feature. 44 replaceDNSResRate(time.Duration(0)) // No need to clean up since we os.Exit 45 overrideDefaultResolver(false) // No need to clean up since we os.Exit 46 code := m.Run() 47 os.Exit(code) 48 } 49 50 const ( 51 defaultTestTimeout = 10 * time.Second 52 defaultTestShortTimeout = 10 * time.Millisecond 53 ) 54 55 type testClientConn struct { 56 resolver.ClientConn // For unimplemented functions 57 target string 58 m1 sync.Mutex 59 state resolver.State 60 updateStateCalls int 61 errChan chan error 62 updateStateErr error 63 } 64 65 func (t *testClientConn) UpdateState(s resolver.State) error { 66 t.m1.Lock() 67 defer t.m1.Unlock() 68 t.state = s 69 t.updateStateCalls++ 70 // This error determines whether DNS Resolver actually decides to exponentially backoff or not. 71 // This can be any error. 72 return t.updateStateErr 73 } 74 75 func (t *testClientConn) getState() (resolver.State, int) { 76 t.m1.Lock() 77 defer t.m1.Unlock() 78 return t.state, t.updateStateCalls 79 } 80 81 func (t *testClientConn) ReportError(err error) { 82 t.errChan <- err 83 } 84 85 type testResolver struct { 86 // A write to this channel is made when this resolver receives a resolution 87 // request. Tests can rely on reading from this channel to be notified about 88 // resolution requests instead of sleeping for a predefined period of time. 89 lookupHostCh *testutils.Channel 90 } 91 92 func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) { 93 if tr.lookupHostCh != nil { 94 tr.lookupHostCh.Send(nil) 95 } 96 return hostLookup(host) 97 } 98 99 func (*testResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) { 100 return srvLookup(service, proto, name) 101 } 102 103 // overrideDefaultResolver overrides the defaultResolver used by the code with 104 // an instance of the testResolver. pushOnLookup controls whether the 105 // testResolver created here pushes lookupHost events on its channel. 106 func overrideDefaultResolver(pushOnLookup bool) func() { 107 oldResolver := defaultResolver 108 109 var lookupHostCh *testutils.Channel 110 if pushOnLookup { 111 lookupHostCh = testutils.NewChannel() 112 } 113 defaultResolver = &testResolver{lookupHostCh: lookupHostCh} 114 115 return func() { 116 defaultResolver = oldResolver 117 } 118 } 119 120 func replaceDNSResRate(d time.Duration) func() { 121 oldMinDNSResRate := minDNSResRate 122 minDNSResRate = d 123 124 return func() { 125 minDNSResRate = oldMinDNSResRate 126 } 127 } 128 129 var hostLookupTbl = struct { 130 sync.Mutex 131 tbl map[string][]string 132 }{ 133 tbl: map[string][]string{ 134 "ipv4.single.fake": {"2.4.6.8"}, 135 "ipv4.multi.fake": {"1.2.3.4", "5.6.7.8", "9.10.11.12"}, 136 "ipv6.single.fake": {"2607:f8b0:400a:801::1001"}, 137 "ipv6.multi.fake": {"2607:f8b0:400a:801::1001", "2607:f8b0:400a:801::1002", "2607:f8b0:400a:801::1003"}, 138 }, 139 } 140 141 func hostLookup(host string) ([]string, error) { 142 hostLookupTbl.Lock() 143 defer hostLookupTbl.Unlock() 144 if addrs, ok := hostLookupTbl.tbl[host]; ok { 145 return addrs, nil 146 } 147 return nil, &net.DNSError{ 148 Err: "hostLookup error", 149 Name: host, 150 Server: "fake", 151 IsTemporary: true, 152 } 153 } 154 155 var srvLookupTbl = struct { 156 sync.Mutex 157 tbl map[string][]*net.SRV 158 }{ 159 tbl: map[string][]*net.SRV{ 160 "_foo._tcp.ipv4.single.fake": {&net.SRV{Target: "ipv4.single.fake", Port: 1234}}, 161 "_foo._tcp.ipv4.multi.fake": {&net.SRV{Target: "ipv4.multi.fake", Port: 1234}}, 162 "_foo._tcp.ipv6.single.fake": {&net.SRV{Target: "ipv6.single.fake", Port: 1234}}, 163 "_foo._tcp.ipv6.multi.fake": {&net.SRV{Target: "ipv6.multi.fake", Port: 1234}}, 164 }, 165 } 166 167 func srvLookup(service, proto, name string) (string, []*net.SRV, error) { 168 cname := "_" + service + "._" + proto + "." + name 169 srvLookupTbl.Lock() 170 defer srvLookupTbl.Unlock() 171 if srvs, cnt := srvLookupTbl.tbl[cname]; cnt { 172 return cname, srvs, nil 173 } 174 return "", nil, &net.DNSError{ 175 Err: "srvLookup error", 176 Name: cname, 177 Server: "fake", 178 IsTemporary: true, 179 } 180 } 181 182 func TestResolve(t *testing.T) { 183 testDNSResolver(t) 184 testDNSResolveNow(t) 185 } 186 187 func testDNSResolver(t *testing.T) { 188 defer func(nt func(d time.Duration) *time.Timer) { 189 newTimer = nt 190 }(newTimer) 191 newTimer = func(_ time.Duration) *time.Timer { 192 // Will never fire on its own, will protect from triggering exponential backoff. 193 return time.NewTimer(time.Hour) 194 } 195 tests := []struct { 196 target string 197 addrWant []resolver.Address 198 }{ 199 { 200 "foo.ipv4.single.fake", 201 []resolver.Address{{Addr: "2.4.6.8:1234", ServerName: "ipv4.single.fake"}}, 202 }, 203 { 204 "foo.ipv4.multi.fake", 205 []resolver.Address{ 206 {Addr: "1.2.3.4:1234", ServerName: "ipv4.multi.fake"}, 207 {Addr: "5.6.7.8:1234", ServerName: "ipv4.multi.fake"}, 208 {Addr: "9.10.11.12:1234", ServerName: "ipv4.multi.fake"}, 209 }, 210 }, 211 { 212 "foo.ipv6.single.fake", 213 []resolver.Address{{Addr: "[2607:f8b0:400a:801::1001]:1234", ServerName: "ipv6.single.fake"}}, 214 }, 215 { 216 "foo.ipv6.multi.fake", 217 []resolver.Address{ 218 {Addr: "[2607:f8b0:400a:801::1001]:1234", ServerName: "ipv6.multi.fake"}, 219 {Addr: "[2607:f8b0:400a:801::1002]:1234", ServerName: "ipv6.multi.fake"}, 220 {Addr: "[2607:f8b0:400a:801::1003]:1234", ServerName: "ipv6.multi.fake"}, 221 }, 222 }, 223 } 224 225 for _, a := range tests { 226 b := NewDefaultSRVBuilder() 227 cc := &testClientConn{target: a.target} 228 r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", a.target))}, cc, resolver.BuildOptions{}) 229 if err != nil { 230 t.Fatalf("%v\n", err) 231 } 232 var state resolver.State 233 var cnt int 234 for range 2000 { 235 state, cnt = cc.getState() 236 if cnt > 0 { 237 break 238 } 239 time.Sleep(time.Millisecond) 240 } 241 if cnt == 0 { 242 t.Fatalf("UpdateState not called after 2s; aborting") 243 } 244 245 if !slices.Equal(a.addrWant, state.Addresses) { 246 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant) 247 } 248 r.Close() 249 } 250 } 251 252 // DNS Resolver immediately starts polling on an error from grpc. This should continue until the ClientConn doesn't 253 // send back an error from updating the DNS Resolver's state. 254 func TestDNSResolverExponentialBackoff(t *testing.T) { 255 defer leakcheck.Check(t) 256 defer func(nt func(d time.Duration) *time.Timer) { 257 newTimer = nt 258 }(newTimer) 259 timerChan := testutils.NewChannel() 260 newTimer = func(d time.Duration) *time.Timer { 261 // Will never fire on its own, allows this test to call timer immediately. 262 t := time.NewTimer(time.Hour) 263 timerChan.Send(t) 264 return t 265 } 266 target := "foo.ipv4.single.fake" 267 wantAddr := []resolver.Address{{Addr: "2.4.6.8:1234", ServerName: "ipv4.single.fake"}} 268 269 b := NewDefaultSRVBuilder() 270 cc := &testClientConn{target: target} 271 // Cause ClientConn to return an error. 272 cc.updateStateErr = balancer.ErrBadResolverState 273 r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{}) 274 if err != nil { 275 t.Fatalf("Error building resolver for target %v: %v", target, err) 276 } 277 defer r.Close() 278 var state resolver.State 279 var cnt int 280 for range 2000 { 281 state, cnt = cc.getState() 282 if cnt > 0 { 283 break 284 } 285 time.Sleep(time.Millisecond) 286 } 287 if cnt == 0 { 288 t.Fatalf("UpdateState not called after 2s; aborting") 289 } 290 if !slices.Equal(wantAddr, state.Addresses) { 291 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, target) 292 } 293 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 294 defer ctxCancel() 295 // Cause timer to go off 10 times, and see if it calls updateState() correctly. 296 for range 10 { 297 timer, err := timerChan.Receive(ctx) 298 if err != nil { 299 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) 300 } 301 timerPointer := timer.(*time.Timer) 302 timerPointer.Reset(0) 303 } 304 // Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call 305 // ClientConn update state. 306 deadline := time.Now().Add(defaultTestTimeout) 307 for { 308 cc.m1.Lock() 309 got := cc.updateStateCalls 310 cc.m1.Unlock() 311 if got == 11 { 312 break 313 } 314 315 if time.Now().After(deadline) { 316 t.Fatalf("Exponential backoff is not working as expected - should update state 11 times instead of %d", got) 317 } 318 319 time.Sleep(time.Millisecond) 320 } 321 322 // Update resolver.ClientConn to not return an error anymore - this should stop it from backing off. 323 cc.updateStateErr = nil 324 timer, err := timerChan.Receive(ctx) 325 if err != nil { 326 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) 327 } 328 timerPointer := timer.(*time.Timer) 329 timerPointer.Reset(0) 330 // Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call 331 // ClientConn update state the final time. The DNS Resolver should then stop polling. 332 deadline = time.Now().Add(defaultTestTimeout) 333 for { 334 cc.m1.Lock() 335 got := cc.updateStateCalls 336 cc.m1.Unlock() 337 if got == 12 { 338 break 339 } 340 341 if time.Now().After(deadline) { 342 t.Fatalf("Exponential backoff is not working as expected - should stop backing off at 12 total UpdateState calls instead of %d", got) 343 } 344 345 _, err := timerChan.ReceiveOrFail() 346 if err { 347 t.Fatalf("Should not poll again after Client Conn stops returning error.") 348 } 349 350 time.Sleep(time.Millisecond) 351 } 352 } 353 354 func mutateTbl(target string) func() { 355 hostLookupTbl.Lock() 356 oldHostTblEntry := hostLookupTbl.tbl[target] 357 358 // Remove the last address from the target's entry. 359 hostLookupTbl.tbl[target] = hostLookupTbl.tbl[target][:len(oldHostTblEntry)-1] 360 hostLookupTbl.Unlock() 361 362 return func() { 363 hostLookupTbl.Lock() 364 hostLookupTbl.tbl[target] = oldHostTblEntry 365 hostLookupTbl.Unlock() 366 } 367 } 368 369 func testDNSResolveNow(t *testing.T) { 370 defer leakcheck.Check(t) 371 defer func(nt func(d time.Duration) *time.Timer) { 372 newTimer = nt 373 }(newTimer) 374 newTimer = func(_ time.Duration) *time.Timer { 375 // Will never fire on its own, will protect from triggering exponential backoff. 376 return time.NewTimer(time.Hour) 377 } 378 tests := []struct { 379 target string 380 addrWant []resolver.Address 381 addrNext []resolver.Address 382 }{ 383 { 384 "foo.ipv4.multi.fake", 385 []resolver.Address{ 386 {Addr: "1.2.3.4:1234", ServerName: "ipv4.multi.fake"}, 387 {Addr: "5.6.7.8:1234", ServerName: "ipv4.multi.fake"}, 388 {Addr: "9.10.11.12:1234", ServerName: "ipv4.multi.fake"}, 389 }, 390 []resolver.Address{ 391 {Addr: "1.2.3.4:1234", ServerName: "ipv4.multi.fake"}, 392 {Addr: "5.6.7.8:1234", ServerName: "ipv4.multi.fake"}, 393 }, 394 }, 395 } 396 397 for _, a := range tests { 398 b := NewDefaultSRVBuilder() 399 cc := &testClientConn{target: a.target} 400 r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", a.target))}, cc, resolver.BuildOptions{}) 401 if err != nil { 402 t.Fatalf("%v\n", err) 403 } 404 defer r.Close() 405 var state resolver.State 406 var cnt int 407 for range 2000 { 408 state, cnt = cc.getState() 409 if cnt > 0 { 410 break 411 } 412 time.Sleep(time.Millisecond) 413 } 414 if cnt == 0 { 415 t.Fatalf("UpdateState not called after 2s; aborting. state=%v", state) 416 } 417 if !slices.Equal(a.addrWant, state.Addresses) { 418 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant) 419 } 420 421 revertTbl := mutateTbl(strings.TrimPrefix(a.target, "foo.")) 422 r.ResolveNow(resolver.ResolveNowOptions{}) 423 for range 2000 { 424 state, cnt = cc.getState() 425 if cnt == 2 { 426 break 427 } 428 time.Sleep(time.Millisecond) 429 } 430 if cnt != 2 { 431 t.Fatalf("UpdateState not called after 2s; aborting. state=%v", state) 432 } 433 if !slices.Equal(a.addrNext, state.Addresses) { 434 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrNext) 435 } 436 revertTbl() 437 } 438 } 439 440 func TestDNSResolverRetry(t *testing.T) { 441 defer func(nt func(d time.Duration) *time.Timer) { 442 newTimer = nt 443 }(newTimer) 444 newTimer = func(d time.Duration) *time.Timer { 445 // Will never fire on its own, will protect from triggering exponential backoff. 446 return time.NewTimer(time.Hour) 447 } 448 b := NewDefaultSRVBuilder() 449 target := "foo.ipv4.single.fake" 450 cc := &testClientConn{target: target} 451 r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{}) 452 if err != nil { 453 t.Fatalf("%v\n", err) 454 } 455 defer r.Close() 456 var state resolver.State 457 for range 2000 { 458 state, _ = cc.getState() 459 if len(state.Addresses) == 1 { 460 break 461 } 462 time.Sleep(time.Millisecond) 463 } 464 if len(state.Addresses) != 1 { 465 t.Fatalf("UpdateState not called with 1 address after 2s; aborting. state=%v", state) 466 } 467 want := []resolver.Address{{Addr: "2.4.6.8:1234", ServerName: "ipv4.single.fake"}} 468 if !slices.Equal(want, state.Addresses) { 469 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, want) 470 } 471 // mutate the host lookup table so the target has 0 address returned. 472 revertTbl := mutateTbl(strings.TrimPrefix(target, "foo.")) 473 // trigger a resolve that will get empty address list 474 r.ResolveNow(resolver.ResolveNowOptions{}) 475 for range 2000 { 476 state, _ = cc.getState() 477 if len(state.Addresses) == 0 { 478 break 479 } 480 time.Sleep(time.Millisecond) 481 } 482 if len(state.Addresses) != 0 { 483 t.Fatalf("UpdateState not called with 0 address after 2s; aborting. state=%v", state) 484 } 485 revertTbl() 486 // wait for the retry to happen in two seconds. 487 r.ResolveNow(resolver.ResolveNowOptions{}) 488 for range 2000 { 489 state, _ = cc.getState() 490 if len(state.Addresses) == 1 { 491 break 492 } 493 time.Sleep(time.Millisecond) 494 } 495 if !slices.Equal(want, state.Addresses) { 496 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, want) 497 } 498 } 499 500 func TestCustomAuthority(t *testing.T) { 501 defer leakcheck.Check(t) 502 defer func(nt func(d time.Duration) *time.Timer) { 503 newTimer = nt 504 }(newTimer) 505 newTimer = func(d time.Duration) *time.Timer { 506 // Will never fire on its own, will protect from triggering exponential backoff. 507 return time.NewTimer(time.Hour) 508 } 509 510 tests := []struct { 511 authority string 512 authorityWant string 513 expectError bool 514 }{ 515 { 516 "4.3.2.1:" + defaultDNSSvrPort, 517 "4.3.2.1:" + defaultDNSSvrPort, 518 false, 519 }, 520 { 521 "4.3.2.1:123", 522 "4.3.2.1:123", 523 false, 524 }, 525 { 526 "4.3.2.1", 527 "4.3.2.1:" + defaultDNSSvrPort, 528 false, 529 }, 530 { 531 "::1", 532 "[::1]:" + defaultDNSSvrPort, 533 false, 534 }, 535 { 536 "[::1]", 537 "[::1]:" + defaultDNSSvrPort, 538 false, 539 }, 540 { 541 "[::1]:123", 542 "[::1]:123", 543 false, 544 }, 545 { 546 "dnsserver.com", 547 "dnsserver.com:" + defaultDNSSvrPort, 548 false, 549 }, 550 { 551 ":123", 552 "localhost:123", 553 false, 554 }, 555 { 556 ":", 557 "", 558 true, 559 }, 560 { 561 "[::1]:", 562 "", 563 true, 564 }, 565 { 566 "dnsserver.com:", 567 "", 568 true, 569 }, 570 } 571 oldcustomAuthorityDialer := customAuthorityDialer 572 defer func() { 573 customAuthorityDialer = oldcustomAuthorityDialer 574 }() 575 576 for _, a := range tests { 577 errChan := make(chan error, 1) 578 customAuthorityDialer = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) { 579 if authority != a.authorityWant { 580 errChan <- fmt.Errorf("wrong custom authority passed to resolver. input: %s expected: %s actual: %s", a.authority, a.authorityWant, authority) 581 } else { 582 errChan <- nil 583 } 584 return func(ctx context.Context, network, address string) (net.Conn, error) { 585 return nil, errors.New("no need to dial") 586 } 587 } 588 589 mockEndpointTarget := "foo.bar.com" 590 b := NewDefaultSRVBuilder() 591 cc := &testClientConn{target: mockEndpointTarget, errChan: make(chan error, 1)} 592 target := resolver.Target{ 593 URL: *testutils.MustParseURL(fmt.Sprintf("scheme://%s/%s", a.authority, mockEndpointTarget)), 594 } 595 r, err := b.Build(target, cc, resolver.BuildOptions{}) 596 597 if err == nil { 598 r.Close() 599 600 err = <-errChan 601 if err != nil { 602 t.Error(err.Error()) 603 } 604 605 if a.expectError { 606 t.Errorf("custom authority should have caused an error: %s", a.authority) 607 } 608 } else if !a.expectError { 609 t.Errorf("unexpected error using custom authority %s: %s", a.authority, err) 610 } 611 } 612 } 613 614 // TestRateLimitedResolve exercises the rate limit enforced on re-resolution 615 // requests. It sets the re-resolution rate to a small value and repeatedly 616 // calls ResolveNow() and ensures only the expected number of resolution 617 // requests are made. 618 func TestRateLimitedResolve(t *testing.T) { 619 defer leakcheck.Check(t) 620 defer func(nt func(d time.Duration) *time.Timer) { 621 newTimer = nt 622 }(newTimer) 623 newTimer = func(d time.Duration) *time.Timer { 624 // Will never fire on its own, will protect from triggering exponential 625 // backoff. 626 return time.NewTimer(time.Hour) 627 } 628 defer func(nt func(d time.Duration) *time.Timer) { 629 newTimerDNSResRate = nt 630 }(newTimerDNSResRate) 631 632 timerChan := testutils.NewChannel() 633 newTimerDNSResRate = func(d time.Duration) *time.Timer { 634 // Will never fire on its own, allows this test to call timer 635 // immediately. 636 t := time.NewTimer(time.Hour) 637 timerChan.Send(t) 638 return t 639 } 640 641 // Create a new testResolver{} for this test because we want the exact count 642 // of the number of times the resolver was invoked. 643 nc := overrideDefaultResolver(true) 644 defer nc() 645 646 target := "foo.ipv4.single.fake" 647 b := NewDefaultSRVBuilder() 648 cc := &testClientConn{target: target} 649 650 r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{}) 651 if err != nil { 652 t.Fatalf("resolver.Build() returned error: %v\n", err) 653 } 654 defer r.Close() 655 656 dnsR, ok := r.(*dnsResolver) 657 if !ok { 658 t.Fatalf("resolver.Build() returned unexpected type: %T\n", dnsR) 659 } 660 661 tr, ok := dnsR.resolver.(*testResolver) 662 if !ok { 663 t.Fatalf("delegate resolver returned unexpected type: %T\n", tr) 664 } 665 666 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 667 defer cancel() 668 669 // Wait for the first resolution request to be done. This happens as part 670 // of the first iteration of the for loop in watcher(). 671 if _, err := tr.lookupHostCh.Receive(ctx); err != nil { 672 t.Fatalf("Timed out waiting for lookup() call.") 673 } 674 675 // Call Resolve Now 100 times, shouldn't continue onto next iteration of 676 // watcher, thus shouldn't lookup again. 677 for range 100 { 678 r.ResolveNow(resolver.ResolveNowOptions{}) 679 } 680 681 continueCtx, continueCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) 682 defer continueCancel() 683 684 if _, err := tr.lookupHostCh.Receive(continueCtx); err == nil { 685 t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.") 686 } 687 688 // Make the DNSMinResRate timer fire immediately (by receiving it, then 689 // resetting to 0), this will unblock the resolver which is currently 690 // blocked on the DNS Min Res Rate timer going off, which will allow it to 691 // continue to the next iteration of the watcher loop. 692 timer, err := timerChan.Receive(ctx) 693 if err != nil { 694 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) 695 } 696 timerPointer := timer.(*time.Timer) 697 timerPointer.Reset(0) 698 699 // Now that DNS Min Res Rate timer has gone off, it should lookup again. 700 if _, err := tr.lookupHostCh.Receive(ctx); err != nil { 701 t.Fatalf("Timed out waiting for lookup() call.") 702 } 703 704 // Resolve Now 1000 more times, shouldn't lookup again as DNS Min Res Rate 705 // timer has not gone off. 706 for range 1000 { 707 r.ResolveNow(resolver.ResolveNowOptions{}) 708 } 709 710 if _, err = tr.lookupHostCh.Receive(continueCtx); err == nil { 711 t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.") 712 } 713 714 // Make the DNSMinResRate timer fire immediately again. 715 timer, err = timerChan.Receive(ctx) 716 if err != nil { 717 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) 718 } 719 timerPointer = timer.(*time.Timer) 720 timerPointer.Reset(0) 721 722 // Now that DNS Min Res Rate timer has gone off, it should lookup again. 723 if _, err = tr.lookupHostCh.Receive(ctx); err != nil { 724 t.Fatalf("Timed out waiting for lookup() call.") 725 } 726 727 wantAddrs := []resolver.Address{{Addr: "2.4.6.8:1234", ServerName: "ipv4.single.fake"}} 728 var state resolver.State 729 for { 730 var cnt int 731 state, cnt = cc.getState() 732 if cnt > 0 { 733 break 734 } 735 time.Sleep(time.Millisecond) 736 } 737 if !slices.Equal(state.Addresses, wantAddrs) { 738 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, wantAddrs) 739 } 740 } 741 742 // DNS Resolver immediately starts polling on an error. This will cause the re-resolution to return another error. 743 // Thus, test that it constantly sends errors to the grpc.ClientConn. 744 func TestReportError(t *testing.T) { 745 const target = "not.found" 746 defer func(nt func(d time.Duration) *time.Timer) { 747 newTimer = nt 748 }(newTimer) 749 timerChan := testutils.NewChannel() 750 newTimer = func(d time.Duration) *time.Timer { 751 // Will never fire on its own, allows this test to call timer immediately. 752 t := time.NewTimer(time.Hour) 753 timerChan.Send(t) 754 return t 755 } 756 cc := &testClientConn{target: target, errChan: make(chan error)} 757 totalTimesCalledError := 0 758 b := NewDefaultSRVBuilder() 759 r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{}) 760 if err != nil { 761 t.Fatalf("Error building resolver for target %v: %v", target, err) 762 } 763 // Should receive first error. 764 err = <-cc.errChan 765 if !strings.Contains(err.Error(), "srvLookup error") { 766 t.Fatalf(`ReportError(err=%v) called; want err contains "srvLookupError"`, err) 767 } 768 totalTimesCalledError++ 769 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 770 defer ctxCancel() 771 timer, err := timerChan.Receive(ctx) 772 if err != nil { 773 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) 774 } 775 timerPointer := timer.(*time.Timer) 776 timerPointer.Reset(0) 777 defer r.Close() 778 779 // Cause timer to go off 10 times, and see if it matches DNS Resolver updating Error. 780 for range 10 { 781 // Should call ReportError(). 782 err = <-cc.errChan 783 if !strings.Contains(err.Error(), "srvLookup error") { 784 t.Fatalf(`ReportError(err=%v) called; want err contains "srvLookupError"`, err) 785 } 786 totalTimesCalledError++ 787 timer, err := timerChan.Receive(ctx) 788 if err != nil { 789 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) 790 } 791 timerPointer := timer.(*time.Timer) 792 timerPointer.Reset(0) 793 } 794 795 if totalTimesCalledError != 11 { 796 t.Errorf("ReportError() not called 11 times, instead called %d times.", totalTimesCalledError) 797 } 798 // Clean up final watcher iteration. 799 <-cc.errChan 800 _, err = timerChan.Receive(ctx) 801 if err != nil { 802 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) 803 } 804 } 805 806 func Test_parseServiceDomain(t *testing.T) { 807 tests := []struct { 808 target string 809 expectService string 810 expectDomain string 811 wantErr bool 812 }{ 813 // valid 814 {"foo.bar", "foo", "bar", false}, 815 {"foo.bar.baz", "foo", "bar.baz", false}, 816 {"foo.bar.baz.", "foo", "bar.baz.", false}, 817 818 // invalid 819 {"", "", "", true}, 820 {".", "", "", true}, 821 {"foo", "", "", true}, 822 {".foo", "", "", true}, 823 {"foo.", "", "", true}, 824 {".foo.bar.baz", "", "", true}, 825 {".foo.bar.baz.", "", "", true}, 826 } 827 for _, tt := range tests { 828 t.Run(tt.target, func(t *testing.T) { 829 gotService, gotDomain, err := parseServiceDomain(tt.target) 830 if tt.wantErr { 831 test.AssertError(t, err, "expect err got nil") 832 } else { 833 test.AssertNotError(t, err, "expect nil err") 834 test.AssertEquals(t, gotService, tt.expectService) 835 test.AssertEquals(t, gotDomain, tt.expectDomain) 836 } 837 }) 838 } 839 }