github.com/letsencrypt/boulder@v0.20251208.0/grpc/internal/resolver/dns/dns_resolver.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Forked from the default internal DNS resolver in the grpc-go package. The
    20  // original source can be found at:
    21  // https://github.com/grpc/grpc-go/blob/v1.49.0/internal/resolver/dns/dns_resolver.go
    22  
    23  package dns
    24  
    25  import (
    26  	"context"
    27  	"errors"
    28  	"fmt"
    29  	"net"
    30  	"net/netip"
    31  	"strconv"
    32  	"strings"
    33  	"sync"
    34  	"time"
    35  
    36  	"google.golang.org/grpc/grpclog"
    37  	"google.golang.org/grpc/resolver"
    38  	"google.golang.org/grpc/serviceconfig"
    39  
    40  	"github.com/letsencrypt/boulder/bdns"
    41  	"github.com/letsencrypt/boulder/grpc/internal/backoff"
    42  	"github.com/letsencrypt/boulder/grpc/noncebalancer"
    43  )
    44  
// logger is the grpclog component logger used by this package; its component
// name is "srv".
var logger = grpclog.Component("srv")
    46  
    47  // Globals to stub out in tests. TODO: Perhaps these two can be combined into a
    48  // single variable for testing the resolver?
    49  var (
    50  	newTimer           = time.NewTimer
    51  	newTimerDNSResRate = time.NewTimer
    52  )
    53  
    54  func init() {
    55  	resolver.Register(NewDefaultSRVBuilder())
    56  	resolver.Register(NewNonceSRVBuilder())
    57  }
    58  
    59  const defaultDNSSvrPort = "53"
    60  
    61  var defaultResolver netResolver = net.DefaultResolver
    62  
    63  var (
    64  	// To prevent excessive re-resolution, we enforce a rate limit on DNS
    65  	// resolution requests.
    66  	minDNSResRate = 30 * time.Second
    67  )
    68  
    69  var customAuthorityDialer = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
    70  	return func(ctx context.Context, network, address string) (net.Conn, error) {
    71  		var dialer net.Dialer
    72  		return dialer.DialContext(ctx, network, authority)
    73  	}
    74  }
    75  
    76  var customAuthorityResolver = func(authority string) (*net.Resolver, error) {
    77  	host, port, err := bdns.ParseTarget(authority, defaultDNSSvrPort)
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	return &net.Resolver{
    82  		PreferGo: true,
    83  		Dial:     customAuthorityDialer(net.JoinHostPort(host, port)),
    84  	}, nil
    85  }
    86  
    87  // NewDefaultSRVBuilder creates a srvBuilder which is used to factory SRV DNS
    88  // resolvers.
    89  func NewDefaultSRVBuilder() resolver.Builder {
    90  	return &srvBuilder{scheme: "srv"}
    91  }
    92  
    93  // NewNonceSRVBuilder creates a srvBuilder which is used to factory SRV DNS
    94  // resolvers with a custom grpc.Balancer used by nonce-service clients.
    95  func NewNonceSRVBuilder() resolver.Builder {
    96  	return &srvBuilder{scheme: noncebalancer.SRVResolverScheme, balancer: noncebalancer.Name}
    97  }
    98  
// srvBuilder is the resolver builder for SRV-based name resolution. One
// instance exists per registered scheme.
type srvBuilder struct {
	// scheme is the naming scheme reported by Scheme().
	scheme string
	// balancer, when non-empty, is the name of the load balancing policy that
	// Build places in the service config attached to every resolver state.
	balancer string
}
   103  
   104  // Build creates and starts a DNS resolver that watches the name resolution of the target.
   105  func (b *srvBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
   106  	var names []name
   107  	for i := range strings.SplitSeq(target.Endpoint(), ",") {
   108  		service, domain, err := parseServiceDomain(i)
   109  		if err != nil {
   110  			return nil, err
   111  		}
   112  		names = append(names, name{service: service, domain: domain})
   113  	}
   114  
   115  	ctx, cancel := context.WithCancel(context.Background())
   116  	d := &dnsResolver{
   117  		names:  names,
   118  		ctx:    ctx,
   119  		cancel: cancel,
   120  		cc:     cc,
   121  		rn:     make(chan struct{}, 1),
   122  	}
   123  
   124  	if target.URL.Host == "" {
   125  		d.resolver = defaultResolver
   126  	} else {
   127  		var err error
   128  		d.resolver, err = customAuthorityResolver(target.URL.Host)
   129  		if err != nil {
   130  			return nil, err
   131  		}
   132  	}
   133  
   134  	if b.balancer != "" {
   135  		d.serviceConfig = cc.ParseServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, b.balancer))
   136  	}
   137  
   138  	d.wg.Add(1)
   139  	go d.watcher()
   140  	return d, nil
   141  }
   142  
