github.com/letsencrypt/boulder@v0.20251208.0/bdns/dns_test.go (about)

     1  package bdns
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"log"
    11  	"net"
    12  	"net/http"
    13  	"net/netip"
    14  	"net/url"
    15  	"os"
    16  	"regexp"
    17  	"slices"
    18  	"strings"
    19  	"sync"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/jmhodges/clock"
    24  	"github.com/miekg/dns"
    25  	"github.com/prometheus/client_golang/prometheus"
    26  
    27  	blog "github.com/letsencrypt/boulder/log"
    28  	"github.com/letsencrypt/boulder/metrics"
    29  	"github.com/letsencrypt/boulder/test"
    30  )
    31  
    32  const dnsLoopbackAddr = "127.0.0.1:4053"
    33  
    34  func mockDNSQuery(w http.ResponseWriter, httpReq *http.Request) {
    35  	if httpReq.Header.Get("Content-Type") != "application/dns-message" {
    36  		w.WriteHeader(http.StatusBadRequest)
    37  		fmt.Fprintf(w, "client didn't send Content-Type: application/dns-message")
    38  	}
    39  	if httpReq.Header.Get("Accept") != "application/dns-message" {
    40  		w.WriteHeader(http.StatusBadRequest)
    41  		fmt.Fprintf(w, "client didn't accept Content-Type: application/dns-message")
    42  	}
    43  
    44  	requestBody, err := io.ReadAll(httpReq.Body)
    45  	if err != nil {
    46  		w.WriteHeader(http.StatusBadRequest)
    47  		fmt.Fprintf(w, "reading body: %s", err)
    48  	}
    49  	httpReq.Body.Close()
    50  
    51  	r := new(dns.Msg)
    52  	err = r.Unpack(requestBody)
    53  	if err != nil {
    54  		w.WriteHeader(http.StatusBadRequest)
    55  		fmt.Fprintf(w, "unpacking request: %s", err)
    56  	}
    57  
    58  	m := new(dns.Msg)
    59  	m.SetReply(r)
    60  	m.Compress = false
    61  
    62  	appendAnswer := func(rr dns.RR) {
    63  		m.Answer = append(m.Answer, rr)
    64  	}
    65  	for _, q := range r.Question {
    66  		q.Name = strings.ToLower(q.Name)
    67  		if q.Name == "servfail.com." || q.Name == "servfailexception.example.com" {
    68  			m.Rcode = dns.RcodeServerFailure
    69  			break
    70  		}
    71  		switch q.Qtype {
    72  		case dns.TypeSOA:
    73  			record := new(dns.SOA)
    74  			record.Hdr = dns.RR_Header{Name: "letsencrypt.org.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 0}
    75  			record.Ns = "ns.letsencrypt.org."
    76  			record.Mbox = "master.letsencrypt.org."
    77  			record.Serial = 1
    78  			record.Refresh = 1
    79  			record.Retry = 1
    80  			record.Expire = 1
    81  			record.Minttl = 1
    82  			appendAnswer(record)
    83  		case dns.TypeAAAA:
    84  			if q.Name == "v6.letsencrypt.org." {
    85  				record := new(dns.AAAA)
    86  				record.Hdr = dns.RR_Header{Name: "v6.letsencrypt.org.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0}
    87  				record.AAAA = net.ParseIP("2602:80a:6000:abad:cafe::1")
    88  				appendAnswer(record)
    89  			}
    90  			if q.Name == "dualstack.letsencrypt.org." {
    91  				record := new(dns.AAAA)
    92  				record.Hdr = dns.RR_Header{Name: "dualstack.letsencrypt.org.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0}
    93  				record.AAAA = net.ParseIP("2602:80a:6000:abad:cafe::1")
    94  				appendAnswer(record)
    95  			}
    96  			if q.Name == "v4error.letsencrypt.org." {
    97  				record := new(dns.AAAA)
    98  				record.Hdr = dns.RR_Header{Name: "v4error.letsencrypt.org.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0}
    99  				record.AAAA = net.ParseIP("2602:80a:6000:abad:cafe::1")
   100  				appendAnswer(record)
   101  			}
   102  			if q.Name == "v6error.letsencrypt.org." {
   103  				m.SetRcode(r, dns.RcodeNotImplemented)
   104  			}
   105  			if q.Name == "nxdomain.letsencrypt.org." {
   106  				m.SetRcode(r, dns.RcodeNameError)
   107  			}
   108  			if q.Name == "dualstackerror.letsencrypt.org." {
   109  				m.SetRcode(r, dns.RcodeNotImplemented)
   110  			}
   111  		case dns.TypeA:
   112  			if q.Name == "cps.letsencrypt.org." {
   113  				record := new(dns.A)
   114  				record.Hdr = dns.RR_Header{Name: "cps.letsencrypt.org.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0}
   115  				record.A = net.ParseIP("64.112.117.1")
   116  				appendAnswer(record)
   117  			}
   118  			if q.Name == "dualstack.letsencrypt.org." {
   119  				record := new(dns.A)
   120  				record.Hdr = dns.RR_Header{Name: "dualstack.letsencrypt.org.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0}
   121  				record.A = net.ParseIP("64.112.117.1")
   122  				appendAnswer(record)
   123  			}
   124  			if q.Name == "v6error.letsencrypt.org." {
   125  				record := new(dns.A)
   126  				record.Hdr = dns.RR_Header{Name: "dualstack.letsencrypt.org.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0}
   127  				record.A = net.ParseIP("64.112.117.1")
   128  				appendAnswer(record)
   129  			}
   130  			if q.Name == "v4error.letsencrypt.org." {
   131  				m.SetRcode(r, dns.RcodeNotImplemented)
   132  			}
   133  			if q.Name == "nxdomain.letsencrypt.org." {
   134  				m.SetRcode(r, dns.RcodeNameError)
   135  			}
   136  			if q.Name == "dualstackerror.letsencrypt.org." {
   137  				m.SetRcode(r, dns.RcodeRefused)
   138  			}
   139  		case dns.TypeCNAME:
   140  			if q.Name == "cname.letsencrypt.org." {
   141  				record := new(dns.CNAME)
   142  				record.Hdr = dns.RR_Header{Name: "cname.letsencrypt.org.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 30}
   143  				record.Target = "cps.letsencrypt.org."
   144  				appendAnswer(record)
   145  			}
   146  			if q.Name == "cname.example.com." {
   147  				record := new(dns.CNAME)
   148  				record.Hdr = dns.RR_Header{Name: "cname.example.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 30}
   149  				record.Target = "CAA.example.com."
   150  				appendAnswer(record)
   151  			}
   152  		case dns.TypeDNAME:
   153  			if q.Name == "dname.letsencrypt.org." {
   154  				record := new(dns.DNAME)
   155  				record.Hdr = dns.RR_Header{Name: "dname.letsencrypt.org.", Rrtype: dns.TypeDNAME, Class: dns.ClassINET, Ttl: 30}
   156  				record.Target = "cps.letsencrypt.org."
   157  				appendAnswer(record)
   158  			}
   159  		case dns.TypeCAA:
   160  			if q.Name == "bracewel.net." || q.Name == "caa.example.com." {
   161  				record := new(dns.CAA)
   162  				record.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeCAA, Class: dns.ClassINET, Ttl: 0}
   163  				record.Tag = "issue"
   164  				record.Value = "letsencrypt.org"
   165  				record.Flag = 1
   166  				appendAnswer(record)
   167  			}
   168  			if q.Name == "cname.example.com." {
   169  				record := new(dns.CAA)
   170  				record.Hdr = dns.RR_Header{Name: "caa.example.com.", Rrtype: dns.TypeCAA, Class: dns.ClassINET, Ttl: 0}
   171  				record.Tag = "issue"
   172  				record.Value = "letsencrypt.org"
   173  				record.Flag = 1
   174  				appendAnswer(record)
   175  			}
   176  			if q.Name == "gonetld." {
   177  				m.SetRcode(r, dns.RcodeNameError)
   178  			}
   179  		case dns.TypeTXT:
   180  			if q.Name == "split-txt.letsencrypt.org." {
   181  				record := new(dns.TXT)
   182  				record.Hdr = dns.RR_Header{Name: "split-txt.letsencrypt.org.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0}
   183  				record.Txt = []string{"a", "b", "c"}
   184  				appendAnswer(record)
   185  			} else {
   186  				auth := new(dns.SOA)
   187  				auth.Hdr = dns.RR_Header{Name: "letsencrypt.org.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 0}
   188  				auth.Ns = "ns.letsencrypt.org."
   189  				auth.Mbox = "master.letsencrypt.org."
   190  				auth.Serial = 1
   191  				auth.Refresh = 1
   192  				auth.Retry = 1
   193  				auth.Expire = 1
   194  				auth.Minttl = 1
   195  				m.Ns = append(m.Ns, auth)
   196  			}
   197  			if q.Name == "nxdomain.letsencrypt.org." {
   198  				m.SetRcode(r, dns.RcodeNameError)
   199  			}
   200  		}
   201  	}
   202  
   203  	body, err := m.Pack()
   204  	if err != nil {
   205  		fmt.Fprintf(os.Stderr, "packing reply: %s\n", err)
   206  	}
   207  	w.Header().Set("Content-Type", "application/dns-message")
   208  	_, err = w.Write(body)
   209  	if err != nil {
   210  		panic(err) // running tests, so panic is OK
   211  	}
   212  }
   213  
   214  func serveLoopResolver(stopChan chan bool) {
   215  	m := http.NewServeMux()
   216  	m.HandleFunc("/dns-query", mockDNSQuery)
   217  	httpServer := &http.Server{
   218  		Addr:         dnsLoopbackAddr,
   219  		Handler:      m,
   220  		ReadTimeout:  time.Second,
   221  		WriteTimeout: time.Second,
   222  	}
   223  	go func() {
   224  		cert := "../test/certs/ipki/localhost/cert.pem"
   225  		key := "../test/certs/ipki/localhost/key.pem"
   226  		err := httpServer.ListenAndServeTLS(cert, key)
   227  		if err != nil {
   228  			fmt.Println(err)
   229  		}
   230  	}()
   231  	go func() {
   232  		<-stopChan
   233  		err := httpServer.Shutdown(context.Background())
   234  		if err != nil {
   235  			log.Fatal(err)
   236  		}
   237  	}()
   238  }
   239  
   240  func pollServer() {
   241  	backoff := 200 * time.Millisecond
   242  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   243  	defer cancel()
   244  	ticker := time.NewTicker(backoff)
   245  
   246  	for {
   247  		select {
   248  		case <-ctx.Done():
   249  			fmt.Fprintln(os.Stderr, "Timeout reached while testing for the dns server to come up")
   250  			os.Exit(1)
   251  		case <-ticker.C:
   252  			conn, _ := dns.DialTimeout("udp", dnsLoopbackAddr, backoff)
   253  			if conn != nil {
   254  				_ = conn.Close()
   255  				return
   256  			}
   257  		}
   258  	}
   259  }
   260  
   261  // tlsConfig is used for the TLS config of client instances that talk to the
   262  // DoH server set up in TestMain.
   263  var tlsConfig *tls.Config
   264  
   265  func TestMain(m *testing.M) {
   266  	root, err := os.ReadFile("../test/certs/ipki/minica.pem")
   267  	if err != nil {
   268  		log.Fatal(err)
   269  	}
   270  	pool := x509.NewCertPool()
   271  	pool.AppendCertsFromPEM(root)
   272  	tlsConfig = &tls.Config{
   273  		RootCAs: pool,
   274  	}
   275  
   276  	stop := make(chan bool, 1)
   277  	serveLoopResolver(stop)
   278  	pollServer()
   279  	ret := m.Run()
   280  	stop <- true
   281  	os.Exit(ret)
   282  }
   283  
   284  func TestDNSNoServers(t *testing.T) {
   285  	staticProvider, err := NewStaticProvider([]string{})
   286  	test.AssertNotError(t, err, "Got error creating StaticProvider")
   287  
   288  	obj := New(time.Hour, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
   289  
   290  	_, resolvers, err := obj.LookupHost(context.Background(), "letsencrypt.org")
   291  	test.AssertEquals(t, len(resolvers), 0)
   292  	test.AssertError(t, err, "No servers")
   293  
   294  	_, _, err = obj.LookupTXT(context.Background(), "letsencrypt.org")
   295  	test.AssertError(t, err, "No servers")
   296  
   297  	_, _, _, err = obj.LookupCAA(context.Background(), "letsencrypt.org")
   298  	test.AssertError(t, err, "No servers")
   299  }
   300  
   301  func TestDNSOneServer(t *testing.T) {
   302  	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
   303  	test.AssertNotError(t, err, "Got error creating StaticProvider")
   304  
   305  	obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
   306  
   307  	_, resolvers, err := obj.LookupHost(context.Background(), "cps.letsencrypt.org")
   308  	test.AssertEquals(t, len(resolvers), 2)
   309  	slices.Sort(resolvers)
   310  	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
   311  	test.AssertNotError(t, err, "No message")
   312  }
   313  
   314  func TestDNSDuplicateServers(t *testing.T) {
   315  	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr, dnsLoopbackAddr})
   316  	test.AssertNotError(t, err, "Got error creating StaticProvider")
   317  
   318  	obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
   319  
   320  	_, resolvers, err := obj.LookupHost(context.Background(), "cps.letsencrypt.org")
   321  	test.AssertEquals(t, len(resolvers), 2)
   322  	slices.Sort(resolvers)
   323  	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
   324  	test.AssertNotError(t, err, "No message")
   325  }
   326  
   327  func TestDNSServFail(t *testing.T) {
   328  	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
   329  	test.AssertNotError(t, err, "Got error creating StaticProvider")
   330  
   331  	obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
   332  	bad := "servfail.com"
   333  
   334  	_, _, err = obj.LookupTXT(context.Background(), bad)
   335  	test.AssertError(t, err, "LookupTXT didn't return an error")
   336  
   337  	_, _, err = obj.LookupHost(context.Background(), bad)
   338  	test.AssertError(t, err, "LookupHost didn't return an error")
   339  
   340  	emptyCaa, _, _, err := obj.LookupCAA(context.Background(), bad)
   341  	test.Assert(t, len(emptyCaa) == 0, "Query returned non-empty list of CAA records")
   342  	test.AssertError(t, err, "LookupCAA should have returned an error")
   343  }
   344  
   345  func TestDNSLookupTXT(t *testing.T) {
   346  	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
   347  	test.AssertNotError(t, err, "Got error creating StaticProvider")
   348  
   349  	obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
   350  
   351  	a, _, err := obj.LookupTXT(context.Background(), "letsencrypt.org")
   352  	t.Logf("A: %v", a)
   353  	test.AssertNotError(t, err, "No message")
   354  
   355  	a, _, err = obj.LookupTXT(context.Background(), "split-txt.letsencrypt.org")
   356  	t.Logf("A: %v ", a)
   357  	test.AssertNotError(t, err, "No message")
   358  	test.AssertEquals(t, len(a), 1)
   359  	test.AssertEquals(t, a[0], "abc")
   360  }
   361  
   362  // TODO(#8213): Convert this to a table test.
   363  func TestDNSLookupHost(t *testing.T) {
   364  	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
   365  	test.AssertNotError(t, err, "Got error creating StaticProvider")
   366  
   367  	obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
   368  
   369  	ip, resolvers, err := obj.LookupHost(context.Background(), "servfail.com")
   370  	t.Logf("servfail.com - IP: %s, Err: %s", ip, err)
   371  	test.AssertError(t, err, "Server failure")
   372  	test.Assert(t, len(ip) == 0, "Should not have IPs")
   373  	slices.Sort(resolvers)
   374  	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
   375  
   376  	ip, resolvers, err = obj.LookupHost(context.Background(), "nonexistent.letsencrypt.org")
   377  	t.Logf("nonexistent.letsencrypt.org - IP: %s, Err: %s", ip, err)
   378  	test.AssertError(t, err, "No valid A or AAAA records should error")
   379  	test.Assert(t, len(ip) == 0, "Should not have IPs")
   380  	slices.Sort(resolvers)
   381  	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
   382  
   383  	// Single IPv4 address
   384  	ip, resolvers, err = obj.LookupHost(context.Background(), "cps.letsencrypt.org")
   385  	t.Logf("cps.letsencrypt.org - IP: %s, Err: %s", ip, err)
   386  	test.AssertNotError(t, err, "Not an error to exist")
   387  	test.Assert(t, len(ip) == 1, "Should have IP")
   388  	slices.Sort(resolvers)
   389  	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
   390  	ip, resolvers, err = obj.LookupHost(context.Background(), "cps.letsencrypt.org")
   391  	t.Logf("cps.letsencrypt.org - IP: %s, Err: %s", ip, err)
   392  	test.AssertNotError(t, err, "Not an error to exist")
   393  	test.Assert(t, len(ip) == 1, "Should have IP")
   394  	slices.Sort(resolvers)
   395  	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
   396  
   397  	// Single IPv6 address
   398  	ip, resolvers, err = obj.LookupHost(context.Background(), "v6.letsencrypt.org")
   399  	t.Logf("v6.letsencrypt.org - IP: %s, Err: %s", ip, err)
   400  	test.AssertNotError(t, err, "Not an error to exist")
   401  	test.Assert(t, len(ip) == 1, "Should not have IPs")
   402  	slices.Sort(resolvers)
   403  	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
   404  
   405  	// Both IPv6 and IPv4 address
   406  	ip, resolvers, err = obj.LookupHost(context.Background(), "dualstack.letsencrypt.org")
   407  	t.Logf("dualstack.letsencrypt.org - IP: %s, Err: %s", ip, err)
   408  	test.AssertNotError(t, err, "Not an error to exist")
   409  	test.Assert(t, len(ip) == 2, "Should have 2 IPs")
   410  	expected := netip.MustParseAddr("64.112.117.1")
   411  	test.Assert(t, ip[0] == expected, "wrong ipv4 address")
   412  	expected = netip.MustParseAddr("2602:80a:6000:abad:cafe::1")
   413  	test.Assert(t, ip[1] == expected, "wrong ipv6 address")
   414  	slices.Sort(resolvers)
   415  	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
   416  
   417  	// IPv6 error, IPv4 success
   418  	ip, resolvers, err = obj.LookupHost(context.Background(), "v6error.letsencrypt.org")
   419  	t.Logf("v6error.letsencrypt.org - IP: %s, Err: %s", ip, err)
   420  	test.AssertNotError(t, err, "Not an error to exist")
   421  	test.Assert(t, len(ip) == 1, "Should have 1 IP")
   422  	expected = netip.MustParseAddr("64.112.117.1")
   423  	test.Assert(t, ip[0] == expected, "wrong ipv4 address")
   424  	slices.Sort(resolvers)
   425  	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
   426  
   427  	// IPv6 success, IPv4 error
   428  	ip, resolvers, err = obj.LookupHost(context.Background(), "v4error.letsencrypt.org")
   429  	t.Logf("v4error.letsencrypt.org - IP: %s, Err: %s", ip, err)
   430  	test.AssertNotError(t, err, "Not an error to exist")
   431  	test.Assert(t, len(ip) == 1, "Should have 1 IP")
   432  	expected = netip.MustParseAddr("2602:80a:6000:abad:cafe::1")
   433  	test.Assert(t, ip[0] == expected, "wrong ipv6 address")
   434  	slices.Sort(resolvers)
   435  	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
   436  
   437  	// IPv6 error, IPv4 error
   438  	// Should return both the IPv4 error (Refused) and the IPv6 error (NotImplemented)
   439  	hostname := "dualstackerror.letsencrypt.org"
   440  	ip, resolvers, err = obj.LookupHost(context.Background(), hostname)
   441  	t.Logf("%s - IP: %s, Err: %s", hostname, ip, err)
   442  	test.AssertError(t, err, "Should be an error")
   443  	test.AssertContains(t, err.Error(), "REFUSED looking up A for")
   444  	test.AssertContains(t, err.Error(), "NOTIMP looking up AAAA for")
   445  	slices.Sort(resolvers)
   446  	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
   447  }
   448  
   449  func TestDNSNXDOMAIN(t *testing.T) {
   450  	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
   451  	test.AssertNotError(t, err, "Got error creating StaticProvider")
   452  
   453  	obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
   454  
   455  	hostname := "nxdomain.letsencrypt.org"
   456  	_, _, err = obj.LookupHost(context.Background(), hostname)
   457  	test.AssertContains(t, err.Error(), "NXDOMAIN looking up A for")
   458  	test.AssertContains(t, err.Error(), "NXDOMAIN looking up AAAA for")
   459  
   460  	_, _, err = obj.LookupTXT(context.Background(), hostname)
   461  	expected := Error{dns.TypeTXT, hostname, nil, dns.RcodeNameError, nil}
   462  	test.AssertDeepEquals(t, err, expected)
   463  }
   464  
   465  func TestDNSLookupCAA(t *testing.T) {
   466  	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
   467  	test.AssertNotError(t, err, "Got error creating StaticProvider")
   468  
   469  	obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
   470  	removeIDExp := regexp.MustCompile(" id: [[:digit:]]+")
   471  
   472  	caas, resp, resolvers, err := obj.LookupCAA(context.Background(), "bracewel.net")
   473  	test.AssertNotError(t, err, "CAA lookup failed")
   474  	test.Assert(t, len(caas) > 0, "Should have CAA records")
   475  	test.AssertEquals(t, len(resolvers), 1)
   476  	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"127.0.0.1:4053"})
   477  	expectedResp := `;; opcode: QUERY, status: NOERROR, id: XXXX
   478  ;; flags: qr rd; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0
   479  
   480  ;; QUESTION SECTION:
   481  ;bracewel.net.	IN	 CAA
   482  
   483  ;; ANSWER SECTION:
   484  bracewel.net.	0	IN	CAA	1 issue "letsencrypt.org"
   485  `
   486  	test.AssertEquals(t, removeIDExp.ReplaceAllString(resp, " id: XXXX"), expectedResp)
   487  
   488  	caas, resp, resolvers, err = obj.LookupCAA(context.Background(), "nonexistent.letsencrypt.org")
   489  	test.AssertNotError(t, err, "CAA lookup failed")
   490  	test.Assert(t, len(caas) == 0, "Shouldn't have CAA records")
   491  	test.AssertEquals(t, resolvers[0], "127.0.0.1:4053")
   492  	expectedResp = ""
   493  	test.AssertEquals(t, resp, expectedResp)
   494  
   495  	caas, resp, resolvers, err = obj.LookupCAA(context.Background(), "nxdomain.letsencrypt.org")
   496  	slices.Sort(resolvers)
   497  	test.AssertNotError(t, err, "CAA lookup failed")
   498  	test.Assert(t, len(caas) == 0, "Shouldn't have CAA records")
   499  	test.AssertEquals(t, resolvers[0], "127.0.0.1:4053")
   500  	expectedResp = ""
   501  	test.AssertEquals(t, resp, expectedResp)
   502  
   503  	caas, resp, resolvers, err = obj.LookupCAA(context.Background(), "cname.example.com")
   504  	test.AssertNotError(t, err, "CAA lookup failed")
   505  	test.Assert(t, len(caas) > 0, "Should follow CNAME to find CAA")
   506  	test.AssertEquals(t, resolvers[0], "127.0.0.1:4053")
   507  	expectedResp = `;; opcode: QUERY, status: NOERROR, id: XXXX
   508  ;; flags: qr rd; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0
   509  
   510  ;; QUESTION SECTION:
   511  ;cname.example.com.	IN	 CAA
   512  
   513  ;; ANSWER SECTION:
   514  caa.example.com.	0	IN	CAA	1 issue "letsencrypt.org"
   515  `
   516  	test.AssertEquals(t, removeIDExp.ReplaceAllString(resp, " id: XXXX"), expectedResp)
   517  
   518  	_, _, resolvers, err = obj.LookupCAA(context.Background(), "gonetld")
   519  	test.AssertError(t, err, "should fail for TLD NXDOMAIN")
   520  	test.AssertContains(t, err.Error(), "NXDOMAIN")
   521  	test.AssertEquals(t, resolvers[0], "127.0.0.1:4053")
   522  }
   523  
   524  type testExchanger struct {
   525  	sync.Mutex
   526  	count int
   527  	errs  []error
   528  }
   529  
   530  var errTooManyRequests = errors.New("too many requests")
   531  
   532  func (te *testExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
   533  	te.Lock()
   534  	defer te.Unlock()
   535  	msg := &dns.Msg{
   536  		MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess},
   537  	}
   538  	if len(te.errs) <= te.count {
   539  		return nil, 0, errTooManyRequests
   540  	}
   541  	err := te.errs[te.count]
   542  	te.count++
   543  
   544  	return msg, 2 * time.Millisecond, err
   545  }
   546  
   547  func TestRetry(t *testing.T) {
   548  	isTimeoutErr := &url.Error{Op: "read", Err: testTimeoutError(true)}
   549  	nonTimeoutErr := &url.Error{Op: "read", Err: testTimeoutError(false)}
   550  	servFailError := errors.New("DNS problem: server failure at resolver looking up TXT for example.com")
   551  	timeoutFailError := errors.New("DNS problem: query timed out looking up TXT for example.com")
   552  	type testCase struct {
   553  		name              string
   554  		maxTries          int
   555  		te                *testExchanger
   556  		expected          error
   557  		expectedCount     int
   558  		metricsAllRetries float64
   559  	}
   560  	tests := []*testCase{
   561  		// The success on first try case
   562  		{
   563  			name:     "success",
   564  			maxTries: 3,
   565  			te: &testExchanger{
   566  				errs: []error{nil},
   567  			},
   568  			expected:      nil,
   569  			expectedCount: 1,
   570  		},
   571  		// Immediate non-OpError, error returns immediately
   572  		{
   573  			name:     "non-operror",
   574  			maxTries: 3,
   575  			te: &testExchanger{
   576  				errs: []error{errors.New("nope")},
   577  			},
   578  			expected:      servFailError,
   579  			expectedCount: 1,
   580  		},
   581  		// Timeout err, then non-OpError stops at two tries
   582  		{
   583  			name:     "err-then-non-operror",
   584  			maxTries: 3,
   585  			te: &testExchanger{
   586  				errs: []error{isTimeoutErr, errors.New("nope")},
   587  			},
   588  			expected:      servFailError,
   589  			expectedCount: 2,
   590  		},
   591  		// Timeout error given always
   592  		{
   593  			name:     "persistent-timeout-error",
   594  			maxTries: 3,
   595  			te: &testExchanger{
   596  				errs: []error{
   597  					isTimeoutErr,
   598  					isTimeoutErr,
   599  					isTimeoutErr,
   600  				},
   601  			},
   602  			expected:          timeoutFailError,
   603  			expectedCount:     3,
   604  			metricsAllRetries: 1,
   605  		},
   606  		// Even with maxTries at 0, we should still let a single request go
   607  		// through
   608  		{
   609  			name:     "zero-maxtries",
   610  			maxTries: 0,
   611  			te: &testExchanger{
   612  				errs: []error{nil},
   613  			},
   614  			expected:      nil,
   615  			expectedCount: 1,
   616  		},
   617  		// Timeout error given just once causes two tries
   618  		{
   619  			name:     "single-timeout-error",
   620  			maxTries: 3,
   621  			te: &testExchanger{
   622  				errs: []error{
   623  					isTimeoutErr,
   624  					nil,
   625  				},
   626  			},
   627  			expected:      nil,
   628  			expectedCount: 2,
   629  		},
   630  		// Timeout error given twice causes three tries
   631  		{
   632  			name:     "double-timeout-error",
   633  			maxTries: 3,
   634  			te: &testExchanger{
   635  				errs: []error{
   636  					isTimeoutErr,
   637  					isTimeoutErr,
   638  					nil,
   639  				},
   640  			},
   641  			expected:      nil,
   642  			expectedCount: 3,
   643  		},
   644  		// Timeout error given thrice causes three tries and fails
   645  		{
   646  			name:     "triple-timeout-error",
   647  			maxTries: 3,
   648  			te: &testExchanger{
   649  				errs: []error{
   650  					isTimeoutErr,
   651  					isTimeoutErr,
   652  					isTimeoutErr,
   653  				},
   654  			},
   655  			expected:          timeoutFailError,
   656  			expectedCount:     3,
   657  			metricsAllRetries: 1,
   658  		},
   659  		// timeout then non-timeout error causes two retries
   660  		{
   661  			name:     "timeout-nontimeout-error",
   662  			maxTries: 3,
   663  			te: &testExchanger{
   664  				errs: []error{
   665  					isTimeoutErr,
   666  					nonTimeoutErr,
   667  				},
   668  			},
   669  			expected:      servFailError,
   670  			expectedCount: 2,
   671  		},
   672  	}
   673  
   674  	for i, tc := range tests {
   675  		t.Run(tc.name, func(t *testing.T) {
   676  			staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
   677  			test.AssertNotError(t, err, "Got error creating StaticProvider")
   678  
   679  			testClient := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), tc.maxTries, "", blog.UseMock(), tlsConfig)
   680  			dr := testClient.(*impl)
   681  			dr.dnsClient = tc.te
   682  			_, _, err = dr.LookupTXT(context.Background(), "example.com")
   683  			if err == errTooManyRequests {
   684  				t.Errorf("#%d, sent more requests than the test case handles", i)
   685  			}
   686  			expectedErr := tc.expected
   687  			if (expectedErr == nil && err != nil) ||
   688  				(expectedErr != nil && err == nil) ||
   689  				(expectedErr != nil && expectedErr.Error() != err.Error()) {
   690  				t.Errorf("#%d, error, expected %v, got %v", i, expectedErr, err)
   691  			}
   692  			if tc.expectedCount != tc.te.count {
   693  				t.Errorf("#%d, error, expectedCount %v, got %v", i, tc.expectedCount, tc.te.count)
   694  			}
   695  			if tc.metricsAllRetries > 0 {
   696  				test.AssertMetricWithLabelsEquals(
   697  					t, dr.timeoutCounter, prometheus.Labels{
   698  						"qtype":    "TXT",
   699  						"type":     "out of retries",
   700  						"resolver": "127.0.0.1",
   701  						"isTLD":    "false",
   702  					}, tc.metricsAllRetries)
   703  			}
   704  		})
   705  	}
   706  
   707  	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
   708  	test.AssertNotError(t, err, "Got error creating StaticProvider")
   709  
   710  	testClient := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 3, "", blog.UseMock(), tlsConfig)
   711  	dr := testClient.(*impl)
   712  	dr.dnsClient = &testExchanger{errs: []error{isTimeoutErr, isTimeoutErr, nil}}
   713  	ctx, cancel := context.WithCancel(context.Background())
   714  	cancel()
   715  	_, _, err = dr.LookupTXT(ctx, "example.com")
   716  	if err == nil ||
   717  		err.Error() != "DNS problem: query timed out (and was canceled) looking up TXT for example.com" {
   718  		t.Errorf("expected %s, got %s", context.Canceled, err)
   719  	}
   720  
   721  	dr.dnsClient = &testExchanger{errs: []error{isTimeoutErr, isTimeoutErr, nil}}
   722  	ctx, cancel = context.WithTimeout(context.Background(), -10*time.Hour)
   723  	defer cancel()
   724  	_, _, err = dr.LookupTXT(ctx, "example.com")
   725  	if err == nil ||
   726  		err.Error() != "DNS problem: query timed out looking up TXT for example.com" {
   727  		t.Errorf("expected %s, got %s", context.DeadlineExceeded, err)
   728  	}
   729  
   730  	dr.dnsClient = &testExchanger{errs: []error{isTimeoutErr, isTimeoutErr, nil}}
   731  	ctx, deadlineCancel := context.WithTimeout(context.Background(), -10*time.Hour)
   732  	deadlineCancel()
   733  	_, _, err = dr.LookupTXT(ctx, "example.com")
   734  	if err == nil ||
   735  		err.Error() != "DNS problem: query timed out looking up TXT for example.com" {
   736  		t.Errorf("expected %s, got %s", context.DeadlineExceeded, err)
   737  	}
   738  
   739  	test.AssertMetricWithLabelsEquals(
   740  		t, dr.timeoutCounter, prometheus.Labels{
   741  			"qtype":    "TXT",
   742  			"type":     "canceled",
   743  			"resolver": "127.0.0.1",
   744  		}, 1)
   745  
   746  	test.AssertMetricWithLabelsEquals(
   747  		t, dr.timeoutCounter, prometheus.Labels{
   748  			"qtype":    "TXT",
   749  			"type":     "deadline exceeded",
   750  			"resolver": "127.0.0.1",
   751  		}, 2)
   752  }
   753  
   754  func TestIsTLD(t *testing.T) {
   755  	if isTLD("com") != "true" {
   756  		t.Errorf("expected 'com' to be a TLD, got %q", isTLD("com"))
   757  	}
   758  	if isTLD("example.com") != "false" {
   759  		t.Errorf("expected 'example.com' to not a TLD, got %q", isTLD("example.com"))
   760  	}
   761  }
   762  
   763  type testTimeoutError bool
   764  
   765  func (t testTimeoutError) Timeout() bool { return bool(t) }
   766  func (t testTimeoutError) Error() string { return fmt.Sprintf("Timeout: %t", t) }
   767  
   768  // rotateFailureExchanger is a dns.Exchange implementation that tracks a count
   769  // of the number of calls to `Exchange` for a given address in the `lookups`
   770  // map. For all addresses in the `brokenAddresses` map, a retryable error is
   771  // returned from `Exchange`. This mock is used by `TestRotateServerOnErr`.
   772  type rotateFailureExchanger struct {
   773  	sync.Mutex
   774  	lookups         map[string]int
   775  	brokenAddresses map[string]bool
   776  }
   777  
   778  // Exchange for rotateFailureExchanger tracks the `a` argument in `lookups` and
   779  // if present in `brokenAddresses`, returns a timeout error.
   780  func (e *rotateFailureExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
   781  	e.Lock()
   782  	defer e.Unlock()
   783  
   784  	// Track that exchange was called for the given server
   785  	e.lookups[a]++
   786  
   787  	// If its a broken server, return a retryable error
   788  	if e.brokenAddresses[a] {
   789  		isTimeoutErr := &url.Error{Op: "read", Err: testTimeoutError(true)}
   790  		return nil, 2 * time.Millisecond, isTimeoutErr
   791  	}
   792  
   793  	return m, 2 * time.Millisecond, nil
   794  }
   795  
   796  // TestRotateServerOnErr ensures that a retryable error returned from a DNS
   797  // server will result in the retry being performed against the next server in
   798  // the list.
   799  func TestRotateServerOnErr(t *testing.T) {
   800  	// Configure three DNS servers
   801  	dnsServers := []string{
   802  		"a:53", "b:53", "[2606:4700:4700::1111]:53",
   803  	}
   804  
   805  	// Set up a DNS client using these servers that will retry queries up to
   806  	// a maximum of 5 times. It's important to choose a maxTries value >= the
   807  	// number of dnsServers to ensure we always get around to trying the one
   808  	// working server
   809  	staticProvider, err := NewStaticProvider(dnsServers)
   810  	test.AssertNotError(t, err, "Got error creating StaticProvider")
   811  
   812  	maxTries := 5
   813  	client := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), maxTries, "", blog.UseMock(), tlsConfig)
   814  
   815  	// Configure a mock exchanger that will always return a retryable error for
   816  	// servers A and B. This will force server "[2606:4700:4700::1111]:53" to do
   817  	// all the work once retries reach it.
   818  	mock := &rotateFailureExchanger{
   819  		brokenAddresses: map[string]bool{
   820  			"a:53": true,
   821  			"b:53": true,
   822  		},
   823  		lookups: make(map[string]int),
   824  	}
   825  	client.(*impl).dnsClient = mock
   826  
   827  	// Perform a bunch of lookups. We choose the initial server randomly. Any time
   828  	// A or B is chosen there should be an error and a retry using the next server
   829  	// in the list. Since we configured maxTries to be larger than the number of
   830  	// servers *all* queries should eventually succeed by being retried against
   831  	// server "[2606:4700:4700::1111]:53".
   832  	for range maxTries * 2 {
   833  		_, resolvers, err := client.LookupTXT(context.Background(), "example.com")
   834  		test.AssertEquals(t, len(resolvers), 1)
   835  		test.AssertEquals(t, resolvers[0], "[2606:4700:4700::1111]:53")
   836  		// Any errors are unexpected - server "[2606:4700:4700::1111]:53" should
   837  		// have responded without error.
   838  		test.AssertNotError(t, err, "Expected no error from eventual retry with functional server")
   839  	}
   840  
   841  	// We expect that the A and B servers had a non-zero number of lookups
   842  	// attempted.
   843  	test.Assert(t, mock.lookups["a:53"] > 0, "Expected A server to have non-zero lookup attempts")
   844  	test.Assert(t, mock.lookups["b:53"] > 0, "Expected B server to have non-zero lookup attempts")
   845  
   846  	// We expect that the server "[2606:4700:4700::1111]:53" eventually served
   847  	// all of the lookups attempted.
   848  	test.AssertEquals(t, mock.lookups["[2606:4700:4700::1111]:53"], maxTries*2)
   849  
   850  }
   851  
   852  type mockTimeoutURLError struct{}
   853  
   854  func (m *mockTimeoutURLError) Error() string { return "whoops, oh gosh" }
   855  func (m *mockTimeoutURLError) Timeout() bool { return true }
   856  
   857  type dohAlwaysRetryExchanger struct {
   858  	sync.Mutex
   859  	err error
   860  }
   861  
   862  func (dohE *dohAlwaysRetryExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
   863  	dohE.Lock()
   864  	defer dohE.Unlock()
   865  
   866  	timeoutURLerror := &url.Error{
   867  		Op:  "GET",
   868  		URL: "https://example.com",
   869  		Err: &mockTimeoutURLError{},
   870  	}
   871  
   872  	return nil, time.Second, timeoutURLerror
   873  }
   874  
   875  func TestDOHMetric(t *testing.T) {
   876  	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
   877  	test.AssertNotError(t, err, "Got error creating StaticProvider")
   878  
   879  	testClient := New(time.Second*11, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 0, "", blog.UseMock(), tlsConfig)
   880  	resolver := testClient.(*impl)
   881  	resolver.dnsClient = &dohAlwaysRetryExchanger{err: &url.Error{Op: "read", Err: testTimeoutError(true)}}
   882  
   883  	// Starting out, we should count 0 "out of retries" errors.
   884  	test.AssertMetricWithLabelsEquals(t, resolver.timeoutCounter, prometheus.Labels{"qtype": "None", "type": "out of retries", "resolver": "127.0.0.1", "isTLD": "false"}, 0)
   885  
   886  	// Trigger the error.
   887  	_, _, _ = resolver.exchangeOne(context.Background(), "example.com", 0)
   888  
   889  	// Now, we should count 1 "out of retries" errors.
   890  	test.AssertMetricWithLabelsEquals(t, resolver.timeoutCounter, prometheus.Labels{"qtype": "None", "type": "out of retries", "resolver": "127.0.0.1", "isTLD": "false"}, 1)
   891  }