google.golang.org/grpc@v1.72.2/internal/resolver/dns/dns_resolver.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
// Package dns implements a DNS resolver to be installed as the default resolver
// in gRPC.
    21  package dns
    22  
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	rand "math/rand/v2"
	"net"
	"net/netip"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	grpclbstate "google.golang.org/grpc/balancer/grpclb/state"
	"google.golang.org/grpc/grpclog"
	"google.golang.org/grpc/internal/backoff"
	"google.golang.org/grpc/internal/envconfig"
	"google.golang.org/grpc/internal/resolver/dns/internal"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/serviceconfig"
)
    44  
var (
	// EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB
	// addresses from SRV records. When false, the SRV lookup is skipped entirely
	// and only host addresses (plus, unless disabled, the TXT service config)
	// are resolved. Must not be changed after init time.
	EnableSRVLookups = false

	// MinResolutionInterval is the minimum interval at which re-resolutions are
	// allowed. After a successful resolution, the next one starts only once
	// ResolveNow has been called AND at least this much time has passed. This
	// helps to prevent excessive re-resolution.
	MinResolutionInterval = 30 * time.Second

	// ResolvingTimeout specifies the maximum duration of one resolution round.
	// The SRV, host and TXT lookups of a round all share a single context with
	// this timeout, so lookups still in flight when it expires see a canceled
	// context.
	//
	// It is recommended to set this value at application startup. Avoid modifying this variable
	// after initialization as it's not thread-safe for concurrent modification.
	ResolvingTimeout = 30 * time.Second

	// logger is the "dns" grpclog component used for this package's messages.
	logger = grpclog.Component("dns")
)
    63  
    64  func init() {
    65  	resolver.Register(NewBuilder())
    66  	internal.TimeAfterFunc = time.After
    67  	internal.TimeNowFunc = time.Now
    68  	internal.TimeUntilFunc = time.Until
    69  	internal.NewNetResolver = newNetResolver
    70  	internal.AddressDialer = addressDialer
    71  }
    72  
const (
	// defaultPort is the port used when the dial target does not name one.
	defaultPort = "443"
	// defaultDNSSvrPort is the port used when a custom DNS server authority
	// does not name one.
	defaultDNSSvrPort = "53"
	// golang is this client's language name, matched against the
	// clientLanguage list of a service config choice.
	golang = "GO"
	// txtPrefix is the prefix string to be prepended to the host name for TXT
	// record lookup.
	txtPrefix = "_grpc_config."
	// In DNS, service config is encoded in a TXT record via the mechanism
	// described in RFC-1464 using the attribute name grpc_config.
	txtAttribute = "grpc_config="
)
    84  
    85  var addressDialer = func(address string) func(context.Context, string, string) (net.Conn, error) {
    86  	return func(ctx context.Context, network, _ string) (net.Conn, error) {
    87  		var dialer net.Dialer
    88  		return dialer.DialContext(ctx, network, address)
    89  	}
    90  }
    91  
    92  var newNetResolver = func(authority string) (internal.NetResolver, error) {
    93  	if authority == "" {
    94  		return net.DefaultResolver, nil
    95  	}
    96  
    97  	host, port, err := parseTarget(authority, defaultDNSSvrPort)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  
   102  	authorityWithPort := net.JoinHostPort(host, port)
   103  
   104  	return &net.Resolver{
   105  		PreferGo: true,
   106  		Dial:     internal.AddressDialer(authorityWithPort),
   107  	}, nil
   108  }
   109  
   110  // NewBuilder creates a dnsBuilder which is used to factory DNS resolvers.
   111  func NewBuilder() resolver.Builder {
   112  	return &dnsBuilder{}
   113  }
   114  
// dnsBuilder is the resolver.Builder for the "dns" scheme. It holds no state;
// every Build call produces an independent resolver.
type dnsBuilder struct{}
   116  
   117  // Build creates and starts a DNS resolver that watches the name resolution of
   118  // the target.
   119  func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
   120  	host, port, err := parseTarget(target.Endpoint(), defaultPort)
   121  	if err != nil {
   122  		return nil, err
   123  	}
   124  
   125  	// IP address.
   126  	if ipAddr, err := formatIP(host); err == nil {
   127  		addr := []resolver.Address{{Addr: ipAddr + ":" + port}}
   128  		cc.UpdateState(resolver.State{Addresses: addr})
   129  		return deadResolver{}, nil
   130  	}
   131  
   132  	// DNS address (non-IP).
   133  	ctx, cancel := context.WithCancel(context.Background())
   134  	d := &dnsResolver{
   135  		host:                 host,
   136  		port:                 port,
   137  		ctx:                  ctx,
   138  		cancel:               cancel,
   139  		cc:                   cc,
   140  		rn:                   make(chan struct{}, 1),
   141  		disableServiceConfig: opts.DisableServiceConfig,
   142  	}
   143  
   144  	d.resolver, err = internal.NewNetResolver(target.URL.Host)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  
   149  	d.wg.Add(1)
   150  	go d.watcher()
   151  	return d, nil
   152  }
   153  
   154  // Scheme returns the naming scheme of this resolver builder, which is "dns".
   155  func (b *dnsBuilder) Scheme() string {
   156  	return "dns"
   157  }
   158  
   159  // deadResolver is a resolver that does nothing.
   160  type deadResolver struct{}
   161  
   162  func (deadResolver) ResolveNow(resolver.ResolveNowOptions) {}
   163  
   164  func (deadResolver) Close() {}
   165  
