github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/internal/resolver/dns/dns_resolver.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package dns implements a dns resolver to be installed as the default resolver
    20  // in grpc.
    21  package dns
    22  
    23  import (
    24  	"context"
    25  	"encoding/json"
    26  	"errors"
    27  	"fmt"
    28  	"net"
    29  	"os"
    30  	"strconv"
    31  	"strings"
    32  	"sync"
    33  	"time"
    34  
    35  	grpclbstate "github.com/hxx258456/ccgo/grpc/balancer/grpclb/state"
    36  	"github.com/hxx258456/ccgo/grpc/grpclog"
    37  	"github.com/hxx258456/ccgo/grpc/internal/backoff"
    38  	"github.com/hxx258456/ccgo/grpc/internal/envconfig"
    39  	"github.com/hxx258456/ccgo/grpc/internal/grpcrand"
    40  	"github.com/hxx258456/ccgo/grpc/resolver"
    41  	"github.com/hxx258456/ccgo/grpc/serviceconfig"
    42  )
    43  
    44  // EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB
    45  // addresses from SRV records.  Must not be changed after init time.
    46  var EnableSRVLookups = false
    47  
    48  var logger = grpclog.Component("dns")
    49  
    50  // Globals to stub out in tests. TODO: Perhaps these two can be combined into a
    51  // single variable for testing the resolver?
    52  var (
    53  	newTimer           = time.NewTimer
    54  	newTimerDNSResRate = time.NewTimer
    55  )
    56  
    57  func init() {
    58  	resolver.Register(NewBuilder())
    59  }
    60  
    61  const (
    62  	defaultPort       = "443"
    63  	defaultDNSSvrPort = "53"
    64  	golang            = "GO"
    65  	// txtPrefix is the prefix string to be prepended to the host name for txt record lookup.
    66  	txtPrefix = "_grpc_config."
    67  	// In DNS, service config is encoded in a TXT record via the mechanism
    68  	// described in RFC-1464 using the attribute name grpc_config.
    69  	txtAttribute = "grpc_config="
    70  )
    71  
    72  var (
    73  	errMissingAddr = errors.New("dns resolver: missing address")
    74  
    75  	// Addresses ending with a colon that is supposed to be the separator
    76  	// between host and port is not allowed.  E.g. "::" is a valid address as
    77  	// it is an IPv6 address (host only) and "[::]:" is invalid as it ends with
    78  	// a colon as the host and port separator
    79  	errEndsWithColon = errors.New("dns resolver: missing port after port-separator colon")
    80  )
    81  
    82  var (
    83  	defaultResolver netResolver = net.DefaultResolver
    84  	// To prevent excessive re-resolution, we enforce a rate limit on DNS
    85  	// resolution requests.
    86  	minDNSResRate = 30 * time.Second
    87  )
    88  
    89  var customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
    90  	return func(ctx context.Context, network, address string) (net.Conn, error) {
    91  		var dialer net.Dialer
    92  		return dialer.DialContext(ctx, network, authority)
    93  	}
    94  }
    95  
    96  var customAuthorityResolver = func(authority string) (netResolver, error) {
    97  	host, port, err := parseTarget(authority, defaultDNSSvrPort)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  
   102  	authorityWithPort := net.JoinHostPort(host, port)
   103  
   104  	return &net.Resolver{
   105  		PreferGo: true,
   106  		Dial:     customAuthorityDialler(authorityWithPort),
   107  	}, nil
   108  }
   109  
   110  // NewBuilder creates a dnsBuilder which is used to factory DNS resolvers.
   111  func NewBuilder() resolver.Builder {
   112  	return &dnsBuilder{}
   113  }
   114  
   115  type dnsBuilder struct{}
   116  
   117  // Build creates and starts a DNS resolver that watches the name resolution of the target.
   118  func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
   119  	host, port, err := parseTarget(target.Endpoint, defaultPort)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  
   124  	// IP address.
   125  	if ipAddr, ok := formatIP(host); ok {
   126  		addr := []resolver.Address{{Addr: ipAddr + ":" + port}}
   127  		cc.UpdateState(resolver.State{Addresses: addr})
   128  		return deadResolver{}, nil
   129  	}
   130  
   131  	// DNS address (non-IP).
   132  	ctx, cancel := context.WithCancel(context.Background())
   133  	d := &dnsResolver{
   134  		host:                 host,
   135  		port:                 port,
   136  		ctx:                  ctx,
   137  		cancel:               cancel,
   138  		cc:                   cc,
   139  		rn:                   make(chan struct{}, 1),
   140  		disableServiceConfig: opts.DisableServiceConfig,
   141  	}
   142  
   143  	if target.Authority == "" {
   144  		d.resolver = defaultResolver
   145  	} else {
   146  		d.resolver, err = customAuthorityResolver(target.Authority)
   147  		if err != nil {
   148  			return nil, err
   149  		}
   150  	}
   151  
   152  	d.wg.Add(1)
   153  	go d.watcher()
   154  	return d, nil
   155  }
   156  
   157  // Scheme returns the naming scheme of this resolver builder, which is "dns".
   158  func (b *dnsBuilder) Scheme() string {
   159  	return "dns"
   160  }
   161  
   162  type netResolver interface {
   163  	LookupHost(ctx context.Context, host string) (addrs []string, err error)
   164  	LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)
   165  	LookupTXT(ctx context.Context, name string) (txts []string, err error)
   166  }
   167  
   168  // deadResolver is a resolver that does nothing.
   169  type deadResolver struct{}
   170  
   171  func (deadResolver) ResolveNow(resolver.ResolveNowOptions) {}
   172  
   173  func (deadResolver) Close() {}
   174  
   175  // dnsResolver watches for the name resolution update for a non-IP target.
   176  type dnsResolver struct {
   177  	host     string
   178  	port     string
   179  	resolver netResolver
   180  	ctx      context.Context
   181  	cancel   context.CancelFunc
   182  	cc       resolver.ClientConn
   183  	// rn channel is used by ResolveNow() to force an immediate resolution of the target.
   184  	rn chan struct{}
   185  	// wg is used to enforce Close() to return after the watcher() goroutine has finished.
   186  	// Otherwise, data race will be possible. [Race Example] in dns_resolver_test we
   187  	// replace the real lookup functions with mocked ones to facilitate testing.
   188  	// If Close() doesn't wait for watcher() goroutine finishes, race detector sometimes
   189  	// will warns lookup (READ the lookup function pointers) inside watcher() goroutine
   190  	// has data race with replaceNetFunc (WRITE the lookup function pointers).
   191  	wg                   sync.WaitGroup
   192  	disableServiceConfig bool
   193  }
   194  
   195  // ResolveNow invoke an immediate resolution of the target that this dnsResolver watches.
   196  func (d *dnsResolver) ResolveNow(resolver.ResolveNowOptions) {
   197  	select {
   198  	case d.rn <- struct{}{}:
   199  	default:
   200  	}
   201  }
   202  
   203  // Close closes the dnsResolver.
   204  func (d *dnsResolver) Close() {
   205  	d.cancel()
   206  	d.wg.Wait()
   207  }
   208  
   209  func (d *dnsResolver) watcher() {
   210  	defer d.wg.Done()
   211  	backoffIndex := 1
   212  	for {
   213  		state, err := d.lookup()
   214  		if err != nil {
   215  			// Report error to the underlying grpc.ClientConn.
   216  			d.cc.ReportError(err)
   217  		} else {
   218  			err = d.cc.UpdateState(*state)
   219  		}
   220  
   221  		var timer *time.Timer
   222  		if err == nil {
   223  			// Success resolving, wait for the next ResolveNow. However, also wait 30 seconds at the very least
   224  			// to prevent constantly re-resolving.
   225  			backoffIndex = 1
   226  			timer = newTimerDNSResRate(minDNSResRate)
   227  			select {
   228  			case <-d.ctx.Done():
   229  				timer.Stop()
   230  				return
   231  			case <-d.rn:
   232  			}
   233  		} else {
   234  			// Poll on an error found in DNS Resolver or an error received from ClientConn.
   235  			timer = newTimer(backoff.DefaultExponential.Backoff(backoffIndex))
   236  			backoffIndex++
   237  		}
   238  		select {
   239  		case <-d.ctx.Done():
   240  			timer.Stop()
   241  			return
   242  		case <-timer.C:
   243  		}
   244  	}
   245  }
   246  
   247  func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) {
   248  	if !EnableSRVLookups {
   249  		return nil, nil
   250  	}
   251  	var newAddrs []resolver.Address
   252  	_, srvs, err := d.resolver.LookupSRV(d.ctx, "grpclb", "tcp", d.host)
   253  	if err != nil {
   254  		err = handleDNSError(err, "SRV") // may become nil
   255  		return nil, err
   256  	}
   257  	for _, s := range srvs {
   258  		lbAddrs, err := d.resolver.LookupHost(d.ctx, s.Target)
   259  		if err != nil {
   260  			err = handleDNSError(err, "A") // may become nil
   261  			if err == nil {
   262  				// If there are other SRV records, look them up and ignore this
   263  				// one that does not exist.
   264  				continue
   265  			}
   266  			return nil, err
   267  		}
   268  		for _, a := range lbAddrs {
   269  			ip, ok := formatIP(a)
   270  			if !ok {
   271  				return nil, fmt.Errorf("dns: error parsing A record IP address %v", a)
   272  			}
   273  			addr := ip + ":" + strconv.Itoa(int(s.Port))
   274  			newAddrs = append(newAddrs, resolver.Address{Addr: addr, ServerName: s.Target})
   275  		}
   276  	}
   277  	return newAddrs, nil
   278  }
   279  
   280  func handleDNSError(err error, lookupType string) error {
   281  	if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary {
   282  		// Timeouts and temporary errors should be communicated to gRPC to
   283  		// attempt another DNS query (with backoff).  Other errors should be
   284  		// suppressed (they may represent the absence of a TXT record).
   285  		return nil
   286  	}
   287  	if err != nil {
   288  		err = fmt.Errorf("dns: %v record lookup error: %v", lookupType, err)
   289  		logger.Info(err)
   290  	}
   291  	return err
   292  }
   293  
   294  func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult {
   295  	ss, err := d.resolver.LookupTXT(d.ctx, txtPrefix+d.host)
   296  	if err != nil {
   297  		if envconfig.TXTErrIgnore {
   298  			return nil
   299  		}
   300  		if err = handleDNSError(err, "TXT"); err != nil {
   301  			return &serviceconfig.ParseResult{Err: err}
   302  		}
   303  		return nil
   304  	}
   305  	var res string
   306  	for _, s := range ss {
   307  		res += s
   308  	}
   309  
   310  	// TXT record must have "grpc_config=" attribute in order to be used as service config.
   311  	if !strings.HasPrefix(res, txtAttribute) {
   312  		logger.Warningf("dns: TXT record %v missing %v attribute", res, txtAttribute)
   313  		// This is not an error; it is the equivalent of not having a service config.
   314  		return nil
   315  	}
   316  	sc := canaryingSC(strings.TrimPrefix(res, txtAttribute))
   317  	return d.cc.ParseServiceConfig(sc)
   318  }
   319  
   320  func (d *dnsResolver) lookupHost() ([]resolver.Address, error) {
   321  	addrs, err := d.resolver.LookupHost(d.ctx, d.host)
   322  	if err != nil {
   323  		err = handleDNSError(err, "A")
   324  		return nil, err
   325  	}
   326  	newAddrs := make([]resolver.Address, 0, len(addrs))
   327  	for _, a := range addrs {
   328  		ip, ok := formatIP(a)
   329  		if !ok {
   330  			return nil, fmt.Errorf("dns: error parsing A record IP address %v", a)
   331  		}
   332  		addr := ip + ":" + d.port
   333  		newAddrs = append(newAddrs, resolver.Address{Addr: addr})
   334  	}
   335  	return newAddrs, nil
   336  }
   337  
   338  func (d *dnsResolver) lookup() (*resolver.State, error) {
   339  	srv, srvErr := d.lookupSRV()
   340  	addrs, hostErr := d.lookupHost()
   341  	if hostErr != nil && (srvErr != nil || len(srv) == 0) {
   342  		return nil, hostErr
   343  	}
   344  
   345  	state := resolver.State{Addresses: addrs}
   346  	if len(srv) > 0 {
   347  		state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv})
   348  	}
   349  	if !d.disableServiceConfig {
   350  		state.ServiceConfig = d.lookupTXT()
   351  	}
   352  	return &state, nil
   353  }
   354  
   355  // formatIP returns ok = false if addr is not a valid textual representation of an IP address.
   356  // If addr is an IPv4 address, return the addr and ok = true.
   357  // If addr is an IPv6 address, return the addr enclosed in square brackets and ok = true.
   358  func formatIP(addr string) (addrIP string, ok bool) {
   359  	ip := net.ParseIP(addr)
   360  	if ip == nil {
   361  		return "", false
   362  	}
   363  	if ip.To4() != nil {
   364  		return addr, true
   365  	}
   366  	return "[" + addr + "]", true
   367  }
   368  
   369  // parseTarget takes the user input target string and default port, returns formatted host and port info.
   370  // If target doesn't specify a port, set the port to be the defaultPort.
   371  // If target is in IPv6 format and host-name is enclosed in square brackets, brackets
   372  // are stripped when setting the host.
   373  // examples:
   374  // target: "www.google.com" defaultPort: "443" returns host: "www.google.com", port: "443"
   375  // target: "ipv4-host:80" defaultPort: "443" returns host: "ipv4-host", port: "80"
   376  // target: "[ipv6-host]" defaultPort: "443" returns host: "ipv6-host", port: "443"
   377  // target: ":80" defaultPort: "443" returns host: "localhost", port: "80"
   378  func parseTarget(target, defaultPort string) (host, port string, err error) {
   379  	if target == "" {
   380  		return "", "", errMissingAddr
   381  	}
   382  	if ip := net.ParseIP(target); ip != nil {
   383  		// target is an IPv4 or IPv6(without brackets) address
   384  		return target, defaultPort, nil
   385  	}
   386  	if host, port, err = net.SplitHostPort(target); err == nil {
   387  		if port == "" {
   388  			// If the port field is empty (target ends with colon), e.g. "[::1]:", this is an error.
   389  			return "", "", errEndsWithColon
   390  		}
   391  		// target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port
   392  		if host == "" {
   393  			// Keep consistent with net.Dial(): If the host is empty, as in ":80", the local system is assumed.
   394  			host = "localhost"
   395  		}
   396  		return host, port, nil
   397  	}
   398  	if host, port, err = net.SplitHostPort(target + ":" + defaultPort); err == nil {
   399  		// target doesn't have port
   400  		return host, port, nil
   401  	}
   402  	return "", "", fmt.Errorf("invalid target address %v, error info: %v", target, err)
   403  }
   404  
   405  type rawChoice struct {
   406  	ClientLanguage *[]string        `json:"clientLanguage,omitempty"`
   407  	Percentage     *int             `json:"percentage,omitempty"`
   408  	ClientHostName *[]string        `json:"clientHostName,omitempty"`
   409  	ServiceConfig  *json.RawMessage `json:"serviceConfig,omitempty"`
   410  }
   411  
   412  func containsString(a *[]string, b string) bool {
   413  	if a == nil {
   414  		return true
   415  	}
   416  	for _, c := range *a {
   417  		if c == b {
   418  			return true
   419  		}
   420  	}
   421  	return false
   422  }
   423  
   424  func chosenByPercentage(a *int) bool {
   425  	if a == nil {
   426  		return true
   427  	}
   428  	return grpcrand.Intn(100)+1 <= *a
   429  }
   430  
   431  func canaryingSC(js string) string {
   432  	if js == "" {
   433  		return ""
   434  	}
   435  	var rcs []rawChoice
   436  	err := json.Unmarshal([]byte(js), &rcs)
   437  	if err != nil {
   438  		logger.Warningf("dns: error parsing service config json: %v", err)
   439  		return ""
   440  	}
   441  	cliHostname, err := os.Hostname()
   442  	if err != nil {
   443  		logger.Warningf("dns: error getting client hostname: %v", err)
   444  		return ""
   445  	}
   446  	var sc string
   447  	for _, c := range rcs {
   448  		if !containsString(c.ClientLanguage, golang) ||
   449  			!chosenByPercentage(c.Percentage) ||
   450  			!containsString(c.ClientHostName, cliHostname) ||
   451  			c.ServiceConfig == nil {
   452  			continue
   453  		}
   454  		sc = string(*c.ServiceConfig)
   455  		break
   456  	}
   457  	return sc
   458  }