github.com/letsencrypt/boulder@v0.20251208.0/bdns/servers.go (about)

     1  package bdns
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"math/rand/v2"
     8  	"net"
     9  	"net/netip"
    10  	"strconv"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/miekg/dns"
    15  	"github.com/prometheus/client_golang/prometheus"
    16  
    17  	"github.com/letsencrypt/boulder/cmd"
    18  )
    19  
// ServerProvider represents a type which can provide a list of addresses for
// the bdns to use as DNS resolvers. Different implementations may provide
// different strategies for providing addresses, and may provide different kinds
// of addresses (e.g. host:port combos vs IP addresses).
type ServerProvider interface {
	// Addrs returns the resolver addresses the provider currently knows
	// about. Both implementations in this file return the list in a freshly
	// randomized order on every call.
	Addrs() ([]string, error)
	// Stop shuts down any background work the provider performs. It is a
	// no-op for providers that have none.
	Stop()
}
    28  
// staticProvider stores a list of host:port combos, and provides that whole
// list in randomized order when asked for addresses. This replicates the old
// behavior of the bdns.impl's servers field.
type staticProvider struct {
	// servers is the fixed set of host:port addresses handed out by Addrs. It
	// is populated by NewStaticProvider and is not modified by any other code
	// in this file.
	servers []string
}
    35  
// Compile-time check that *staticProvider satisfies ServerProvider.
var _ ServerProvider = &staticProvider{}
    37  
    38  // validateServerAddress ensures that a given server address is formatted in
    39  // such a way that it can be dialed. The provided server address must include a
    40  // host/IP and port separated by colon. Additionally, if the host is a literal
    41  // IPv6 address, it must be enclosed in square brackets.
    42  // (https://golang.org/src/net/dial.go?s=9833:9881#L281)
    43  func validateServerAddress(address string) error {
    44  	// Ensure the host and port portions of `address` can be split.
    45  	host, port, err := net.SplitHostPort(address)
    46  	if err != nil {
    47  		return err
    48  	}
    49  
    50  	// Ensure `address` contains both a `host` and `port` portion.
    51  	if host == "" || port == "" {
    52  		return errors.New("port cannot be missing")
    53  	}
    54  
    55  	// Ensure the `port` portion of `address` is a valid port.
    56  	portNum, err := strconv.Atoi(port)
    57  	if err != nil {
    58  		return fmt.Errorf("parsing port number: %s", err)
    59  	}
    60  	if portNum <= 0 || portNum > 65535 {
    61  		return errors.New("port must be an integer between 0 - 65535")
    62  	}
    63  
    64  	// Ensure the `host` portion of `address` is a valid FQDN or IP address.
    65  	_, err = netip.ParseAddr(host)
    66  	FQDN := dns.IsFqdn(dns.Fqdn(host))
    67  	if err != nil && !FQDN {
    68  		return errors.New("host is not an FQDN or IP address")
    69  	}
    70  	return nil
    71  }
    72  
    73  func NewStaticProvider(servers []string) (*staticProvider, error) {
    74  	var serverAddrs []string
    75  	for _, server := range servers {
    76  		err := validateServerAddress(server)
    77  		if err != nil {
    78  			return nil, fmt.Errorf("server address %q invalid: %s", server, err)
    79  		}
    80  		serverAddrs = append(serverAddrs, server)
    81  	}
    82  	return &staticProvider{servers: serverAddrs}, nil
    83  }
    84  
    85  func (sp *staticProvider) Addrs() ([]string, error) {
    86  	if len(sp.servers) == 0 {
    87  		return nil, fmt.Errorf("no servers configured")
    88  	}
    89  	r := make([]string, len(sp.servers))
    90  	perm := rand.Perm(len(sp.servers))
    91  	for i, v := range perm {
    92  		r[i] = sp.servers[v]
    93  	}
    94  	return r, nil
    95  }
    96  
// Stop is a no-op: the method body is empty. It exists so that staticProvider
// satisfies ServerProvider.
func (sp *staticProvider) Stop() {}
    98  