// Scheme returns the naming scheme of this resolver builder, i.e. the scheme
// value it was constructed with.
func (b *srvBuilder) Scheme() string {
	return b.scheme
}
   147  
// netResolver is the subset of DNS lookup methods this package needs. Both
// net.DefaultResolver and the *net.Resolver returned by
// customAuthorityResolver are used as implementations.
type netResolver interface {
	// LookupHost resolves host to a list of textual addresses.
	LookupHost(ctx context.Context, host string) (addrs []string, err error)
	// LookupSRV resolves the SRV records for the given service, protocol and
	// domain name.
	LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)
}
   152  
// name is one SRV lookup target, split into the two parts that are passed to
// LookupSRV.
type name struct {
	// service is the first label of the configured hostname.
	service string
	// domain is everything after the first label of the configured hostname.
	domain string
}
   157  
// dnsResolver watches for name resolution updates for a non-IP target. It
// resolves each configured name via SRV and host lookups and reports the
// result to the gRPC ClientConn.
type dnsResolver struct {
	// names are the service/domain pairs to look up on every resolution.
	names []name
	// resolver performs the actual DNS lookups.
	resolver netResolver
	// ctx is cancelled by Close() to stop the watcher() goroutine; it is also
	// passed to the DNS lookups.
	ctx    context.Context
	cancel context.CancelFunc
	// cc receives resolved state and resolution errors.
	cc resolver.ClientConn
	// rn channel is used by ResolveNow() to force an immediate resolution of the target.
	rn chan struct{}
	// wg is used to make Close() return only after the watcher() goroutine has
	// finished. Otherwise, a data race would be possible. [Race Example] In
	// dns_resolver_test we replace the real lookup functions with mocked ones
	// to facilitate testing. If Close() did not wait for the watcher()
	// goroutine to finish, the race detector would sometimes warn that lookup
	// (READ of the lookup function pointers) inside the watcher() goroutine
	// races with replaceNetFunc (WRITE of the lookup function pointers).
	wg sync.WaitGroup
	// serviceConfig, when non-nil, is attached to every state sent to cc.
	serviceConfig *serviceconfig.ParseResult
}
   176  
// ResolveNow requests an immediate resolution of the target that this
// dnsResolver watches. It never blocks: the request is signalled on the rn
// channel, and if a request is already pending there this one is dropped.
func (d *dnsResolver) ResolveNow(resolver.ResolveNowOptions) {
	select {
	case d.rn <- struct{}{}:
	default:
		// A resolution request is already queued; nothing more to do.
	}
}
   184  