// dnsResolver watches for the name resolution update for a non-IP target.
// Build creates it and starts watcher(); Close stops it.
type dnsResolver struct {
	// host and port are the target as split by parseTarget (Build passes
	// defaultPort as the fallback port).
	host string
	port string
	// resolver performs the DNS queries. It comes from internal.NewNetResolver
	// and may be bound to a custom DNS server named in the target authority.
	resolver internal.NetResolver
	// ctx is the parent of every lookup context; cancel is called by Close to
	// stop watcher().
	ctx    context.Context
	cancel context.CancelFunc
	// cc receives resolved state (UpdateState) and lookup errors (ReportError).
	cc resolver.ClientConn
	// rn channel is used by ResolveNow() to request a re-resolution of the
	// target. Build gives it capacity 1 and ResolveNow sends without blocking,
	// so repeated requests coalesce into one pending signal.
	rn chan struct{}
	// wg lets Close() wait until the watcher() goroutine has returned.
	// Without it, code that swaps the lookup hooks after Close() (as the
	// dns_resolver_test mocks do) could race with a watcher() still reading
	// them. NOTE(review): the original comment named replaceNetFunc as the
	// writer; that helper is not visible here and may be stale now that the
	// hooks live in the internal package.
	wg sync.WaitGroup
	// disableServiceConfig skips the TXT (service config) lookup in lookup().
	disableServiceConfig bool
}
   187  
   188  // ResolveNow invoke an immediate resolution of the target that this
   189  // dnsResolver watches.
   190  func (d *dnsResolver) ResolveNow(resolver.ResolveNowOptions) {
   191  	select {
   192  	case d.rn <- struct{}{}:
   193  	default:
   194  	}
   195  }
   196  
// Close closes the dnsResolver. It first cancels d.ctx, which makes watcher()
// return and cancels the context passed to any in-flight lookup, then blocks
// until the watcher goroutine has exited.
func (d *dnsResolver) Close() {
	d.cancel()
	d.wg.Wait()
}
   202  