// dynamicProvider uses DNS to look up the set of IP addresses which correspond
// to its single host. It returns this list in random order when asked for
// addresses, and refreshes it regularly using a goroutine started by its
// constructor.
type dynamicProvider struct {
	// dnsAuthority is the single <hostname|IPv4|[IPv6]>:<port> of the DNS
	// server to be used for resolution of DNS backends. If the address contains
	// a hostname it will be resolved via the system DNS. StartDynamicProvider
	// normalizes the configured value with ParseTarget, so a value without a
	// port defaults to '53'. An empty configured value is rejected:
	// ParseTarget returns a "missing address" error and StartDynamicProvider
	// fails rather than falling back to the system DNS.
	dnsAuthority string
	// service is the service name to look up SRV records for within the domain.
	// StartDynamicProvider substitutes 'dns' when the configured service name
	// is empty.
	service string
	// proto is the IP protocol (tcp or udp) to look up SRV records for.
	proto string
	// domain is the name to look up SRV records within.
	domain string
	// A map of IP addresses (results of host lookups for SRV Targets) to
	// ports (Port fields in SRV records) associated with those addresses.
	// Replaced wholesale by update and read by Addrs; guarded by mu.
	addrs map[string][]uint16
	// cancel is closed by Stop to tell the run goroutine to exit.
	cancel chan any
	// mu guards addrs: update takes the write lock, Addrs the read lock.
	mu sync.RWMutex
	// refresh is the interval between updates performed by run. Half of it is
	// used as the timeout for each individual update.
	refresh time.Duration
	// updateCounter counts the update attempts made by run, labeled by
	// "success" ("true" or "false").
	updateCounter *prometheus.CounterVec
}
   126  
   127  // ParseTarget takes the user input target string and default port, returns
   128  // formatted host and port info. If target doesn't specify a port, set the port
   129  // to be the defaultPort. If target is in IPv6 format and host-name is enclosed
   130  // in square brackets, brackets are stripped when setting the host.
   131  //
   132  // Examples:
   133  //   - target: "www.google.com" defaultPort: "443" returns host: "www.google.com", port: "443"
   134  //   - target: "ipv4-host:80" defaultPort: "443" returns host: "ipv4-host", port: "80"
   135  //   - target: "[ipv6-host]" defaultPort: "443" returns host: "ipv6-host", port: "443"
   136  //   - target: ":80" defaultPort: "443" returns host: "localhost", port: "80"
   137  //
   138  // This function is copied from:
   139  // https://github.com/grpc/grpc-go/blob/master/internal/resolver/dns/dns_resolver.go
   140  // It has been minimally modified to fit our code style.
   141  func ParseTarget(target, defaultPort string) (host, port string, err error) {
   142  	if target == "" {
   143  		return "", "", errors.New("missing address")
   144  	}
   145  	ip := net.ParseIP(target)
   146  	if ip != nil {
   147  		// Target is an IPv4 or IPv6(without brackets) address.
   148  		return target, defaultPort, nil
   149  	}
   150  	host, port, err = net.SplitHostPort(target)
   151  	if err == nil {
   152  		if port == "" {
   153  			// If the port field is empty (target ends with colon), e.g.
   154  			// "[::1]:", this is an error.
   155  			return "", "", errors.New("missing port after port-separator colon")
   156  		}
   157  		// target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port
   158  		if host == "" {
   159  			// Keep consistent with net.Dial(): If the host is empty, as in
   160  			// ":80", the local system is assumed.
   161  			host = "localhost"
   162  		}
   163  		return host, port, nil
   164  	}
   165  	host, port, err = net.SplitHostPort(target + ":" + defaultPort)
   166  	if err == nil {
   167  		// Target doesn't have port.
   168  		return host, port, nil
   169  	}
   170  	return "", "", fmt.Errorf("invalid target address %v, error info: %v", target, err)
   171  }
   172  
