github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/resolver/resolver_test.go (about) 1 /* 2 * Copyright (c) 2022, Psiphon Inc. 3 * All rights reserved. 4 * 5 * This program is free software: you can redistribute it and/or modify 6 * it under the terms of the GNU General Public License as published by 7 * the Free Software Foundation, either version 3 of the License, or 8 * (at your option) any later version. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 package resolver 21 22 import ( 23 "context" 24 "fmt" 25 "net" 26 "reflect" 27 "sync/atomic" 28 "testing" 29 "time" 30 31 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common" 32 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors" 33 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters" 34 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng" 35 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms" 36 "github.com/miekg/dns" 37 ) 38 39 func TestMakeResolveParameters(t *testing.T) { 40 err := runTestMakeResolveParameters() 41 if err != nil { 42 t.Fatalf(errors.Trace(err).Error()) 43 } 44 } 45 46 func TestResolver(t *testing.T) { 47 err := runTestResolver() 48 if err != nil { 49 t.Fatalf(errors.Trace(err).Error()) 50 } 51 } 52 53 func TestPublicDNSServers(t *testing.T) { 54 IPs, metrics, err := runTestPublicDNSServers() 55 if err != nil { 56 t.Fatalf(errors.Trace(err).Error()) 57 } 58 t.Logf("IPs: %v", IPs) 59 t.Logf("Metrics: %v", metrics) 60 } 61 62 func runTestMakeResolveParameters() error { 63 64 frontingProviderID := "frontingProvider" 65 alternateDNSServer := "172.16.0.1" 66 alternateDNSServerWithPort := net.JoinHostPort(alternateDNSServer, resolverDNSPort) 67 preferredAlternateDNSServer := "172.16.0.2" 68 preferredAlternateDNSServerWithPort := net.JoinHostPort(preferredAlternateDNSServer, resolverDNSPort) 69 transformName := "exampleTransform" 70 71 paramValues := map[string]interface{}{ 72 "DNSResolverAttemptsPerServer": 2, 73 "DNSResolverAttemptsPerPreferredServer": 1, 74 "DNSResolverPreresolvedIPAddressProbability": 1.0, 75 "DNSResolverPreresolvedIPAddressCIDRs": parameters.LabeledCIDRs{frontingProviderID: []string{exampleIPv4CIDR}}, 76 "DNSResolverAlternateServers": []string{alternateDNSServer}, 77 "DNSResolverPreferredAlternateServers": []string{preferredAlternateDNSServer}, 78 "DNSResolverPreferAlternateServerProbability": 1.0, 79 "DNSResolverProtocolTransformProbability": 1.0, 80 "DNSResolverProtocolTransformSpecs": transforms.Specs{transformName: exampleTransform}, 81 "DNSResolverProtocolTransformScopedSpecNames": transforms.ScopedSpecNames{preferredAlternateDNSServer: []string{transformName}}, 82 "DNSResolverIncludeEDNS0Probability": 1.0, 83 } 84 85 params, err := parameters.NewParameters(nil) 86 if err != nil { 87 return errors.Trace(err) 88 } 89 _, err = params.Set("", false, paramValues) 90 if err != nil { 91 return errors.Trace(err) 92 } 93 94 resolver := NewResolver(&NetworkConfig{}, "") 95 defer resolver.Stop() 96 97 resolverParams, err := resolver.MakeResolveParameters( 98 params.Get(), frontingProviderID) 99 if err != nil { 100 return errors.Trace(err) 101 } 102 103 // Test: PreresolvedIPAddress 104 105 CIDRContainsIP := func(CIDR, IP string) bool { 106 _, IPNet, _ := net.ParseCIDR(CIDR) 107 return IPNet.Contains(net.ParseIP(IP)) 108 } 109 110 if resolverParams.AttemptsPerServer != 2 || 111 resolverParams.AttemptsPerPreferredServer != 1 || 112 resolverParams.RequestTimeout != 5*time.Second || 113 resolverParams.AwaitTimeout != 10*time.Millisecond || 114 !CIDRContainsIP(exampleIPv4CIDR, resolverParams.PreresolvedIPAddress) || 115 resolverParams.AlternateDNSServer != "" || 116 resolverParams.PreferAlternateDNSServer != false || 117 resolverParams.ProtocolTransformName != "" || 118 resolverParams.ProtocolTransformSpec != nil || 119 resolverParams.IncludeEDNS0 != false { 120 return errors.Tracef("unexpected resolver parameters: %+v", resolverParams) 121 } 122 123 // Test: additional generateIPAddressFromCIDR cases 124 125 for i := 0; i < 10000; i++ { 126 for _, CIDR := range []string{exampleIPv4CIDR, exampleIPv6CIDR} { 127 IP, err := generateIPAddressFromCIDR(CIDR) 128 if err != nil { 129 return errors.Trace(err) 130 } 131 if !CIDRContainsIP(CIDR, IP.String()) || common.IsBogon(IP) { 132 return errors.Tracef( 133 "invalid generated IP address %v for CIDR %v", IP, CIDR) 134 } 135 } 136 } 137 138 // Test: Preferred/Transform/EDNS(0) 139 140 paramValues["DNSResolverPreresolvedIPAddressProbability"] = 0.0 141 142 _, err = params.Set("", false, paramValues) 143 if err != nil { 144 return errors.Trace(err) 145 } 146 147 resolverParams, err = resolver.MakeResolveParameters( 148 params.Get(), frontingProviderID) 149 if err != nil { 150 return errors.Trace(err) 151 } 152 153 if resolverParams.AttemptsPerServer != 2 || 154 resolverParams.AttemptsPerPreferredServer != 1 || 155 resolverParams.RequestTimeout != 5*time.Second || 156 resolverParams.AwaitTimeout != 10*time.Millisecond || 157 resolverParams.PreresolvedIPAddress != "" || 158 resolverParams.AlternateDNSServer != preferredAlternateDNSServerWithPort || 159 resolverParams.PreferAlternateDNSServer != true || 160 resolverParams.ProtocolTransformName != transformName || 161 resolverParams.ProtocolTransformSpec == nil || 162 resolverParams.IncludeEDNS0 != true { 163 return errors.Tracef("unexpected resolver parameters: %+v", resolverParams) 164 } 165 166 // Test: No Preferred/Transform/EDNS(0) 167 168 paramValues["DNSResolverPreferAlternateServerProbability"] = 0.0 169 paramValues["DNSResolverProtocolTransformProbability"] = 0.0 170 paramValues["DNSResolverIncludeEDNS0Probability"] = 0.0 171 172 _, err = params.Set("", false, paramValues) 173 if err != nil { 174 return errors.Trace(err) 175 } 176 177 resolverParams, err = resolver.MakeResolveParameters( 178 params.Get(), frontingProviderID) 179 if err != nil { 180 return errors.Trace(err) 181 } 182 183 if resolverParams.AttemptsPerServer != 2 || 184 resolverParams.AttemptsPerPreferredServer != 1 || 185 resolverParams.RequestTimeout != 5*time.Second || 186 resolverParams.AwaitTimeout != 10*time.Millisecond || 187 resolverParams.PreresolvedIPAddress != "" || 188 resolverParams.AlternateDNSServer != alternateDNSServerWithPort || 189 resolverParams.PreferAlternateDNSServer != false || 190 resolverParams.ProtocolTransformName != "" || 191 resolverParams.ProtocolTransformSpec != nil || 192 resolverParams.IncludeEDNS0 != false { 193 return errors.Tracef("unexpected resolver parameters: %+v", resolverParams) 194 } 195 196 return nil 197 } 198 199 func runTestResolver() error { 200 201 // noResponseServer will not respond to requests 202 noResponseServer, err := newTestDNSServer(false, false, false) 203 if err != nil { 204 return errors.Trace(err) 205 } 206 defer noResponseServer.stop() 207 208 // invalidIPServer will respond with an invalid IP 209 invalidIPServer, err := newTestDNSServer(true, false, false) 210 if err != nil { 211 return errors.Trace(err) 212 } 213 defer invalidIPServer.stop() 214 215 // okServer will respond to correct requests (expected domain) with the 216 // correct response (expected IPv4 or IPv6 address) 217 okServer, err := newTestDNSServer(true, true, false) 218 if err != nil { 219 return errors.Trace(err) 220 } 221 defer okServer.stop() 222 223 // alternateOkServer behaves like okServer; getRequestCount is used to 224 // confirm that the alternate server was indeed used 225 alternateOkServer, err := newTestDNSServer(true, true, false) 226 if err != nil { 227 return errors.Trace(err) 228 } 229 defer alternateOkServer.stop() 230 231 // transformOkServer behaves like okServer but only responds if the 232 // transform was applied; other servers do not respond if the transform 233 // is applied 234 transformOkServer, err := newTestDNSServer(true, true, true) 235 if err != nil { 236 return errors.Trace(err) 237 } 238 defer transformOkServer.stop() 239 240 servers := []string{noResponseServer.getAddr(), invalidIPServer.getAddr(), okServer.getAddr()} 241 242 networkConfig := &NetworkConfig{ 243 GetDNSServers: func() []string { return servers }, 244 LogWarning: func(err error) { fmt.Printf("LogWarning: %v\n", err) }, 245 } 246 247 networkID := "networkID-1" 248 249 resolver := NewResolver(networkConfig, networkID) 250 defer resolver.Stop() 251 252 params := &ResolveParameters{ 253 AttemptsPerServer: 1, 254 AttemptsPerPreferredServer: 1, 255 RequestTimeout: 250 * time.Millisecond, 256 AwaitTimeout: 250 * time.Millisecond, 257 IncludeEDNS0: true, 258 } 259 260 checkResult := func(IPs []net.IP) error { 261 var IPv4, IPv6 net.IP 262 for _, IP := range IPs { 263 if IP.To4() != nil { 264 IPv4 = IP 265 } else { 266 IPv6 = IP 267 } 268 } 269 if IPv4 == nil { 270 return errors.TraceNew("missing IPv4 response") 271 } 272 if IPv4.String() != exampleIPv4 { 273 return errors.TraceNew("unexpected IPv4 response") 274 } 275 if resolver.hasIPv6Route { 276 if IPv6 == nil { 277 return errors.TraceNew("missing IPv6 response") 278 } 279 if IPv6.String() != exampleIPv6 { 280 return errors.TraceNew("unexpected IPv6 response") 281 } 282 } 283 return nil 284 } 285 286 ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second) 287 defer cancelFunc() 288 289 // Test: should retry until okServer responds 290 291 IPs, err := resolver.ResolveIP(ctx, networkID, params, exampleDomain) 292 if err != nil { 293 return errors.Trace(err) 294 } 295 296 err = checkResult(IPs) 297 if err != nil { 298 return errors.Trace(err) 299 } 300 301 if resolver.metrics.resolves != 1 || 302 resolver.metrics.cacheHits != 0 || 303 resolver.metrics.requestsIPv4 != 3 || resolver.metrics.responsesIPv4 != 1 || 304 (resolver.hasIPv6Route && (resolver.metrics.requestsIPv6 != 3 || resolver.metrics.responsesIPv6 != 1)) { 305 return errors.Tracef("unexpected metrics: %+v", resolver.metrics) 306 } 307 308 // Test: cached response 309 310 beforeMetrics := resolver.metrics 311 312 IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain) 313 if err != nil { 314 return errors.Trace(err) 315 } 316 317 err = checkResult(IPs) 318 if err != nil { 319 return errors.Trace(err) 320 } 321 322 if resolver.metrics.resolves != beforeMetrics.resolves+1 || 323 resolver.metrics.cacheHits != beforeMetrics.cacheHits+1 || 324 resolver.metrics.requestsIPv4 != beforeMetrics.requestsIPv4 || 325 resolver.metrics.requestsIPv6 != beforeMetrics.requestsIPv6 { 326 return errors.Tracef("unexpected metrics: %+v", resolver.metrics) 327 } 328 329 // Test: PreresolvedIPAddress 330 331 beforeMetrics = resolver.metrics 332 333 params.PreresolvedIPAddress = exampleIPv4 334 335 IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain) 336 if err != nil { 337 return errors.Trace(err) 338 } 339 340 if len(IPs) != 1 || IPs[0].String() != exampleIPv4 { 341 return errors.TraceNew("unexpected preresolved response") 342 } 343 344 if resolver.metrics.resolves != beforeMetrics.resolves+1 || 345 resolver.metrics.cacheHits != beforeMetrics.cacheHits || 346 resolver.metrics.requestsIPv4 != beforeMetrics.requestsIPv4 || 347 resolver.metrics.requestsIPv6 != beforeMetrics.requestsIPv6 { 348 return errors.Tracef("unexpected metrics: %+v", resolver.metrics) 349 } 350 351 params.PreresolvedIPAddress = "" 352 353 // Test: change network ID, which must clear cache 354 355 beforeMetrics = resolver.metrics 356 357 networkID = "networkID-2" 358 359 IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain) 360 if err != nil { 361 return errors.Trace(err) 362 } 363 364 err = checkResult(IPs) 365 if err != nil { 366 return errors.Trace(err) 367 } 368 369 if resolver.metrics.resolves != beforeMetrics.resolves+1 || 370 resolver.metrics.cacheHits != beforeMetrics.cacheHits { 371 return errors.Tracef("unexpected metrics: %+v (%+v)", resolver.metrics, beforeMetrics) 372 } 373 374 // Test: PreferAlternateDNSServer 375 376 if alternateOkServer.getRequestCount() != 0 { 377 return errors.TraceNew("unexpected alternate server request count") 378 } 379 380 resolver.cache.Flush() 381 382 params.AlternateDNSServer = alternateOkServer.getAddr() 383 params.PreferAlternateDNSServer = true 384 385 IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain) 386 if err != nil { 387 return errors.Trace(err) 388 } 389 390 err = checkResult(IPs) 391 if err != nil { 392 return errors.Trace(err) 393 } 394 395 if alternateOkServer.getRequestCount() < 1 { 396 return errors.TraceNew("unexpected alternate server request count") 397 } 398 399 params.AlternateDNSServer = "" 400 params.PreferAlternateDNSServer = false 401 402 // Test: PreferAlternateDNSServer with failed attempt (exercise maxAttempts prefer case) 403 404 resolver.cache.Flush() 405 406 params.AlternateDNSServer = invalidIPServer.getAddr() 407 params.PreferAlternateDNSServer = true 408 409 IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain) 410 if err != nil { 411 return errors.Trace(err) 412 } 413 414 err = checkResult(IPs) 415 if err != nil { 416 return errors.Trace(err) 417 } 418 419 params.AlternateDNSServer = "" 420 params.PreferAlternateDNSServer = false 421 422 // Test: fall over to AlternateDNSServer when no system servers 423 424 beforeCount := alternateOkServer.getRequestCount() 425 426 previousGetDNSServers := networkConfig.GetDNSServers 427 428 networkConfig.GetDNSServers = func() []string { return nil } 429 430 // Force system servers update 431 networkID = "networkID-3" 432 433 resolver.cache.Flush() 434 435 params.AlternateDNSServer = alternateOkServer.getAddr() 436 params.PreferAlternateDNSServer = false 437 438 IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain) 439 if err != nil { 440 return errors.Trace(err) 441 } 442 443 err = checkResult(IPs) 444 if err != nil { 445 return errors.Trace(err) 446 } 447 448 if alternateOkServer.getRequestCount() <= beforeCount { 449 return errors.TraceNew("unexpected alterate server request count") 450 } 451 452 // Test: use default, standard resolver when no servers 453 454 resolver.cache.Flush() 455 456 params.AlternateDNSServer = "" 457 params.PreferAlternateDNSServer = false 458 459 if len(resolver.systemServers) != 0 { 460 return errors.TraceNew("unexpected server count") 461 } 462 463 IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain) 464 if err != nil { 465 return errors.Trace(err) 466 } 467 468 if len(IPs) == 0 { 469 return errors.TraceNew("unexpected response") 470 } 471 472 // Test: ResolveAddress 473 474 networkConfig.GetDNSServers = previousGetDNSServers 475 476 // Force system servers update 477 networkID = "networkID-4" 478 479 domainAddress := net.JoinHostPort(exampleDomain, "443") 480 481 address, err := resolver.ResolveAddress(ctx, networkID, params, domainAddress) 482 if err != nil { 483 return errors.Trace(err) 484 } 485 486 host, port, err := net.SplitHostPort(address) 487 if err != nil { 488 return errors.Trace(err) 489 } 490 491 IP := net.ParseIP(host) 492 493 if IP == nil || (host != exampleIPv4 && host != exampleIPv6) || port != "443" { 494 return errors.TraceNew("unexpected response") 495 } 496 497 // Test: protocol transform 498 499 if transformOkServer.getRequestCount() != 0 { 500 return errors.TraceNew("unexpected transform server request count") 501 } 502 503 resolver.cache.Flush() 504 505 params.AlternateDNSServer = transformOkServer.getAddr() 506 params.PreferAlternateDNSServer = true 507 508 seed, err := prng.NewSeed() 509 if err != nil { 510 return errors.Trace(err) 511 } 512 513 params.ProtocolTransformName = "exampleTransform" 514 params.ProtocolTransformSpec = exampleTransform 515 params.ProtocolTransformSeed = seed 516 517 IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain) 518 if err != nil { 519 return errors.Trace(err) 520 } 521 522 err = checkResult(IPs) 523 if err != nil { 524 return errors.Trace(err) 525 } 526 527 if transformOkServer.getRequestCount() < 1 { 528 return errors.TraceNew("unexpected transform server request count") 529 } 530 531 params.AlternateDNSServer = "" 532 params.PreferAlternateDNSServer = false 533 params.ProtocolTransformName = "" 534 params.ProtocolTransformSpec = nil 535 params.ProtocolTransformSeed = nil 536 537 // Test: EDNS(0) 538 539 resolver.cache.Flush() 540 541 params.IncludeEDNS0 = true 542 543 IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain) 544 if err != nil { 545 return errors.Trace(err) 546 } 547 548 err = checkResult(IPs) 549 if err != nil { 550 return errors.Trace(err) 551 } 552 553 params.IncludeEDNS0 = false 554 555 // Test: input IP address 556 557 beforeMetrics = resolver.metrics 558 559 resolver.cache.Flush() 560 561 IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleIPv4) 562 if err != nil { 563 return errors.Trace(err) 564 } 565 566 if len(IPs) != 1 || IPs[0].String() != exampleIPv4 { 567 return errors.TraceNew("unexpected IPv4 response") 568 } 569 570 if resolver.metrics.resolves != beforeMetrics.resolves { 571 return errors.Tracef("unexpected metrics: %+v", resolver.metrics) 572 } 573 574 // Test: DNS cache extension 575 576 resolver.cache.Flush() 577 578 networkConfig.CacheExtensionInitialTTL = (exampleTTLSeconds * 2) * time.Second 579 networkConfig.CacheExtensionVerifiedTTL = 2 * time.Hour 580 581 now := time.Now() 582 583 IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain) 584 if err != nil { 585 return errors.Trace(err) 586 } 587 588 entry, expiry, ok := resolver.cache.GetWithExpiration(exampleDomain) 589 if !ok || 590 !reflect.DeepEqual(entry, IPs) || 591 expiry.Before(now.Add(networkConfig.CacheExtensionInitialTTL)) || 592 expiry.After(now.Add(networkConfig.CacheExtensionVerifiedTTL)) { 593 return errors.TraceNew("unexpected CacheExtensionInitialTTL state") 594 } 595 596 resolver.VerifyCacheExtension(exampleDomain) 597 598 entry, expiry, ok = resolver.cache.GetWithExpiration(exampleDomain) 599 if !ok || 600 !reflect.DeepEqual(entry, IPs) || 601 expiry.Before(now.Add(networkConfig.CacheExtensionVerifiedTTL)) { 602 return errors.TraceNew("unexpected CacheExtensionInitialTTL state") 603 } 604 605 // Set cache flush condition, which should be ignored 606 networkID = "networkID-5" 607 608 resolver.updateNetworkState(networkID) 609 610 entry, expiry, ok = resolver.cache.GetWithExpiration(exampleDomain) 611 if !ok || 612 !reflect.DeepEqual(entry, IPs) || 613 expiry.Before(now.Add(networkConfig.CacheExtensionVerifiedTTL)) { 614 return errors.TraceNew("unexpected CacheExtensionInitialTTL state") 615 } 616 617 // Test: cancel context 618 619 resolver.cache.Flush() 620 621 cancelFunc() 622 623 IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain) 624 if err == nil { 625 return errors.TraceNew("unexpected success") 626 } 627 628 // Test: cancel context while resolving 629 630 // This test exercises the additional answers and await cases in 631 // ResolveIP. The test is timing dependent, and so imperfect, but this 632 // configuration can reproduce panics in those cases before bugs were 633 // fixed, where DNS responses need to be received just as the context is 634 // cancelled. 635 636 networkConfig.GetDNSServers = func() []string { return []string{okServer.getAddr()} } 637 networkID = "networkID-6" 638 639 for i := 0; i < 500; i++ { 640 resolver.cache.Flush() 641 642 ctx, cancelFunc := context.WithTimeout( 643 context.Background(), time.Duration((i%10+1)*20)*time.Microsecond) 644 defer cancelFunc() 645 646 _, _ = resolver.ResolveIP(ctx, networkID, params, exampleDomain) 647 } 648 649 return nil 650 } 651 652 func runTestPublicDNSServers() ([]net.IP, string, error) { 653 654 networkConfig := &NetworkConfig{ 655 GetDNSServers: getPublicDNSServers, 656 } 657 658 networkID := "networkID-1" 659 660 resolver := NewResolver(networkConfig, networkID) 661 defer resolver.Stop() 662 663 params := &ResolveParameters{ 664 AttemptsPerServer: 1, 665 RequestTimeout: 5 * time.Second, 666 AwaitTimeout: 1 * time.Second, 667 IncludeEDNS0: true, 668 } 669 670 IPs, err := resolver.ResolveIP( 671 context.Background(), networkID, params, exampleDomain) 672 if err != nil { 673 return nil, "", errors.Trace(err) 674 } 675 676 gotIPv4 := false 677 gotIPv6 := false 678 for _, IP := range IPs { 679 if IP.To4() != nil { 680 gotIPv4 = true 681 } else { 682 gotIPv6 = true 683 } 684 } 685 if !gotIPv4 { 686 return nil, "", errors.TraceNew("missing IPv4 response") 687 } 688 if !gotIPv6 && resolver.hasIPv6Route { 689 return nil, "", errors.TraceNew("missing IPv6 response") 690 } 691 692 return IPs, resolver.GetMetrics(), nil 693 } 694 695 func getPublicDNSServers() []string { 696 servers := []string{"1.1.1.1", "8.8.8.8", "9.9.9.9"} 697 shuffledServers := make([]string, len(servers)) 698 for i, j := range prng.Perm(len(servers)) { 699 shuffledServers[i] = servers[j] 700 } 701 return shuffledServers 702 } 703 704 const ( 705 exampleDomain = "example.com" 706 exampleIPv4 = "93.184.216.34" 707 exampleIPv4CIDR = "93.184.216.0/24" 708 exampleIPv6 = "2606:2800:220:1:248:1893:25c8:1946" 709 exampleIPv6CIDR = "2606:2800:220::/48" 710 exampleTTLSeconds = 60 711 ) 712 713 // Set the reserved Z flag 714 var exampleTransform = transforms.Spec{[2]string{"^([a-f0-9]{4})0100", "\\$\\{1\\}0140"}} 715 716 type testDNSServer struct { 717 respond bool 718 validResponse bool 719 expectTransform bool 720 addr string 721 requestCount int32 722 server *dns.Server 723 } 724 725 func newTestDNSServer(respond, validResponse, expectTransform bool) (*testDNSServer, error) { 726 727 udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") 728 if err != nil { 729 return nil, errors.Trace(err) 730 } 731 732 udpConn, err := net.ListenUDP("udp", udpAddr) 733 if err != nil { 734 return nil, errors.Trace(err) 735 } 736 737 s := &testDNSServer{ 738 respond: respond, 739 validResponse: validResponse, 740 expectTransform: expectTransform, 741 addr: udpConn.LocalAddr().String(), 742 } 743 744 server := &dns.Server{ 745 PacketConn: udpConn, 746 Handler: s, 747 } 748 749 s.server = server 750 751 go server.ActivateAndServe() 752 753 return s, nil 754 } 755 756 func (s *testDNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { 757 atomic.AddInt32(&s.requestCount, 1) 758 759 if !s.respond { 760 return 761 } 762 763 // Check the reserved Z flag 764 if s.expectTransform != r.MsgHdr.Zero { 765 return 766 } 767 768 if len(r.Question) != 1 || r.Question[0].Name != dns.Fqdn(exampleDomain) { 769 return 770 } 771 772 m := new(dns.Msg) 773 m.SetReply(r) 774 m.Answer = make([]dns.RR, 1) 775 if r.Question[0].Qtype == dns.TypeA { 776 IP := net.ParseIP(exampleIPv4) 777 if !s.validResponse { 778 IP = net.ParseIP("127.0.0.1") 779 } 780 m.Answer[0] = &dns.A{ 781 Hdr: dns.RR_Header{ 782 Name: r.Question[0].Name, 783 Rrtype: dns.TypeA, 784 Class: dns.ClassINET, 785 Ttl: exampleTTLSeconds}, 786 A: IP, 787 } 788 } else { 789 IP := net.ParseIP(exampleIPv6) 790 if !s.validResponse { 791 IP = net.ParseIP("::1") 792 } 793 m.Answer[0] = &dns.AAAA{ 794 Hdr: dns.RR_Header{ 795 Name: r.Question[0].Name, 796 Rrtype: dns.TypeAAAA, 797 Class: dns.ClassINET, 798 Ttl: exampleTTLSeconds}, 799 AAAA: IP, 800 } 801 } 802 803 w.WriteMsg(m) 804 } 805 806 func (s *testDNSServer) getAddr() string { 807 return s.addr 808 } 809 810 func (s *testDNSServer) getRequestCount() int { 811 return int(atomic.LoadInt32(&s.requestCount)) 812 } 813 814 func (s *testDNSServer) stop() { 815 s.server.PacketConn.Close() 816 s.server.Shutdown() 817 }