// watcher is the resolution loop. Build starts it in its own goroutine (after
// wg.Add(1)), and it runs until d.ctx is canceled by Close. Each iteration
// performs one lookup, reports the outcome to the ClientConn, and then waits:
//
//   - After a successful lookup and UpdateState, it blocks until ResolveNow
//     signals d.rn, and additionally ensures at least MinResolutionInterval
//     has passed since the update.
//   - After a lookup error (sent via ReportError) or an UpdateState error, it
//     waits backoff.DefaultExponential.Backoff(backoffIndex). The index grows
//     with each consecutive failure and is reset to 1 on success.
//
// Time is read through the internal.TimeNowFunc, TimeUntilFunc and
// TimeAfterFunc hooks (init points them at time.Now, time.Until and
// time.After) rather than the time package directly, presumably so tests
// can control the clock.
func (d *dnsResolver) watcher() {
	defer d.wg.Done()
	backoffIndex := 1
	for {
		state, err := d.lookup()
		if err != nil {
			// Report error to the underlying grpc.ClientConn.
			d.cc.ReportError(err)
		} else {
			// A rejected update is treated like a lookup failure (backoff).
			err = d.cc.UpdateState(*state)
		}

		var nextResolutionTime time.Time
		if err == nil {
			// Success resolving, wait for the next ResolveNow. However, also wait
			// MinResolutionInterval at the very least to prevent constantly
			// re-resolving.
			backoffIndex = 1
			nextResolutionTime = internal.TimeNowFunc().Add(MinResolutionInterval)
			select {
			case <-d.ctx.Done():
				return
			case <-d.rn:
			}
		} else {
			// Poll on an error found in DNS Resolver or an error received from
			// ClientConn.
			nextResolutionTime = internal.TimeNowFunc().Add(backoff.DefaultExponential.Backoff(backoffIndex))
			backoffIndex++
		}
		// Sleep until nextResolutionTime unless closed first. With the real
		// time.After, a deadline already in the past fires immediately.
		select {
		case <-d.ctx.Done():
			return
		case <-internal.TimeAfterFunc(internal.TimeUntilFunc(nextResolutionTime)):
		}
	}
}
   239  
   240  func (d *dnsResolver) lookupSRV(ctx context.Context) ([]resolver.Address, error) {
   241  	// Skip this particular host to avoid timeouts with some versions of
   242  	// systemd-resolved.
   243  	if !EnableSRVLookups || d.host == "metadata.google.internal." {
   244  		return nil, nil
   245  	}
   246  	var newAddrs []resolver.Address
   247  	_, srvs, err := d.resolver.LookupSRV(ctx, "grpclb", "tcp", d.host)
   248  	if err != nil {
   249  		err = handleDNSError(err, "SRV") // may become nil
   250  		return nil, err
   251  	}
   252  	for _, s := range srvs {
   253  		lbAddrs, err := d.resolver.LookupHost(ctx, s.Target)
   254  		if err != nil {
   255  			err = handleDNSError(err, "A") // may become nil
   256  			if err == nil {
   257  				// If there are other SRV records, look them up and ignore this
   258  				// one that does not exist.
   259  				continue
   260  			}
   261  			return nil, err
   262  		}
   263  		for _, a := range lbAddrs {
   264  			ip, err := formatIP(a)
   265  			if err != nil {
   266  				return nil, fmt.Errorf("dns: error parsing A record IP address %v: %v", a, err)
   267  			}
   268  			addr := ip + ":" + strconv.Itoa(int(s.Port))
   269  			newAddrs = append(newAddrs, resolver.Address{Addr: addr, ServerName: s.Target})
   270  		}
   271  	}
   272  	return newAddrs, nil
   273  }
   274  
   275  func handleDNSError(err error, lookupType string) error {
   276  	dnsErr, ok := err.(*net.DNSError)
   277  	if ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary {
   278  		// Timeouts and temporary errors should be communicated to gRPC to
   279  		// attempt another DNS query (with backoff).  Other errors should be
   280  		// suppressed (they may represent the absence of a TXT record).
   281  		return nil
   282  	}
   283  	if err != nil {
   284  		err = fmt.Errorf("dns: %v record lookup error: %v", lookupType, err)
   285  		logger.Info(err)
   286  	}
   287  	return err
   288  }
   289  
   290  func (d *dnsResolver) lookupTXT(ctx context.Context) *serviceconfig.ParseResult {
   291  	ss, err := d.resolver.LookupTXT(ctx, txtPrefix+d.host)
   292  	if err != nil {
   293  		if envconfig.TXTErrIgnore {
   294  			return nil
   295  		}
   296  		if err = handleDNSError(err, "TXT"); err != nil {
   297  			return &serviceconfig.ParseResult{Err: err}
   298  		}
   299  		return nil
   300  	}
   301  	var res string
   302  	for _, s := range ss {
   303  		res += s
   304  	}
   305  
   306  	// TXT record must have "grpc_config=" attribute in order to be used as
   307  	// service config.
   308  	if !strings.HasPrefix(res, txtAttribute) {
   309  		logger.Warningf("dns: TXT record %v missing %v attribute", res, txtAttribute)
   310  		// This is not an error; it is the equivalent of not having a service
   311  		// config.
   312  		return nil
   313  	}
   314  	sc := canaryingSC(strings.TrimPrefix(res, txtAttribute))
   315  	return d.cc.ParseServiceConfig(sc)
   316  }
   317  
   318  func (d *dnsResolver) lookupHost(ctx context.Context) ([]resolver.Address, error) {
   319  	addrs, err := d.resolver.LookupHost(ctx, d.host)
   320  	if err != nil {
   321  		err = handleDNSError(err, "A")
   322  		return nil, err
   323  	}
   324  	newAddrs := make([]resolver.Address, 0, len(addrs))
   325  	for _, a := range addrs {
   326  		ip, err := formatIP(a)
   327  		if err != nil {
   328  			return nil, fmt.Errorf("dns: error parsing A record IP address %v: %v", a, err)
   329  		}
   330  		addr := ip + ":" + d.port
   331  		newAddrs = append(newAddrs, resolver.Address{Addr: addr})
   332  	}
   333  	return newAddrs, nil
   334  }
   335  
   336  func (d *dnsResolver) lookup() (*resolver.State, error) {
   337  	ctx, cancel := context.WithTimeout(d.ctx, ResolvingTimeout)
   338  	defer cancel()
   339  	srv, srvErr := d.lookupSRV(ctx)
   340  	addrs, hostErr := d.lookupHost(ctx)
   341  	if hostErr != nil && (srvErr != nil || len(srv) == 0) {
   342  		return nil, hostErr
   343  	}
   344  
   345  	state := resolver.State{Addresses: addrs}
   346  	if len(srv) > 0 {
   347  		state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv})
   348  	}
   349  	if !d.disableServiceConfig {
   350  		state.ServiceConfig = d.lookupTXT(ctx)
   351  	}
   352  	return &state, nil
   353  }
   354  
   355  // formatIP returns an error if addr is not a valid textual representation of
   356  // an IP address. If addr is an IPv4 address, return the addr and error = nil.
   357  // If addr is an IPv6 address, return the addr enclosed in square brackets and
   358  // error = nil.
   359  func formatIP(addr string) (string, error) {
   360  	ip, err := netip.ParseAddr(addr)
   361  	if err != nil {
   362  		return "", err
   363  	}
   364  	if ip.Is4() {
   365  		return addr, nil
   366  	}
   367  	return "[" + addr + "]", nil
   368  }
   369  
   370  // parseTarget takes the user input target string and default port, returns
   371  // formatted host and port info. If target doesn't specify a port, set the port
   372  // to be the defaultPort. If target is in IPv6 format and host-name is enclosed
   373  // in square brackets, brackets are stripped when setting the host.
   374  // examples:
   375  // target: "www.google.com" defaultPort: "443" returns host: "www.google.com", port: "443"
   376  // target: "ipv4-host:80" defaultPort: "443" returns host: "ipv4-host", port: "80"
   377  // target: "[ipv6-host]" defaultPort: "443" returns host: "ipv6-host", port: "443"
   378  // target: ":80" defaultPort: "443" returns host: "localhost", port: "80"
   379  func parseTarget(target, defaultPort string) (host, port string, err error) {
   380  	if target == "" {
   381  		return "", "", internal.ErrMissingAddr
   382  	}
   383  	if _, err := netip.ParseAddr(target); err == nil {
   384  		// target is an IPv4 or IPv6(without brackets) address
   385  		return target, defaultPort, nil
   386  	}
   387  	if host, port, err = net.SplitHostPort(target); err == nil {
   388  		if port == "" {
   389  			// If the port field is empty (target ends with colon), e.g. "[::1]:",
   390  			// this is an error.
   391  			return "", "", internal.ErrEndsWithColon
   392  		}
   393  		// target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port
   394  		if host == "" {
   395  			// Keep consistent with net.Dial(): If the host is empty, as in ":80",
   396  			// the local system is assumed.
   397  			host = "localhost"
   398  		}
   399  		return host, port, nil
   400  	}
   401  	if host, port, err = net.SplitHostPort(target + ":" + defaultPort); err == nil {
   402  		// target doesn't have port
   403  		return host, port, nil
   404  	}
   405  	return "", "", fmt.Errorf("invalid target address %v, error info: %v", target, err)
   406  }
   407  
