google.golang.org/grpc@v1.62.1/internal/resolver/dns/dns_resolver.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package dns implements a dns resolver to be installed as the default resolver
    20  // in grpc.
    21  package dns
    22  
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	grpclbstate "google.golang.org/grpc/balancer/grpclb/state"
	"google.golang.org/grpc/grpclog"
	"google.golang.org/grpc/internal/backoff"
	"google.golang.org/grpc/internal/envconfig"
	"google.golang.org/grpc/internal/grpcrand"
	"google.golang.org/grpc/internal/resolver/dns/internal"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/serviceconfig"
)
    43  
    44  // EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB
    45  // addresses from SRV records.  Must not be changed after init time.
    46  var EnableSRVLookups = false
    47  
    48  var logger = grpclog.Component("dns")
    49  
    50  func init() {
    51  	resolver.Register(NewBuilder())
    52  	internal.TimeAfterFunc = time.After
    53  	internal.NewNetResolver = newNetResolver
    54  	internal.AddressDialer = addressDialer
    55  }
    56  
    57  const (
    58  	defaultPort       = "443"
    59  	defaultDNSSvrPort = "53"
    60  	golang            = "GO"
    61  	// txtPrefix is the prefix string to be prepended to the host name for txt
    62  	// record lookup.
    63  	txtPrefix = "_grpc_config."
    64  	// In DNS, service config is encoded in a TXT record via the mechanism
    65  	// described in RFC-1464 using the attribute name grpc_config.
    66  	txtAttribute = "grpc_config="
    67  )
    68  
    69  var addressDialer = func(address string) func(context.Context, string, string) (net.Conn, error) {
    70  	return func(ctx context.Context, network, _ string) (net.Conn, error) {
    71  		var dialer net.Dialer
    72  		return dialer.DialContext(ctx, network, address)
    73  	}
    74  }
    75  
    76  var newNetResolver = func(authority string) (internal.NetResolver, error) {
    77  	if authority == "" {
    78  		return net.DefaultResolver, nil
    79  	}
    80  
    81  	host, port, err := parseTarget(authority, defaultDNSSvrPort)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	authorityWithPort := net.JoinHostPort(host, port)
    87  
    88  	return &net.Resolver{
    89  		PreferGo: true,
    90  		Dial:     internal.AddressDialer(authorityWithPort),
    91  	}, nil
    92  }
    93  
    94  // NewBuilder creates a dnsBuilder which is used to factory DNS resolvers.
    95  func NewBuilder() resolver.Builder {
    96  	return &dnsBuilder{}
    97  }
    98  
    99  type dnsBuilder struct{}
   100  
   101  // Build creates and starts a DNS resolver that watches the name resolution of
   102  // the target.
   103  func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
   104  	host, port, err := parseTarget(target.Endpoint(), defaultPort)
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  
   109  	// IP address.
   110  	if ipAddr, ok := formatIP(host); ok {
   111  		addr := []resolver.Address{{Addr: ipAddr + ":" + port}}
   112  		cc.UpdateState(resolver.State{Addresses: addr})
   113  		return deadResolver{}, nil
   114  	}
   115  
   116  	// DNS address (non-IP).
   117  	ctx, cancel := context.WithCancel(context.Background())
   118  	d := &dnsResolver{
   119  		host:                 host,
   120  		port:                 port,
   121  		ctx:                  ctx,
   122  		cancel:               cancel,
   123  		cc:                   cc,
   124  		rn:                   make(chan struct{}, 1),
   125  		disableServiceConfig: opts.DisableServiceConfig,
   126  	}
   127  
   128  	d.resolver, err = internal.NewNetResolver(target.URL.Host)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  
   133  	d.wg.Add(1)
   134  	go d.watcher()
   135  	return d, nil
   136  }
   137  
   138  // Scheme returns the naming scheme of this resolver builder, which is "dns".
   139  func (b *dnsBuilder) Scheme() string {
   140  	return "dns"
   141  }
   142  
   143  // deadResolver is a resolver that does nothing.
   144  type deadResolver struct{}
   145  
   146  func (deadResolver) ResolveNow(resolver.ResolveNowOptions) {}
   147  
   148  func (deadResolver) Close() {}
   149  
   150  // dnsResolver watches for the name resolution update for a non-IP target.
   151  type dnsResolver struct {
   152  	host     string
   153  	port     string
   154  	resolver internal.NetResolver
   155  	ctx      context.Context
   156  	cancel   context.CancelFunc
   157  	cc       resolver.ClientConn
   158  	// rn channel is used by ResolveNow() to force an immediate resolution of the
   159  	// target.
   160  	rn chan struct{}
   161  	// wg is used to enforce Close() to return after the watcher() goroutine has
   162  	// finished. Otherwise, data race will be possible. [Race Example] in
   163  	// dns_resolver_test we replace the real lookup functions with mocked ones to
   164  	// facilitate testing. If Close() doesn't wait for watcher() goroutine
   165  	// finishes, race detector sometimes will warns lookup (READ the lookup
   166  	// function pointers) inside watcher() goroutine has data race with
   167  	// replaceNetFunc (WRITE the lookup function pointers).
   168  	wg                   sync.WaitGroup
   169  	disableServiceConfig bool
   170  }
   171  
   172  // ResolveNow invoke an immediate resolution of the target that this
   173  // dnsResolver watches.
   174  func (d *dnsResolver) ResolveNow(resolver.ResolveNowOptions) {
   175  	select {
   176  	case d.rn <- struct{}{}:
   177  	default:
   178  	}
   179  }
   180  
   181  // Close closes the dnsResolver.
   182  func (d *dnsResolver) Close() {
   183  	d.cancel()
   184  	d.wg.Wait()
   185  }
   186  
   187  func (d *dnsResolver) watcher() {
   188  	defer d.wg.Done()
   189  	backoffIndex := 1
   190  	for {
   191  		state, err := d.lookup()
   192  		if err != nil {
   193  			// Report error to the underlying grpc.ClientConn.
   194  			d.cc.ReportError(err)
   195  		} else {
   196  			err = d.cc.UpdateState(*state)
   197  		}
   198  
   199  		var waitTime time.Duration
   200  		if err == nil {
   201  			// Success resolving, wait for the next ResolveNow. However, also wait 30
   202  			// seconds at the very least to prevent constantly re-resolving.
   203  			backoffIndex = 1
   204  			waitTime = internal.MinResolutionRate
   205  			select {
   206  			case <-d.ctx.Done():
   207  				return
   208  			case <-d.rn:
   209  			}
   210  		} else {
   211  			// Poll on an error found in DNS Resolver or an error received from
   212  			// ClientConn.
   213  			waitTime = backoff.DefaultExponential.Backoff(backoffIndex)
   214  			backoffIndex++
   215  		}
   216  		select {
   217  		case <-d.ctx.Done():
   218  			return
   219  		case <-internal.TimeAfterFunc(waitTime):
   220  		}
   221  	}
   222  }
   223  
   224  func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) {
   225  	if !EnableSRVLookups {
   226  		return nil, nil
   227  	}
   228  	var newAddrs []resolver.Address
   229  	_, srvs, err := d.resolver.LookupSRV(d.ctx, "grpclb", "tcp", d.host)
   230  	if err != nil {
   231  		err = handleDNSError(err, "SRV") // may become nil
   232  		return nil, err
   233  	}
   234  	for _, s := range srvs {
   235  		lbAddrs, err := d.resolver.LookupHost(d.ctx, s.Target)
   236  		if err != nil {
   237  			err = handleDNSError(err, "A") // may become nil
   238  			if err == nil {
   239  				// If there are other SRV records, look them up and ignore this
   240  				// one that does not exist.
   241  				continue
   242  			}
   243  			return nil, err
   244  		}
   245  		for _, a := range lbAddrs {
   246  			ip, ok := formatIP(a)
   247  			if !ok {
   248  				return nil, fmt.Errorf("dns: error parsing A record IP address %v", a)
   249  			}
   250  			addr := ip + ":" + strconv.Itoa(int(s.Port))
   251  			newAddrs = append(newAddrs, resolver.Address{Addr: addr, ServerName: s.Target})
   252  		}
   253  	}
   254  	return newAddrs, nil
   255  }
   256  
   257  func handleDNSError(err error, lookupType string) error {
   258  	dnsErr, ok := err.(*net.DNSError)
   259  	if ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary {
   260  		// Timeouts and temporary errors should be communicated to gRPC to
   261  		// attempt another DNS query (with backoff).  Other errors should be
   262  		// suppressed (they may represent the absence of a TXT record).
   263  		return nil
   264  	}
   265  	if err != nil {
   266  		err = fmt.Errorf("dns: %v record lookup error: %v", lookupType, err)
   267  		logger.Info(err)
   268  	}
   269  	return err
   270  }
   271  
   272  func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult {
   273  	ss, err := d.resolver.LookupTXT(d.ctx, txtPrefix+d.host)
   274  	if err != nil {
   275  		if envconfig.TXTErrIgnore {
   276  			return nil
   277  		}
   278  		if err = handleDNSError(err, "TXT"); err != nil {
   279  			return &serviceconfig.ParseResult{Err: err}
   280  		}
   281  		return nil
   282  	}
   283  	var res string
   284  	for _, s := range ss {
   285  		res += s
   286  	}
   287  
   288  	// TXT record must have "grpc_config=" attribute in order to be used as
   289  	// service config.
   290  	if !strings.HasPrefix(res, txtAttribute) {
   291  		logger.Warningf("dns: TXT record %v missing %v attribute", res, txtAttribute)
   292  		// This is not an error; it is the equivalent of not having a service
   293  		// config.
   294  		return nil
   295  	}
   296  	sc := canaryingSC(strings.TrimPrefix(res, txtAttribute))
   297  	return d.cc.ParseServiceConfig(sc)
   298  }
   299  
   300  func (d *dnsResolver) lookupHost() ([]resolver.Address, error) {
   301  	addrs, err := d.resolver.LookupHost(d.ctx, d.host)
   302  	if err != nil {
   303  		err = handleDNSError(err, "A")
   304  		return nil, err
   305  	}
   306  	newAddrs := make([]resolver.Address, 0, len(addrs))
   307  	for _, a := range addrs {
   308  		ip, ok := formatIP(a)
   309  		if !ok {
   310  			return nil, fmt.Errorf("dns: error parsing A record IP address %v", a)
   311  		}
   312  		addr := ip + ":" + d.port
   313  		newAddrs = append(newAddrs, resolver.Address{Addr: addr})
   314  	}
   315  	return newAddrs, nil
   316  }
   317  
   318  func (d *dnsResolver) lookup() (*resolver.State, error) {
   319  	srv, srvErr := d.lookupSRV()
   320  	addrs, hostErr := d.lookupHost()
   321  	if hostErr != nil && (srvErr != nil || len(srv) == 0) {
   322  		return nil, hostErr
   323  	}
   324  
   325  	state := resolver.State{Addresses: addrs}
   326  	if len(srv) > 0 {
   327  		state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv})
   328  	}
   329  	if !d.disableServiceConfig {
   330  		state.ServiceConfig = d.lookupTXT()
   331  	}
   332  	return &state, nil
   333  }
   334  
   335  // formatIP returns ok = false if addr is not a valid textual representation of
   336  // an IP address. If addr is an IPv4 address, return the addr and ok = true.
   337  // If addr is an IPv6 address, return the addr enclosed in square brackets and
   338  // ok = true.
   339  func formatIP(addr string) (addrIP string, ok bool) {
   340  	ip := net.ParseIP(addr)
   341  	if ip == nil {
   342  		return "", false
   343  	}
   344  	if ip.To4() != nil {
   345  		return addr, true
   346  	}
   347  	return "[" + addr + "]", true
   348  }
   349  
   350  // parseTarget takes the user input target string and default port, returns
   351  // formatted host and port info. If target doesn't specify a port, set the port
   352  // to be the defaultPort. If target is in IPv6 format and host-name is enclosed
   353  // in square brackets, brackets are stripped when setting the host.
   354  // examples:
   355  // target: "www.google.com" defaultPort: "443" returns host: "www.google.com", port: "443"
   356  // target: "ipv4-host:80" defaultPort: "443" returns host: "ipv4-host", port: "80"
   357  // target: "[ipv6-host]" defaultPort: "443" returns host: "ipv6-host", port: "443"
   358  // target: ":80" defaultPort: "443" returns host: "localhost", port: "80"
   359  func parseTarget(target, defaultPort string) (host, port string, err error) {
   360  	if target == "" {
   361  		return "", "", internal.ErrMissingAddr
   362  	}
   363  	if ip := net.ParseIP(target); ip != nil {
   364  		// target is an IPv4 or IPv6(without brackets) address
   365  		return target, defaultPort, nil
   366  	}
   367  	if host, port, err = net.SplitHostPort(target); err == nil {
   368  		if port == "" {
   369  			// If the port field is empty (target ends with colon), e.g. "[::1]:",
   370  			// this is an error.
   371  			return "", "", internal.ErrEndsWithColon
   372  		}
   373  		// target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port
   374  		if host == "" {
   375  			// Keep consistent with net.Dial(): If the host is empty, as in ":80",
   376  			// the local system is assumed.
   377  			host = "localhost"
   378  		}
   379  		return host, port, nil
   380  	}
   381  	if host, port, err = net.SplitHostPort(target + ":" + defaultPort); err == nil {
   382  		// target doesn't have port
   383  		return host, port, nil
   384  	}
   385  	return "", "", fmt.Errorf("invalid target address %v, error info: %v", target, err)
   386  }
   387  
   388  type rawChoice struct {
   389  	ClientLanguage *[]string        `json:"clientLanguage,omitempty"`
   390  	Percentage     *int             `json:"percentage,omitempty"`
   391  	ClientHostName *[]string        `json:"clientHostName,omitempty"`
   392  	ServiceConfig  *json.RawMessage `json:"serviceConfig,omitempty"`
   393  }
   394  
   395  func containsString(a *[]string, b string) bool {
   396  	if a == nil {
   397  		return true
   398  	}
   399  	for _, c := range *a {
   400  		if c == b {
   401  			return true
   402  		}
   403  	}
   404  	return false
   405  }
   406  
   407  func chosenByPercentage(a *int) bool {
   408  	if a == nil {
   409  		return true
   410  	}
   411  	return grpcrand.Intn(100)+1 <= *a
   412  }
   413  
   414  func canaryingSC(js string) string {
   415  	if js == "" {
   416  		return ""
   417  	}
   418  	var rcs []rawChoice
   419  	err := json.Unmarshal([]byte(js), &rcs)
   420  	if err != nil {
   421  		logger.Warningf("dns: error parsing service config json: %v", err)
   422  		return ""
   423  	}
   424  	cliHostname, err := os.Hostname()
   425  	if err != nil {
   426  		logger.Warningf("dns: error getting client hostname: %v", err)
   427  		return ""
   428  	}
   429  	var sc string
   430  	for _, c := range rcs {
   431  		if !containsString(c.ClientLanguage, golang) ||
   432  			!chosenByPercentage(c.Percentage) ||
   433  			!containsString(c.ClientHostName, cliHostname) ||
   434  			c.ServiceConfig == nil {
   435  			continue
   436  		}
   437  		sc = string(*c.ServiceConfig)
   438  		break
   439  	}
   440  	return sc
   441  }