// Compile-time check that *dynamicProvider satisfies ServerProvider.
var _ ServerProvider = &dynamicProvider{}
   174  
   175  // StartDynamicProvider constructs a new dynamicProvider and starts its
   176  // auto-update goroutine. The auto-update process queries DNS for SRV records
   177  // at refresh intervals and uses the resulting IP/port combos to populate the
   178  // list returned by Addrs. The update process ignores the Priority and Weight
   179  // attributes of the SRV records.
   180  //
   181  // `proto` is the IP protocol (tcp or udp) to look up SRV records for.
   182  func StartDynamicProvider(c *cmd.DNSProvider, refresh time.Duration, proto string) (*dynamicProvider, error) {
   183  	if c.SRVLookup.Domain == "" {
   184  		return nil, fmt.Errorf("'domain' cannot be empty")
   185  	}
   186  
   187  	service := c.SRVLookup.Service
   188  	if service == "" {
   189  		// Default to "dns" if no service is specified. This is the default
   190  		// service name for DNS servers.
   191  		service = "dns"
   192  	}
   193  
   194  	host, port, err := ParseTarget(c.DNSAuthority, "53")
   195  	if err != nil {
   196  		return nil, err
   197  	}
   198  
   199  	dnsAuthority := net.JoinHostPort(host, port)
   200  	err = validateServerAddress(dnsAuthority)
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  
   205  	dp := dynamicProvider{
   206  		dnsAuthority: dnsAuthority,
   207  		service:      service,
   208  		proto:        proto,
   209  		domain:       c.SRVLookup.Domain,
   210  		addrs:        make(map[string][]uint16),
   211  		cancel:       make(chan any),
   212  		refresh:      refresh,
   213  		updateCounter: prometheus.NewCounterVec(
   214  			prometheus.CounterOpts{
   215  				Name: "dns_update",
   216  				Help: "Counter of attempts to update a dynamic provider",
   217  			},
   218  			[]string{"success"},
   219  		),
   220  	}
   221  
   222  	// Update once immediately, so we can know whether that was successful, then
   223  	// kick off the long-running update goroutine.
   224  	err = dp.update()
   225  	if err != nil {
   226  		return nil, fmt.Errorf("failed to start dynamic provider: %w", err)
   227  	}
   228  	go dp.run()
   229  
   230  	return &dp, nil
   231  }
   232  
   233  // run loops forever, calling dp.update() every dp.refresh interval. Does not
   234  // halt until the dp.cancel channel is closed, so should be run in a goroutine.
   235  func (dp *dynamicProvider) run() {
   236  	t := time.NewTicker(dp.refresh)
   237  	for {
   238  		select {
   239  		case <-t.C:
   240  			err := dp.update()
   241  			if err != nil {
   242  				dp.updateCounter.With(prometheus.Labels{
   243  					"success": "false",
   244  				}).Inc()
   245  				continue
   246  			}
   247  			dp.updateCounter.With(prometheus.Labels{
   248  				"success": "true",
   249  			}).Inc()
   250  		case <-dp.cancel:
   251  			return
   252  		}
   253  	}
   254  }
   255  
   256  // update performs the SRV and A record queries necessary to map the given DNS
   257  // domain name to a set of cacheable IP addresses and ports, and stores the
   258  // results in dp.addrs.
   259  func (dp *dynamicProvider) update() error {
   260  	ctx, cancel := context.WithTimeout(context.Background(), dp.refresh/2)
   261  	defer cancel()
   262  
   263  	resolver := &net.Resolver{
   264  		PreferGo: true,
   265  		Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
   266  			d := &net.Dialer{}
   267  			return d.DialContext(ctx, network, dp.dnsAuthority)
   268  		},
   269  	}
   270  
   271  	// RFC 2782 formatted SRV record being queried e.g. "_service._proto.name."
   272  	record := fmt.Sprintf("_%s._%s.%s.", dp.service, dp.proto, dp.domain)
   273  
   274  	_, srvs, err := resolver.LookupSRV(ctx, dp.service, dp.proto, dp.domain)
   275  	if err != nil {
   276  		return fmt.Errorf("during SRV lookup of %q: %w", record, err)
   277  	}
   278  	if len(srvs) == 0 {
   279  		return fmt.Errorf("SRV lookup of %q returned 0 results", record)
   280  	}
   281  
   282  	addrPorts := make(map[string][]uint16)
   283  	for _, srv := range srvs {
   284  		addrs, err := resolver.LookupHost(ctx, srv.Target)
   285  		if err != nil {
   286  			return fmt.Errorf("during A/AAAA lookup of target %q from SRV record %q: %w", srv.Target, record, err)
   287  		}
   288  		for _, addr := range addrs {
   289  			joinedHostPort := net.JoinHostPort(addr, fmt.Sprint(srv.Port))
   290  			err := validateServerAddress(joinedHostPort)
   291  			if err != nil {
   292  				return fmt.Errorf("invalid addr %q from SRV record %q: %w", joinedHostPort, record, err)
   293  			}
   294  			addrPorts[addr] = append(addrPorts[addr], srv.Port)
   295  		}
   296  	}
   297  
   298  	dp.mu.Lock()
   299  	dp.addrs = addrPorts
   300  	dp.mu.Unlock()
   301  	return nil
   302  }
   303  
   304  // Addrs returns a shuffled list of IP/port pairs, with the guarantee that no
   305  // two IP/port pairs will share the same IP.
   306  func (dp *dynamicProvider) Addrs() ([]string, error) {
   307  	var r []string
   308  	dp.mu.RLock()
   309  	for ip, ports := range dp.addrs {
   310  		port := fmt.Sprint(ports[rand.IntN(len(ports))])
   311  		addr := net.JoinHostPort(ip, port)
   312  		r = append(r, addr)
   313  	}
   314  	dp.mu.RUnlock()
   315  	rand.Shuffle(len(r), func(i, j int) {
   316  		r[i], r[j] = r[j], r[i]
   317  	})
   318  	return r, nil
   319  }
   320  
// Stop tells the background update goroutine to cease. It does not wait for
// confirmation that it has done so.
//
// Stop must be called at most once: it closes dp.cancel, and closing an
// already-closed channel panics.
func (dp *dynamicProvider) Stop() {
	close(dp.cancel)
}