github.com/letsencrypt/boulder@v0.20251208.0/bdns/dns_test.go (about) 1 package bdns 2 3 import ( 4 "context" 5 "crypto/tls" 6 "crypto/x509" 7 "errors" 8 "fmt" 9 "io" 10 "log" 11 "net" 12 "net/http" 13 "net/netip" 14 "net/url" 15 "os" 16 "regexp" 17 "slices" 18 "strings" 19 "sync" 20 "testing" 21 "time" 22 23 "github.com/jmhodges/clock" 24 "github.com/miekg/dns" 25 "github.com/prometheus/client_golang/prometheus" 26 27 blog "github.com/letsencrypt/boulder/log" 28 "github.com/letsencrypt/boulder/metrics" 29 "github.com/letsencrypt/boulder/test" 30 ) 31 32 const dnsLoopbackAddr = "127.0.0.1:4053" 33 34 func mockDNSQuery(w http.ResponseWriter, httpReq *http.Request) { 35 if httpReq.Header.Get("Content-Type") != "application/dns-message" { 36 w.WriteHeader(http.StatusBadRequest) 37 fmt.Fprintf(w, "client didn't send Content-Type: application/dns-message") 38 } 39 if httpReq.Header.Get("Accept") != "application/dns-message" { 40 w.WriteHeader(http.StatusBadRequest) 41 fmt.Fprintf(w, "client didn't accept Content-Type: application/dns-message") 42 } 43 44 requestBody, err := io.ReadAll(httpReq.Body) 45 if err != nil { 46 w.WriteHeader(http.StatusBadRequest) 47 fmt.Fprintf(w, "reading body: %s", err) 48 } 49 httpReq.Body.Close() 50 51 r := new(dns.Msg) 52 err = r.Unpack(requestBody) 53 if err != nil { 54 w.WriteHeader(http.StatusBadRequest) 55 fmt.Fprintf(w, "unpacking request: %s", err) 56 } 57 58 m := new(dns.Msg) 59 m.SetReply(r) 60 m.Compress = false 61 62 appendAnswer := func(rr dns.RR) { 63 m.Answer = append(m.Answer, rr) 64 } 65 for _, q := range r.Question { 66 q.Name = strings.ToLower(q.Name) 67 if q.Name == "servfail.com." || q.Name == "servfailexception.example.com" { 68 m.Rcode = dns.RcodeServerFailure 69 break 70 } 71 switch q.Qtype { 72 case dns.TypeSOA: 73 record := new(dns.SOA) 74 record.Hdr = dns.RR_Header{Name: "letsencrypt.org.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 0} 75 record.Ns = "ns.letsencrypt.org." 76 record.Mbox = "master.letsencrypt.org." 77 record.Serial = 1 78 record.Refresh = 1 79 record.Retry = 1 80 record.Expire = 1 81 record.Minttl = 1 82 appendAnswer(record) 83 case dns.TypeAAAA: 84 if q.Name == "v6.letsencrypt.org." { 85 record := new(dns.AAAA) 86 record.Hdr = dns.RR_Header{Name: "v6.letsencrypt.org.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0} 87 record.AAAA = net.ParseIP("2602:80a:6000:abad:cafe::1") 88 appendAnswer(record) 89 } 90 if q.Name == "dualstack.letsencrypt.org." { 91 record := new(dns.AAAA) 92 record.Hdr = dns.RR_Header{Name: "dualstack.letsencrypt.org.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0} 93 record.AAAA = net.ParseIP("2602:80a:6000:abad:cafe::1") 94 appendAnswer(record) 95 } 96 if q.Name == "v4error.letsencrypt.org." { 97 record := new(dns.AAAA) 98 record.Hdr = dns.RR_Header{Name: "v4error.letsencrypt.org.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0} 99 record.AAAA = net.ParseIP("2602:80a:6000:abad:cafe::1") 100 appendAnswer(record) 101 } 102 if q.Name == "v6error.letsencrypt.org." { 103 m.SetRcode(r, dns.RcodeNotImplemented) 104 } 105 if q.Name == "nxdomain.letsencrypt.org." { 106 m.SetRcode(r, dns.RcodeNameError) 107 } 108 if q.Name == "dualstackerror.letsencrypt.org." { 109 m.SetRcode(r, dns.RcodeNotImplemented) 110 } 111 case dns.TypeA: 112 if q.Name == "cps.letsencrypt.org." { 113 record := new(dns.A) 114 record.Hdr = dns.RR_Header{Name: "cps.letsencrypt.org.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0} 115 record.A = net.ParseIP("64.112.117.1") 116 appendAnswer(record) 117 } 118 if q.Name == "dualstack.letsencrypt.org." { 119 record := new(dns.A) 120 record.Hdr = dns.RR_Header{Name: "dualstack.letsencrypt.org.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0} 121 record.A = net.ParseIP("64.112.117.1") 122 appendAnswer(record) 123 } 124 if q.Name == "v6error.letsencrypt.org." { 125 record := new(dns.A) 126 record.Hdr = dns.RR_Header{Name: "dualstack.letsencrypt.org.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0} 127 record.A = net.ParseIP("64.112.117.1") 128 appendAnswer(record) 129 } 130 if q.Name == "v4error.letsencrypt.org." { 131 m.SetRcode(r, dns.RcodeNotImplemented) 132 } 133 if q.Name == "nxdomain.letsencrypt.org." { 134 m.SetRcode(r, dns.RcodeNameError) 135 } 136 if q.Name == "dualstackerror.letsencrypt.org." { 137 m.SetRcode(r, dns.RcodeRefused) 138 } 139 case dns.TypeCNAME: 140 if q.Name == "cname.letsencrypt.org." { 141 record := new(dns.CNAME) 142 record.Hdr = dns.RR_Header{Name: "cname.letsencrypt.org.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 30} 143 record.Target = "cps.letsencrypt.org." 144 appendAnswer(record) 145 } 146 if q.Name == "cname.example.com." { 147 record := new(dns.CNAME) 148 record.Hdr = dns.RR_Header{Name: "cname.example.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 30} 149 record.Target = "CAA.example.com." 150 appendAnswer(record) 151 } 152 case dns.TypeDNAME: 153 if q.Name == "dname.letsencrypt.org." { 154 record := new(dns.DNAME) 155 record.Hdr = dns.RR_Header{Name: "dname.letsencrypt.org.", Rrtype: dns.TypeDNAME, Class: dns.ClassINET, Ttl: 30} 156 record.Target = "cps.letsencrypt.org." 157 appendAnswer(record) 158 } 159 case dns.TypeCAA: 160 if q.Name == "bracewel.net." || q.Name == "caa.example.com." { 161 record := new(dns.CAA) 162 record.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeCAA, Class: dns.ClassINET, Ttl: 0} 163 record.Tag = "issue" 164 record.Value = "letsencrypt.org" 165 record.Flag = 1 166 appendAnswer(record) 167 } 168 if q.Name == "cname.example.com." { 169 record := new(dns.CAA) 170 record.Hdr = dns.RR_Header{Name: "caa.example.com.", Rrtype: dns.TypeCAA, Class: dns.ClassINET, Ttl: 0} 171 record.Tag = "issue" 172 record.Value = "letsencrypt.org" 173 record.Flag = 1 174 appendAnswer(record) 175 } 176 if q.Name == "gonetld." { 177 m.SetRcode(r, dns.RcodeNameError) 178 } 179 case dns.TypeTXT: 180 if q.Name == "split-txt.letsencrypt.org." { 181 record := new(dns.TXT) 182 record.Hdr = dns.RR_Header{Name: "split-txt.letsencrypt.org.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0} 183 record.Txt = []string{"a", "b", "c"} 184 appendAnswer(record) 185 } else { 186 auth := new(dns.SOA) 187 auth.Hdr = dns.RR_Header{Name: "letsencrypt.org.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 0} 188 auth.Ns = "ns.letsencrypt.org." 189 auth.Mbox = "master.letsencrypt.org." 190 auth.Serial = 1 191 auth.Refresh = 1 192 auth.Retry = 1 193 auth.Expire = 1 194 auth.Minttl = 1 195 m.Ns = append(m.Ns, auth) 196 } 197 if q.Name == "nxdomain.letsencrypt.org." { 198 m.SetRcode(r, dns.RcodeNameError) 199 } 200 } 201 } 202 203 body, err := m.Pack() 204 if err != nil { 205 fmt.Fprintf(os.Stderr, "packing reply: %s\n", err) 206 } 207 w.Header().Set("Content-Type", "application/dns-message") 208 _, err = w.Write(body) 209 if err != nil { 210 panic(err) // running tests, so panic is OK 211 } 212 } 213 214 func serveLoopResolver(stopChan chan bool) { 215 m := http.NewServeMux() 216 m.HandleFunc("/dns-query", mockDNSQuery) 217 httpServer := &http.Server{ 218 Addr: dnsLoopbackAddr, 219 Handler: m, 220 ReadTimeout: time.Second, 221 WriteTimeout: time.Second, 222 } 223 go func() { 224 cert := "../test/certs/ipki/localhost/cert.pem" 225 key := "../test/certs/ipki/localhost/key.pem" 226 err := httpServer.ListenAndServeTLS(cert, key) 227 if err != nil { 228 fmt.Println(err) 229 } 230 }() 231 go func() { 232 <-stopChan 233 err := httpServer.Shutdown(context.Background()) 234 if err != nil { 235 log.Fatal(err) 236 } 237 }() 238 } 239 240 func pollServer() { 241 backoff := 200 * time.Millisecond 242 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 243 defer cancel() 244 ticker := time.NewTicker(backoff) 245 246 for { 247 select { 248 case <-ctx.Done(): 249 fmt.Fprintln(os.Stderr, "Timeout reached while testing for the dns server to come up") 250 os.Exit(1) 251 case <-ticker.C: 252 conn, _ := dns.DialTimeout("udp", dnsLoopbackAddr, backoff) 253 if conn != nil { 254 _ = conn.Close() 255 return 256 } 257 } 258 } 259 } 260 261 // tlsConfig is used for the TLS config of client instances that talk to the 262 // DoH server set up in TestMain. 263 var tlsConfig *tls.Config 264 265 func TestMain(m *testing.M) { 266 root, err := os.ReadFile("../test/certs/ipki/minica.pem") 267 if err != nil { 268 log.Fatal(err) 269 } 270 pool := x509.NewCertPool() 271 pool.AppendCertsFromPEM(root) 272 tlsConfig = &tls.Config{ 273 RootCAs: pool, 274 } 275 276 stop := make(chan bool, 1) 277 serveLoopResolver(stop) 278 pollServer() 279 ret := m.Run() 280 stop <- true 281 os.Exit(ret) 282 } 283 284 func TestDNSNoServers(t *testing.T) { 285 staticProvider, err := NewStaticProvider([]string{}) 286 test.AssertNotError(t, err, "Got error creating StaticProvider") 287 288 obj := New(time.Hour, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) 289 290 _, resolvers, err := obj.LookupHost(context.Background(), "letsencrypt.org") 291 test.AssertEquals(t, len(resolvers), 0) 292 test.AssertError(t, err, "No servers") 293 294 _, _, err = obj.LookupTXT(context.Background(), "letsencrypt.org") 295 test.AssertError(t, err, "No servers") 296 297 _, _, _, err = obj.LookupCAA(context.Background(), "letsencrypt.org") 298 test.AssertError(t, err, "No servers") 299 } 300 301 func TestDNSOneServer(t *testing.T) { 302 staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) 303 test.AssertNotError(t, err, "Got error creating StaticProvider") 304 305 obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) 306 307 _, resolvers, err := obj.LookupHost(context.Background(), "cps.letsencrypt.org") 308 test.AssertEquals(t, len(resolvers), 2) 309 slices.Sort(resolvers) 310 test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) 311 test.AssertNotError(t, err, "No message") 312 } 313 314 func TestDNSDuplicateServers(t *testing.T) { 315 staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr, dnsLoopbackAddr}) 316 test.AssertNotError(t, err, "Got error creating StaticProvider") 317 318 obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) 319 320 _, resolvers, err := obj.LookupHost(context.Background(), "cps.letsencrypt.org") 321 test.AssertEquals(t, len(resolvers), 2) 322 slices.Sort(resolvers) 323 test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) 324 test.AssertNotError(t, err, "No message") 325 } 326 327 func TestDNSServFail(t *testing.T) { 328 staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) 329 test.AssertNotError(t, err, "Got error creating StaticProvider") 330 331 obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) 332 bad := "servfail.com" 333 334 _, _, err = obj.LookupTXT(context.Background(), bad) 335 test.AssertError(t, err, "LookupTXT didn't return an error") 336 337 _, _, err = obj.LookupHost(context.Background(), bad) 338 test.AssertError(t, err, "LookupHost didn't return an error") 339 340 emptyCaa, _, _, err := obj.LookupCAA(context.Background(), bad) 341 test.Assert(t, len(emptyCaa) == 0, "Query returned non-empty list of CAA records") 342 test.AssertError(t, err, "LookupCAA should have returned an error") 343 } 344 345 func TestDNSLookupTXT(t *testing.T) { 346 staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) 347 test.AssertNotError(t, err, "Got error creating StaticProvider") 348 349 obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) 350 351 a, _, err := obj.LookupTXT(context.Background(), "letsencrypt.org") 352 t.Logf("A: %v", a) 353 test.AssertNotError(t, err, "No message") 354 355 a, _, err = obj.LookupTXT(context.Background(), "split-txt.letsencrypt.org") 356 t.Logf("A: %v ", a) 357 test.AssertNotError(t, err, "No message") 358 test.AssertEquals(t, len(a), 1) 359 test.AssertEquals(t, a[0], "abc") 360 } 361 362 // TODO(#8213): Convert this to a table test. 363 func TestDNSLookupHost(t *testing.T) { 364 staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) 365 test.AssertNotError(t, err, "Got error creating StaticProvider") 366 367 obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) 368 369 ip, resolvers, err := obj.LookupHost(context.Background(), "servfail.com") 370 t.Logf("servfail.com - IP: %s, Err: %s", ip, err) 371 test.AssertError(t, err, "Server failure") 372 test.Assert(t, len(ip) == 0, "Should not have IPs") 373 slices.Sort(resolvers) 374 test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) 375 376 ip, resolvers, err = obj.LookupHost(context.Background(), "nonexistent.letsencrypt.org") 377 t.Logf("nonexistent.letsencrypt.org - IP: %s, Err: %s", ip, err) 378 test.AssertError(t, err, "No valid A or AAAA records should error") 379 test.Assert(t, len(ip) == 0, "Should not have IPs") 380 slices.Sort(resolvers) 381 test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) 382 383 // Single IPv4 address 384 ip, resolvers, err = obj.LookupHost(context.Background(), "cps.letsencrypt.org") 385 t.Logf("cps.letsencrypt.org - IP: %s, Err: %s", ip, err) 386 test.AssertNotError(t, err, "Not an error to exist") 387 test.Assert(t, len(ip) == 1, "Should have IP") 388 slices.Sort(resolvers) 389 test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) 390 ip, resolvers, err = obj.LookupHost(context.Background(), "cps.letsencrypt.org") 391 t.Logf("cps.letsencrypt.org - IP: %s, Err: %s", ip, err) 392 test.AssertNotError(t, err, "Not an error to exist") 393 test.Assert(t, len(ip) == 1, "Should have IP") 394 slices.Sort(resolvers) 395 test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) 396 397 // Single IPv6 address 398 ip, resolvers, err = obj.LookupHost(context.Background(), "v6.letsencrypt.org") 399 t.Logf("v6.letsencrypt.org - IP: %s, Err: %s", ip, err) 400 test.AssertNotError(t, err, "Not an error to exist") 401 test.Assert(t, len(ip) == 1, "Should not have IPs") 402 slices.Sort(resolvers) 403 test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) 404 405 // Both IPv6 and IPv4 address 406 ip, resolvers, err = obj.LookupHost(context.Background(), "dualstack.letsencrypt.org") 407 t.Logf("dualstack.letsencrypt.org - IP: %s, Err: %s", ip, err) 408 test.AssertNotError(t, err, "Not an error to exist") 409 test.Assert(t, len(ip) == 2, "Should have 2 IPs") 410 expected := netip.MustParseAddr("64.112.117.1") 411 test.Assert(t, ip[0] == expected, "wrong ipv4 address") 412 expected = netip.MustParseAddr("2602:80a:6000:abad:cafe::1") 413 test.Assert(t, ip[1] == expected, "wrong ipv6 address") 414 slices.Sort(resolvers) 415 test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) 416 417 // IPv6 error, IPv4 success 418 ip, resolvers, err = obj.LookupHost(context.Background(), "v6error.letsencrypt.org") 419 t.Logf("v6error.letsencrypt.org - IP: %s, Err: %s", ip, err) 420 test.AssertNotError(t, err, "Not an error to exist") 421 test.Assert(t, len(ip) == 1, "Should have 1 IP") 422 expected = netip.MustParseAddr("64.112.117.1") 423 test.Assert(t, ip[0] == expected, "wrong ipv4 address") 424 slices.Sort(resolvers) 425 test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) 426 427 // IPv6 success, IPv4 error 428 ip, resolvers, err = obj.LookupHost(context.Background(), "v4error.letsencrypt.org") 429 t.Logf("v4error.letsencrypt.org - IP: %s, Err: %s", ip, err) 430 test.AssertNotError(t, err, "Not an error to exist") 431 test.Assert(t, len(ip) == 1, "Should have 1 IP") 432 expected = netip.MustParseAddr("2602:80a:6000:abad:cafe::1") 433 test.Assert(t, ip[0] == expected, "wrong ipv6 address") 434 slices.Sort(resolvers) 435 test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) 436 437 // IPv6 error, IPv4 error 438 // Should return both the IPv4 error (Refused) and the IPv6 error (NotImplemented) 439 hostname := "dualstackerror.letsencrypt.org" 440 ip, resolvers, err = obj.LookupHost(context.Background(), hostname) 441 t.Logf("%s - IP: %s, Err: %s", hostname, ip, err) 442 test.AssertError(t, err, "Should be an error") 443 test.AssertContains(t, err.Error(), "REFUSED looking up A for") 444 test.AssertContains(t, err.Error(), "NOTIMP looking up AAAA for") 445 slices.Sort(resolvers) 446 test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) 447 } 448 449 func TestDNSNXDOMAIN(t *testing.T) { 450 staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) 451 test.AssertNotError(t, err, "Got error creating StaticProvider") 452 453 obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) 454 455 hostname := "nxdomain.letsencrypt.org" 456 _, _, err = obj.LookupHost(context.Background(), hostname) 457 test.AssertContains(t, err.Error(), "NXDOMAIN looking up A for") 458 test.AssertContains(t, err.Error(), "NXDOMAIN looking up AAAA for") 459 460 _, _, err = obj.LookupTXT(context.Background(), hostname) 461 expected := Error{dns.TypeTXT, hostname, nil, dns.RcodeNameError, nil} 462 test.AssertDeepEquals(t, err, expected) 463 } 464 465 func TestDNSLookupCAA(t *testing.T) { 466 staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) 467 test.AssertNotError(t, err, "Got error creating StaticProvider") 468 469 obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) 470 removeIDExp := regexp.MustCompile(" id: [[:digit:]]+") 471 472 caas, resp, resolvers, err := obj.LookupCAA(context.Background(), "bracewel.net") 473 test.AssertNotError(t, err, "CAA lookup failed") 474 test.Assert(t, len(caas) > 0, "Should have CAA records") 475 test.AssertEquals(t, len(resolvers), 1) 476 test.AssertDeepEquals(t, resolvers, ResolverAddrs{"127.0.0.1:4053"}) 477 expectedResp := `;; opcode: QUERY, status: NOERROR, id: XXXX 478 ;; flags: qr rd; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0 479 480 ;; QUESTION SECTION: 481 ;bracewel.net. IN CAA 482 483 ;; ANSWER SECTION: 484 bracewel.net. 0 IN CAA 1 issue "letsencrypt.org" 485 ` 486 test.AssertEquals(t, removeIDExp.ReplaceAllString(resp, " id: XXXX"), expectedResp) 487 488 caas, resp, resolvers, err = obj.LookupCAA(context.Background(), "nonexistent.letsencrypt.org") 489 test.AssertNotError(t, err, "CAA lookup failed") 490 test.Assert(t, len(caas) == 0, "Shouldn't have CAA records") 491 test.AssertEquals(t, resolvers[0], "127.0.0.1:4053") 492 expectedResp = "" 493 test.AssertEquals(t, resp, expectedResp) 494 495 caas, resp, resolvers, err = obj.LookupCAA(context.Background(), "nxdomain.letsencrypt.org") 496 slices.Sort(resolvers) 497 test.AssertNotError(t, err, "CAA lookup failed") 498 test.Assert(t, len(caas) == 0, "Shouldn't have CAA records") 499 test.AssertEquals(t, resolvers[0], "127.0.0.1:4053") 500 expectedResp = "" 501 test.AssertEquals(t, resp, expectedResp) 502 503 caas, resp, resolvers, err = obj.LookupCAA(context.Background(), "cname.example.com") 504 test.AssertNotError(t, err, "CAA lookup failed") 505 test.Assert(t, len(caas) > 0, "Should follow CNAME to find CAA") 506 test.AssertEquals(t, resolvers[0], "127.0.0.1:4053") 507 expectedResp = `;; opcode: QUERY, status: NOERROR, id: XXXX 508 ;; flags: qr rd; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0 509 510 ;; QUESTION SECTION: 511 ;cname.example.com. IN CAA 512 513 ;; ANSWER SECTION: 514 caa.example.com. 0 IN CAA 1 issue "letsencrypt.org" 515 ` 516 test.AssertEquals(t, removeIDExp.ReplaceAllString(resp, " id: XXXX"), expectedResp) 517 518 _, _, resolvers, err = obj.LookupCAA(context.Background(), "gonetld") 519 test.AssertError(t, err, "should fail for TLD NXDOMAIN") 520 test.AssertContains(t, err.Error(), "NXDOMAIN") 521 test.AssertEquals(t, resolvers[0], "127.0.0.1:4053") 522 } 523 524 type testExchanger struct { 525 sync.Mutex 526 count int 527 errs []error 528 } 529 530 var errTooManyRequests = errors.New("too many requests") 531 532 func (te *testExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) { 533 te.Lock() 534 defer te.Unlock() 535 msg := &dns.Msg{ 536 MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess}, 537 } 538 if len(te.errs) <= te.count { 539 return nil, 0, errTooManyRequests 540 } 541 err := te.errs[te.count] 542 te.count++ 543 544 return msg, 2 * time.Millisecond, err 545 } 546 547 func TestRetry(t *testing.T) { 548 isTimeoutErr := &url.Error{Op: "read", Err: testTimeoutError(true)} 549 nonTimeoutErr := &url.Error{Op: "read", Err: testTimeoutError(false)} 550 servFailError := errors.New("DNS problem: server failure at resolver looking up TXT for example.com") 551 timeoutFailError := errors.New("DNS problem: query timed out looking up TXT for example.com") 552 type testCase struct { 553 name string 554 maxTries int 555 te *testExchanger 556 expected error 557 expectedCount int 558 metricsAllRetries float64 559 } 560 tests := []*testCase{ 561 // The success on first try case 562 { 563 name: "success", 564 maxTries: 3, 565 te: &testExchanger{ 566 errs: []error{nil}, 567 }, 568 expected: nil, 569 expectedCount: 1, 570 }, 571 // Immediate non-OpError, error returns immediately 572 { 573 name: "non-operror", 574 maxTries: 3, 575 te: &testExchanger{ 576 errs: []error{errors.New("nope")}, 577 }, 578 expected: servFailError, 579 expectedCount: 1, 580 }, 581 // Timeout err, then non-OpError stops at two tries 582 { 583 name: "err-then-non-operror", 584 maxTries: 3, 585 te: &testExchanger{ 586 errs: []error{isTimeoutErr, errors.New("nope")}, 587 }, 588 expected: servFailError, 589 expectedCount: 2, 590 }, 591 // Timeout error given always 592 { 593 name: "persistent-timeout-error", 594 maxTries: 3, 595 te: &testExchanger{ 596 errs: []error{ 597 isTimeoutErr, 598 isTimeoutErr, 599 isTimeoutErr, 600 }, 601 }, 602 expected: timeoutFailError, 603 expectedCount: 3, 604 metricsAllRetries: 1, 605 }, 606 // Even with maxTries at 0, we should still let a single request go 607 // through 608 { 609 name: "zero-maxtries", 610 maxTries: 0, 611 te: &testExchanger{ 612 errs: []error{nil}, 613 }, 614 expected: nil, 615 expectedCount: 1, 616 }, 617 // Timeout error given just once causes two tries 618 { 619 name: "single-timeout-error", 620 maxTries: 3, 621 te: &testExchanger{ 622 errs: []error{ 623 isTimeoutErr, 624 nil, 625 }, 626 }, 627 expected: nil, 628 expectedCount: 2, 629 }, 630 // Timeout error given twice causes three tries 631 { 632 name: "double-timeout-error", 633 maxTries: 3, 634 te: &testExchanger{ 635 errs: []error{ 636 isTimeoutErr, 637 isTimeoutErr, 638 nil, 639 }, 640 }, 641 expected: nil, 642 expectedCount: 3, 643 }, 644 // Timeout error given thrice causes three tries and fails 645 { 646 name: "triple-timeout-error", 647 maxTries: 3, 648 te: &testExchanger{ 649 errs: []error{ 650 isTimeoutErr, 651 isTimeoutErr, 652 isTimeoutErr, 653 }, 654 }, 655 expected: timeoutFailError, 656 expectedCount: 3, 657 metricsAllRetries: 1, 658 }, 659 // timeout then non-timeout error causes two retries 660 { 661 name: "timeout-nontimeout-error", 662 maxTries: 3, 663 te: &testExchanger{ 664 errs: []error{ 665 isTimeoutErr, 666 nonTimeoutErr, 667 }, 668 }, 669 expected: servFailError, 670 expectedCount: 2, 671 }, 672 } 673 674 for i, tc := range tests { 675 t.Run(tc.name, func(t *testing.T) { 676 staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) 677 test.AssertNotError(t, err, "Got error creating StaticProvider") 678 679 testClient := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), tc.maxTries, "", blog.UseMock(), tlsConfig) 680 dr := testClient.(*impl) 681 dr.dnsClient = tc.te 682 _, _, err = dr.LookupTXT(context.Background(), "example.com") 683 if err == errTooManyRequests { 684 t.Errorf("#%d, sent more requests than the test case handles", i) 685 } 686 expectedErr := tc.expected 687 if (expectedErr == nil && err != nil) || 688 (expectedErr != nil && err == nil) || 689 (expectedErr != nil && expectedErr.Error() != err.Error()) { 690 t.Errorf("#%d, error, expected %v, got %v", i, expectedErr, err) 691 } 692 if tc.expectedCount != tc.te.count { 693 t.Errorf("#%d, error, expectedCount %v, got %v", i, tc.expectedCount, tc.te.count) 694 } 695 if tc.metricsAllRetries > 0 { 696 test.AssertMetricWithLabelsEquals( 697 t, dr.timeoutCounter, prometheus.Labels{ 698 "qtype": "TXT", 699 "type": "out of retries", 700 "resolver": "127.0.0.1", 701 "isTLD": "false", 702 }, tc.metricsAllRetries) 703 } 704 }) 705 } 706 707 staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) 708 test.AssertNotError(t, err, "Got error creating StaticProvider") 709 710 testClient := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 3, "", blog.UseMock(), tlsConfig) 711 dr := testClient.(*impl) 712 dr.dnsClient = &testExchanger{errs: []error{isTimeoutErr, isTimeoutErr, nil}} 713 ctx, cancel := context.WithCancel(context.Background()) 714 cancel() 715 _, _, err = dr.LookupTXT(ctx, "example.com") 716 if err == nil || 717 err.Error() != "DNS problem: query timed out (and was canceled) looking up TXT for example.com" { 718 t.Errorf("expected %s, got %s", context.Canceled, err) 719 } 720 721 dr.dnsClient = &testExchanger{errs: []error{isTimeoutErr, isTimeoutErr, nil}} 722 ctx, cancel = context.WithTimeout(context.Background(), -10*time.Hour) 723 defer cancel() 724 _, _, err = dr.LookupTXT(ctx, "example.com") 725 if err == nil || 726 err.Error() != "DNS problem: query timed out looking up TXT for example.com" { 727 t.Errorf("expected %s, got %s", context.DeadlineExceeded, err) 728 } 729 730 dr.dnsClient = &testExchanger{errs: []error{isTimeoutErr, isTimeoutErr, nil}} 731 ctx, deadlineCancel := context.WithTimeout(context.Background(), -10*time.Hour) 732 deadlineCancel() 733 _, _, err = dr.LookupTXT(ctx, "example.com") 734 if err == nil || 735 err.Error() != "DNS problem: query timed out looking up TXT for example.com" { 736 t.Errorf("expected %s, got %s", context.DeadlineExceeded, err) 737 } 738 739 test.AssertMetricWithLabelsEquals( 740 t, dr.timeoutCounter, prometheus.Labels{ 741 "qtype": "TXT", 742 "type": "canceled", 743 "resolver": "127.0.0.1", 744 }, 1) 745 746 test.AssertMetricWithLabelsEquals( 747 t, dr.timeoutCounter, prometheus.Labels{ 748 "qtype": "TXT", 749 "type": "deadline exceeded", 750 "resolver": "127.0.0.1", 751 }, 2) 752 } 753 754 func TestIsTLD(t *testing.T) { 755 if isTLD("com") != "true" { 756 t.Errorf("expected 'com' to be a TLD, got %q", isTLD("com")) 757 } 758 if isTLD("example.com") != "false" { 759 t.Errorf("expected 'example.com' to not a TLD, got %q", isTLD("example.com")) 760 } 761 } 762 763 type testTimeoutError bool 764 765 func (t testTimeoutError) Timeout() bool { return bool(t) } 766 func (t testTimeoutError) Error() string { return fmt.Sprintf("Timeout: %t", t) } 767 768 // rotateFailureExchanger is a dns.Exchange implementation that tracks a count 769 // of the number of calls to `Exchange` for a given address in the `lookups` 770 // map. For all addresses in the `brokenAddresses` map, a retryable error is 771 // returned from `Exchange`. This mock is used by `TestRotateServerOnErr`. 772 type rotateFailureExchanger struct { 773 sync.Mutex 774 lookups map[string]int 775 brokenAddresses map[string]bool 776 } 777 778 // Exchange for rotateFailureExchanger tracks the `a` argument in `lookups` and 779 // if present in `brokenAddresses`, returns a timeout error. 780 func (e *rotateFailureExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) { 781 e.Lock() 782 defer e.Unlock() 783 784 // Track that exchange was called for the given server 785 e.lookups[a]++ 786 787 // If its a broken server, return a retryable error 788 if e.brokenAddresses[a] { 789 isTimeoutErr := &url.Error{Op: "read", Err: testTimeoutError(true)} 790 return nil, 2 * time.Millisecond, isTimeoutErr 791 } 792 793 return m, 2 * time.Millisecond, nil 794 } 795 796 // TestRotateServerOnErr ensures that a retryable error returned from a DNS 797 // server will result in the retry being performed against the next server in 798 // the list. 799 func TestRotateServerOnErr(t *testing.T) { 800 // Configure three DNS servers 801 dnsServers := []string{ 802 "a:53", "b:53", "[2606:4700:4700::1111]:53", 803 } 804 805 // Set up a DNS client using these servers that will retry queries up to 806 // a maximum of 5 times. It's important to choose a maxTries value >= the 807 // number of dnsServers to ensure we always get around to trying the one 808 // working server 809 staticProvider, err := NewStaticProvider(dnsServers) 810 test.AssertNotError(t, err, "Got error creating StaticProvider") 811 812 maxTries := 5 813 client := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), maxTries, "", blog.UseMock(), tlsConfig) 814 815 // Configure a mock exchanger that will always return a retryable error for 816 // servers A and B. This will force server "[2606:4700:4700::1111]:53" to do 817 // all the work once retries reach it. 818 mock := &rotateFailureExchanger{ 819 brokenAddresses: map[string]bool{ 820 "a:53": true, 821 "b:53": true, 822 }, 823 lookups: make(map[string]int), 824 } 825 client.(*impl).dnsClient = mock 826 827 // Perform a bunch of lookups. We choose the initial server randomly. Any time 828 // A or B is chosen there should be an error and a retry using the next server 829 // in the list. Since we configured maxTries to be larger than the number of 830 // servers *all* queries should eventually succeed by being retried against 831 // server "[2606:4700:4700::1111]:53". 832 for range maxTries * 2 { 833 _, resolvers, err := client.LookupTXT(context.Background(), "example.com") 834 test.AssertEquals(t, len(resolvers), 1) 835 test.AssertEquals(t, resolvers[0], "[2606:4700:4700::1111]:53") 836 // Any errors are unexpected - server "[2606:4700:4700::1111]:53" should 837 // have responded without error. 838 test.AssertNotError(t, err, "Expected no error from eventual retry with functional server") 839 } 840 841 // We expect that the A and B servers had a non-zero number of lookups 842 // attempted. 843 test.Assert(t, mock.lookups["a:53"] > 0, "Expected A server to have non-zero lookup attempts") 844 test.Assert(t, mock.lookups["b:53"] > 0, "Expected B server to have non-zero lookup attempts") 845 846 // We expect that the server "[2606:4700:4700::1111]:53" eventually served 847 // all of the lookups attempted. 848 test.AssertEquals(t, mock.lookups["[2606:4700:4700::1111]:53"], maxTries*2) 849 850 } 851 852 type mockTimeoutURLError struct{} 853 854 func (m *mockTimeoutURLError) Error() string { return "whoops, oh gosh" } 855 func (m *mockTimeoutURLError) Timeout() bool { return true } 856 857 type dohAlwaysRetryExchanger struct { 858 sync.Mutex 859 err error 860 } 861 862 func (dohE *dohAlwaysRetryExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) { 863 dohE.Lock() 864 defer dohE.Unlock() 865 866 timeoutURLerror := &url.Error{ 867 Op: "GET", 868 URL: "https://example.com", 869 Err: &mockTimeoutURLError{}, 870 } 871 872 return nil, time.Second, timeoutURLerror 873 } 874 875 func TestDOHMetric(t *testing.T) { 876 staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) 877 test.AssertNotError(t, err, "Got error creating StaticProvider") 878 879 testClient := New(time.Second*11, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 0, "", blog.UseMock(), tlsConfig) 880 resolver := testClient.(*impl) 881 resolver.dnsClient = &dohAlwaysRetryExchanger{err: &url.Error{Op: "read", Err: testTimeoutError(true)}} 882 883 // Starting out, we should count 0 "out of retries" errors. 884 test.AssertMetricWithLabelsEquals(t, resolver.timeoutCounter, prometheus.Labels{"qtype": "None", "type": "out of retries", "resolver": "127.0.0.1", "isTLD": "false"}, 0) 885 886 // Trigger the error. 887 _, _, _ = resolver.exchangeOne(context.Background(), "example.com", 0) 888 889 // Now, we should count 1 "out of retries" errors. 890 test.AssertMetricWithLabelsEquals(t, resolver.timeoutCounter, prometheus.Labels{"qtype": "None", "type": "out of retries", "resolver": "127.0.0.1", "isTLD": "false"}, 1) 891 }