// Close closes the dnsResolver. It cancels the resolver's context and then
// blocks until the watcher() goroutine has exited.
func (d *dnsResolver) Close() {
	d.cancel()
	d.wg.Wait()
}
   190  
   191  func (d *dnsResolver) watcher() {
   192  	defer d.wg.Done()
   193  	backoffIndex := 1
   194  	for {
   195  		state, err := d.lookup()
   196  		if err != nil {
   197  			// Report error to the underlying grpc.ClientConn.
   198  			d.cc.ReportError(err)
   199  		} else {
   200  			if d.serviceConfig != nil {
   201  				state.ServiceConfig = d.serviceConfig
   202  			}
   203  			err = d.cc.UpdateState(*state)
   204  		}
   205  
   206  		var timer *time.Timer
   207  		if err == nil {
   208  			// Success resolving, wait for the next ResolveNow. However, also wait 30 seconds at the very least
   209  			// to prevent constantly re-resolving.
   210  			backoffIndex = 1
   211  			timer = newTimerDNSResRate(minDNSResRate)
   212  			select {
   213  			case <-d.ctx.Done():
   214  				timer.Stop()
   215  				return
   216  			case <-d.rn:
   217  			}
   218  		} else {
   219  			// Poll on an error found in DNS Resolver or an error received from ClientConn.
   220  			timer = newTimer(backoff.DefaultExponential.Backoff(backoffIndex))
   221  			backoffIndex++
   222  		}
   223  		select {
   224  		case <-d.ctx.Done():
   225  			timer.Stop()
   226  			return
   227  		case <-timer.C:
   228  		}
   229  	}
   230  }
   231  
   232  func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) {
   233  	var newAddrs []resolver.Address
   234  	var errs []error
   235  	for _, n := range d.names {
   236  		_, srvs, err := d.resolver.LookupSRV(d.ctx, n.service, "tcp", n.domain)
   237  		if err != nil {
   238  			err = handleDNSError(err, "SRV") // may become nil
   239  			if err != nil {
   240  				errs = append(errs, err)
   241  				continue
   242  			}
   243  		}
   244  		for _, s := range srvs {
   245  			backendAddrs, err := d.resolver.LookupHost(d.ctx, s.Target)
   246  			if err != nil {
   247  				err = handleDNSError(err, "A") // may become nil
   248  				if err != nil {
   249  					errs = append(errs, err)
   250  					continue
   251  				}
   252  			}
   253  			for _, a := range backendAddrs {
   254  				ip, ok := formatIP(a)
   255  				if !ok {
   256  					errs = append(errs, fmt.Errorf("srv: error parsing A record IP address %v", a))
   257  					continue
   258  				}
   259  				addr := ip + ":" + strconv.Itoa(int(s.Port))
   260  				newAddrs = append(newAddrs, resolver.Address{Addr: addr, ServerName: s.Target})
   261  			}
   262  		}
   263  	}
   264  	// Only return an error if all lookups failed.
   265  	if len(errs) > 0 && len(newAddrs) == 0 {
   266  		return nil, errors.Join(errs...)
   267  	}
   268  	return newAddrs, nil
   269  }
   270  
   271  func handleDNSError(err error, lookupType string) error {
   272  	if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary {
   273  		// Timeouts and temporary errors should be communicated to gRPC to
   274  		// attempt another DNS query (with backoff).  Other errors should be
   275  		// suppressed (they may represent the absence of a TXT record).
   276  		return nil
   277  	}
   278  	if err != nil {
   279  		err = fmt.Errorf("srv: %v record lookup error: %v", lookupType, err)
   280  		logger.Info(err)
   281  	}
   282  	return err
   283  }
   284  
   285  func (d *dnsResolver) lookup() (*resolver.State, error) {
   286  	addrs, err := d.lookupSRV()
   287  	if err != nil {
   288  		return nil, err
   289  	}
   290  	return &resolver.State{Addresses: addrs}, nil
   291  }
   292  
   293  // formatIP returns ok = false if addr is not a valid textual representation of an IP address.
   294  // If addr is an IPv4 address, return the addr and ok = true.
   295  // If addr is an IPv6 address, return the addr enclosed in square brackets and ok = true.
   296  func formatIP(addr string) (addrIP string, ok bool) {
   297  	ip, err := netip.ParseAddr(addr)
   298  	if err != nil {
   299  		return "", false
   300  	}
   301  	if ip.Is4() {
   302  		return addr, true
   303  	}
   304  	return "[" + addr + "]", true
   305  }
   306  
   307  // parseServiceDomain takes the user input target string and parses the service domain
   308  // names for SRV lookup. Input is expected to be a hostname containing at least
   309  // two labels (e.g. "foo.bar", "foo.bar.baz"). The first label is the service
   310  // name and the rest is the domain name. If the target is not in the expected
   311  // format, an error is returned.
   312  func parseServiceDomain(target string) (string, string, error) {
   313  	sd := strings.SplitN(target, ".", 2)
   314  	if len(sd) < 2 || sd[0] == "" || sd[1] == "" {
   315  		return "", "", fmt.Errorf("srv: hostname %q contains < 2 labels", target)
   316  	}
   317  	return sd[0], sd[1], nil
   318  }