github.com/Psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/resolver/resolver_test.go (about)

     1  /*
     2   * Copyright (c) 2022, Psiphon Inc.
     3   * All rights reserved.
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License as published by
     7   * the Free Software Foundation, either version 3 of the License, or
     8   * (at your option) any later version.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package resolver
    21  
    22  import (
    23  	"context"
    24  	"fmt"
    25  	"net"
    26  	"reflect"
    27  	"sync/atomic"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
    32  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
    33  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
    34  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
    35  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms"
    36  	"github.com/miekg/dns"
    37  )
    38  
    39  func TestMakeResolveParameters(t *testing.T) {
    40  	err := runTestMakeResolveParameters()
    41  	if err != nil {
    42  		t.Fatalf(errors.Trace(err).Error())
    43  	}
    44  }
    45  
    46  func TestResolver(t *testing.T) {
    47  	err := runTestResolver()
    48  	if err != nil {
    49  		t.Fatalf(errors.Trace(err).Error())
    50  	}
    51  }
    52  
    53  func TestPublicDNSServers(t *testing.T) {
    54  	IPs, metrics, err := runTestPublicDNSServers()
    55  	if err != nil {
    56  		t.Fatalf(errors.Trace(err).Error())
    57  	}
    58  	t.Logf("IPs: %v", IPs)
    59  	t.Logf("Metrics: %v", metrics)
    60  }
    61  
    62  func runTestMakeResolveParameters() error {
    63  
    64  	frontingProviderID := "frontingProvider"
    65  	alternateDNSServer := "172.16.0.1"
    66  	alternateDNSServerWithPort := net.JoinHostPort(alternateDNSServer, resolverDNSPort)
    67  	preferredAlternateDNSServer := "172.16.0.2"
    68  	preferredAlternateDNSServerWithPort := net.JoinHostPort(preferredAlternateDNSServer, resolverDNSPort)
    69  	transformName := "exampleTransform"
    70  
    71  	paramValues := map[string]interface{}{
    72  		"DNSResolverAttemptsPerServer":                2,
    73  		"DNSResolverAttemptsPerPreferredServer":       1,
    74  		"DNSResolverPreresolvedIPAddressProbability":  1.0,
    75  		"DNSResolverPreresolvedIPAddressCIDRs":        parameters.LabeledCIDRs{frontingProviderID: []string{exampleIPv4CIDR}},
    76  		"DNSResolverAlternateServers":                 []string{alternateDNSServer},
    77  		"DNSResolverPreferredAlternateServers":        []string{preferredAlternateDNSServer},
    78  		"DNSResolverPreferAlternateServerProbability": 1.0,
    79  		"DNSResolverProtocolTransformProbability":     1.0,
    80  		"DNSResolverProtocolTransformSpecs":           transforms.Specs{transformName: exampleTransform},
    81  		"DNSResolverProtocolTransformScopedSpecNames": transforms.ScopedSpecNames{preferredAlternateDNSServer: []string{transformName}},
    82  		"DNSResolverIncludeEDNS0Probability":          1.0,
    83  	}
    84  
    85  	params, err := parameters.NewParameters(nil)
    86  	if err != nil {
    87  		return errors.Trace(err)
    88  	}
    89  	_, err = params.Set("", false, paramValues)
    90  	if err != nil {
    91  		return errors.Trace(err)
    92  	}
    93  
    94  	resolver := NewResolver(&NetworkConfig{}, "")
    95  	defer resolver.Stop()
    96  
    97  	resolverParams, err := resolver.MakeResolveParameters(
    98  		params.Get(), frontingProviderID)
    99  	if err != nil {
   100  		return errors.Trace(err)
   101  	}
   102  
   103  	// Test: PreresolvedIPAddress
   104  
   105  	CIDRContainsIP := func(CIDR, IP string) bool {
   106  		_, IPNet, _ := net.ParseCIDR(CIDR)
   107  		return IPNet.Contains(net.ParseIP(IP))
   108  	}
   109  
   110  	if resolverParams.AttemptsPerServer != 2 ||
   111  		resolverParams.AttemptsPerPreferredServer != 1 ||
   112  		resolverParams.RequestTimeout != 5*time.Second ||
   113  		resolverParams.AwaitTimeout != 10*time.Millisecond ||
   114  		!CIDRContainsIP(exampleIPv4CIDR, resolverParams.PreresolvedIPAddress) ||
   115  		resolverParams.AlternateDNSServer != "" ||
   116  		resolverParams.PreferAlternateDNSServer != false ||
   117  		resolverParams.ProtocolTransformName != "" ||
   118  		resolverParams.ProtocolTransformSpec != nil ||
   119  		resolverParams.IncludeEDNS0 != false {
   120  		return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
   121  	}
   122  
   123  	// Test: additional generateIPAddressFromCIDR cases
   124  
   125  	for i := 0; i < 10000; i++ {
   126  		for _, CIDR := range []string{exampleIPv4CIDR, exampleIPv6CIDR} {
   127  			IP, err := generateIPAddressFromCIDR(CIDR)
   128  			if err != nil {
   129  				return errors.Trace(err)
   130  			}
   131  			if !CIDRContainsIP(CIDR, IP.String()) || common.IsBogon(IP) {
   132  				return errors.Tracef(
   133  					"invalid generated IP address %v for CIDR %v", IP, CIDR)
   134  			}
   135  		}
   136  	}
   137  
   138  	// Test: Preferred/Transform/EDNS(0)
   139  
   140  	paramValues["DNSResolverPreresolvedIPAddressProbability"] = 0.0
   141  
   142  	_, err = params.Set("", false, paramValues)
   143  	if err != nil {
   144  		return errors.Trace(err)
   145  	}
   146  
   147  	resolverParams, err = resolver.MakeResolveParameters(
   148  		params.Get(), frontingProviderID)
   149  	if err != nil {
   150  		return errors.Trace(err)
   151  	}
   152  
   153  	if resolverParams.AttemptsPerServer != 2 ||
   154  		resolverParams.AttemptsPerPreferredServer != 1 ||
   155  		resolverParams.RequestTimeout != 5*time.Second ||
   156  		resolverParams.AwaitTimeout != 10*time.Millisecond ||
   157  		resolverParams.PreresolvedIPAddress != "" ||
   158  		resolverParams.AlternateDNSServer != preferredAlternateDNSServerWithPort ||
   159  		resolverParams.PreferAlternateDNSServer != true ||
   160  		resolverParams.ProtocolTransformName != transformName ||
   161  		resolverParams.ProtocolTransformSpec == nil ||
   162  		resolverParams.IncludeEDNS0 != true {
   163  		return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
   164  	}
   165  
   166  	// Test: No Preferred/Transform/EDNS(0)
   167  
   168  	paramValues["DNSResolverPreferAlternateServerProbability"] = 0.0
   169  	paramValues["DNSResolverProtocolTransformProbability"] = 0.0
   170  	paramValues["DNSResolverIncludeEDNS0Probability"] = 0.0
   171  
   172  	_, err = params.Set("", false, paramValues)
   173  	if err != nil {
   174  		return errors.Trace(err)
   175  	}
   176  
   177  	resolverParams, err = resolver.MakeResolveParameters(
   178  		params.Get(), frontingProviderID)
   179  	if err != nil {
   180  		return errors.Trace(err)
   181  	}
   182  
   183  	if resolverParams.AttemptsPerServer != 2 ||
   184  		resolverParams.AttemptsPerPreferredServer != 1 ||
   185  		resolverParams.RequestTimeout != 5*time.Second ||
   186  		resolverParams.AwaitTimeout != 10*time.Millisecond ||
   187  		resolverParams.PreresolvedIPAddress != "" ||
   188  		resolverParams.AlternateDNSServer != alternateDNSServerWithPort ||
   189  		resolverParams.PreferAlternateDNSServer != false ||
   190  		resolverParams.ProtocolTransformName != "" ||
   191  		resolverParams.ProtocolTransformSpec != nil ||
   192  		resolverParams.IncludeEDNS0 != false {
   193  		return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
   194  	}
   195  
   196  	return nil
   197  }
   198  
   199  func runTestResolver() error {
   200  
   201  	// noResponseServer will not respond to requests
   202  	noResponseServer, err := newTestDNSServer(false, false, false)
   203  	if err != nil {
   204  		return errors.Trace(err)
   205  	}
   206  	defer noResponseServer.stop()
   207  
   208  	// invalidIPServer will respond with an invalid IP
   209  	invalidIPServer, err := newTestDNSServer(true, false, false)
   210  	if err != nil {
   211  		return errors.Trace(err)
   212  	}
   213  	defer invalidIPServer.stop()
   214  
   215  	// okServer will respond to correct requests (expected domain) with the
   216  	// correct response (expected IPv4 or IPv6 address)
   217  	okServer, err := newTestDNSServer(true, true, false)
   218  	if err != nil {
   219  		return errors.Trace(err)
   220  	}
   221  	defer okServer.stop()
   222  
   223  	// alternateOkServer behaves like okServer; getRequestCount is used to
   224  	// confirm that the alternate server was indeed used
   225  	alternateOkServer, err := newTestDNSServer(true, true, false)
   226  	if err != nil {
   227  		return errors.Trace(err)
   228  	}
   229  	defer alternateOkServer.stop()
   230  
   231  	// transformOkServer behaves like okServer but only responds if the
   232  	// transform was applied; other servers do not respond if the transform
   233  	// is applied
   234  	transformOkServer, err := newTestDNSServer(true, true, true)
   235  	if err != nil {
   236  		return errors.Trace(err)
   237  	}
   238  	defer transformOkServer.stop()
   239  
   240  	servers := []string{noResponseServer.getAddr(), invalidIPServer.getAddr(), okServer.getAddr()}
   241  
   242  	networkConfig := &NetworkConfig{
   243  		GetDNSServers: func() []string { return servers },
   244  		LogWarning:    func(err error) { fmt.Printf("LogWarning: %v\n", err) },
   245  	}
   246  
   247  	networkID := "networkID-1"
   248  
   249  	resolver := NewResolver(networkConfig, networkID)
   250  	defer resolver.Stop()
   251  
   252  	params := &ResolveParameters{
   253  		AttemptsPerServer:          1,
   254  		AttemptsPerPreferredServer: 1,
   255  		RequestTimeout:             250 * time.Millisecond,
   256  		AwaitTimeout:               250 * time.Millisecond,
   257  		IncludeEDNS0:               true,
   258  	}
   259  
   260  	checkResult := func(IPs []net.IP) error {
   261  		var IPv4, IPv6 net.IP
   262  		for _, IP := range IPs {
   263  			if IP.To4() != nil {
   264  				IPv4 = IP
   265  			} else {
   266  				IPv6 = IP
   267  			}
   268  		}
   269  		if IPv4 == nil {
   270  			return errors.TraceNew("missing IPv4 response")
   271  		}
   272  		if IPv4.String() != exampleIPv4 {
   273  			return errors.TraceNew("unexpected IPv4 response")
   274  		}
   275  		if resolver.hasIPv6Route {
   276  			if IPv6 == nil {
   277  				return errors.TraceNew("missing IPv6 response")
   278  			}
   279  			if IPv6.String() != exampleIPv6 {
   280  				return errors.TraceNew("unexpected IPv6 response")
   281  			}
   282  		}
   283  		return nil
   284  	}
   285  
   286  	ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second)
   287  	defer cancelFunc()
   288  
   289  	// Test: should retry until okServer responds
   290  
   291  	IPs, err := resolver.ResolveIP(ctx, networkID, params, exampleDomain)
   292  	if err != nil {
   293  		return errors.Trace(err)
   294  	}
   295  
   296  	err = checkResult(IPs)
   297  	if err != nil {
   298  		return errors.Trace(err)
   299  	}
   300  
   301  	if resolver.metrics.resolves != 1 ||
   302  		resolver.metrics.cacheHits != 0 ||
   303  		resolver.metrics.requestsIPv4 != 3 || resolver.metrics.responsesIPv4 != 1 ||
   304  		(resolver.hasIPv6Route && (resolver.metrics.requestsIPv6 != 3 || resolver.metrics.responsesIPv6 != 1)) {
   305  		return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
   306  	}
   307  
   308  	// Test: cached response
   309  
   310  	beforeMetrics := resolver.metrics
   311  
   312  	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
   313  	if err != nil {
   314  		return errors.Trace(err)
   315  	}
   316  
   317  	err = checkResult(IPs)
   318  	if err != nil {
   319  		return errors.Trace(err)
   320  	}
   321  
   322  	if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
   323  		resolver.metrics.cacheHits != beforeMetrics.cacheHits+1 ||
   324  		resolver.metrics.requestsIPv4 != beforeMetrics.requestsIPv4 ||
   325  		resolver.metrics.requestsIPv6 != beforeMetrics.requestsIPv6 {
   326  		return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
   327  	}
   328  
   329  	// Test: PreresolvedIPAddress
   330  
   331  	beforeMetrics = resolver.metrics
   332  
   333  	params.PreresolvedIPAddress = exampleIPv4
   334  
   335  	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
   336  	if err != nil {
   337  		return errors.Trace(err)
   338  	}
   339  
   340  	if len(IPs) != 1 || IPs[0].String() != exampleIPv4 {
   341  		return errors.TraceNew("unexpected preresolved response")
   342  	}
   343  
   344  	if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
   345  		resolver.metrics.cacheHits != beforeMetrics.cacheHits ||
   346  		resolver.metrics.requestsIPv4 != beforeMetrics.requestsIPv4 ||
   347  		resolver.metrics.requestsIPv6 != beforeMetrics.requestsIPv6 {
   348  		return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
   349  	}
   350  
   351  	params.PreresolvedIPAddress = ""
   352  
   353  	// Test: change network ID, which must clear cache
   354  
   355  	beforeMetrics = resolver.metrics
   356  
   357  	networkID = "networkID-2"
   358  
   359  	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
   360  	if err != nil {
   361  		return errors.Trace(err)
   362  	}
   363  
   364  	err = checkResult(IPs)
   365  	if err != nil {
   366  		return errors.Trace(err)
   367  	}
   368  
   369  	if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
   370  		resolver.metrics.cacheHits != beforeMetrics.cacheHits {
   371  		return errors.Tracef("unexpected metrics: %+v (%+v)", resolver.metrics, beforeMetrics)
   372  	}
   373  
   374  	// Test: PreferAlternateDNSServer
   375  
   376  	if alternateOkServer.getRequestCount() != 0 {
   377  		return errors.TraceNew("unexpected alternate server request count")
   378  	}
   379  
   380  	resolver.cache.Flush()
   381  
   382  	params.AlternateDNSServer = alternateOkServer.getAddr()
   383  	params.PreferAlternateDNSServer = true
   384  
   385  	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
   386  	if err != nil {
   387  		return errors.Trace(err)
   388  	}
   389  
   390  	err = checkResult(IPs)
   391  	if err != nil {
   392  		return errors.Trace(err)
   393  	}
   394  
   395  	if alternateOkServer.getRequestCount() < 1 {
   396  		return errors.TraceNew("unexpected alternate server request count")
   397  	}
   398  
   399  	params.AlternateDNSServer = ""
   400  	params.PreferAlternateDNSServer = false
   401  
   402  	// Test: PreferAlternateDNSServer with failed attempt (exercise maxAttempts prefer case)
   403  
   404  	resolver.cache.Flush()
   405  
   406  	params.AlternateDNSServer = invalidIPServer.getAddr()
   407  	params.PreferAlternateDNSServer = true
   408  
   409  	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
   410  	if err != nil {
   411  		return errors.Trace(err)
   412  	}
   413  
   414  	err = checkResult(IPs)
   415  	if err != nil {
   416  		return errors.Trace(err)
   417  	}
   418  
   419  	params.AlternateDNSServer = ""
   420  	params.PreferAlternateDNSServer = false
   421  
   422  	// Test: fall over to AlternateDNSServer when no system servers
   423  
   424  	beforeCount := alternateOkServer.getRequestCount()
   425  
   426  	previousGetDNSServers := networkConfig.GetDNSServers
   427  
   428  	networkConfig.GetDNSServers = func() []string { return nil }
   429  
   430  	// Force system servers update
   431  	networkID = "networkID-3"
   432  
   433  	resolver.cache.Flush()
   434  
   435  	params.AlternateDNSServer = alternateOkServer.getAddr()
   436  	params.PreferAlternateDNSServer = false
   437  
   438  	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
   439  	if err != nil {
   440  		return errors.Trace(err)
   441  	}
   442  
   443  	err = checkResult(IPs)
   444  	if err != nil {
   445  		return errors.Trace(err)
   446  	}
   447  
   448  	if alternateOkServer.getRequestCount() <= beforeCount {
   449  		return errors.TraceNew("unexpected alterate server request count")
   450  	}
   451  
   452  	// Test: use default, standard resolver when no servers
   453  
   454  	resolver.cache.Flush()
   455  
   456  	params.AlternateDNSServer = ""
   457  	params.PreferAlternateDNSServer = false
   458  
   459  	if len(resolver.systemServers) != 0 {
   460  		return errors.TraceNew("unexpected server count")
   461  	}
   462  
   463  	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
   464  	if err != nil {
   465  		return errors.Trace(err)
   466  	}
   467  
   468  	if len(IPs) == 0 {
   469  		return errors.TraceNew("unexpected response")
   470  	}
   471  
   472  	// Test: ResolveAddress
   473  
   474  	networkConfig.GetDNSServers = previousGetDNSServers
   475  
   476  	// Force system servers update
   477  	networkID = "networkID-4"
   478  
   479  	domainAddress := net.JoinHostPort(exampleDomain, "443")
   480  
   481  	address, err := resolver.ResolveAddress(ctx, networkID, params, domainAddress)
   482  	if err != nil {
   483  		return errors.Trace(err)
   484  	}
   485  
   486  	host, port, err := net.SplitHostPort(address)
   487  	if err != nil {
   488  		return errors.Trace(err)
   489  	}
   490  
   491  	IP := net.ParseIP(host)
   492  
   493  	if IP == nil || (host != exampleIPv4 && host != exampleIPv6) || port != "443" {
   494  		return errors.TraceNew("unexpected response")
   495  	}
   496  
   497  	// Test: protocol transform
   498  
   499  	if transformOkServer.getRequestCount() != 0 {
   500  		return errors.TraceNew("unexpected transform server request count")
   501  	}
   502  
   503  	resolver.cache.Flush()
   504  
   505  	params.AlternateDNSServer = transformOkServer.getAddr()
   506  	params.PreferAlternateDNSServer = true
   507  
   508  	seed, err := prng.NewSeed()
   509  	if err != nil {
   510  		return errors.Trace(err)
   511  	}
   512  
   513  	params.ProtocolTransformName = "exampleTransform"
   514  	params.ProtocolTransformSpec = exampleTransform
   515  	params.ProtocolTransformSeed = seed
   516  
   517  	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
   518  	if err != nil {
   519  		return errors.Trace(err)
   520  	}
   521  
   522  	err = checkResult(IPs)
   523  	if err != nil {
   524  		return errors.Trace(err)
   525  	}
   526  
   527  	if transformOkServer.getRequestCount() < 1 {
   528  		return errors.TraceNew("unexpected transform server request count")
   529  	}
   530  
   531  	params.AlternateDNSServer = ""
   532  	params.PreferAlternateDNSServer = false
   533  	params.ProtocolTransformName = ""
   534  	params.ProtocolTransformSpec = nil
   535  	params.ProtocolTransformSeed = nil
   536  
   537  	// Test: EDNS(0)
   538  
   539  	resolver.cache.Flush()
   540  
   541  	params.IncludeEDNS0 = true
   542  
   543  	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
   544  	if err != nil {
   545  		return errors.Trace(err)
   546  	}
   547  
   548  	err = checkResult(IPs)
   549  	if err != nil {
   550  		return errors.Trace(err)
   551  	}
   552  
   553  	params.IncludeEDNS0 = false
   554  
   555  	// Test: input IP address
   556  
   557  	beforeMetrics = resolver.metrics
   558  
   559  	resolver.cache.Flush()
   560  
   561  	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleIPv4)
   562  	if err != nil {
   563  		return errors.Trace(err)
   564  	}
   565  
   566  	if len(IPs) != 1 || IPs[0].String() != exampleIPv4 {
   567  		return errors.TraceNew("unexpected IPv4 response")
   568  	}
   569  
   570  	if resolver.metrics.resolves != beforeMetrics.resolves {
   571  		return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
   572  	}
   573  
   574  	// Test: DNS cache extension
   575  
   576  	resolver.cache.Flush()
   577  
   578  	networkConfig.CacheExtensionInitialTTL = (exampleTTLSeconds * 2) * time.Second
   579  	networkConfig.CacheExtensionVerifiedTTL = 2 * time.Hour
   580  
   581  	now := time.Now()
   582  
   583  	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
   584  	if err != nil {
   585  		return errors.Trace(err)
   586  	}
   587  
   588  	entry, expiry, ok := resolver.cache.GetWithExpiration(exampleDomain)
   589  	if !ok ||
   590  		!reflect.DeepEqual(entry, IPs) ||
   591  		expiry.Before(now.Add(networkConfig.CacheExtensionInitialTTL)) ||
   592  		expiry.After(now.Add(networkConfig.CacheExtensionVerifiedTTL)) {
   593  		return errors.TraceNew("unexpected CacheExtensionInitialTTL state")
   594  	}
   595  
   596  	resolver.VerifyCacheExtension(exampleDomain)
   597  
   598  	entry, expiry, ok = resolver.cache.GetWithExpiration(exampleDomain)
   599  	if !ok ||
   600  		!reflect.DeepEqual(entry, IPs) ||
   601  		expiry.Before(now.Add(networkConfig.CacheExtensionVerifiedTTL)) {
   602  		return errors.TraceNew("unexpected CacheExtensionInitialTTL state")
   603  	}
   604  
   605  	// Set cache flush condition, which should be ignored
   606  	networkID = "networkID-5"
   607  
   608  	resolver.updateNetworkState(networkID)
   609  
   610  	entry, expiry, ok = resolver.cache.GetWithExpiration(exampleDomain)
   611  	if !ok ||
   612  		!reflect.DeepEqual(entry, IPs) ||
   613  		expiry.Before(now.Add(networkConfig.CacheExtensionVerifiedTTL)) {
   614  		return errors.TraceNew("unexpected CacheExtensionInitialTTL state")
   615  	}
   616  
   617  	// Test: cancel context
   618  
   619  	resolver.cache.Flush()
   620  
   621  	cancelFunc()
   622  
   623  	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
   624  	if err == nil {
   625  		return errors.TraceNew("unexpected success")
   626  	}
   627  
   628  	// Test: cancel context while resolving
   629  
   630  	// This test exercises the additional answers and await cases in
   631  	// ResolveIP. The test is timing dependent, and so imperfect, but this
   632  	// configuration can reproduce panics in those cases before bugs were
   633  	// fixed, where DNS responses need to be received just as the context is
   634  	// cancelled.
   635  
   636  	networkConfig.GetDNSServers = func() []string { return []string{okServer.getAddr()} }
   637  	networkID = "networkID-6"
   638  
   639  	for i := 0; i < 500; i++ {
   640  		resolver.cache.Flush()
   641  
   642  		ctx, cancelFunc := context.WithTimeout(
   643  			context.Background(), time.Duration((i%10+1)*20)*time.Microsecond)
   644  		defer cancelFunc()
   645  
   646  		_, _ = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
   647  	}
   648  
   649  	return nil
   650  }
   651  
   652  func runTestPublicDNSServers() ([]net.IP, string, error) {
   653  
   654  	networkConfig := &NetworkConfig{
   655  		GetDNSServers: getPublicDNSServers,
   656  	}
   657  
   658  	networkID := "networkID-1"
   659  
   660  	resolver := NewResolver(networkConfig, networkID)
   661  	defer resolver.Stop()
   662  
   663  	params := &ResolveParameters{
   664  		AttemptsPerServer: 1,
   665  		RequestTimeout:    5 * time.Second,
   666  		AwaitTimeout:      1 * time.Second,
   667  		IncludeEDNS0:      true,
   668  	}
   669  
   670  	IPs, err := resolver.ResolveIP(
   671  		context.Background(), networkID, params, exampleDomain)
   672  	if err != nil {
   673  		return nil, "", errors.Trace(err)
   674  	}
   675  
   676  	gotIPv4 := false
   677  	gotIPv6 := false
   678  	for _, IP := range IPs {
   679  		if IP.To4() != nil {
   680  			gotIPv4 = true
   681  		} else {
   682  			gotIPv6 = true
   683  		}
   684  	}
   685  	if !gotIPv4 {
   686  		return nil, "", errors.TraceNew("missing IPv4 response")
   687  	}
   688  	if !gotIPv6 && resolver.hasIPv6Route {
   689  		return nil, "", errors.TraceNew("missing IPv6 response")
   690  	}
   691  
   692  	return IPs, resolver.GetMetrics(), nil
   693  }
   694  
   695  func getPublicDNSServers() []string {
   696  	servers := []string{"1.1.1.1", "8.8.8.8", "9.9.9.9"}
   697  	shuffledServers := make([]string, len(servers))
   698  	for i, j := range prng.Perm(len(servers)) {
   699  		shuffledServers[i] = servers[j]
   700  	}
   701  	return shuffledServers
   702  }
   703  
   704  const (
   705  	exampleDomain     = "example.com"
   706  	exampleIPv4       = "93.184.216.34"
   707  	exampleIPv4CIDR   = "93.184.216.0/24"
   708  	exampleIPv6       = "2606:2800:220:1:248:1893:25c8:1946"
   709  	exampleIPv6CIDR   = "2606:2800:220::/48"
   710  	exampleTTLSeconds = 60
   711  )
   712  
   713  // Set the reserved Z flag
   714  var exampleTransform = transforms.Spec{[2]string{"^([a-f0-9]{4})0100", "\\$\\{1\\}0140"}}
   715  
   716  type testDNSServer struct {
   717  	respond         bool
   718  	validResponse   bool
   719  	expectTransform bool
   720  	addr            string
   721  	requestCount    int32
   722  	server          *dns.Server
   723  }
   724  
   725  func newTestDNSServer(respond, validResponse, expectTransform bool) (*testDNSServer, error) {
   726  
   727  	udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
   728  	if err != nil {
   729  		return nil, errors.Trace(err)
   730  	}
   731  
   732  	udpConn, err := net.ListenUDP("udp", udpAddr)
   733  	if err != nil {
   734  		return nil, errors.Trace(err)
   735  	}
   736  
   737  	s := &testDNSServer{
   738  		respond:         respond,
   739  		validResponse:   validResponse,
   740  		expectTransform: expectTransform,
   741  		addr:            udpConn.LocalAddr().String(),
   742  	}
   743  
   744  	server := &dns.Server{
   745  		PacketConn: udpConn,
   746  		Handler:    s,
   747  	}
   748  
   749  	s.server = server
   750  
   751  	go server.ActivateAndServe()
   752  
   753  	return s, nil
   754  }
   755  
   756  func (s *testDNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
   757  	atomic.AddInt32(&s.requestCount, 1)
   758  
   759  	if !s.respond {
   760  		return
   761  	}
   762  
   763  	// Check the reserved Z flag
   764  	if s.expectTransform != r.MsgHdr.Zero {
   765  		return
   766  	}
   767  
   768  	if len(r.Question) != 1 || r.Question[0].Name != dns.Fqdn(exampleDomain) {
   769  		return
   770  	}
   771  
   772  	m := new(dns.Msg)
   773  	m.SetReply(r)
   774  	m.Answer = make([]dns.RR, 1)
   775  	if r.Question[0].Qtype == dns.TypeA {
   776  		IP := net.ParseIP(exampleIPv4)
   777  		if !s.validResponse {
   778  			IP = net.ParseIP("127.0.0.1")
   779  		}
   780  		m.Answer[0] = &dns.A{
   781  			Hdr: dns.RR_Header{
   782  				Name:   r.Question[0].Name,
   783  				Rrtype: dns.TypeA,
   784  				Class:  dns.ClassINET,
   785  				Ttl:    exampleTTLSeconds},
   786  			A: IP,
   787  		}
   788  	} else {
   789  		IP := net.ParseIP(exampleIPv6)
   790  		if !s.validResponse {
   791  			IP = net.ParseIP("::1")
   792  		}
   793  		m.Answer[0] = &dns.AAAA{
   794  			Hdr: dns.RR_Header{
   795  				Name:   r.Question[0].Name,
   796  				Rrtype: dns.TypeAAAA,
   797  				Class:  dns.ClassINET,
   798  				Ttl:    exampleTTLSeconds},
   799  			AAAA: IP,
   800  		}
   801  	}
   802  
   803  	w.WriteMsg(m)
   804  }
   805  
   806  func (s *testDNSServer) getAddr() string {
   807  	return s.addr
   808  }
   809  
   810  func (s *testDNSServer) getRequestCount() int {
   811  	return int(atomic.LoadInt32(&s.requestCount))
   812  }
   813  
   814  func (s *testDNSServer) stop() {
   815  	s.server.PacketConn.Close()
   816  	s.server.Shutdown()
   817  }