// rawChoice is one entry of the JSON array carried after "grpc_config=" in
// the service config TXT record (decoded by canaryingSC). The selector fields
// are pointers so that an omitted field (nil) can be told apart from an
// explicitly empty one. canaryingSC treats a nil selector as matching every
// client, and skips choices whose ServiceConfig is nil.
type rawChoice struct {
	ClientLanguage *[]string        `json:"clientLanguage,omitempty"`
	Percentage     *int             `json:"percentage,omitempty"`
	ClientHostName *[]string        `json:"clientHostName,omitempty"`
	ServiceConfig  *json.RawMessage `json:"serviceConfig,omitempty"`
}
   414  
   415  func containsString(a *[]string, b string) bool {
   416  	if a == nil {
   417  		return true
   418  	}
   419  	for _, c := range *a {
   420  		if c == b {
   421  			return true
   422  		}
   423  	}
   424  	return false
   425  }
   426  
   427  func chosenByPercentage(a *int) bool {
   428  	if a == nil {
   429  		return true
   430  	}
   431  	return rand.IntN(100)+1 <= *a
   432  }
   433  
   434  func canaryingSC(js string) string {
   435  	if js == "" {
   436  		return ""
   437  	}
   438  	var rcs []rawChoice
   439  	err := json.Unmarshal([]byte(js), &rcs)
   440  	if err != nil {
   441  		logger.Warningf("dns: error parsing service config json: %v", err)
   442  		return ""
   443  	}
   444  	cliHostname, err := os.Hostname()
   445  	if err != nil {
   446  		logger.Warningf("dns: error getting client hostname: %v", err)
   447  		return ""
   448  	}
   449  	var sc string
   450  	for _, c := range rcs {
   451  		if !containsString(c.ClientLanguage, golang) ||
   452  			!chosenByPercentage(c.Percentage) ||
   453  			!containsString(c.ClientHostName, cliHostname) ||
   454  			c.ServiceConfig == nil {
   455  			continue
   456  		}
   457  		sc = string(*c.ServiceConfig)
   458  		break
   459  	}
   460  	return sc
   461  }