github.com/letsencrypt/boulder@v0.20251208.0/bdns/dns.go (about)

     1  package bdns
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"net/http"
    11  	"net/netip"
    12  	"slices"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/jmhodges/clock"
    19  	"github.com/miekg/dns"
    20  	"github.com/prometheus/client_golang/prometheus"
    21  	"github.com/prometheus/client_golang/prometheus/promauto"
    22  
    23  	"github.com/letsencrypt/boulder/iana"
    24  	blog "github.com/letsencrypt/boulder/log"
    25  	"github.com/letsencrypt/boulder/metrics"
    26  )
    27  
    28  // ResolverAddrs contains DNS resolver(s) that were chosen to perform a
    29  // validation request or CAA recheck. A ResolverAddr will be in the form of
    30  // host:port, A:host:port, or AAAA:host:port depending on which type of lookup
    31  // was done.
    32  type ResolverAddrs []string
    33  
    34  // Client queries for DNS records
    35  type Client interface {
    36  	LookupTXT(context.Context, string) (txts []string, resolver ResolverAddrs, err error)
    37  	LookupHost(context.Context, string) ([]netip.Addr, ResolverAddrs, error)
    38  	LookupCAA(context.Context, string) ([]*dns.CAA, string, ResolverAddrs, error)
    39  }
    40  
    41  // impl represents a client that talks to an external resolver
    42  type impl struct {
    43  	dnsClient                exchanger
    44  	servers                  ServerProvider
    45  	allowRestrictedAddresses bool
    46  	maxTries                 int
    47  	clk                      clock.Clock
    48  	log                      blog.Logger
    49  
    50  	queryTime       *prometheus.HistogramVec
    51  	totalLookupTime *prometheus.HistogramVec
    52  	timeoutCounter  *prometheus.CounterVec
    53  }
    54  
    55  var _ Client = &impl{}
    56  
    57  type exchanger interface {
    58  	Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error)
    59  }
    60  
    61  // New constructs a new DNS resolver object that utilizes the
    62  // provided list of DNS servers for resolution.
    63  //
    64  // `tlsConfig` is the configuration used for outbound DoH queries,
    65  // if applicable.
    66  func New(
    67  	readTimeout time.Duration,
    68  	servers ServerProvider,
    69  	stats prometheus.Registerer,
    70  	clk clock.Clock,
    71  	maxTries int,
    72  	userAgent string,
    73  	log blog.Logger,
    74  	tlsConfig *tls.Config,
    75  ) Client {
    76  	var client exchanger
    77  
    78  	// Clone the default transport because it comes with various settings
    79  	// that we like, which are different from the zero value of an
    80  	// `http.Transport`.
    81  	transport := http.DefaultTransport.(*http.Transport).Clone()
    82  	transport.TLSClientConfig = tlsConfig
    83  	// The default transport already sets this field, but it isn't
    84  	// documented that it will always be set. Set it again to be sure,
    85  	// because Unbound will reject non-HTTP/2 DoH requests.
    86  	transport.ForceAttemptHTTP2 = true
    87  	client = &dohExchanger{
    88  		clk: clk,
    89  		hc: http.Client{
    90  			Timeout:   readTimeout,
    91  			Transport: transport,
    92  		},
    93  		userAgent: userAgent,
    94  	}
    95  
    96  	queryTime := promauto.With(stats).NewHistogramVec(
    97  		prometheus.HistogramOpts{
    98  			Name:    "dns_query_time",
    99  			Help:    "Time taken to perform a DNS query",
   100  			Buckets: metrics.InternetFacingBuckets,
   101  		},
   102  		[]string{"qtype", "result", "resolver"},
   103  	)
   104  	totalLookupTime := promauto.With(stats).NewHistogramVec(
   105  		prometheus.HistogramOpts{
   106  			Name:    "dns_total_lookup_time",
   107  			Help:    "Time taken to perform a DNS lookup, including all retried queries",
   108  			Buckets: metrics.InternetFacingBuckets,
   109  		},
   110  		[]string{"qtype", "result", "retries", "resolver"},
   111  	)
   112  	timeoutCounter := promauto.With(stats).NewCounterVec(
   113  		prometheus.CounterOpts{
   114  			Name: "dns_timeout",
   115  			Help: "Counter of various types of DNS query timeouts",
   116  		},
   117  		[]string{"qtype", "type", "resolver", "isTLD"},
   118  	)
   119  	return &impl{
   120  		dnsClient:                client,
   121  		servers:                  servers,
   122  		allowRestrictedAddresses: false,
   123  		maxTries:                 maxTries,
   124  		clk:                      clk,
   125  		queryTime:                queryTime,
   126  		totalLookupTime:          totalLookupTime,
   127  		timeoutCounter:           timeoutCounter,
   128  		log:                      log,
   129  	}
   130  }
   131  
   132  // NewTest constructs a new DNS resolver object that utilizes the
   133  // provided list of DNS servers for resolution and will allow loopback addresses.
   134  // This constructor should *only* be called from tests (unit or integration).
   135  func NewTest(
   136  	readTimeout time.Duration,
   137  	servers ServerProvider,
   138  	stats prometheus.Registerer,
   139  	clk clock.Clock,
   140  	maxTries int,
   141  	userAgent string,
   142  	log blog.Logger,
   143  	tlsConfig *tls.Config,
   144  ) Client {
   145  	resolver := New(readTimeout, servers, stats, clk, maxTries, userAgent, log, tlsConfig)
   146  	resolver.(*impl).allowRestrictedAddresses = true
   147  	return resolver
   148  }
   149  
   150  // exchangeOne performs a single DNS exchange with a randomly chosen server
   151  // out of the server list, returning the response, time, and error (if any).
   152  // We assume that the upstream resolver requests and validates DNSSEC records
   153  // itself.
   154  func (dnsClient *impl) exchangeOne(ctx context.Context, hostname string, qtype uint16) (resp *dns.Msg, resolver string, err error) {
   155  	m := new(dns.Msg)
   156  	// Set question type
   157  	m.SetQuestion(dns.Fqdn(hostname), qtype)
   158  	// Set the AD bit in the query header so that the resolver knows that
   159  	// we are interested in this bit in the response header. If this isn't
   160  	// set the AD bit in the response is useless (RFC 6840 Section 5.7).
   161  	// This has no security implications, it simply allows us to gather
   162  	// metrics about the percentage of responses that are secured with
   163  	// DNSSEC.
   164  	m.AuthenticatedData = true
   165  	// Tell the resolver that we're willing to receive responses up to 4096 bytes.
   166  	// This happens sometimes when there are a very large number of CAA records
   167  	// present.
   168  	m.SetEdns0(4096, false)
   169  
   170  	servers, err := dnsClient.servers.Addrs()
   171  	if err != nil {
   172  		return nil, "", fmt.Errorf("failed to list DNS servers: %w", err)
   173  	}
   174  	chosenServerIndex := 0
   175  	chosenServer := servers[chosenServerIndex]
   176  	resolver = chosenServer
   177  
   178  	// Strip off the IP address part of the server address because
   179  	// we talk to the same server on multiple ports, and don't want
   180  	// to blow up the cardinality.
   181  	chosenServerIP, _, err := net.SplitHostPort(chosenServer)
   182  	if err != nil {
   183  		return
   184  	}
   185  
   186  	start := dnsClient.clk.Now()
   187  	client := dnsClient.dnsClient
   188  	qtypeStr := dns.TypeToString[qtype]
   189  	tries := 1
   190  	defer func() {
   191  		result := "failed"
   192  		if resp != nil {
   193  			result = dns.RcodeToString[resp.Rcode]
   194  		}
   195  		dnsClient.totalLookupTime.With(prometheus.Labels{
   196  			"qtype":    qtypeStr,
   197  			"result":   result,
   198  			"retries":  strconv.Itoa(tries),
   199  			"resolver": chosenServerIP,
   200  		}).Observe(dnsClient.clk.Since(start).Seconds())
   201  	}()
   202  	for {
   203  		ch := make(chan dnsResp, 1)
   204  
   205  		// Strip off the IP address part of the server address because
   206  		// we talk to the same server on multiple ports, and don't want
   207  		// to blow up the cardinality.
   208  		// Note: validateServerAddress() has already checked net.SplitHostPort()
   209  		// and ensures that chosenServer can't be a bare port, e.g. ":1337"
   210  		chosenServerIP, _, err = net.SplitHostPort(chosenServer)
   211  		if err != nil {
   212  			return
   213  		}
   214  
   215  		go func() {
   216  			rsp, rtt, err := client.Exchange(m, chosenServer)
   217  			result := "failed"
   218  			if rsp != nil {
   219  				result = dns.RcodeToString[rsp.Rcode]
   220  			}
   221  			if err != nil {
   222  				dnsClient.log.Infof("logDNSError chosenServer=[%s] hostname=[%s] queryType=[%s] err=[%s]",
   223  					chosenServer,
   224  					hostname,
   225  					qtypeStr,
   226  					err)
   227  			}
   228  			dnsClient.queryTime.With(prometheus.Labels{
   229  				"qtype":    qtypeStr,
   230  				"result":   result,
   231  				"resolver": chosenServerIP,
   232  			}).Observe(rtt.Seconds())
   233  			ch <- dnsResp{m: rsp, err: err}
   234  		}()
   235  		select {
   236  		case <-ctx.Done():
   237  			if ctx.Err() == context.DeadlineExceeded {
   238  				dnsClient.timeoutCounter.With(prometheus.Labels{
   239  					"qtype":    qtypeStr,
   240  					"type":     "deadline exceeded",
   241  					"resolver": chosenServerIP,
   242  					"isTLD":    isTLD(hostname),
   243  				}).Inc()
   244  			} else if ctx.Err() == context.Canceled {
   245  				dnsClient.timeoutCounter.With(prometheus.Labels{
   246  					"qtype":    qtypeStr,
   247  					"type":     "canceled",
   248  					"resolver": chosenServerIP,
   249  					"isTLD":    isTLD(hostname),
   250  				}).Inc()
   251  			} else {
   252  				dnsClient.timeoutCounter.With(prometheus.Labels{
   253  					"qtype":    qtypeStr,
   254  					"type":     "unknown",
   255  					"resolver": chosenServerIP,
   256  				}).Inc()
   257  			}
   258  			err = ctx.Err()
   259  			return
   260  		case r := <-ch:
   261  			if r.err != nil {
   262  				var isRetryable bool
   263  				// Check if the error is a timeout error. Network errors
   264  				// that can timeout implement the net.Error interface.
   265  				var netErr net.Error
   266  				isRetryable = errors.As(r.err, &netErr) && netErr.Timeout()
   267  				hasRetriesLeft := tries < dnsClient.maxTries
   268  				if isRetryable && hasRetriesLeft {
   269  					tries++
   270  					// Chose a new server to retry the query with by incrementing the
   271  					// chosen server index modulo the number of servers. This ensures that
   272  					// if one dns server isn't available we retry with the next in the
   273  					// list.
   274  					chosenServerIndex = (chosenServerIndex + 1) % len(servers)
   275  					chosenServer = servers[chosenServerIndex]
   276  					resolver = chosenServer
   277  					continue
   278  				} else if isRetryable && !hasRetriesLeft {
   279  					dnsClient.timeoutCounter.With(prometheus.Labels{
   280  						"qtype":    qtypeStr,
   281  						"type":     "out of retries",
   282  						"resolver": chosenServerIP,
   283  						"isTLD":    isTLD(hostname),
   284  					}).Inc()
   285  				}
   286  			}
   287  			resp, err = r.m, r.err
   288  			return
   289  		}
   290  	}
   291  }
   292  
   293  // isTLD returns a simplified view of whether something is a TLD: does it have
   294  // any dots in it? This returns true or false as a string, and is meant solely
   295  // for Prometheus metrics.
   296  func isTLD(hostname string) string {
   297  	if strings.Contains(hostname, ".") {
   298  		return "false"
   299  	} else {
   300  		return "true"
   301  	}
   302  }
   303  
   304  type dnsResp struct {
   305  	m   *dns.Msg
   306  	err error
   307  }
   308  
   309  // LookupTXT sends a DNS query to find all TXT records associated with
   310  // the provided hostname which it returns along with the returned
   311  // DNS authority section.
   312  func (dnsClient *impl) LookupTXT(ctx context.Context, hostname string) ([]string, ResolverAddrs, error) {
   313  	var txt []string
   314  	dnsType := dns.TypeTXT
   315  	r, resolver, err := dnsClient.exchangeOne(ctx, hostname, dnsType)
   316  	errWrap := wrapErr(dnsType, hostname, r, err)
   317  	if errWrap != nil {
   318  		return nil, ResolverAddrs{resolver}, errWrap
   319  	}
   320  
   321  	for _, answer := range r.Answer {
   322  		if answer.Header().Rrtype == dnsType {
   323  			if txtRec, ok := answer.(*dns.TXT); ok {
   324  				txt = append(txt, strings.Join(txtRec.Txt, ""))
   325  			}
   326  		}
   327  	}
   328  
   329  	return txt, ResolverAddrs{resolver}, err
   330  }
   331  
   332  func (dnsClient *impl) lookupIP(ctx context.Context, hostname string, ipType uint16) ([]dns.RR, string, error) {
   333  	resp, resolver, err := dnsClient.exchangeOne(ctx, hostname, ipType)
   334  	switch ipType {
   335  	case dns.TypeA:
   336  		if resolver != "" {
   337  			resolver = "A:" + resolver
   338  		}
   339  	case dns.TypeAAAA:
   340  		if resolver != "" {
   341  			resolver = "AAAA:" + resolver
   342  		}
   343  	}
   344  	errWrap := wrapErr(ipType, hostname, resp, err)
   345  	if errWrap != nil {
   346  		return nil, resolver, errWrap
   347  	}
   348  	return resp.Answer, resolver, nil
   349  }
   350  
   351  // LookupHost sends a DNS query to find all A and AAAA records associated with
   352  // the provided hostname. This method assumes that the external resolver will
   353  // chase CNAME/DNAME aliases and return relevant records. It will retry
   354  // requests in the case of temporary network errors. It returns an error if
   355  // both the A and AAAA lookups fail or are empty, but succeeds otherwise.
   356  func (dnsClient *impl) LookupHost(ctx context.Context, hostname string) ([]netip.Addr, ResolverAddrs, error) {
   357  	var recordsA, recordsAAAA []dns.RR
   358  	var errA, errAAAA error
   359  	var resolverA, resolverAAAA string
   360  	var wg sync.WaitGroup
   361  
   362  	wg.Go(func() {
   363  		recordsA, resolverA, errA = dnsClient.lookupIP(ctx, hostname, dns.TypeA)
   364  	})
   365  	wg.Go(func() {
   366  		recordsAAAA, resolverAAAA, errAAAA = dnsClient.lookupIP(ctx, hostname, dns.TypeAAAA)
   367  	})
   368  	wg.Wait()
   369  
   370  	resolvers := ResolverAddrs{resolverA, resolverAAAA}
   371  	resolvers = slices.DeleteFunc(resolvers, func(a string) bool {
   372  		return a == ""
   373  	})
   374  
   375  	var addrsA []netip.Addr
   376  	if errA == nil {
   377  		for _, answer := range recordsA {
   378  			if answer.Header().Rrtype == dns.TypeA {
   379  				a, ok := answer.(*dns.A)
   380  				if ok && a.A.To4() != nil {
   381  					netIP, ok := netip.AddrFromSlice(a.A)
   382  					if ok && (iana.IsReservedAddr(netIP) == nil || dnsClient.allowRestrictedAddresses) {
   383  						addrsA = append(addrsA, netIP)
   384  					}
   385  				}
   386  			}
   387  		}
   388  		if len(addrsA) == 0 {
   389  			errA = fmt.Errorf("no valid A records found for %s", hostname)
   390  		}
   391  	}
   392  
   393  	var addrsAAAA []netip.Addr
   394  	if errAAAA == nil {
   395  		for _, answer := range recordsAAAA {
   396  			if answer.Header().Rrtype == dns.TypeAAAA {
   397  				aaaa, ok := answer.(*dns.AAAA)
   398  				if ok && aaaa.AAAA.To16() != nil {
   399  					netIP, ok := netip.AddrFromSlice(aaaa.AAAA)
   400  					if ok && (iana.IsReservedAddr(netIP) == nil || dnsClient.allowRestrictedAddresses) {
   401  						addrsAAAA = append(addrsAAAA, netIP)
   402  					}
   403  				}
   404  			}
   405  		}
   406  		if len(addrsAAAA) == 0 {
   407  			errAAAA = fmt.Errorf("no valid AAAA records found for %s", hostname)
   408  		}
   409  	}
   410  
   411  	if errA != nil && errAAAA != nil {
   412  		// Construct a new error from both underlying errors. We can only use %w for
   413  		// one of them, because the go error unwrapping protocol doesn't support
   414  		// branching. We don't use ProblemDetails and SubProblemDetails here, because
   415  		// this error will get wrapped in a DNSError and further munged by higher
   416  		// layers in the stack.
   417  		return nil, resolvers, fmt.Errorf("%w; %s", errA, errAAAA)
   418  	}
   419  
   420  	return append(addrsA, addrsAAAA...), resolvers, nil
   421  }
   422  
   423  // LookupCAA sends a DNS query to find all CAA records associated with
   424  // the provided hostname and the complete dig-style RR `response`. This
   425  // response is quite verbose, however it's only populated when the CAA
   426  // response is non-empty.
   427  func (dnsClient *impl) LookupCAA(ctx context.Context, hostname string) ([]*dns.CAA, string, ResolverAddrs, error) {
   428  	dnsType := dns.TypeCAA
   429  	r, resolver, err := dnsClient.exchangeOne(ctx, hostname, dnsType)
   430  
   431  	// Special case: when checking CAA for non-TLD names, treat NXDOMAIN as a
   432  	// successful response containing an empty set of records. This can come up in
   433  	// situations where records were provisioned for validation (e.g. TXT records
   434  	// for DNS-01 challenge) and then removed after validation but before CAA
   435  	// rechecking. But allow NXDOMAIN for TLDs to fall through to the error code
   436  	// below, so we don't issue for gTLDs that have been removed by ICANN.
   437  	if err == nil && r.Rcode == dns.RcodeNameError && strings.Contains(hostname, ".") {
   438  		return nil, "", ResolverAddrs{resolver}, nil
   439  	}
   440  
   441  	errWrap := wrapErr(dnsType, hostname, r, err)
   442  	if errWrap != nil {
   443  		return nil, "", ResolverAddrs{resolver}, errWrap
   444  	}
   445  
   446  	var CAAs []*dns.CAA
   447  	for _, answer := range r.Answer {
   448  		if caaR, ok := answer.(*dns.CAA); ok {
   449  			CAAs = append(CAAs, caaR)
   450  		}
   451  	}
   452  	var response string
   453  	if len(CAAs) > 0 {
   454  		response = r.String()
   455  	}
   456  	return CAAs, response, ResolverAddrs{resolver}, nil
   457  }
   458  
   459  type dohExchanger struct {
   460  	clk       clock.Clock
   461  	hc        http.Client
   462  	userAgent string
   463  }
   464  
   465  // Exchange sends a DoH query to the provided DoH server and returns the response.
   466  func (d *dohExchanger) Exchange(query *dns.Msg, server string) (*dns.Msg, time.Duration, error) {
   467  	q, err := query.Pack()
   468  	if err != nil {
   469  		return nil, 0, err
   470  	}
   471  
   472  	// The default Unbound URL template
   473  	url := fmt.Sprintf("https://%s/dns-query", server)
   474  	req, err := http.NewRequest("POST", url, strings.NewReader(string(q)))
   475  	if err != nil {
   476  		return nil, 0, err
   477  	}
   478  	req.Header.Set("Content-Type", "application/dns-message")
   479  	req.Header.Set("Accept", "application/dns-message")
   480  	if len(d.userAgent) > 0 {
   481  		req.Header.Set("User-Agent", d.userAgent)
   482  	}
   483  
   484  	start := d.clk.Now()
   485  	resp, err := d.hc.Do(req)
   486  	if err != nil {
   487  		return nil, d.clk.Since(start), err
   488  	}
   489  	defer resp.Body.Close()
   490  
   491  	if resp.StatusCode != http.StatusOK {
   492  		return nil, d.clk.Since(start), fmt.Errorf("doh: http status %d", resp.StatusCode)
   493  	}
   494  
   495  	b, err := io.ReadAll(resp.Body)
   496  	if err != nil {
   497  		return nil, d.clk.Since(start), fmt.Errorf("doh: reading response body: %w", err)
   498  	}
   499  
   500  	response := new(dns.Msg)
   501  	err = response.Unpack(b)
   502  	if err != nil {
   503  		return nil, d.clk.Since(start), fmt.Errorf("doh: unpacking response: %w", err)
   504  	}
   505  
   506  	return response, d.clk.Since(start), nil
   507  }